Files changed (1) hide show
  1. app.py +70 -78
app.py CHANGED
@@ -5,7 +5,7 @@ import numpy as np
5
  import cv2
6
  import traceback
7
  import gc
8
- from PIL import Image, ImageFilter, ImageEnhance, ImageDraw, ImageFont
9
  from torchvision.transforms import functional as TF
10
  from scipy.ndimage import label
11
  import antialiased_cnns
@@ -15,7 +15,10 @@ from transformers import SegformerImageProcessor, SegformerForSemanticSegmentati
15
  from ultralytics import YOLO
16
  from gfpgan import GFPGANer
17
  import urllib.request
18
- import gradio as gr
 
 
 
19
 
20
  # ========================= CONFIG =========================
21
  AGING_MODEL_PATH = "face_aging_model/best_unet_model.pth"
@@ -96,6 +99,7 @@ def load_aging_model():
96
  if age_model is not None:
97
  return age_model
98
  print("Loading UNet aging model...")
 
99
  class DownLayer(nn.Module):
100
  def __init__(self, in_ch, out_ch):
101
  super().__init__()
@@ -161,9 +165,11 @@ def load_aging_model():
161
  state = torch.load(AGING_MODEL_PATH, map_location=DEVICE, weights_only=True)
162
  age_model.load_state_dict(state)
163
  age_model.eval()
 
164
  if DEVICE.type == "cuda" and int(torch.__version__.split('.')[0]) >= 2:
165
  print("Compiling UNet with torch.compile...")
166
  age_model = torch.compile(age_model, mode="reduce-overhead")
 
167
  print("✅ Aging model loaded!")
168
  return age_model
169
 
@@ -177,9 +183,11 @@ def load_face_parser():
177
  face_parser = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
178
  face_parser.to(DEVICE)
179
  face_parser.eval()
 
180
  if DEVICE.type == "cuda" and int(torch.__version__.split('.')[0]) >= 2:
181
  print("Compiling Segformer with torch.compile...")
182
  face_parser = torch.compile(face_parser, mode="reduce-overhead")
 
183
  print("✅ Face parser loaded!")
184
  return face_processor, face_parser
185
 
@@ -196,12 +204,12 @@ def get_lips_mask(pil_image: Image.Image) -> np.ndarray:
196
  img_np = np.array(pil_image)
197
  h, w = img_np.shape[:2]
198
  lips_mask = np.zeros((h, w), dtype=np.uint8)
199
-
200
  with mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=True,
201
  min_detection_confidence=0.5) as face_mesh:
202
  rgb_image = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
203
  results = face_mesh.process(rgb_image)
204
-
205
  if results.multi_face_landmarks:
206
  for face_landmarks in results.multi_face_landmarks:
207
  lip_landmarks = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 308, 324, 318, 402, 317, 14, 87, 178, 88, 95]
@@ -240,11 +248,11 @@ def get_beard_mask(pil_image: Image.Image) -> np.ndarray:
240
  model = load_beard_model()
241
  results = model(temp_path, device=DEVICE.type, conf=0.25, iou=0.5, verbose=False,
242
  half=True if DEVICE.type == "cuda" else False)
243
-
244
  img_np = np.array(pil_image)
245
  h, w = img_np.shape[:2]
246
  beard_mask = np.zeros((h, w), dtype=np.uint8)
247
-
248
  if results[0].masks is not None:
249
  for i, cls in enumerate(results[0].boxes.cls):
250
  if int(cls) == 0: # beard class
@@ -252,7 +260,7 @@ def get_beard_mask(pil_image: Image.Image) -> np.ndarray:
252
  mask = cv2.resize(mask, (w, h))
253
  mask = (mask > 0.4).astype(np.uint8) * 255
254
  beard_mask = cv2.bitwise_or(beard_mask, mask)
255
-
256
  if np.sum(beard_mask) > 0:
257
  beard_mask_float = beard_mask.astype(np.float32) / 255.0
258
  beard_mask_float = cv2.dilate(beard_mask_float, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)), iterations=2)
@@ -262,7 +270,7 @@ def get_beard_mask(pil_image: Image.Image) -> np.ndarray:
262
  beard_mask_float = cv2.GaussianBlur(beard_mask_float, (7, 7), 2)
263
  beard_mask_float = np.clip(beard_mask_float, 0, 1)
264
  return beard_mask_float
265
-
266
  return np.zeros((h, w), dtype=np.float32)
267
  finally:
268
  if os.path.exists(temp_path):
@@ -280,25 +288,25 @@ def clean_mask(mask, min_area=150):
280
  def get_hair_mask_segformer(pil_image: Image.Image) -> np.ndarray:
281
  processor, parser = load_face_parser()
282
  inputs = processor(images=pil_image, return_tensors="pt").to(DEVICE)
283
-
284
  with torch.no_grad():
285
  outputs = parser(**inputs)
286
-
287
  logits = outputs.logits
288
  upsampled = torch.nn.functional.interpolate(logits, size=pil_image.size[::-1], mode="bilinear", align_corners=False)
289
  probs = torch.softmax(upsampled, dim=1)[0]
290
  hair_prob = probs[13].cpu().numpy()
291
-
292
  hair_mask = (hair_prob > 0.12).astype(np.uint8)
293
-
294
  face_classes = list(range(1, 6)) + list(range(8, 13)) + [17, 18]
295
  parsing = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()
296
  face_mask = np.isin(parsing, face_classes).astype(np.uint8)
297
-
298
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
299
  face_mask = cv2.dilate(face_mask, kernel, iterations=1)
300
  hair_mask = hair_mask * (1 - face_mask)
301
-
302
  hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)), iterations=1)
303
  hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11)), iterations=2)
304
  hair_mask = clean_mask(hair_mask, min_area=100)
@@ -310,22 +318,22 @@ def apply_hair_and_beard_color(image: Image.Image, hair_mask: np.ndarray, beard_
310
  combined_mask = np.maximum(hair_mask, beard_mask)
311
  if np.sum(combined_mask) == 0:
312
  return image
313
-
314
  combined_mask = cv2.GaussianBlur(combined_mask, (BLUR_RADIUS*2+1, BLUR_RADIUS*2+1), BLUR_RADIUS)
315
  combined_mask = np.clip(combined_mask, 0, 1)
316
-
317
  if EDGE_SMOOTHING:
318
  combined_mask = cv2.bilateralFilter(combined_mask.astype(np.float32), 9, 75, 75)
319
  combined_mask = np.clip(combined_mask, 0, 1)
320
-
321
  combined_mask = np.clip(combined_mask * 1.2, 0, 1)
322
-
323
  img_np = np.array(image).astype(np.float32)
324
  target_color = np.array([255, 255, 255], dtype=np.float32)
325
  gray = cv2.cvtColor(img_np.astype(np.uint8), cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0
326
  lum_factor = 0.6 + 0.4 * gray
327
  white_layer = target_color * lum_factor[..., np.newaxis]
328
-
329
  alpha = ALPHA_HAIR
330
  result = (1 - alpha * combined_mask[..., np.newaxis]) * img_np + (alpha * combined_mask[..., np.newaxis]) * white_layer
331
  result = np.clip(result, 0, 255).astype(np.uint8)
@@ -348,34 +356,10 @@ def enhance_texture(img: Image.Image) -> Image.Image:
348
  img = ImageEnhance.Sharpness(img).enhance(SHARPNESS_BOOST)
349
  return img
350
 
351
- def create_comparison(orig, raw_aged, final):
352
- W = 640
353
- def rsz(img):
354
- ratio = img.height / img.width if img.width else 1
355
- return img.resize((W, int(W * ratio)), Image.LANCZOS)
356
-
357
- o, r, f = rsz(orig), rsz(raw_aged), rsz(final)
358
- H = max(o.height, r.height, f.height)
359
- canvas = Image.new("RGB", (W*3, H), (255, 255, 255))
360
- canvas.paste(o, (0, (H - o.height)//2))
361
- canvas.paste(r, (W, (H - r.height)//2))
362
- canvas.paste(f, (W*2, (H - f.height)//2))
363
-
364
- draw = ImageDraw.Draw(canvas)
365
- try:
366
- font = ImageFont.truetype("arial.ttf", 28)
367
- except:
368
- font = ImageFont.load_default()
369
-
370
- draw.text((W//4, 8), "Original", (0, 0, 0), font=font)
371
- draw.text((W + W//5, 8), "Aged Raw", (0, 0, 0), font=font)
372
- draw.text((W*2 + W//6, 8), "Final Result", (0, 0, 0), font=font)
373
- return canvas
374
-
375
  # ================== MAIN PROCESSING FUNCTION ==================
376
- def process_face_aging(input_image: Image.Image):
377
  if input_image is None:
378
- raise gr.Error("Please upload a clear photo of a young person!")
379
 
380
  try:
381
  print(f"→ Processing image: {input_image.size}")
@@ -387,27 +371,27 @@ def process_face_aging(input_image: Image.Image):
387
 
388
  src_age = torch.full((1, SAFE_IMG_SIZE, SAFE_IMG_SIZE), SOURCE_AGE / 100.0)
389
  tgt_age = torch.full((1, SAFE_IMG_SIZE, SAFE_IMG_SIZE), TARGET_AGE / 100.0)
 
390
  cond_input = torch.cat([rgb_tensor, src_age, tgt_age], dim=0).unsqueeze(0).to(DEVICE)
391
 
392
  with torch.no_grad():
393
  aging_net = load_aging_model()
394
  raw_output = aging_net(cond_input).squeeze(0)
395
- raw_aged = TF.to_pil_image(raw_output.clamp(0, 1)).resize((ow, oh), Image.LANCZOS)
396
-
397
  alpha = WRINKLE_STRENGTH
398
  blended = (1 - alpha) * rgb_tensor.unsqueeze(0) + alpha * raw_output
399
  blended = blended.clamp(0, 1).squeeze(0)
 
400
  final_aged = TF.to_pil_image(blended).resize((ow, oh), Image.LANCZOS)
401
-
402
  final_aged = enhance_texture(final_aged)
403
  final_aged = post_correct_aged(orig, final_aged)
404
 
405
  print(" Generating hair mask...")
406
  hair_mask = get_hair_mask_segformer(final_aged)
407
-
408
  print(" Generating beard mask...")
409
  beard_mask = get_beard_mask(final_aged)
410
-
411
  print(" Applying white hair & beard...")
412
  final_img = apply_hair_and_beard_color(final_aged, hair_mask, beard_mask)
413
 
@@ -424,44 +408,52 @@ def process_face_aging(input_image: Image.Image):
424
  except Exception as e:
425
  print(f" GFPGAN error: {e}")
426
 
427
- comparison = create_comparison(orig, raw_aged, final_img)
428
-
429
  print("✓ Processing completed!")
430
  gc.collect()
431
- return final_img, comparison
432
 
433
  except Exception as e:
434
  print(f"❌ Error: {str(e)}")
435
  traceback.print_exc()
436
- raise gr.Error(f"Processing failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
 
438
- # ================== GRADIO INTERFACE ==================
439
- with gr.Blocks(theme=gr.themes.Soft(), title="👴 Face Aging + White Hair & Beard Generator") as demo:
440
- gr.Markdown("# 👴 Face Aging + White Hair & Beard Generator")
441
- gr.Markdown("Upload a clear photo of a young person.<br>This tool will age them to ~90 years with realistic wrinkles and add natural white hair & beard.")
442
 
443
- with gr.Row():
444
- input_img = gr.Image(type="pil", label="Upload Young Face Photo", height=450)
 
 
 
 
 
 
445
 
446
- with gr.Row():
447
- output_img = gr.Image(type="pil", label="Final Aged Result (with White Hair & Beard)", height=450)
448
- comparison_img = gr.Image(type="pil", label="Comparison: Original | Raw Aged | Final", height=450)
449
 
450
- btn = gr.Button("🚀 Generate Aged Face", variant="primary")
 
 
 
451
 
452
- btn.click(
453
- fn=process_face_aging,
454
- inputs=input_img,
455
- outputs=[output_img, comparison_img],
456
- queue=True,
457
- concurrency_limit=2 # ← Yeh line concurrency_count ki jagah use hui
458
- )
459
 
 
460
  if __name__ == "__main__":
461
- print("Starting Face Aging App...")
462
- demo.queue(max_size=8).launch( # concurrency_count add kar do
463
- server_name="0.0.0.0",
464
- server_port=7860,
465
- share=False,
466
- debug=False
467
- )
 
5
  import cv2
6
  import traceback
7
  import gc
8
+ from PIL import Image, ImageFilter, ImageEnhance
9
  from torchvision.transforms import functional as TF
10
  from scipy.ndimage import label
11
  import antialiased_cnns
 
15
  from ultralytics import YOLO
16
  from gfpgan import GFPGANer
17
  import urllib.request
18
+ from fastapi import FastAPI, File, UploadFile, HTTPException
19
+ from fastapi.responses import StreamingResponse
20
+ from fastapi.middleware.cors import CORSMiddleware
21
+ import io
22
 
23
  # ========================= CONFIG =========================
24
  AGING_MODEL_PATH = "face_aging_model/best_unet_model.pth"
 
99
  if age_model is not None:
100
  return age_model
101
  print("Loading UNet aging model...")
102
+
103
  class DownLayer(nn.Module):
104
  def __init__(self, in_ch, out_ch):
105
  super().__init__()
 
165
  state = torch.load(AGING_MODEL_PATH, map_location=DEVICE, weights_only=True)
166
  age_model.load_state_dict(state)
167
  age_model.eval()
168
+
169
  if DEVICE.type == "cuda" and int(torch.__version__.split('.')[0]) >= 2:
170
  print("Compiling UNet with torch.compile...")
171
  age_model = torch.compile(age_model, mode="reduce-overhead")
172
+
173
  print("✅ Aging model loaded!")
174
  return age_model
175
 
 
183
  face_parser = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
184
  face_parser.to(DEVICE)
185
  face_parser.eval()
186
+
187
  if DEVICE.type == "cuda" and int(torch.__version__.split('.')[0]) >= 2:
188
  print("Compiling Segformer with torch.compile...")
189
  face_parser = torch.compile(face_parser, mode="reduce-overhead")
190
+
191
  print("✅ Face parser loaded!")
192
  return face_processor, face_parser
193
 
 
204
  img_np = np.array(pil_image)
205
  h, w = img_np.shape[:2]
206
  lips_mask = np.zeros((h, w), dtype=np.uint8)
207
+
208
  with mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=True,
209
  min_detection_confidence=0.5) as face_mesh:
210
  rgb_image = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
211
  results = face_mesh.process(rgb_image)
212
+
213
  if results.multi_face_landmarks:
214
  for face_landmarks in results.multi_face_landmarks:
215
  lip_landmarks = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 308, 324, 318, 402, 317, 14, 87, 178, 88, 95]
 
248
  model = load_beard_model()
249
  results = model(temp_path, device=DEVICE.type, conf=0.25, iou=0.5, verbose=False,
250
  half=True if DEVICE.type == "cuda" else False)
251
+
252
  img_np = np.array(pil_image)
253
  h, w = img_np.shape[:2]
254
  beard_mask = np.zeros((h, w), dtype=np.uint8)
255
+
256
  if results[0].masks is not None:
257
  for i, cls in enumerate(results[0].boxes.cls):
258
  if int(cls) == 0: # beard class
 
260
  mask = cv2.resize(mask, (w, h))
261
  mask = (mask > 0.4).astype(np.uint8) * 255
262
  beard_mask = cv2.bitwise_or(beard_mask, mask)
263
+
264
  if np.sum(beard_mask) > 0:
265
  beard_mask_float = beard_mask.astype(np.float32) / 255.0
266
  beard_mask_float = cv2.dilate(beard_mask_float, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)), iterations=2)
 
270
  beard_mask_float = cv2.GaussianBlur(beard_mask_float, (7, 7), 2)
271
  beard_mask_float = np.clip(beard_mask_float, 0, 1)
272
  return beard_mask_float
273
+
274
  return np.zeros((h, w), dtype=np.float32)
275
  finally:
276
  if os.path.exists(temp_path):
 
288
  def get_hair_mask_segformer(pil_image: Image.Image) -> np.ndarray:
289
  processor, parser = load_face_parser()
290
  inputs = processor(images=pil_image, return_tensors="pt").to(DEVICE)
291
+
292
  with torch.no_grad():
293
  outputs = parser(**inputs)
294
+
295
  logits = outputs.logits
296
  upsampled = torch.nn.functional.interpolate(logits, size=pil_image.size[::-1], mode="bilinear", align_corners=False)
297
  probs = torch.softmax(upsampled, dim=1)[0]
298
  hair_prob = probs[13].cpu().numpy()
299
+
300
  hair_mask = (hair_prob > 0.12).astype(np.uint8)
301
+
302
  face_classes = list(range(1, 6)) + list(range(8, 13)) + [17, 18]
303
  parsing = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()
304
  face_mask = np.isin(parsing, face_classes).astype(np.uint8)
305
+
306
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
307
  face_mask = cv2.dilate(face_mask, kernel, iterations=1)
308
  hair_mask = hair_mask * (1 - face_mask)
309
+
310
  hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)), iterations=1)
311
  hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11)), iterations=2)
312
  hair_mask = clean_mask(hair_mask, min_area=100)
 
318
  combined_mask = np.maximum(hair_mask, beard_mask)
319
  if np.sum(combined_mask) == 0:
320
  return image
321
+
322
  combined_mask = cv2.GaussianBlur(combined_mask, (BLUR_RADIUS*2+1, BLUR_RADIUS*2+1), BLUR_RADIUS)
323
  combined_mask = np.clip(combined_mask, 0, 1)
324
+
325
  if EDGE_SMOOTHING:
326
  combined_mask = cv2.bilateralFilter(combined_mask.astype(np.float32), 9, 75, 75)
327
  combined_mask = np.clip(combined_mask, 0, 1)
328
+
329
  combined_mask = np.clip(combined_mask * 1.2, 0, 1)
330
+
331
  img_np = np.array(image).astype(np.float32)
332
  target_color = np.array([255, 255, 255], dtype=np.float32)
333
  gray = cv2.cvtColor(img_np.astype(np.uint8), cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0
334
  lum_factor = 0.6 + 0.4 * gray
335
  white_layer = target_color * lum_factor[..., np.newaxis]
336
+
337
  alpha = ALPHA_HAIR
338
  result = (1 - alpha * combined_mask[..., np.newaxis]) * img_np + (alpha * combined_mask[..., np.newaxis]) * white_layer
339
  result = np.clip(result, 0, 255).astype(np.uint8)
 
356
  img = ImageEnhance.Sharpness(img).enhance(SHARPNESS_BOOST)
357
  return img
358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
  # ================== MAIN PROCESSING FUNCTION ==================
360
+ def process_face_aging(input_image: Image.Image) -> Image.Image:
361
  if input_image is None:
362
+ raise ValueError("Please provide a valid image!")
363
 
364
  try:
365
  print(f"→ Processing image: {input_image.size}")
 
371
 
372
  src_age = torch.full((1, SAFE_IMG_SIZE, SAFE_IMG_SIZE), SOURCE_AGE / 100.0)
373
  tgt_age = torch.full((1, SAFE_IMG_SIZE, SAFE_IMG_SIZE), TARGET_AGE / 100.0)
374
+
375
  cond_input = torch.cat([rgb_tensor, src_age, tgt_age], dim=0).unsqueeze(0).to(DEVICE)
376
 
377
  with torch.no_grad():
378
  aging_net = load_aging_model()
379
  raw_output = aging_net(cond_input).squeeze(0)
380
+
 
381
  alpha = WRINKLE_STRENGTH
382
  blended = (1 - alpha) * rgb_tensor.unsqueeze(0) + alpha * raw_output
383
  blended = blended.clamp(0, 1).squeeze(0)
384
+
385
  final_aged = TF.to_pil_image(blended).resize((ow, oh), Image.LANCZOS)
 
386
  final_aged = enhance_texture(final_aged)
387
  final_aged = post_correct_aged(orig, final_aged)
388
 
389
  print(" Generating hair mask...")
390
  hair_mask = get_hair_mask_segformer(final_aged)
391
+
392
  print(" Generating beard mask...")
393
  beard_mask = get_beard_mask(final_aged)
394
+
395
  print(" Applying white hair & beard...")
396
  final_img = apply_hair_and_beard_color(final_aged, hair_mask, beard_mask)
397
 
 
408
  except Exception as e:
409
  print(f" GFPGAN error: {e}")
410
 
 
 
411
  print("✓ Processing completed!")
412
  gc.collect()
413
+ return final_img
414
 
415
  except Exception as e:
416
  print(f"❌ Error: {str(e)}")
417
  traceback.print_exc()
418
+ raise
419
+
420
+ # ================== FASTAPI SETUP ==================
421
# ASGI application object; the route handlers below are registered on it.
app = FastAPI(title="Face Aging + White Hair & Beard API")

# CORS policy: any origin may call the API, with any method and headers.
# Credentials are disabled because a wildcard origin list must not be combined
# with allow_credentials=True: the CORS middleware documentation requires
# explicit origins/methods/headers in that case, and the combination would let
# any website issue credentialed cross-origin requests.
# NOTE(review): no handler visible in this file reads cookies or auth headers;
# if credentialed browser calls are ever needed, list the trusted origins
# explicitly and re-enable allow_credentials for them only.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
430
+
431
@app.post("/age-face")
async def age_face(file: UploadFile = File(...)):
    """Age the uploaded face photo and stream the result back as a PNG.

    Args:
        file: Multipart upload. Its declared content type must start with
            ``image/``.

    Returns:
        A ``StreamingResponse`` with media type ``image/png`` containing the
        image returned by ``process_face_aging``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or its
            bytes cannot be decoded as one; 500 if the aging pipeline or the
            PNG encoding fails.
    """
    # content_type is None when the client omits the part's Content-Type
    # header; treat that as "not an image" rather than failing on None.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Only image files allowed")

    contents = await file.read()

    # Decode separately from processing so a corrupt, empty or oversized
    # upload is reported as a client error (400) instead of a server error.
    try:
        input_image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from e

    try:
        # NOTE(review): process_face_aging is synchronous, so it runs on the
        # event loop and requests are handled one at a time. Offloading it to
        # a worker thread would first need the lazily loaded global models to
        # be confirmed safe for concurrent use.
        result_image = process_face_aging(input_image)

        # Encode the result into an in-memory PNG and rewind it for streaming.
        img_byte_arr = io.BytesIO()
        result_image.save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)

        return StreamingResponse(img_byte_arr, media_type="image/png")
    except Exception as e:
        # Top-level boundary: convert any pipeline failure into an HTTP 500
        # while keeping the original exception as the cause.
        raise HTTPException(
            status_code=500, detail=f"Processing failed: {e}"
        ) from e
    finally:
        # Release per-request intermediates whether or not processing worked.
        gc.collect()
453
 
 
 
 
 
 
 
 
454
 
455
# Script entry point: serve the API with uvicorn when this file is run directly.
if __name__ == "__main__":
    import uvicorn

    print("Starting FastAPI server...")
    # Listen on all interfaces, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)