File size: 12,722 Bytes
aec1787
97ae321
8fc4073
 
 
 
 
7093eab
8fc4073
97ae321
8fc4073
 
 
482599f
97ae321
 
482599f
97ae321
 
 
 
 
482599f
7093eab
97ae321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aec1787
97ae321
165f68d
97ae321
aec1787
 
97ae321
 
165f68d
97ae321
 
 
 
 
 
7093eab
97ae321
165f68d
97ae321
 
 
 
 
aec1787
482599f
 
165f68d
 
482599f
 
 
aec1787
7093eab
 
aec1787
482599f
 
7093eab
482599f
 
 
 
 
 
 
165f68d
 
 
 
aec1787
165f68d
7093eab
 
6e327e0
165f68d
482599f
165f68d
aec1787
165f68d
 
aec1787
 
7093eab
 
 
 
 
 
 
165f68d
 
 
 
 
 
f63b4cd
482599f
aec1787
165f68d
 
aec1787
7093eab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aec1787
482599f
 
 
8fc4073
97ae321
8fc4073
165f68d
 
 
 
 
 
 
7093eab
aec1787
 
 
 
 
 
 
 
 
 
7093eab
 
 
 
 
 
 
 
 
 
aec1787
7093eab
 
aec1787
7093eab
 
aec1787
 
7093eab
 
 
aec1787
7093eab
aec1787
7093eab
aec1787
97ae321
482599f
8fc4073
97ae321
aec1787
2187ded
aec1787
8fc4073
aec1787
f63b4cd
aec1787
165f68d
97ae321
482599f
aec1787
482599f
 
7093eab
 
482599f
 
aec1787
482599f
 
aec1787
482599f
aec1787
482599f
aec1787
7093eab
482599f
 
 
 
 
 
7093eab
 
aec1787
 
 
 
482599f
 
 
 
aec1787
 
 
 
482599f
 
 
 
aec1787
7093eab
aec1787
482599f
aec1787
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482599f
 
aec1787
482599f
 
 
 
aec1787
482599f
7093eab
 
 
 
 
 
 
aec1787
7093eab
482599f
 
7093eab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
# app.py — Compact UI: age-first prediction + fast cartoon (SD-Turbo) with collapsible advanced options

import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from typing import Optional
import gradio as gr
from PIL import Image, ImageDraw
import numpy as np
import torch

# ------------------ Age estimator (Hugging Face) ------------------
from transformers import AutoImageProcessor, AutoModelForImageClassification

HF_MODEL_ID = "nateraw/vit-age-classifier"
# Representative age (years) for each age-range class label; used to turn the
# classifier's probability vector into a single expected age. Labels missing
# from this table get a fallback value inside PretrainedAgeEstimator.
# The model's oldest bucket is labelled "more than 70" in its config; "70+" is
# kept for compatibility. NOTE(review): re-check against model.config.id2label
# whenever HF_MODEL_ID changes.
AGE_RANGE_TO_MID = {
    "0-2": 1, "3-9": 6, "10-19": 15, "20-29": 25, "30-39": 35,
    "40-49": 45, "50-59": 55, "60-69": 65, "70+": 75,
    "more than 70": 75,
}

class PretrainedAgeEstimator:
    """Age estimator backed by a Hugging Face image-classification model.

    The model classifies an image into age-range labels; ``predict`` returns
    the top-k labels with their probabilities plus an expected age computed as
    the probability-weighted mean of each label's representative age.
    """

    def __init__(self, model_id: str = HF_MODEL_ID, device: Optional[str] = None):
        # Use CUDA when available unless the caller pins a device explicitly.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
        self.model = AutoModelForImageClassification.from_pretrained(model_id)
        self.model.to(self.device).eval()
        self.id2label = self.model.config.id2label

    @staticmethod
    def _label_midpoint(label: str) -> float:
        """Return a representative age (years) for a class label.

        Labels in AGE_RANGE_TO_MID use that value. Otherwise ranges such as
        "a-b" map to their midpoint, and open-ended labels ("N+", "more than N",
        "over N", "above N") map to N + 5. A single bare number maps to itself.
        Anything without digits falls back to 35, the value previously used
        for every unknown label.
        """
        import re

        if label in AGE_RANGE_TO_MID:
            return float(AGE_RANGE_TO_MID[label])
        text = str(label).lower()
        numbers = [int(n) for n in re.findall(r"\d+", text)]
        if len(numbers) >= 2:
            return (numbers[0] + numbers[1]) / 2.0
        if len(numbers) == 1:
            open_ended = "+" in text or any(w in text for w in ("more", "over", "above"))
            return numbers[0] + 5.0 if open_ended else float(numbers[0])
        return 35.0

    @torch.inference_mode()
    def predict(self, img: Image.Image, topk: int = 5):
        """Estimate the age of the person in *img*.

        Args:
            img: Input PIL image; converted to RGB if needed.
            topk: Number of highest-probability labels to return (clipped to
                the number of classes).

        Returns:
            ``(expected_age, top)``. ``expected_age`` is the probability-weighted
            mean of each label's representative age (see ``_label_midpoint``).
            ``top`` is a list of ``(label, probability)`` pairs in descending
            probability order.
        """
        if img.mode != "RGB":
            img = img.convert("RGB")
        inputs = self.processor(images=img, return_tensors="pt").to(self.device)
        logits = self.model(**inputs).logits
        # Bring the small probability vector to the CPU once, so the Python
        # loops below do not trigger one device sync per element.
        probs = logits.float().softmax(dim=-1).squeeze(0).cpu()
        k = min(topk, probs.numel())
        values, indices = torch.topk(probs, k=k)
        top = [(self.id2label[i], float(v)) for i, v in zip(indices.tolist(), values.tolist())]
        expected = sum(self._label_midpoint(self.id2label[i]) * p
                       for i, p in enumerate(probs.tolist()))
        return expected, top

# ------------------ Largest-face detector with nice margin ------------------
from facenet_pytorch import MTCNN

class FaceCropper:
    """Detect faces with MTCNN and return a generously framed crop of the largest one.

    ``detect_and_crop_wide`` returns ``(wide_crop, annotated)``.
    ``wide_crop`` is a roughly 4:5 portrait window around the largest detected
    face, enlarged by ``margin_scale`` so the face does not fill the frame; it
    is ``None`` when no usable face is found. ``annotated`` is a copy of the
    input with every detected box and its confidence drawn in red.
    """

    def __init__(self, device: Optional[str] = None, margin_scale: float = 1.85):
        # Use CUDA when available unless the caller pins a device explicitly.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Face detector; only its .detect() method is used below.
        self.mtcnn = MTCNN(keep_all=True, device=self.device)
        self.margin_scale = margin_scale

    def _ensure_pil(self, img):
        """Return *img* as an RGB PIL image (accepts a PIL image or an array for Image.fromarray)."""
        if isinstance(img, Image.Image):
            return img.convert("RGB")
        return Image.fromarray(img).convert("RGB")

    def detect_and_crop_wide(self, img):
        """Return ``(wide_crop, annotated)`` for the largest face in *img*.

        Args:
            img: PIL image or array accepted by ``Image.fromarray``.

        Returns:
            ``wide_crop``: the region centred on the largest detected face.
            Its width is ``margin_scale`` times the face's longer side and its
            height is 1.25 times that, clipped to the image bounds. It is
            ``None`` when no face is detected or the clipped window is empty.
            ``annotated``: a copy of the input with every detection box and
            its confidence drawn in red.
        """
        pil = self._ensure_pil(img)
        W, H = pil.size
        boxes, probs = self.mtcnn.detect(pil)

        annotated = pil.copy()
        draw = ImageDraw.Draw(annotated)
        if boxes is None or len(boxes) == 0:
            return None, annotated

        # Draw every detection so the preview shows everything that was found.
        for b, p in zip(boxes, probs):
            bx1, by1, bx2, by2 = map(float, b)
            draw.rectangle([bx1, by1, bx2, by2], outline=(255, 0, 0), width=3)
            if p is not None:
                draw.text((bx1, max(0, by1 - 12)), f"{p:.2f}", fill=(255, 0, 0))

        # Choose the face with the largest box area.
        idx = int(np.argmax([(b[2] - b[0]) * (b[3] - b[1]) for b in boxes]))
        x1, y1, x2, y2 = boxes[idx]
        # Expand around the face centre with margin (approx 4:5 portrait).
        cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
        w, h = (x2 - x1), (y2 - y1)
        side = max(w, h) * self.margin_scale
        target_w = side
        target_h = side * 1.25

        nx1 = int(max(0, cx - target_w / 2))
        nx2 = int(min(W, cx + target_w / 2))
        ny1 = int(max(0, cy - target_h / 2))
        ny2 = int(min(H, cy + target_h / 2))

        # A tiny face or margin_scale can truncate to an empty window. A 0-size
        # crop would break later resizing, so report "no usable face" instead.
        if nx2 <= nx1 or ny2 <= ny1:
            return None, annotated

        crop = pil.crop((nx1, ny1, nx2, ny2))
        return crop, annotated

# ------------------ Fast Cartoonizer (SD-Turbo) with safety ------------------
from diffusers import AutoPipelineForImage2Image
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor

TURBO_ID = "stabilityai/sd-turbo"

def load_turbo_pipe(device):
    """Load the SD-Turbo image-to-image pipeline on *device* with the safety checker enabled.

    Args:
        device: Torch device (e.g. "cuda" or "cpu") to place the pipeline on.

    Returns:
        The ready-to-call ``AutoPipelineForImage2Image`` instance.
    """
    # Use half precision only when the target device is actually a GPU.
    # fp16 on CPU is slow or unsupported, even if CUDA exists on the machine.
    on_cuda = str(device).startswith("cuda") and torch.cuda.is_available()
    dtype = torch.float16 if on_cuda else torch.float32
    pipe = AutoPipelineForImage2Image.from_pretrained(
        TURBO_ID,
        dtype=dtype,  # NOTE(review): older diffusers releases name this torch_dtype; confirm with the pinned version
    ).to(device)
    # Safety checker ON for public Spaces.
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    # from_pretrained leaves the checker on the CPU in fp32, but the pipeline
    # feeds it inputs on its own device/dtype. Match them, or GPU runs fail
    # with a device/dtype mismatch.
    pipe.safety_checker = safety_checker.to(device=pipe.device, dtype=pipe.dtype)
    pipe.feature_extractor = AutoFeatureExtractor.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    try:
        pipe.enable_attention_slicing()
    except Exception:
        # Best-effort memory optimisation; the pipeline still works without it.
        pass
    return pipe

# ------------------ Init models once ------------------
# Models are loaded once, at import time, and reused by every request handled
# by predict_age_only / generate_cartoon. The face cropper and the diffusion
# pipeline follow the age estimator's device choice (CUDA when available, else CPU).
age_est = PretrainedAgeEstimator()
cropper = FaceCropper(device=age_est.device, margin_scale=1.85)
sd_pipe = load_turbo_pipe(age_est.device)

# ------------------ Hint choices (with defaults) ------------------
# Option lists shown in the "Advanced Cartoon Options" accordion. The UI's
# default selections and build_prompt's fallbacks reuse values from these
# lists, so keep the spellings in sync when editing.
# Roles: build_prompt maps each one to a longer descriptive phrase.
ROLE_CHOICES = [
    "Queen/Princess", "King/Prince", "Fairy", "Elf", "Knight", "Sorcerer/Sorceress",
    "Steampunk Royalty", "Cyberpunk Royalty", "Superhero", "Anime Protagonist"
]
# Multi-select groups below: build_prompt joins the selected strings verbatim
# (comma-separated) into the prompt.
BACKGROUND_CHOICES = [
    "grand castle hall", "castle balcony at sunset", "enchanted forest", "starry night sky",
    "throne room with banners", "crystal palace", "moonlit garden", "winter snow castle",
    "golden hour meadow", "mystical waterfall"
]
LIGHTING_CHOICES = [
    "soft magical lighting", "golden hour rim light", "cinematic soft light",
    "glowing ambience", "volumetric light rays", "dramatic chiaroscuro"
]
ARTSTYLE_CHOICES = [
    "Disney/Pixar style", "Studio Ghibli watercolor", "cel-shaded cartoon",
    "storybook illustration", "painterly brush strokes", "anime lineart"
]
COLOR_CHOICES = [
    "pastel palette", "vibrant colors", "warm tones", "cool tones",
    "iridescent highlights", "royal gold & sapphire"
]
OUTFIT_CHOICES = [
    "elegant gown", "ornate royal cloak", "jeweled tiara/crown",
    "silver diadem", "flowing cape", "intricate embroidery"
]
EFFECTS_CHOICES = [
    "sparkles", "soft bokeh background", "floating petals", "glowing particles",
    "butterflies", "magical aura"
]

# Passed as negative_prompt by generate_cartoon. NOTE(review): that call uses
# guidance_scale=0.0, and diffusers documents negative_prompt as ignored when
# guidance is disabled; confirm whether this has any effect.
NEGATIVE_PROMPT = (
    "deformed, disfigured, ugly, extra limbs, extra fingers, bad anatomy, low quality, blurry, watermark, text, logo"
)

# ------------------ Helpers ------------------
def _ensure_pil(img):
    """Return *img* as a PIL image.

    PIL images are passed through untouched; anything else (e.g. a numpy
    array from Gradio) is wrapped with ``Image.fromarray``. No colour-mode
    conversion happens here; callers apply ``.convert("RGB")`` themselves.
    """
    if isinstance(img, Image.Image):
        return img
    return Image.fromarray(img)

def _resize_512(im: Image.Image, max_side: int = 512):
    """Downscale *im* so that its longer side is at most *max_side* pixels.

    The aspect ratio is preserved. Images already within the limit (including
    degenerate 0-sized ones) are returned unchanged; nothing is upscaled.

    Args:
        im: PIL image to shrink.
        max_side: Maximum length of the longer side (default 512, SD-Turbo's
            native resolution).

    Returns:
        The original image or a LANCZOS-resampled copy.
    """
    w, h = im.size
    longest = max(w, h)
    if longest <= max_side:
        return im
    scale = max_side / longest
    # round() instead of int() so the long side lands exactly on max_side rather
    # than max_side - 1. max(1, ...) keeps extreme aspect ratios from producing
    # a zero-length side, which Image.resize rejects.
    new_w = max(1, round(w * scale))
    new_h = max(1, round(h * scale))
    return im.resize((new_w, new_h), Image.LANCZOS)

def build_prompt(role, background, lighting, artstyle, colors, outfit, effects, extra):
    """Assemble the text prompt for the cartoon pipeline from the UI selections.

    Defaults always exist: any empty/None argument is replaced by a built-in
    default, so the prompt always describes a complete scene. A non-empty
    selection replaces (does not extend) the corresponding default.

    Args:
        role: A ROLE_CHOICES entry, which is mapped to a descriptive phrase,
            or any other string, which is used verbatim.
        background, lighting, artstyle, colors, outfit, effects: Selections
            for each group. A list/tuple of strings is joined with ", ", and a
            bare string is treated as a single selection.
        extra: Optional free text appended at the end (whitespace-stripped).

    Returns:
        One comma-separated prompt string.
    """
    def _items(group, default):
        # Gradio's CheckboxGroup sends a list, but API callers may send a tuple
        # or a bare string; accept those rather than silently dropping them.
        if not group:
            return list(default)
        if isinstance(group, str):
            return [group]
        if isinstance(group, (list, tuple)):
            return [str(item) for item in group if item]
        return []  # unsupported type: ignored, as before

    role = role or "Queen/Princess"
    role_map = {
        "Queen/Princess": "regal queen/princess portrait",
        "King/Prince": "regal king/prince portrait",
        "Fairy": "ethereal fairy portrait with delicate wings",
        "Elf": "elven royalty portrait with elegant ears",
        "Knight": "valiant knight portrait in ornate armor",
        "Sorcerer/Sorceress": "mystical sorcerer portrait with arcane motifs",
        "Steampunk Royalty": "steampunk royal portrait with brass filigree",
        "Cyberpunk Royalty": "cyberpunk royal portrait with neon accents",
        "Superhero": "heroic comic-style portrait",
        "Anime Protagonist": "anime protagonist portrait",
    }

    groups = (
        _items(background, ["castle balcony at sunset"]),
        _items(lighting, ["soft magical lighting"]),
        _items(artstyle, ["storybook illustration"]),
        _items(colors, ["vibrant colors"]),
        _items(outfit, ["elegant gown", "jeweled tiara/crown"]),
        _items(effects, ["sparkles", "glowing particles"]),
    )

    parts = [role_map.get(role, role)]
    parts.extend(", ".join(items) for items in groups if items)
    parts.append("clean lineart, high quality")

    extra = (extra or "").strip()
    if extra:
        parts.append(extra)

    return ", ".join(p for p in parts if p)

# ------------------ Actions ------------------
@torch.inference_mode()
def predict_age_only(img, auto_crop=True):
    """Run the age estimator on an uploaded image for the "Predict Age" button.

    When *auto_crop* is set, the largest detected face (with margin) is
    analysed and the detector's annotated copy is used as the preview. When no
    face is found, or cropping is off, the full image is analysed instead, and
    when cropping is off the input itself is the preview.

    Returns:
        ``(label_probs, summary_markdown, preview_image)`` for the Label,
        Markdown and Image outputs, or ``({}, message, None)`` when no image
        was given.
    """
    if img is None:
        return {}, "Please upload an image.", None

    full_image = _ensure_pil(img).convert("RGB")
    crop = preview_image = None
    if auto_crop:
        crop, preview_image = cropper.detect_and_crop_wide(full_image)

    subject = full_image if crop is None else crop
    age, top = age_est.predict(subject, topk=5)
    label_probs = dict((label, float(score)) for label, score in top)
    summary = f"**Estimated age:** {age:.1f} years"
    if preview_image is None:
        preview_image = full_image
    return label_probs, summary, preview_image

def _effective_turbo_steps(steps, strength):
    """Return a step count for which img2img performs at least one denoising step."""
    import math

    num_steps = max(1, int(steps))
    if strength > 0:
        # diffusers' img2img runs only int(num_inference_steps * strength) steps.
        # For example, 1 step at strength 0.5, or 2 steps at 0.4, would run zero
        # steps and fail. Raise the requested count just enough for one real step.
        num_steps = max(num_steps, math.ceil(1.0 / strength))
        while int(num_steps * strength) < 1:
            num_steps += 1
    return num_steps


@torch.inference_mode()
def generate_cartoon(img, role, background, lighting, artstyle, colors, outfit, effects,
                     extra_desc, auto_crop=True, strength=0.5, steps=2, seed=-1):
    """Turn the uploaded photo into a stylised cartoon portrait with SD-Turbo.

    Args:
        img: Uploaded image (PIL or array); returns None when missing.
        role, background, lighting, artstyle, colors, outfit, effects, extra_desc:
            Prompt selections forwarded to build_prompt.
        auto_crop: Use the largest detected face (with margin) instead of the
            whole image when a face is found.
        strength: img2img strength; higher values change the image more.
        steps: Requested Turbo steps; raised if needed so that at least one
            denoising step runs at the given strength.
        seed: Non-negative value gives reproducible output; negative, None or
            non-finite values use a random seed.

    Returns:
        The generated PIL image, or None when no image was provided.
    """
    import math

    if img is None:
        return None
    pil = _ensure_pil(img).convert("RGB")

    if auto_crop:
        face_wide, _ = cropper.detect_and_crop_wide(pil)
        if face_wide is not None:
            pil = face_wide

    pil = _resize_512(pil)
    prompt = build_prompt(role, background, lighting, artstyle, colors, outfit, effects, extra_desc)

    strength = float(strength)
    num_steps = _effective_turbo_steps(steps, strength)

    generator = None
    # gr.Number may deliver None (cleared field); NaN/inf would make int() raise.
    if isinstance(seed, (int, float)) and math.isfinite(seed) and int(seed) >= 0:
        generator = torch.Generator(device=age_est.device).manual_seed(int(seed))

    out = sd_pipe(
        prompt=prompt,
        negative_prompt=NEGATIVE_PROMPT,  # NOTE: diffusers ignores this while guidance_scale < 1
        image=pil,
        strength=strength,                # 0.4-0.6 keeps identity & adds dress/background
        guidance_scale=0.0,               # SD-Turbo is meant to run without classifier-free guidance
        num_inference_steps=num_steps,    # 1-4 requested -> fast
        generator=generator,
    )
    return out.images[0]

# ------------------ Compact UI ------------------
with gr.Blocks(title="Age + Cartoon (Compact)") as demo:
    # Header text (the arrows were previously mis-encoded as "β†’").
    gr.Markdown("## Upload → Predict Age → Make Cartoon ✨")
    gr.Markdown("Largest face is used if multiple people are present. Defaults are applied automatically.")

    with gr.Row():
        # Left column: input, crop toggle, action buttons and advanced options.
        with gr.Column(scale=1):
            img_in = gr.Image(sources=["upload", "webcam"], type="pil", label="Upload / Webcam")
            auto = gr.Checkbox(True, label="Auto face crop (recommended)")

            # Buttons visible immediately (no scrolling)
            with gr.Row():
                btn_age = gr.Button("Predict Age", variant="primary")
                btn_cartoon = gr.Button("Make Cartoon", variant="secondary")

            # Collapsible advanced options; default values match build_prompt's fallbacks.
            with gr.Accordion("🎨 Advanced Cartoon Options", open=False):
                role = gr.Dropdown(choices=ROLE_CHOICES, value="Queen/Princess", label="Role")
                background = gr.CheckboxGroup(choices=BACKGROUND_CHOICES, value=["castle balcony at sunset"], label="Background")
                lighting = gr.CheckboxGroup(choices=LIGHTING_CHOICES, value=["soft magical lighting"], label="Lighting")
                artstyle = gr.CheckboxGroup(choices=ARTSTYLE_CHOICES, value=["storybook illustration"], label="Art Style")
                colors = gr.CheckboxGroup(choices=COLOR_CHOICES, value=["vibrant colors"], label="Color Mood")
                outfit = gr.CheckboxGroup(choices=OUTFIT_CHOICES, value=["elegant gown", "jeweled tiara/crown"], label="Outfit / Accessories")
                effects = gr.CheckboxGroup(choices=EFFECTS_CHOICES, value=["sparkles", "glowing particles"], label="Magical Effects")
                extra = gr.Textbox(label="Extra description (optional)", placeholder="e.g., silver tiara, flowing gown, balcony at sunset")
                with gr.Row():
                    strength = gr.Slider(0.3, 0.8, value=0.5, step=0.05, label="Cartoon strength")
                    steps = gr.Slider(1, 4, value=2, step=1, label="Turbo steps (1–4)")
                    seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")

        # Right column: outputs.
        with gr.Column(scale=1):
            probs_out = gr.Label(num_top_classes=5, label="Age Prediction")
            age_md = gr.Markdown(label="Age Summary")
            preview = gr.Image(label="Detection Preview")
            cartoon_out = gr.Image(label="Cartoon Result")

    # Wire events; input order must match the handler signatures.
    btn_age.click(fn=predict_age_only, inputs=[img_in, auto], outputs=[probs_out, age_md, preview])
    btn_cartoon.click(
        fn=generate_cartoon,
        inputs=[img_in, role, background, lighting, artstyle, colors, outfit, effects,
                extra, auto, strength, steps, seed],
        outputs=cartoon_out
    )

# Expose for HF Spaces: the Blocks app is also published under the name "app".
app = demo

if __name__ == "__main__":
    # Local run: enable Gradio's request queue, then start the server.
    app.queue().launch()