hongyu12321 commited on
Commit
482599f
·
verified ·
1 Parent(s): f63b4cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -97
app.py CHANGED
@@ -1,27 +1,26 @@
1
- # app.py — Single-face: Age + Gender + Fast Cartoon (Queen/King/Fairy)
2
 
3
  import os
4
  os.environ["TRANSFORMERS_NO_TF"] = "1"
5
  os.environ["TRANSFORMERS_NO_FLAX"] = "1"
6
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
7
 
8
- from typing import Optional
9
  import gradio as gr
10
  from PIL import Image, ImageDraw
11
  import numpy as np
12
  import torch
13
 
14
- # ------------------ Age estimator ------------------
15
  from transformers import AutoImageProcessor, AutoModelForImageClassification
16
 
17
- HF_AGE_ID = "nateraw/vit-age-classifier"
18
  AGE_RANGE_TO_MID = {
19
  "0-2": 1, "3-9": 6, "10-19": 15, "20-29": 25, "30-39": 35,
20
  "40-49": 45, "50-59": 55, "60-69": 65, "70+": 75
21
  }
22
 
23
- class AgeEstimator:
24
- def __init__(self, model_id: str = HF_AGE_ID, device: Optional[str] = None):
25
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
26
  self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
27
  self.model = AutoModelForImageClassification.from_pretrained(model_id)
@@ -42,51 +41,12 @@ class AgeEstimator:
42
  for i, p in enumerate(probs))
43
  return expected, top
44
 
45
- # ------------------ Gender estimator (best-effort, optional) ------------------
46
- # We try to load a small HF gender classifier. If unavailable, we return "unknown".
47
- _GENDER_MODEL_IDS = [
48
- "phiyodr/vit-gender-classification", # (common community model)
49
- "rizvandwiki/gender-classification", # fallback
50
- ]
51
- class GenderEstimator:
52
- def __init__(self, device: Optional[str] = None):
53
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
54
- self.model = None
55
- self.processor = None
56
- self.id2label = None
57
- from transformers import AutoImageProcessor, AutoModelForImageClassification
58
- for mid in _GENDER_MODEL_IDS:
59
- try:
60
- self.processor = AutoImageProcessor.from_pretrained(mid, use_fast=True)
61
- self.model = AutoModelForImageClassification.from_pretrained(mid)
62
- self.model.to(self.device).eval()
63
- self.id2label = self.model.config.id2label
64
- self.model_id = mid
65
- break
66
- except Exception:
67
- self.processor = None
68
- self.model = None
69
- self.id2label = None
70
- self.available = self.model is not None
71
-
72
- @torch.inference_mode()
73
- def predict(self, img: Image.Image):
74
- if not self.available:
75
- return "unknown", 0.0
76
- if img.mode != "RGB":
77
- img = img.convert("RGB")
78
- inputs = self.processor(images=img, return_tensors="pt").to(self.device)
79
- logits = self.model(**inputs).logits
80
- probs = logits.softmax(dim=-1).squeeze(0)
81
- score, idx = torch.max(probs, dim=0)
82
- label = self.id2label[idx.item()]
83
- return label, float(score.item())
84
-
85
- # ------------------ Largest-face detector with margin ------------------
86
  from facenet_pytorch import MTCNN
87
 
88
  class FaceCropper:
89
- def __init__(self, device: Optional[str] = None, margin_scale: float = 1.9):
 
90
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
91
  self.mtcnn = MTCNN(keep_all=True, device=self.device)
92
  self.margin_scale = margin_scale
@@ -96,58 +56,59 @@ class FaceCropper:
96
  return img.convert("RGB")
97
  return Image.fromarray(img).convert("RGB")
98
 
99
- def _expand_box(self, box, W, H, aspect=0.8): # ~4:5 portrait
100
- x1, y1, x2, y2 = box
101
- cx, cy = (x1 + x2)/2, (y1 + y2)/2
102
- w, h = (x2 - x1), (y2 - y1)
103
- side = max(w, h) * self.margin_scale
104
- tw = side
105
- th = side / aspect
106
- nx1 = int(max(0, cx - tw/2)); nx2 = int(min(W, cx + tw/2))
107
- ny1 = int(max(0, cy - th/2)); ny2 = int(min(H, cy + th/2))
108
- return nx1, ny1, nx2, ny2
109
-
110
- def detect_largest_wide(self, img):
111
  pil = self._ensure_pil(img)
112
  W, H = pil.size
113
  boxes, probs = self.mtcnn.detect(pil)
114
 
115
  annotated = pil.copy()
116
  draw = ImageDraw.Draw(annotated)
 
117
  if boxes is None or len(boxes) == 0:
118
- return None, annotated # no detection
119
 
120
- # draw all detections
 
 
 
 
 
 
121
  for b, p in zip(boxes, probs):
122
  bx1, by1, bx2, by2 = map(float, b)
123
- draw.rectangle([bx1, by1, bx2, by2], outline=(0, 200, 255), width=3)
124
- draw.text((bx1, max(0, by1-12)), f"{p:.2f}", fill=(0, 200, 255))
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
- # pick largest box
127
- idx = int(np.argmax([(b[2]-b[0])*(b[3]-b[1]) for b in boxes]))
128
- nx1, ny1, nx2, ny2 = self._expand_box(boxes[idx], W, H)
129
  crop = pil.crop((nx1, ny1, nx2, ny2))
130
  return crop, annotated
131
 
132
- # ------------------ FAST Cartoonizer (SD-Turbo, with safety) ------------------
133
  from diffusers import AutoPipelineForImage2Image
134
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
135
- from transformers import AutoFeatureExtractor
136
 
 
137
  TURBO_ID = "stabilityai/sd-turbo"
 
138
  def load_turbo_pipe(device):
139
- dtype = torch.float16 if torch.cuda.is_available() else torch.float32
140
  pipe = AutoPipelineForImage2Image.from_pretrained(
141
  TURBO_ID,
142
- dtype=dtype, # no deprecation warning
143
- ).to(device)
144
- # enable safety checker for public Space
145
- pipe.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
146
- "CompVis/stable-diffusion-safety-checker"
147
- )
148
- pipe.feature_extractor = AutoFeatureExtractor.from_pretrained(
149
- "CompVis/stable-diffusion-safety-checker"
150
  )
 
151
  try:
152
  pipe.enable_attention_slicing()
153
  except Exception:
@@ -155,42 +116,116 @@ def load_turbo_pipe(device):
155
  return pipe
156
 
157
  # ------------------ Init models once ------------------
158
- age_est = AgeEstimator()
159
- gender_est = GenderEstimator(device=age_est.device)
160
- cropper = FaceCropper(device=age_est.device, margin_scale=1.9)
161
  sd_pipe = load_turbo_pipe(age_est.device)
162
 
163
  # ------------------ Prompts ------------------
164
- STYLE_BASE = {
165
- "Queen": "regal queen portrait, elegant royal gown, jeweled tiara, ornate details, dreamy castle background, soft magical lighting, sparkles, storybook illustration, high quality",
166
- "King": "regal king portrait, ornate royal cloak and crown, majestic posture, grand throne room background, cinematic soft lighting, painterly style, storybook illustration, high quality",
167
- "Fairy": "fairy portrait, ethereal wings, glowing particles, enchanted forest background, luminous soft lighting, delicate dress, whimsical, storybook illustration, high quality",
168
- }
169
- NEGATIVE_PROMPT = "deformed, disfigured, ugly, extra limbs, extra fingers, bad anatomy, low quality, blurry, watermark, text, logo"
170
-
 
 
 
 
171
  def _ensure_pil(img):
172
  return img if isinstance(img, Image.Image) else Image.fromarray(img)
173
 
174
  def _resize_512(im: Image.Image):
 
175
  w, h = im.size
176
  scale = 512 / max(w, h)
177
  if scale < 1.0:
178
  im = im.resize((int(w*scale), int(h*scale)), Image.LANCZOS)
179
  return im
180
 
181
- # ------------------ 1) Predict Age+Gender (fast) ------------------
182
  @torch.inference_mode()
183
- def predict_age_gender(img, auto_crop=True):
184
  if img is None:
185
  return {}, "Please upload an image.", None
186
- pil = _ensure_pil(img).convert("RGB")
187
 
188
- face_wide, annotated = (None, None)
 
189
  if auto_crop:
190
- face_wide, annotated = cropper.detect_largest_wide(pil)
 
191
 
192
- target = face_wide if face_wide is not None else pil
193
-
194
- # age
195
  age, top = age_est.predict(target, topk=5)
196
  probs = {lbl: float(p) for lbl, p in top}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py — Age-first + FAST cartoon (Turbo), nicer framing & magical background
2
 
3
  import os
4
  os.environ["TRANSFORMERS_NO_TF"] = "1"
5
  os.environ["TRANSFORMERS_NO_FLAX"] = "1"
6
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
7
 
 
8
  import gradio as gr
9
  from PIL import Image, ImageDraw
10
  import numpy as np
11
  import torch
12
 
13
+ # ------------------ Age estimator (Hugging Face) ------------------
14
  from transformers import AutoImageProcessor, AutoModelForImageClassification
15
 
16
+ HF_MODEL_ID = "nateraw/vit-age-classifier"
17
  AGE_RANGE_TO_MID = {
18
  "0-2": 1, "3-9": 6, "10-19": 15, "20-29": 25, "30-39": 35,
19
  "40-49": 45, "50-59": 55, "60-69": 65, "70+": 75
20
  }
21
 
22
+ class PretrainedAgeEstimator:
23
+ def __init__(self, model_id: str = HF_MODEL_ID, device: str | None = None):
24
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
25
  self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
26
  self.model = AutoModelForImageClassification.from_pretrained(model_id)
 
41
  for i, p in enumerate(probs))
42
  return expected, top
43
 
44
+ # ------------------ Face detection with WIDER crop ------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  from facenet_pytorch import MTCNN
46
 
47
  class FaceCropper:
48
+ """Detect faces; return (cropped_wide, annotated). Adds margin so face isn't full screen."""
49
+ def __init__(self, device: str | None = None, margin_scale: float = 1.8):
50
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
51
  self.mtcnn = MTCNN(keep_all=True, device=self.device)
52
  self.margin_scale = margin_scale
 
56
  return img.convert("RGB")
57
  return Image.fromarray(img).convert("RGB")
58
 
59
+ def detect_and_crop_wide(self, img, select="largest"):
 
 
 
 
 
 
 
 
 
 
 
60
  pil = self._ensure_pil(img)
61
  W, H = pil.size
62
  boxes, probs = self.mtcnn.detect(pil)
63
 
64
  annotated = pil.copy()
65
  draw = ImageDraw.Draw(annotated)
66
+
67
  if boxes is None or len(boxes) == 0:
68
+ return None, annotated
69
 
70
+ # choose largest face
71
+ idx = int(np.argmax([(b[2]-b[0])*(b[3]-b[1]) for b in boxes]))
72
+ if isinstance(select, int) and 0 <= select < len(boxes):
73
+ idx = select
74
+ x1, y1, x2, y2 = boxes[idx]
75
+
76
+ # draw all boxes
77
  for b, p in zip(boxes, probs):
78
  bx1, by1, bx2, by2 = map(float, b)
79
+ draw.rectangle([bx1, by1, bx2, by2], outline=(255, 0, 0), width=3)
80
+ draw.text((bx1, max(0, by1-12)), f"{p:.2f}", fill=(255, 0, 0))
81
+
82
+ # expand with margin
83
+ cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
84
+ w, h = (x2 - x1), (y2 - y1)
85
+ side = max(w, h) * self.margin_scale # wider frame to include background/shoulders
86
+ # keep a pleasant portrait aspect (4:5)
87
+ target_w = side
88
+ target_h = side * 1.25
89
+
90
+ nx1 = int(max(0, cx - target_w/2))
91
+ nx2 = int(min(W, cx + target_w/2))
92
+ ny1 = int(max(0, cy - target_h/2))
93
+ ny2 = int(min(H, cy + target_h/2))
94
 
 
 
 
95
  crop = pil.crop((nx1, ny1, nx2, ny2))
96
  return crop, annotated
97
 
98
+ # ------------------ FAST Cartoonizer (SD-Turbo) ------------------
99
  from diffusers import AutoPipelineForImage2Image
 
 
100
 
101
+ # Turbo is very fast (1–4 steps). Great for stylization on CPU/GPU.
102
  TURBO_ID = "stabilityai/sd-turbo"
103
+
104
  def load_turbo_pipe(device):
105
+ dtype = torch.float16 if (device == "cuda") else torch.float32
106
  pipe = AutoPipelineForImage2Image.from_pretrained(
107
  TURBO_ID,
108
+ torch_dtype=dtype,
109
+ safety_checker=None,
 
 
 
 
 
 
110
  )
111
+ pipe = pipe.to(device)
112
  try:
113
  pipe.enable_attention_slicing()
114
  except Exception:
 
116
  return pipe
117
 
118
  # ------------------ Init models once ------------------
119
+ age_est = PretrainedAgeEstimator()
120
+ cropper = FaceCropper(device=age_est.device, margin_scale=1.8) # 1.6–2.0 feels good
 
121
  sd_pipe = load_turbo_pipe(age_est.device)
122
 
123
  # ------------------ Prompts ------------------
124
+ DEFAULT_POSITIVE = (
125
+ "beautiful princess portrait, elegant gown, tiara, soft magical lighting, "
126
+ "sparkles, dreamy castle background, painterly, clean lineart, vibrant but natural colors, "
127
+ "storybook illustration, high quality"
128
+ )
129
+ DEFAULT_NEGATIVE = (
130
+ "deformed, disfigured, ugly, extra limbs, extra fingers, bad anatomy, low quality, "
131
+ "blurry, watermark, text, logo"
132
+ )
133
+
134
+ # ------------------ Helpers ------------------
135
  def _ensure_pil(img):
136
  return img if isinstance(img, Image.Image) else Image.fromarray(img)
137
 
138
  def _resize_512(im: Image.Image):
139
+ # keep aspect, fit longest side to 512 (faster, fewer artifacts)
140
  w, h = im.size
141
  scale = 512 / max(w, h)
142
  if scale < 1.0:
143
  im = im.resize((int(w*scale), int(h*scale)), Image.LANCZOS)
144
  return im
145
 
146
+ # ------------------ 1) Predict Age (fast) ------------------
147
  @torch.inference_mode()
148
+ def predict_age_only(img, auto_crop=True):
149
  if img is None:
150
  return {}, "Please upload an image.", None
151
+ img = _ensure_pil(img).convert("RGB")
152
 
153
+ face_wide = None
154
+ annotated = None
155
  if auto_crop:
156
+ face_wide, annotated = cropper.detect_and_crop_wide(img)
157
+ target = face_wide if face_wide is not None else img
158
 
 
 
 
159
  age, top = age_est.predict(target, topk=5)
160
  probs = {lbl: float(p) for lbl, p in top}
161
+ summary = f"**Estimated age:** {age:.1f} years"
162
+ return probs, summary, (annotated if annotated is not None else img)
163
+
164
+ # ------------------ 2) Generate Cartoon (fast) ------------------
165
+ @torch.inference_mode()
166
+ def generate_cartoon(img, prompt="", auto_crop=True, strength=0.5, steps=2, seed=-1):
167
+ if img is None:
168
+ return None
169
+
170
+ img = _ensure_pil(img).convert("RGB")
171
+ # use wide face crop to include background/shoulders
172
+ if auto_crop:
173
+ face_wide, _ = cropper.detect_and_crop_wide(img)
174
+ if face_wide is not None:
175
+ img = face_wide
176
+
177
+ img = _resize_512(img)
178
+
179
+ # prompt assembly
180
+ user = (prompt or "").strip()
181
+ pos = DEFAULT_POSITIVE if not user else f"{DEFAULT_POSITIVE}, {user}"
182
+ neg = DEFAULT_NEGATIVE
183
+
184
+ generator = None
185
+ if isinstance(seed, (int, float)) and int(seed) >= 0:
186
+ generator = torch.Generator(device=age_est.device).manual_seed(int(seed))
187
+
188
+ # Turbo likes low steps and guidance ~0
189
+ out = sd_pipe(
190
+ prompt=pos,
191
+ negative_prompt=neg,
192
+ image=img,
193
+ strength=float(strength), # 0.4–0.6 keeps identity & adds dress/background
194
+ guidance_scale=0.0, # Turbo typically uses 0
195
+ num_inference_steps=int(steps), # 1–4 steps → very fast
196
+ generator=generator,
197
+ )
198
+ return out.images[0]
199
+
200
+ # ------------------ UI ------------------
201
+ with gr.Blocks(title="Age First + Fast Cartoon") as demo:
202
+ gr.Markdown("# Upload or capture once — get age prediction first, then a faster cartoon ✨")
203
+
204
+ with gr.Row():
205
+ with gr.Column(scale=1):
206
+ img_in = gr.Image(sources=["upload", "webcam"], type="pil", label="Upload / Webcam")
207
+ auto = gr.Checkbox(True, label="Auto face crop (wide, recommended)")
208
+ prompt = gr.Textbox(
209
+ label="(Optional) Extra cartoon style",
210
+ placeholder="e.g., studio ghibli watercolor, soft bokeh, pastel palette"
211
+ )
212
+ with gr.Row():
213
+ strength = gr.Slider(0.3, 0.8, value=0.5, step=0.05, label="Cartoon strength")
214
+ steps = gr.Slider(1, 4, value=2, step=1, label="Turbo steps (1–4)")
215
+ seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
216
+
217
+ btn_age = gr.Button("Predict Age (fast)", variant="primary")
218
+ btn_cartoon = gr.Button("Make Cartoon (fast)", variant="secondary")
219
+
220
+ with gr.Column(scale=1):
221
+ probs_out = gr.Label(num_top_classes=5, label="Age Prediction (probabilities)")
222
+ age_md = gr.Markdown(label="Age Summary")
223
+ preview = gr.Image(label="Detection Preview")
224
+ cartoon_out = gr.Image(label="Cartoon Result")
225
+
226
+ # Wire the buttons
227
+ btn_age.click(fn=predict_age_only, inputs=[img_in, auto], outputs=[probs_out, age_md, preview])
228
+ btn_cartoon.click(fn=generate_cartoon, inputs=[img_in, prompt, auto, strength, steps, seed], outputs=cartoon_out)
229
+
230
+ if __name__ == "__main__":
231
+ demo.launch()