hongyu12321 committed on
Commit
7093eab
·
verified ·
1 Parent(s): 482599f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -46
app.py CHANGED
@@ -1,10 +1,11 @@
1
- # app.py β€” Age-first + FAST cartoon (Turbo), nicer framing & magical background
2
 
3
  import os
4
  os.environ["TRANSFORMERS_NO_TF"] = "1"
5
  os.environ["TRANSFORMERS_NO_FLAX"] = "1"
6
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
7
 
 
8
  import gradio as gr
9
  from PIL import Image, ImageDraw
10
  import numpy as np
@@ -20,7 +21,7 @@ AGE_RANGE_TO_MID = {
20
  }
21
 
22
  class PretrainedAgeEstimator:
23
- def __init__(self, model_id: str = HF_MODEL_ID, device: str | None = None):
24
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
25
  self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
26
  self.model = AutoModelForImageClassification.from_pretrained(model_id)
@@ -41,12 +42,12 @@ class PretrainedAgeEstimator:
41
  for i, p in enumerate(probs))
42
  return expected, top
43
 
44
- # ------------------ Face detection with WIDER crop ------------------
45
  from facenet_pytorch import MTCNN
46
 
47
  class FaceCropper:
48
  """Detect faces; return (cropped_wide, annotated). Adds margin so face isn't full screen."""
49
- def __init__(self, device: str | None = None, margin_scale: float = 1.8):
50
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
51
  self.mtcnn = MTCNN(keep_all=True, device=self.device)
52
  self.margin_scale = margin_scale
@@ -56,7 +57,7 @@ class FaceCropper:
56
  return img.convert("RGB")
57
  return Image.fromarray(img).convert("RGB")
58
 
59
- def detect_and_crop_wide(self, img, select="largest"):
60
  pil = self._ensure_pil(img)
61
  W, H = pil.size
62
  boxes, probs = self.mtcnn.detect(pil)
@@ -65,13 +66,7 @@ class FaceCropper:
65
  draw = ImageDraw.Draw(annotated)
66
 
67
  if boxes is None or len(boxes) == 0:
68
- return None, annotated
69
-
70
- # choose largest face
71
- idx = int(np.argmax([(b[2]-b[0])*(b[3]-b[1]) for b in boxes]))
72
- if isinstance(select, int) and 0 <= select < len(boxes):
73
- idx = select
74
- x1, y1, x2, y2 = boxes[idx]
75
 
76
  # draw all boxes
77
  for b, p in zip(boxes, probs):
@@ -79,11 +74,13 @@ class FaceCropper:
79
  draw.rectangle([bx1, by1, bx2, by2], outline=(255, 0, 0), width=3)
80
  draw.text((bx1, max(0, by1-12)), f"{p:.2f}", fill=(255, 0, 0))
81
 
82
- # expand with margin
 
 
 
83
  cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
84
  w, h = (x2 - x1), (y2 - y1)
85
- side = max(w, h) * self.margin_scale # wider frame to include background/shoulders
86
- # keep a pleasant portrait aspect (4:5)
87
  target_w = side
88
  target_h = side * 1.25
89
 
@@ -95,8 +92,10 @@ class FaceCropper:
95
  crop = pil.crop((nx1, ny1, nx2, ny2))
96
  return crop, annotated
97
 
98
- # ------------------ FAST Cartoonizer (SD-Turbo) ------------------
99
  from diffusers import AutoPipelineForImage2Image
 
 
100
 
101
  # Turbo is very fast (1–4 steps). Great for stylization on CPU/GPU.
102
  TURBO_ID = "stabilityai/sd-turbo"
@@ -105,10 +104,16 @@ def load_turbo_pipe(device):
105
  dtype = torch.float16 if (device == "cuda") else torch.float32
106
  pipe = AutoPipelineForImage2Image.from_pretrained(
107
  TURBO_ID,
108
- torch_dtype=dtype,
109
- safety_checker=None,
110
  )
111
  pipe = pipe.to(device)
 
 
 
 
 
 
 
112
  try:
113
  pipe.enable_attention_slicing()
114
  except Exception:
@@ -117,16 +122,41 @@ def load_turbo_pipe(device):
117
 
118
  # ------------------ Init models once ------------------
119
  age_est = PretrainedAgeEstimator()
120
- cropper = FaceCropper(device=age_est.device, margin_scale=1.8) # 1.6–2.0 feels good
121
  sd_pipe = load_turbo_pipe(age_est.device)
122
 
123
- # ------------------ Prompts ------------------
124
- DEFAULT_POSITIVE = (
125
- "beautiful princess portrait, elegant gown, tiara, soft magical lighting, "
126
- "sparkles, dreamy castle background, painterly, clean lineart, vibrant but natural colors, "
127
- "storybook illustration, high quality"
128
- )
129
- DEFAULT_NEGATIVE = (
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  "deformed, disfigured, ugly, extra limbs, extra fingers, bad anatomy, low quality, "
131
  "blurry, watermark, text, logo"
132
  )
@@ -143,7 +173,41 @@ def _resize_512(im: Image.Image):
143
  im = im.resize((int(w*scale), int(h*scale)), Image.LANCZOS)
144
  return im
145
 
146
- # ------------------ 1) Predict Age (fast) ------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  @torch.inference_mode()
148
  def predict_age_only(img, auto_crop=True):
149
  if img is None:
@@ -161,14 +225,14 @@ def predict_age_only(img, auto_crop=True):
161
  summary = f"**Estimated age:** {age:.1f} years"
162
  return probs, summary, (annotated if annotated is not None else img)
163
 
164
- # ------------------ 2) Generate Cartoon (fast) ------------------
165
  @torch.inference_mode()
166
- def generate_cartoon(img, prompt="", auto_crop=True, strength=0.5, steps=2, seed=-1):
 
167
  if img is None:
168
  return None
169
 
170
  img = _ensure_pil(img).convert("RGB")
171
- # use wide face crop to include background/shoulders
172
  if auto_crop:
173
  face_wide, _ = cropper.detect_and_crop_wide(img)
174
  if face_wide is not None:
@@ -176,45 +240,55 @@ def generate_cartoon(img, prompt="", auto_crop=True, strength=0.5, steps=2, seed
176
 
177
  img = _resize_512(img)
178
 
179
- # prompt assembly
180
- user = (prompt or "").strip()
181
- pos = DEFAULT_POSITIVE if not user else f"{DEFAULT_POSITIVE}, {user}"
182
- neg = DEFAULT_NEGATIVE
183
 
184
  generator = None
185
  if isinstance(seed, (int, float)) and int(seed) >= 0:
186
  generator = torch.Generator(device=age_est.device).manual_seed(int(seed))
187
 
188
- # Turbo likes low steps and guidance ~0
189
  out = sd_pipe(
190
- prompt=pos,
191
- negative_prompt=neg,
192
  image=img,
193
  strength=float(strength), # 0.4–0.6 keeps identity & adds dress/background
194
- guidance_scale=0.0, # Turbo typically uses 0
195
  num_inference_steps=int(steps), # 1–4 steps β†’ very fast
196
  generator=generator,
197
  )
198
  return out.images[0]
199
 
200
  # ------------------ UI ------------------
201
- with gr.Blocks(title="Age First + Fast Cartoon") as demo:
202
- gr.Markdown("# Upload or capture once β€” get age prediction first, then a faster cartoon ✨")
 
203
 
204
  with gr.Row():
205
  with gr.Column(scale=1):
206
  img_in = gr.Image(sources=["upload", "webcam"], type="pil", label="Upload / Webcam")
207
  auto = gr.Checkbox(True, label="Auto face crop (wide, recommended)")
208
- prompt = gr.Textbox(
209
- label="(Optional) Extra cartoon style",
210
- placeholder="e.g., studio ghibli watercolor, soft bokeh, pastel palette"
 
 
 
 
 
 
 
 
 
 
 
 
211
  )
 
212
  with gr.Row():
213
  strength = gr.Slider(0.3, 0.8, value=0.5, step=0.05, label="Cartoon strength")
214
  steps = gr.Slider(1, 4, value=2, step=1, label="Turbo steps (1–4)")
215
  seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
216
 
217
- btn_age = gr.Button("Predict Age (fast)", variant="primary")
218
  btn_cartoon = gr.Button("Make Cartoon (fast)", variant="secondary")
219
 
220
  with gr.Column(scale=1):
@@ -225,7 +299,15 @@ with gr.Blocks(title="Age First + Fast Cartoon") as demo:
225
 
226
  # Wire the buttons
227
  btn_age.click(fn=predict_age_only, inputs=[img_in, auto], outputs=[probs_out, age_md, preview])
228
- btn_cartoon.click(fn=generate_cartoon, inputs=[img_in, prompt, auto, strength, steps, seed], outputs=cartoon_out)
 
 
 
 
 
 
 
 
229
 
230
  if __name__ == "__main__":
231
- demo.launch()
 
1
+ # app.py β€” Age-first + FAST cartoon (Turbo) with prompt hint pickers (largest face only)
2
 
3
  import os
4
  os.environ["TRANSFORMERS_NO_TF"] = "1"
5
  os.environ["TRANSFORMERS_NO_FLAX"] = "1"
6
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
7
 
8
+ from typing import Optional
9
  import gradio as gr
10
  from PIL import Image, ImageDraw
11
  import numpy as np
 
21
  }
22
 
23
  class PretrainedAgeEstimator:
24
+ def __init__(self, model_id: str = HF_MODEL_ID, device: Optional[str] = None):
25
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
26
  self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
27
  self.model = AutoModelForImageClassification.from_pretrained(model_id)
 
42
  for i, p in enumerate(probs))
43
  return expected, top
44
 
45
+ # ------------------ Face detection with WIDER crop (largest face) ------------------
46
  from facenet_pytorch import MTCNN
47
 
48
  class FaceCropper:
49
  """Detect faces; return (cropped_wide, annotated). Adds margin so face isn't full screen."""
50
+ def __init__(self, device: Optional[str] = None, margin_scale: float = 1.8):
51
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
52
  self.mtcnn = MTCNN(keep_all=True, device=self.device)
53
  self.margin_scale = margin_scale
 
57
  return img.convert("RGB")
58
  return Image.fromarray(img).convert("RGB")
59
 
60
+ def detect_and_crop_wide(self, img):
61
  pil = self._ensure_pil(img)
62
  W, H = pil.size
63
  boxes, probs = self.mtcnn.detect(pil)
 
66
  draw = ImageDraw.Draw(annotated)
67
 
68
  if boxes is None or len(boxes) == 0:
69
+ return None, annotated # no faces
 
 
 
 
 
 
70
 
71
  # draw all boxes
72
  for b, p in zip(boxes, probs):
 
74
  draw.rectangle([bx1, by1, bx2, by2], outline=(255, 0, 0), width=3)
75
  draw.text((bx1, max(0, by1-12)), f"{p:.2f}", fill=(255, 0, 0))
76
 
77
+ # choose largest face
78
+ idx = int(np.argmax([(b[2]-b[0])*(b[3]-b[1]) for b in boxes]))
79
+ x1, y1, x2, y2 = boxes[idx]
80
+ # expand with margin (4:5 portrait feel)
81
  cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
82
  w, h = (x2 - x1), (y2 - y1)
83
+ side = max(w, h) * self.margin_scale
 
84
  target_w = side
85
  target_h = side * 1.25
86
 
 
92
  crop = pil.crop((nx1, ny1, nx2, ny2))
93
  return crop, annotated
94
 
95
+ # ------------------ FAST Cartoonizer (SD-Turbo) with safety ------------------
96
  from diffusers import AutoPipelineForImage2Image
97
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
98
+ from transformers import AutoFeatureExtractor
99
 
100
  # Turbo is very fast (1–4 steps). Great for stylization on CPU/GPU.
101
  TURBO_ID = "stabilityai/sd-turbo"
 
104
  dtype = torch.float16 if (device == "cuda") else torch.float32
105
  pipe = AutoPipelineForImage2Image.from_pretrained(
106
  TURBO_ID,
107
+ dtype=dtype, # βœ… use dtype (no deprecation warning)
 
108
  )
109
  pipe = pipe.to(device)
110
+ # Safety checker ON for public Spaces
111
+ pipe.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
112
+ "CompVis/stable-diffusion-safety-checker"
113
+ )
114
+ pipe.feature_extractor = AutoFeatureExtractor.from_pretrained(
115
+ "CompVis/stable-diffusion-safety-checker"
116
+ )
117
  try:
118
  pipe.enable_attention_slicing()
119
  except Exception:
 
122
 
123
  # ------------------ Init models once ------------------
124
  age_est = PretrainedAgeEstimator()
125
+ cropper = FaceCropper(device=age_est.device, margin_scale=1.85) # 1.6–2.0 feels good
126
  sd_pipe = load_turbo_pipe(age_est.device)
127
 
128
# ------------------ Prompt hint dictionaries ------------------
# Fixed option lists backing the Gradio pickers (Dropdown / CheckboxGroup).
# Order matters: it is the display order in the UI, and the first ROLE_CHOICES
# entry is the Dropdown default. String contents must stay stable because
# build_prompt() maps/joins these exact values into the SD prompt.

# Character archetype for the portrait (single-select Dropdown).
ROLE_CHOICES = [
    "Queen/Princess", "King/Prince", "Fairy", "Elf", "Knight", "Sorcerer/Sorceress",
    "Steampunk Royalty", "Cyberpunk Royalty", "Superhero", "Anime Protagonist"
]

# Scene/setting hints (multi-select).
BACKGROUND_CHOICES = [
    "grand castle hall", "castle balcony at sunset", "enchanted forest", "starry night sky",
    "throne room with banners", "crystal palace", "moonlit garden", "winter snow castle",
    "golden hour meadow", "mystical waterfall"
]

# Lighting-mood hints (multi-select).
LIGHTING_CHOICES = [
    "soft magical lighting", "golden hour rim light", "cinematic soft light",
    "glowing ambience", "volumetric light rays", "dramatic chiaroscuro"
]

# Rendering/art-style hints (multi-select).
ARTSTYLE_CHOICES = [
    "Disney/Pixar style", "Studio Ghibli watercolor", "cel-shaded cartoon",
    "storybook illustration", "painterly brush strokes", "anime lineart"
]

# Color-palette hints (multi-select).
COLOR_CHOICES = [
    "pastel palette", "vibrant colors", "warm tones", "cool tones",
    "iridescent highlights", "royal gold & sapphire"
]

# Wardrobe/accessory hints (multi-select).
OUTFIT_CHOICES = [
    "elegant gown", "ornate royal cloak", "jeweled tiara/crown",
    "silver diadem", "flowing cape", "intricate embroidery"
]

# Decorative particle/ambience hints (multi-select).
EFFECTS_CHOICES = [
    "sparkles", "soft bokeh background", "floating petals", "glowing particles",
    "butterflies", "magical aura"
]

# Shared negative prompt passed to the SD-Turbo pipeline for every generation.
NEGATIVE_PROMPT = (
    "deformed, disfigured, ugly, extra limbs, extra fingers, bad anatomy, low quality, "
    "blurry, watermark, text, logo"
)
 
173
  im = im.resize((int(w*scale), int(h*scale)), Image.LANCZOS)
174
  return im
175
 
176
def build_prompt(role, background, lighting, artstyle, colors, outfit, effects, extra):
    """Assemble the positive SD prompt from the UI picker selections.

    Args:
        role: Selected role label (str) from the Dropdown, or falsy for none.
              Known labels are expanded to richer descriptors; unknown labels
              pass through verbatim.
        background, lighting, artstyle, colors, outfit, effects: Hint groups.
            Each is normally a list of strings from a CheckboxGroup; a bare
            string or any other iterable of strings is also accepted (fix:
            the previous version silently dropped non-list iterables).
        extra: Optional free-text user additions; surrounding whitespace is
            stripped and empty text is ignored.

    Returns:
        A single comma-separated prompt string. Quality/style anchors are
        always appended so the prompt is never empty.
    """
    # Role label -> richer base descriptor for the prompt.
    role_map = {
        "Queen/Princess": "regal queen/princess portrait",
        "King/Prince": "regal king/prince portrait",
        "Fairy": "ethereal fairy portrait with delicate wings",
        "Elf": "elven royalty portrait with elegant ears",
        "Knight": "valiant knight portrait in ornate armor",
        "Sorcerer/Sorceress": "mystical sorcerer portrait with arcane motifs",
        "Steampunk Royalty": "steampunk royal portrait with brass filigree",
        "Cyberpunk Royalty": "cyberpunk royal portrait with neon accents",
        "Superhero": "heroic comic-style portrait",
        "Anime Protagonist": "anime protagonist portrait"
    }

    bits = []
    if role:
        # Unknown roles fall through unchanged so custom values still work.
        bits.append(role_map.get(role, role))

    # Flatten each hint group into one comma-joined fragment.
    for group in (background, lighting, artstyle, colors, outfit, effects):
        if not group:
            continue
        if isinstance(group, str):
            bits.append(group)          # single string selection
        else:
            bits.append(", ".join(group))  # list/tuple/iterable of strings

    # Strong general quality/style anchors (always present).
    bits.append("clean lineart, storybook illustration, high quality")

    # Optional extra user text.
    extra_text = (extra or "").strip()
    if extra_text:
        bits.append(extra_text)

    return ", ".join(b for b in bits if b)
209
+
210
+ # ------------------ 1) Predict Age (fast, largest face) ------------------
211
  @torch.inference_mode()
212
  def predict_age_only(img, auto_crop=True):
213
  if img is None:
 
225
  summary = f"**Estimated age:** {age:.1f} years"
226
  return probs, summary, (annotated if annotated is not None else img)
227
 
228
+ # ------------------ 2) Generate Cartoon (fast, largest face) ------------------
229
  @torch.inference_mode()
230
+ def generate_cartoon(img, role, background, lighting, artstyle, colors, outfit, effects,
231
+ extra_desc, auto_crop=True, strength=0.5, steps=2, seed=-1):
232
  if img is None:
233
  return None
234
 
235
  img = _ensure_pil(img).convert("RGB")
 
236
  if auto_crop:
237
  face_wide, _ = cropper.detect_and_crop_wide(img)
238
  if face_wide is not None:
 
240
 
241
  img = _resize_512(img)
242
 
243
+ # prompt assembly from pickers
244
+ prompt = build_prompt(role, background, lighting, artstyle, colors, outfit, effects, extra_desc)
 
 
245
 
246
  generator = None
247
  if isinstance(seed, (int, float)) and int(seed) >= 0:
248
  generator = torch.Generator(device=age_est.device).manual_seed(int(seed))
249
 
 
250
  out = sd_pipe(
251
+ prompt=prompt,
252
+ negative_prompt=NEGATIVE_PROMPT,
253
  image=img,
254
  strength=float(strength), # 0.4–0.6 keeps identity & adds dress/background
255
+ guidance_scale=0.0, # Turbo commonly uses 0
256
  num_inference_steps=int(steps), # 1–4 steps β†’ very fast
257
  generator=generator,
258
  )
259
  return out.images[0]
260
 
261
  # ------------------ UI ------------------
262
+ with gr.Blocks(title="Age First + Fast Cartoon (with Hint Pickers)") as demo:
263
+ gr.Markdown("# Upload or capture once β€” get age prediction first, then a beautiful cartoon ✨")
264
+ gr.Markdown("Largest face is used if multiple people are present.")
265
 
266
  with gr.Row():
267
  with gr.Column(scale=1):
268
  img_in = gr.Image(sources=["upload", "webcam"], type="pil", label="Upload / Webcam")
269
  auto = gr.Checkbox(True, label="Auto face crop (wide, recommended)")
270
+
271
+ # --- Age first
272
+ btn_age = gr.Button("Predict Age (fast)", variant="primary")
273
+
274
+ gr.Markdown("### Cartoon Description Hints")
275
+ role = gr.Dropdown(choices=ROLE_CHOICES, value="Queen/Princess", label="Role")
276
+ background = gr.CheckboxGroup(choices=BACKGROUND_CHOICES, label="Background")
277
+ lighting = gr.CheckboxGroup(choices=LIGHTING_CHOICES, label="Lighting")
278
+ artstyle = gr.CheckboxGroup(choices=ARTSTYLE_CHOICES, label="Art Style")
279
+ colors = gr.CheckboxGroup(choices=COLOR_CHOICES, label="Color Mood")
280
+ outfit = gr.CheckboxGroup(choices=OUTFIT_CHOICES, label="Outfit / Accessories")
281
+ effects = gr.CheckboxGroup(choices=EFFECTS_CHOICES, label="Magical Effects")
282
+ extra = gr.Textbox(
283
+ label="Extra description (optional)",
284
+ placeholder="e.g., silver tiara, flowing gown, castle balcony at sunset"
285
  )
286
+
287
  with gr.Row():
288
  strength = gr.Slider(0.3, 0.8, value=0.5, step=0.05, label="Cartoon strength")
289
  steps = gr.Slider(1, 4, value=2, step=1, label="Turbo steps (1–4)")
290
  seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
291
 
 
292
  btn_cartoon = gr.Button("Make Cartoon (fast)", variant="secondary")
293
 
294
  with gr.Column(scale=1):
 
299
 
300
  # Wire the buttons
301
  btn_age.click(fn=predict_age_only, inputs=[img_in, auto], outputs=[probs_out, age_md, preview])
302
+ btn_cartoon.click(
303
+ fn=generate_cartoon,
304
+ inputs=[img_in, role, background, lighting, artstyle, colors, outfit, effects,
305
+ extra, auto, strength, steps, seed],
306
+ outputs=cartoon_out
307
+ )
308
+
309
+ # Expose app for HF Spaces
310
+ app = demo
311
 
312
  if __name__ == "__main__":
313
+ app.queue().launch()