hongyu12321 commited on
Commit
2187ded
·
verified ·
1 Parent(s): 165f68d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -100
app.py CHANGED
@@ -1,16 +1,17 @@
1
- # app.py — Age-first + FAST cartoon (Turbo), nicer framing & magical background
2
 
3
  import os
4
  os.environ["TRANSFORMERS_NO_TF"] = "1"
5
  os.environ["TRANSFORMERS_NO_FLAX"] = "1"
6
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
7
 
 
8
  import gradio as gr
9
  from PIL import Image, ImageDraw
10
  import numpy as np
11
  import torch
12
 
13
- # ------------------ Age estimator (Hugging Face) ------------------
14
  from transformers import AutoImageProcessor, AutoModelForImageClassification
15
 
16
  HF_MODEL_ID = "nateraw/vit-age-classifier"
@@ -41,11 +42,16 @@ class PretrainedAgeEstimator:
41
  for i, p in enumerate(probs))
42
  return expected, top
43
 
44
- # ------------------ Face detection with WIDER crop ------------------
45
  from facenet_pytorch import MTCNN
46
 
47
  class FaceCropper:
48
- """Detect faces; return (cropped_wide, annotated). Adds margin so face isn't full screen."""
 
 
 
 
 
49
  def __init__(self, device: str | None = None, margin_scale: float = 1.8):
50
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
51
  self.mtcnn = MTCNN(keep_all=True, device=self.device)
@@ -56,49 +62,65 @@ class FaceCropper:
56
  return img.convert("RGB")
57
  return Image.fromarray(img).convert("RGB")
58
 
59
- def detect_and_crop_wide(self, img, select="largest"):
 
 
 
 
 
 
 
 
 
 
 
60
  pil = self._ensure_pil(img)
61
  W, H = pil.size
62
  boxes, probs = self.mtcnn.detect(pil)
63
 
64
  annotated = pil.copy()
65
  draw = ImageDraw.Draw(annotated)
66
-
67
  if boxes is None or len(boxes) == 0:
68
  return None, annotated
69
 
70
- # choose largest face
71
- idx = int(np.argmax([(b[2]-b[0])*(b[3]-b[1]) for b in boxes]))
72
- if isinstance(select, int) and 0 <= select < len(boxes):
73
- idx = select
74
- x1, y1, x2, y2 = boxes[idx]
75
-
76
  # draw all boxes
77
  for b, p in zip(boxes, probs):
78
  bx1, by1, bx2, by2 = map(float, b)
79
  draw.rectangle([bx1, by1, bx2, by2], outline=(255, 0, 0), width=3)
80
  draw.text((bx1, max(0, by1-12)), f"{p:.2f}", fill=(255, 0, 0))
81
 
82
- # expand with margin
83
- cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
84
- w, h = (x2 - x1), (y2 - y1)
85
- side = max(w, h) * self.margin_scale # wider frame to include background/shoulders
86
- # keep a pleasant portrait aspect (4:5)
87
- target_w = side
88
- target_h = side * 1.25
89
-
90
- nx1 = int(max(0, cx - target_w/2))
91
- nx2 = int(min(W, cx + target_w/2))
92
- ny1 = int(max(0, cy - target_h/2))
93
- ny2 = int(min(H, cy + target_h/2))
94
-
95
  crop = pil.crop((nx1, ny1, nx2, ny2))
96
  return crop, annotated
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  # ------------------ FAST Cartoonizer (SD-Turbo) ------------------
99
  from diffusers import AutoPipelineForImage2Image
100
-
101
- # Turbo is very fast (1–4 steps). Great for stylization on CPU/GPU.
102
  TURBO_ID = "stabilityai/sd-turbo"
103
 
104
  def load_turbo_pipe(device):
@@ -107,20 +129,19 @@ def load_turbo_pipe(device):
107
  TURBO_ID,
108
  torch_dtype=dtype,
109
  safety_checker=None,
110
- )
111
- pipe = pipe.to(device)
112
  try:
113
  pipe.enable_attention_slicing()
114
  except Exception:
115
  pass
116
  return pipe
117
 
118
- # ------------------ Init models once ------------------
119
  age_est = PretrainedAgeEstimator()
120
- cropper = FaceCropper(device=age_est.device, margin_scale=1.8) # 1.6–2.0 feels good
121
  sd_pipe = load_turbo_pipe(age_est.device)
122
 
123
- # ------------------ Prompts ------------------
124
  DEFAULT_POSITIVE = (
125
  "beautiful princess portrait, elegant gown, tiara, soft magical lighting, "
126
  "sparkles, dreamy castle background, painterly, clean lineart, vibrant but natural colors, "
@@ -131,52 +152,62 @@ DEFAULT_NEGATIVE = (
131
  "blurry, watermark, text, logo"
132
  )
133
 
134
- # ------------------ Helpers ------------------
135
  def _ensure_pil(img):
136
  return img if isinstance(img, Image.Image) else Image.fromarray(img)
137
 
138
  def _resize_512(im: Image.Image):
139
- # keep aspect, fit longest side to 512 (faster, fewer artifacts)
140
  w, h = im.size
141
  scale = 512 / max(w, h)
142
  if scale < 1.0:
143
  im = im.resize((int(w*scale), int(h*scale)), Image.LANCZOS)
144
  return im
145
 
146
- # ------------------ 1) Predict Age (fast) ------------------
147
  @torch.inference_mode()
148
- def predict_age_only(img, auto_crop=True):
149
  if img is None:
150
  return {}, "Please upload an image.", None
151
- img = _ensure_pil(img).convert("RGB")
152
 
153
- face_wide = None
154
- annotated = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  if auto_crop:
156
- face_wide, annotated = cropper.detect_and_crop_wide(img)
157
- target = face_wide if face_wide is not None else img
158
-
159
  age, top = age_est.predict(target, topk=5)
160
  probs = {lbl: float(p) for lbl, p in top}
161
- summary = f"**Estimated age:** {age:.1f} years"
162
- return probs, summary, (annotated if annotated is not None else img)
163
 
164
- # ------------------ 2) Generate Cartoon (fast) ------------------
165
  @torch.inference_mode()
166
- def generate_cartoon(img, prompt="", auto_crop=True, strength=0.5, steps=2, seed=-1):
167
  if img is None:
168
  return None
 
169
 
170
- img = _ensure_pil(img).convert("RGB")
171
- # use wide face crop to include background/shoulders
172
- if auto_crop:
173
- face_wide, _ = cropper.detect_and_crop_wide(img)
174
- if face_wide is not None:
175
- img = face_wide
176
-
177
- img = _resize_512(img)
178
-
179
- # prompt assembly
180
  user = (prompt or "").strip()
181
  pos = DEFAULT_POSITIVE if not user else f"{DEFAULT_POSITIVE}, {user}"
182
  neg = DEFAULT_NEGATIVE
@@ -185,47 +216,13 @@ def generate_cartoon(img, prompt="", auto_crop=True, strength=0.5, steps=2, seed
185
  if isinstance(seed, (int, float)) and int(seed) >= 0:
186
  generator = torch.Generator(device=age_est.device).manual_seed(int(seed))
187
 
188
- # Turbo likes low steps and guidance ~0
189
- out = sd_pipe(
190
- prompt=pos,
191
- negative_prompt=neg,
192
- image=img,
193
- strength=float(strength), # 0.4–0.6 keeps identity & adds dress/background
194
- guidance_scale=0.0, # Turbo typically uses 0
195
- num_inference_steps=int(steps), # 1–4 steps → very fast
196
- generator=generator,
197
- )
198
- return out.images[0]
199
-
200
- # ------------------ UI ------------------
201
- with gr.Blocks(title="Age First + Fast Cartoon") as demo:
202
- gr.Markdown("# Upload or capture once — get age prediction first, then a faster cartoon ✨")
203
-
204
- with gr.Row():
205
- with gr.Column(scale=1):
206
- img_in = gr.Image(sources=["upload", "webcam"], type="pil", label="Upload / Webcam")
207
- auto = gr.Checkbox(True, label="Auto face crop (wide, recommended)")
208
- prompt = gr.Textbox(
209
- label="(Optional) Extra cartoon style",
210
- placeholder="e.g., studio ghibli watercolor, soft bokeh, pastel palette"
211
- )
212
- with gr.Row():
213
- strength = gr.Slider(0.3, 0.8, value=0.5, step=0.05, label="Cartoon strength")
214
- steps = gr.Slider(1, 4, value=2, step=1, label="Turbo steps (1–4)")
215
- seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
216
-
217
- btn_age = gr.Button("Predict Age (fast)", variant="primary")
218
- btn_cartoon = gr.Button("Make Cartoon (fast)", variant="secondary")
219
-
220
- with gr.Column(scale=1):
221
- probs_out = gr.Label(num_top_classes=5, label="Age Prediction (probabilities)")
222
- age_md = gr.Markdown(label="Age Summary")
223
- preview = gr.Image(label="Detection Preview")
224
- cartoon_out = gr.Image(label="Cartoon Result")
225
-
226
- # Wire the buttons
227
- btn_age.click(fn=predict_age_only, inputs=[img_in, auto], outputs=[probs_out, age_md, preview])
228
- btn_cartoon.click(fn=generate_cartoon, inputs=[img_in, prompt, auto, strength, steps, seed], outputs=cartoon_out)
229
-
230
- if __name__ == "__main__":
231
- demo.launch()
 
1
+ # app.py — Age-first + FAST group cartoons (SD-Turbo), single page
2
 
3
  import os
4
  os.environ["TRANSFORMERS_NO_TF"] = "1"
5
  os.environ["TRANSFORMERS_NO_FLAX"] = "1"
6
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
7
 
8
+ import math
9
  import gradio as gr
10
  from PIL import Image, ImageDraw
11
  import numpy as np
12
  import torch
13
 
14
+ # ------------------ Age estimator ------------------
15
  from transformers import AutoImageProcessor, AutoModelForImageClassification
16
 
17
  HF_MODEL_ID = "nateraw/vit-age-classifier"
 
42
  for i, p in enumerate(probs))
43
  return expected, top
44
 
45
+ # ------------------ Face detection (single & group) ------------------
46
  from facenet_pytorch import MTCNN
47
 
48
  class FaceCropper:
49
+ """
50
+ Detect faces.
51
+ - detect_one_wide: returns (crop_with_margin, annotated)
52
+ - detect_all_wide: returns (list[crops], annotated, list[boxes])
53
+ Boxes are (x1,y1,x2,y2) floats.
54
+ """
55
  def __init__(self, device: str | None = None, margin_scale: float = 1.8):
56
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
57
  self.mtcnn = MTCNN(keep_all=True, device=self.device)
 
62
  return img.convert("RGB")
63
  return Image.fromarray(img).convert("RGB")
64
 
65
+ def _expand_box(self, box, W, H, aspect=0.8): # 4:5 portrait (w/h=0.8)
66
+ x1, y1, x2, y2 = box
67
+ cx, cy = (x1 + x2)/2, (y1 + y2)/2
68
+ w, h = (x2 - x1), (y2 - y1)
69
+ side = max(w, h) * self.margin_scale
70
+ tw = side
71
+ th = side / aspect # make it taller than wide
72
+ nx1 = int(max(0, cx - tw/2)); nx2 = int(min(W, cx + tw/2))
73
+ ny1 = int(max(0, cy - th/2)); ny2 = int(min(H, cy + th/2))
74
+ return nx1, ny1, nx2, ny2
75
+
76
+ def detect_one_wide(self, img):
77
  pil = self._ensure_pil(img)
78
  W, H = pil.size
79
  boxes, probs = self.mtcnn.detect(pil)
80
 
81
  annotated = pil.copy()
82
  draw = ImageDraw.Draw(annotated)
 
83
  if boxes is None or len(boxes) == 0:
84
  return None, annotated
85
 
 
 
 
 
 
 
86
  # draw all boxes
87
  for b, p in zip(boxes, probs):
88
  bx1, by1, bx2, by2 = map(float, b)
89
  draw.rectangle([bx1, by1, bx2, by2], outline=(255, 0, 0), width=3)
90
  draw.text((bx1, max(0, by1-12)), f"{p:.2f}", fill=(255, 0, 0))
91
 
92
+ # choose largest
93
+ idx = int(np.argmax([(b[2]-b[0])*(b[3]-b[1]) for b in boxes]))
94
+ nx1, ny1, nx2, ny2 = self._expand_box(boxes[idx], W, H)
 
 
 
 
 
 
 
 
 
 
95
  crop = pil.crop((nx1, ny1, nx2, ny2))
96
  return crop, annotated
97
 
98
+ def detect_all_wide(self, img):
99
+ pil = self._ensure_pil(img)
100
+ W, H = pil.size
101
+ boxes, probs = self.mtcnn.detect(pil)
102
+
103
+ annotated = pil.copy()
104
+ draw = ImageDraw.Draw(annotated)
105
+ crops = []
106
+ ordered = []
107
+
108
+ if boxes is None or len(boxes) == 0:
109
+ return crops, annotated, []
110
+
111
+ for b, p in sorted(zip(boxes, probs), key=lambda x: (x[0][0]+x[0][2])/2):
112
+ bx1, by1, bx2, by2 = map(float, b)
113
+ draw.rectangle([bx1, by1, bx2, by2], outline=(0, 200, 255), width=3)
114
+ draw.text((bx1, max(0, by1-12)), f"{p:.2f}", fill=(0, 200, 255))
115
+
116
+ nx1, ny1, nx2, ny2 = self._expand_box(b, W, H)
117
+ crops.append(pil.crop((nx1, ny1, nx2, ny2)))
118
+ ordered.append((bx1, by1, bx2, by2))
119
+
120
+ return crops, annotated, ordered
121
+
122
  # ------------------ FAST Cartoonizer (SD-Turbo) ------------------
123
  from diffusers import AutoPipelineForImage2Image
 
 
124
  TURBO_ID = "stabilityai/sd-turbo"
125
 
126
  def load_turbo_pipe(device):
 
129
  TURBO_ID,
130
  torch_dtype=dtype,
131
  safety_checker=None,
132
+ ).to(device)
 
133
  try:
134
  pipe.enable_attention_slicing()
135
  except Exception:
136
  pass
137
  return pipe
138
 
139
+ # init models once
140
  age_est = PretrainedAgeEstimator()
141
+ cropper = FaceCropper(device=age_est.device, margin_scale=1.9)
142
  sd_pipe = load_turbo_pipe(age_est.device)
143
 
144
+ # prompts
145
  DEFAULT_POSITIVE = (
146
  "beautiful princess portrait, elegant gown, tiara, soft magical lighting, "
147
  "sparkles, dreamy castle background, painterly, clean lineart, vibrant but natural colors, "
 
152
  "blurry, watermark, text, logo"
153
  )
154
 
 
155
  def _ensure_pil(img):
156
  return img if isinstance(img, Image.Image) else Image.fromarray(img)
157
 
158
  def _resize_512(im: Image.Image):
 
159
  w, h = im.size
160
  scale = 512 / max(w, h)
161
  if scale < 1.0:
162
  im = im.resize((int(w*scale), int(h*scale)), Image.LANCZOS)
163
  return im
164
 
165
+ # ------------- AGE (single/group) -------------
166
  @torch.inference_mode()
167
+ def predict_age(img, group_mode=False, auto_crop=True):
168
  if img is None:
169
  return {}, "Please upload an image.", None
 
170
 
171
+ pil = _ensure_pil(img).convert("RGB")
172
+
173
+ if group_mode:
174
+ crops, annotated, boxes = cropper.detect_all_wide(pil)
175
+ if not crops:
176
+ # fallback to full image
177
+ age, top = age_est.predict(pil, topk=5)
178
+ probs = {lbl: float(p) for lbl, p in top}
179
+ md = f"**Estimated age (whole image):** {age:.1f} years"
180
+ return probs, md, pil
181
+
182
+ # per-face ages
183
+ rows = ["| # | Age (yrs) | Top-1 | p |", "|---:|---:|---|---:|"]
184
+ for i, face in enumerate(crops, 1):
185
+ age, top = age_est.predict(face, topk=3)
186
+ top1, p1 = top[0]
187
+ rows.append(f"| {i} | {age:.1f} | {top1} | {p1:.2f} |")
188
+ md = "\n".join(rows)
189
+ # also return a simple dict from the largest (first) face just to feed Label
190
+ age0, top0 = age_est.predict(crops[0], topk=5)
191
+ probs0 = {lbl: float(p) for lbl, p in top0}
192
+ return probs0, md, annotated
193
+
194
+ # single
195
+ face_wide = None; annotated = None
196
  if auto_crop:
197
+ face_wide, annotated = cropper.detect_one_wide(pil)
198
+ target = face_wide if face_wide is not None else pil
 
199
  age, top = age_est.predict(target, topk=5)
200
  probs = {lbl: float(p) for lbl, p in top}
201
+ md = f"**Estimated age:** {age:.1f} years"
202
+ return probs, md, (annotated if annotated is not None else pil)
203
 
204
+ # ------------- CARTOON (single/group) -------------
205
  @torch.inference_mode()
206
+ def cartoonize(img, prompt="", group_mode=False, auto_crop=True, strength=0.5, steps=2, seed=-1):
207
  if img is None:
208
  return None
209
+ pil = _ensure_pil(img).convert("RGB")
210
 
 
 
 
 
 
 
 
 
 
 
211
  user = (prompt or "").strip()
212
  pos = DEFAULT_POSITIVE if not user else f"{DEFAULT_POSITIVE}, {user}"
213
  neg = DEFAULT_NEGATIVE
 
216
  if isinstance(seed, (int, float)) and int(seed) >= 0:
217
  generator = torch.Generator(device=age_est.device).manual_seed(int(seed))
218
 
219
+ if group_mode:
220
+ # detect all faces, stylize each, assemble grid
221
+ crops, _, _ = cropper.detect_all_wide(pil)
222
+ if not crops:
223
+ crops = [pil] # fallback
224
+
225
+ # resize each to 384 for speed/variety
226
+ proc = []
227
+ for c in crops:
228
+ c = _resiz_