File size: 6,066 Bytes
5a67aab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# app.py
import os
import io
import random
from PIL import Image, ImageOps
import numpy as np
import streamlit as st

# ---- ML libs ----
import torch
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
from huggingface_hub import login

# ---- OpenCV for simple preproc (Canny) ----
import cv2

st.set_page_config(page_title="Sketch2Face (Streamlit + ControlNet)", layout="centered")

st.title("Sketch2Face — turn your face sketches into stylized images")
st.write("Upload a face sketch (line drawing). Use the prompt to guide style, pose & mood.")

# Get HF token (recommended to set as secret on HF Spaces or as env var locally).
# Gated model downloads need an authenticated hub session; login failures are
# non-fatal because public models still work anonymously.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception:
        pass
else:
    st.warning("No Hugging Face token found. On Spaces, add HF_TOKEN in Settings → Secrets for model download. Locally use 'huggingface-cli login'.")

# Sidebar controls: every generation knob lives here so the main column stays
# focused on upload + result.
with st.sidebar:
    st.header("Generation settings")
    model_id = st.text_input("Stable Diffusion model (hf repo)", value="runwayml/stable-diffusion-v1-5")
    controlnet_id = st.text_input("ControlNet (canny) repo", value="lllyasviel/sd-controlnet-canny")
    prompt = st.text_area("Prompt", value="A realistic portrait of a young man, soft lighting, cinematic")
    negative_prompt = st.text_area("Negative prompt (optional)", value="lowres, deformed, extra fingers, watermark")
    guidance_scale = st.slider("Guidance scale", 1.0, 20.0, 7.5)
    strength = st.slider("Strength (how much to change sketch)", 0.1, 1.0, 0.7)
    num_inference_steps = st.slider("Steps", 10, 60, 28)
    seed = st.number_input("Seed (0 for random)", min_value=0, max_value=999999999, value=0, step=1)
    use_gpu = st.checkbox("Use GPU (if available)", value=True)
    run_btn = st.button("Generate")

# Upload sketch
uploaded = st.file_uploader("Upload your sketch (png/jpg). Prefer simple line art.", type=["png","jpg","jpeg"])
example_col1, example_col2 = st.columns(2)
with example_col1:
    # Fixed broken markdown: "**Tip:****" left an unbalanced bold marker.
    st.markdown("**Tip:** clear black lines on white background work best.")
with example_col2:
    st.markdown("**Tip:** crop to face / 1:1 or 3:4 ratio.")

@st.cache_resource(show_spinner=False)
def load_models(sd_model_id, cn_model_id, device):
    """Load the ControlNet img2img pipeline, cached across Streamlit reruns.

    Parameters
    ----------
    sd_model_id : str
        Hugging Face repo id of the base Stable Diffusion model.
    cn_model_id : str
        Hugging Face repo id of the (canny) ControlNet.
    device : str
        "cuda" or "cpu"; also selects fp16 vs fp32 weights.

    Returns
    -------
    StableDiffusionControlNetImg2ImgPipeline
        Pipeline moved to `device` with a UniPC scheduler.
    """
    # fp16 only makes sense on GPU; CPU inference needs fp32.
    dtype = torch.float16 if device == "cuda" else torch.float32
    controlnet = ControlNetModel.from_pretrained(cn_model_id, torch_dtype=dtype)
    pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
        sd_model_id,
        controlnet=controlnet,
        safety_checker=None,
        torch_dtype=dtype,
    )
    # UniPC converges in fewer steps than the default scheduler.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    if device == "cuda":
        # xformers is an optional dependency: fall back to the default
        # attention implementation instead of crashing when it is absent.
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass
    pipe.to(device)
    return pipe

def prepare_control_image_pil(pil_img, target_size=512):
    """Build the Canny edge map that ControlNet conditions on.

    The sketch is first center-cropped/resized with ImageOps.fit — exactly
    like prepare_init_image — so the control image stays pixel-aligned with
    the init image. (The previous cv2.resize stretched non-square uploads,
    misaligning the edges relative to the init latents.)

    Parameters
    ----------
    pil_img : PIL.Image.Image
        The uploaded sketch.
    target_size : int
        Output edge length in pixels (square).

    Returns
    -------
    PIL.Image.Image
        3-channel RGB image whose pixels are the Canny edges.
    """
    # Same geometry as prepare_init_image so both images line up.
    img = ImageOps.fit(pil_img.convert("RGB"), (target_size, target_size), Image.LANCZOS)
    gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
    # Fixed thresholds work well for clean line art; expose sliders if needed.
    edges = cv2.Canny(gray, 100, 200)
    # ControlNet expects a 3-channel image, so replicate the single channel.
    edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(edges_rgb)

def prepare_init_image(pil_img, target_size=512):
    """Return the sketch as an RGB image center-cropped to a square."""
    return ImageOps.fit(pil_img.convert("RGB"), (target_size, target_size), Image.LANCZOS)

if run_btn:
    if not uploaded:
        st.error("Please upload a sketch first.")
    else:
        device = "cuda" if (torch.cuda.is_available() and use_gpu) else "cpu"
        with st.spinner("Loading models (first run may take ~1-2 minutes)..."):
            pipe = load_models(model_id, controlnet_id, device)

        # Rewind the upload stream: Streamlit reruns the whole script on each
        # interaction, so the file pointer may already sit at EOF.
        uploaded.seek(0)
        img = Image.open(uploaded)
        control_image = prepare_control_image_pil(img, target_size=512)
        init_image = prepare_init_image(img, target_size=512)

        # Seed 0 means "random": let diffusers pick; otherwise seed a
        # device-local generator for reproducible output.
        if seed == 0:
            generator = None
        else:
            generator = torch.Generator(device=device).manual_seed(int(seed))

        with st.spinner("Generating..."):
            try:
                # Diffusers pipelines are invoked by calling them directly;
                # StableDiffusionControlNetImg2ImgPipeline has no .img2img()
                # method, so the previous pipe.img2img(...) always raised.
                output = pipe(
                    prompt=prompt,
                    image=init_image,
                    control_image=control_image,
                    negative_prompt=negative_prompt or None,
                    strength=float(strength),
                    guidance_scale=float(guidance_scale),
                    num_inference_steps=int(num_inference_steps),
                    generator=generator,
                )
            except Exception as e:
                # st.exception expects the exception object, not a string.
                st.exception(e)
                raise

        result = output.images[0]
        st.image(result, caption="Generated image", use_column_width=True)
        # Offer the result as a PNG download.
        buf = io.BytesIO()
        result.save(buf, format="PNG")
        buf.seek(0)
        st.download_button("Download image (PNG)", data=buf, file_name="sketch2face.png", mime="image/png")

# Show sample control image / debug
if uploaded:
    try:
        # Rewind first — the generation branch may already have consumed the
        # upload stream, which previously made this preview fail silently.
        uploaded.seek(0)
        img = Image.open(uploaded)
        control_img = prepare_control_image_pil(img, target_size=256)
        st.caption("Preview: internal Canny/control image (what ControlNet sees)")
        st.image(control_img)
    except Exception:
        # Best-effort preview: tell the user instead of swallowing the error.
        st.caption("Could not render the control-image preview for this upload.")

st.markdown("---")
st.markdown("Made for sketch-to-face. Adjust prompt & strength. For best results, upload clear line sketches and try style prompts like 'photorealistic', 'studio lighting', or artists' names (check licenses).")