Files changed (1) hide show
  1. app.py +98 -42
app.py CHANGED
@@ -9,6 +9,9 @@ import cv2
9
  import traceback
10
  import gc
11
  import uuid
 
 
 
12
  from PIL import Image, ImageFilter, ImageEnhance
13
  from torchvision.transforms import functional as TF
14
  from scipy.ndimage import label
@@ -23,8 +26,6 @@ from fastapi import FastAPI, File, UploadFile, HTTPException
23
  from fastapi.responses import StreamingResponse
24
  from fastapi.middleware.cors import CORSMiddleware
25
  import io
26
- import asyncio
27
- from concurrent.futures import ThreadPoolExecutor
28
  import logging
29
 
30
  logging.basicConfig(level=logging.INFO)
@@ -37,7 +38,7 @@ AGING_MODEL_PATH = "face_aging_model/best_unet_model.pth"
37
  BEARD_MODEL_PATH = "models/best_hair_117_epoch_v4.pt"
38
  GFPGAN_MODEL_PATH = "GFPGANv1.4.pth"
39
 
40
- SAFE_IMG_SIZE = 384
41
  SOURCE_AGE = 20
42
  TARGET_AGE = 80
43
  WRINKLE_STRENGTH = 0.42
@@ -52,7 +53,6 @@ GFPGAN_WEIGHT = 0.5
52
 
53
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
  USE_FP16 = DEVICE.type == "cuda" and torch.cuda.is_available()
55
-
56
  logger.info(f"🚀 Device: {DEVICE}, FP16: {USE_FP16}")
57
 
58
  os.environ["HF_HOME"] = "/tmp/hf_cache"
@@ -199,7 +199,7 @@ def load_aging_model():
199
  return age_model
200
 
201
  # ================================================
202
- # 6. LOAD FACE PARSER & BEARD MODEL
203
  # ================================================
204
  def load_face_parser():
205
  global face_processor, face_parser
@@ -225,7 +225,7 @@ def load_beard_model():
225
  return beard_model
226
 
227
  # ================================================
228
- # 7. MASK FUNCTIONS (shortened for space - same as before)
229
  # ================================================
230
  def get_lips_mask(pil_image: Image.Image) -> np.ndarray:
231
  img_np = np.array(pil_image.resize((256, 256), Image.LANCZOS))
@@ -254,9 +254,10 @@ def exclude_lips_from_mask(beard_mask, pil_image):
254
  def get_beard_mask(pil_image: Image.Image) -> np.ndarray:
255
  temp = f"temp_{uuid.uuid4().hex[:8]}.jpg"
256
  try:
257
- pil_image.resize((384,384), Image.LANCZOS).save(temp)
 
258
  model = load_beard_model()
259
- res = model(temp, device=DEVICE.type, conf=0.3, iou=0.5, verbose=False, half=USE_FP16, imgsz=384)
260
  h, w = np.array(pil_image).shape[:2]
261
  mask = np.zeros((h,w), dtype=np.uint8)
262
  if res[0].masks is not None:
@@ -287,7 +288,8 @@ def clean_mask(mask, min_area=100):
287
 
288
  def get_hair_mask_segformer(pil_image: Image.Image) -> np.ndarray:
289
  processor, parser = load_face_parser()
290
- img_r = pil_image.resize((384,384), Image.LANCZOS)
 
291
  inputs = processor(images=img_r, return_tensors="pt").to(DEVICE)
292
  if USE_FP16: inputs['pixel_values'] = inputs['pixel_values'].half()
293
  with torch.no_grad():
@@ -345,42 +347,83 @@ def process_masks_parallel(image):
345
  return h.result(), b.result()
346
 
347
  # ================================================
348
- # 8. MAIN PROCESSING (FIXED tensor dimension)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  # ================================================
350
  def process_face_aging(input_image: Image.Image) -> Image.Image:
 
351
  try:
352
  logger.info(f"→ Processing image: {input_image.size}")
353
  orig = input_image.convert("RGB")
354
  ow, oh = orig.size
355
 
 
356
  target_size = min(SAFE_IMG_SIZE, max(ow, oh))
357
  if target_size % 2 == 1:
358
  target_size -= 1
359
 
360
  img_resized = orig.resize((target_size, target_size), Image.LANCZOS)
361
- rgb_tensor = TF.to_tensor(img_resized) # shape: (3, H, W)
362
-
363
  src_age = torch.full((1, target_size, target_size), SOURCE_AGE / 100.0)
364
  tgt_age = torch.full((1, target_size, target_size), TARGET_AGE / 100.0)
365
-
366
- # FIXED: Make all 4D tensors
367
  cond_input = torch.cat([
368
- rgb_tensor.unsqueeze(0), # (1, 3, H, W)
369
- src_age.unsqueeze(1), # (1, 1, H, W)
370
- tgt_age.unsqueeze(1) # (1, 1, H, W)
371
- ], dim=1).to(DEVICE) # Final: (1, 5, H, W)
372
-
373
  if USE_FP16:
374
  cond_input = cond_input.half()
375
 
376
  with torch.no_grad():
377
  aging_net = load_aging_model()
378
- raw_output = aging_net(cond_input).squeeze(0) # (3, H, W)
379
 
380
  alpha = WRINKLE_STRENGTH
381
  blended = (1 - alpha) * rgb_tensor + alpha * raw_output
382
  blended = blended.clamp(0, 1)
383
-
384
  if USE_FP16:
385
  blended = blended.float()
386
 
@@ -388,14 +431,17 @@ def process_face_aging(input_image: Image.Image) -> Image.Image:
388
  final_aged = enhance_texture(final_aged)
389
  final_aged = post_correct_aged(orig, final_aged)
390
 
391
- logger.info(" Generating masks...")
 
392
  hair_mask, beard_mask = process_masks_parallel(final_aged)
393
 
394
- logger.info(" Applying white hair & beard...")
395
  final_img = apply_hair_and_beard_color(final_aged, hair_mask, beard_mask)
396
 
397
- if USE_GFPGAN and (ow * oh) < 2000000:
398
- logger.info(" Applying GFPGAN...")
 
 
399
  gfpgan = load_gfpgan()
400
  if gfpgan:
401
  try:
@@ -406,8 +452,10 @@ def process_face_aging(input_image: Image.Image) -> Image.Image:
406
  final_img = Image.fromarray(cv2.cvtColor(restored, cv2.COLOR_BGR2RGB))
407
  except Exception as e:
408
  logger.warning(f"GFPGAN skipped: {e}")
 
 
409
 
410
- logger.info("✅ Processing completed!")
411
  gc.collect()
412
  return final_img
413
 
@@ -417,12 +465,17 @@ def process_face_aging(input_image: Image.Image) -> Image.Image:
417
  raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
418
 
419
  # ================================================
420
- # 9. FASTAPI
421
  # ================================================
422
  app = FastAPI(title="Face Aging API")
423
 
424
- app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
425
- allow_methods=["*"], allow_headers=["*"])
 
 
 
 
 
426
 
427
  @app.on_event("startup")
428
  async def startup_event():
@@ -440,20 +493,23 @@ async def age_face(file: UploadFile = File(...)):
440
  if not file.content_type.startswith("image/"):
441
  raise HTTPException(400, "Only image files allowed")
442
  contents = await file.read()
 
 
 
443
  try:
444
- input_image = Image.open(io.BytesIO(contents)).convert("RGB")
445
- loop = asyncio.get_event_loop()
446
- result = await loop.run_in_executor(executor, process_face_aging, input_image)
447
-
448
- buf = io.BytesIO()
449
- result.save(buf, format="JPEG", quality=90, optimize=True)
450
- buf.seek(0)
451
- return StreamingResponse(buf, media_type="image/jpeg")
452
- except Exception as e:
453
- logger.error(f"Endpoint error: {e}")
454
- raise HTTPException(500, f"Failed: {str(e)}")
455
- finally:
456
- gc.collect()
457
 
458
  if __name__ == "__main__":
459
  import uvicorn
 
9
  import traceback
10
  import gc
11
  import uuid
12
+ import time
13
+ import asyncio
14
+ from concurrent.futures import ThreadPoolExecutor
15
  from PIL import Image, ImageFilter, ImageEnhance
16
  from torchvision.transforms import functional as TF
17
  from scipy.ndimage import label
 
26
  from fastapi.responses import StreamingResponse
27
  from fastapi.middleware.cors import CORSMiddleware
28
  import io
 
 
29
  import logging
30
 
31
  logging.basicConfig(level=logging.INFO)
 
38
  BEARD_MODEL_PATH = "models/best_hair_117_epoch_v4.pt"
39
  GFPGAN_MODEL_PATH = "GFPGANv1.4.pth"
40
 
41
+ SAFE_IMG_SIZE = 384 # used only for aging model
42
  SOURCE_AGE = 20
43
  TARGET_AGE = 80
44
  WRINKLE_STRENGTH = 0.42
 
53
 
54
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
  USE_FP16 = DEVICE.type == "cuda" and torch.cuda.is_available()
 
56
  logger.info(f"🚀 Device: {DEVICE}, FP16: {USE_FP16}")
57
 
58
  os.environ["HF_HOME"] = "/tmp/hf_cache"
 
199
  return age_model
200
 
201
  # ================================================
202
+ # 6. LOAD FACE PARSER & BEARD MODEL (optimized sizes)
203
  # ================================================
204
  def load_face_parser():
205
  global face_processor, face_parser
 
225
  return beard_model
226
 
227
  # ================================================
228
+ # 7. MASK FUNCTIONS (optimized: 256px inference)
229
  # ================================================
230
  def get_lips_mask(pil_image: Image.Image) -> np.ndarray:
231
  img_np = np.array(pil_image.resize((256, 256), Image.LANCZOS))
 
254
  def get_beard_mask(pil_image: Image.Image) -> np.ndarray:
255
  temp = f"temp_{uuid.uuid4().hex[:8]}.jpg"
256
  try:
257
+ # OPTIMIZATION: use 256px instead of 384
258
+ pil_image.resize((256, 256), Image.LANCZOS).save(temp)
259
  model = load_beard_model()
260
+ res = model(temp, device=DEVICE.type, conf=0.3, iou=0.5, verbose=False, half=USE_FP16, imgsz=256)
261
  h, w = np.array(pil_image).shape[:2]
262
  mask = np.zeros((h,w), dtype=np.uint8)
263
  if res[0].masks is not None:
 
288
 
289
  def get_hair_mask_segformer(pil_image: Image.Image) -> np.ndarray:
290
  processor, parser = load_face_parser()
291
+ # OPTIMIZATION: 256px instead of 384
292
+ img_r = pil_image.resize((256, 256), Image.LANCZOS)
293
  inputs = processor(images=img_r, return_tensors="pt").to(DEVICE)
294
  if USE_FP16: inputs['pixel_values'] = inputs['pixel_values'].half()
295
  with torch.no_grad():
 
347
  return h.result(), b.result()
348
 
349
  # ================================================
350
+ # 8. FAST FALLBACK (when full pipeline times out)
351
+ # ================================================
352
+ def fast_aging_fallback(input_image: Image.Image) -> Image.Image:
353
+ """Very fast aging: no GFPGAN, no beard mask, simplified hair mask."""
354
+ logger.info("⚡ Using fast fallback aging")
355
+ orig = input_image.convert("RGB")
356
+ ow, oh = orig.size
357
+ target_size = min(256, max(ow, oh))
358
+ if target_size % 2 == 0:
359
+ target_size -= 1
360
+
361
+ img_resized = orig.resize((target_size, target_size), Image.LANCZOS)
362
+ rgb_tensor = TF.to_tensor(img_resized)
363
+ src_age = torch.full((1, target_size, target_size), SOURCE_AGE / 100.0)
364
+ tgt_age = torch.full((1, target_size, target_size), TARGET_AGE / 100.0)
365
+ cond_input = torch.cat([
366
+ rgb_tensor.unsqueeze(0),
367
+ src_age.unsqueeze(1),
368
+ tgt_age.unsqueeze(1)
369
+ ], dim=1).to(DEVICE)
370
+ if USE_FP16:
371
+ cond_input = cond_input.half()
372
+
373
+ with torch.no_grad():
374
+ raw_output = load_aging_model()(cond_input).squeeze(0)
375
+
376
+ alpha = WRINKLE_STRENGTH
377
+ blended = (1 - alpha) * rgb_tensor + alpha * raw_output
378
+ blended = blended.clamp(0, 1).float() if USE_FP16 else blended
379
+
380
+ aged = TF.to_pil_image(blended).resize((ow, oh), Image.LANCZOS)
381
+ aged = enhance_texture(aged)
382
+ aged = post_correct_aged(orig, aged)
383
+
384
+ # Simple luminance-based hair whitening (no segmentation)
385
+ gray = np.array(aged.convert('L')) / 255.0
386
+ hair_mask = (gray > 0.65).astype(np.float32)
387
+ hair_mask = cv2.GaussianBlur(hair_mask, (9,9), 3)
388
+ beard_mask = np.zeros_like(hair_mask)
389
+
390
+ final = apply_hair_and_beard_color(aged, hair_mask, beard_mask)
391
+ return final
392
+
393
+ # ================================================
394
+ # 9. MAIN PROCESSING (with optional time check)
395
  # ================================================
396
  def process_face_aging(input_image: Image.Image) -> Image.Image:
397
+ start_time = time.time()
398
  try:
399
  logger.info(f"→ Processing image: {input_image.size}")
400
  orig = input_image.convert("RGB")
401
  ow, oh = orig.size
402
 
403
+ # Aging model step
404
  target_size = min(SAFE_IMG_SIZE, max(ow, oh))
405
  if target_size % 2 == 1:
406
  target_size -= 1
407
 
408
  img_resized = orig.resize((target_size, target_size), Image.LANCZOS)
409
+ rgb_tensor = TF.to_tensor(img_resized)
 
410
  src_age = torch.full((1, target_size, target_size), SOURCE_AGE / 100.0)
411
  tgt_age = torch.full((1, target_size, target_size), TARGET_AGE / 100.0)
 
 
412
  cond_input = torch.cat([
413
+ rgb_tensor.unsqueeze(0),
414
+ src_age.unsqueeze(1),
415
+ tgt_age.unsqueeze(1)
416
+ ], dim=1).to(DEVICE)
 
417
  if USE_FP16:
418
  cond_input = cond_input.half()
419
 
420
  with torch.no_grad():
421
  aging_net = load_aging_model()
422
+ raw_output = aging_net(cond_input).squeeze(0)
423
 
424
  alpha = WRINKLE_STRENGTH
425
  blended = (1 - alpha) * rgb_tensor + alpha * raw_output
426
  blended = blended.clamp(0, 1)
 
427
  if USE_FP16:
428
  blended = blended.float()
429
 
 
431
  final_aged = enhance_texture(final_aged)
432
  final_aged = post_correct_aged(orig, final_aged)
433
 
434
+ # Masks (parallel)
435
+ logger.info("🔄 Generating masks...")
436
  hair_mask, beard_mask = process_masks_parallel(final_aged)
437
 
438
+ logger.info("🎨 Applying white hair & beard...")
439
  final_img = apply_hair_and_beard_color(final_aged, hair_mask, beard_mask)
440
 
441
+ # GFPGAN only if image is not too large and we have time left (> 2 sec)
442
+ elapsed = time.time() - start_time
443
+ if USE_GFPGAN and (ow * oh) < 1000000 and elapsed < 7.0:
444
+ logger.info("✨ Applying GFPGAN...")
445
  gfpgan = load_gfpgan()
446
  if gfpgan:
447
  try:
 
452
  final_img = Image.fromarray(cv2.cvtColor(restored, cv2.COLOR_BGR2RGB))
453
  except Exception as e:
454
  logger.warning(f"GFPGAN skipped: {e}")
455
+ else:
456
+ logger.info("⏭️ Skipping GFPGAN (image too large or time low)")
457
 
458
+ logger.info(f"✅ Processing completed in {time.time()-start_time:.2f}s")
459
  gc.collect()
460
  return final_img
461
 
 
465
  raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
466
 
467
  # ================================================
468
+ # 10. FASTAPI WITH TIMEOUT
469
  # ================================================
470
  app = FastAPI(title="Face Aging API")
471
 
472
+ app.add_middleware(
473
+ CORSMiddleware,
474
+ allow_origins=["*"],
475
+ allow_credentials=True,
476
+ allow_methods=["*"],
477
+ allow_headers=["*"],
478
+ )
479
 
480
  @app.on_event("startup")
481
  async def startup_event():
 
493
  if not file.content_type.startswith("image/"):
494
  raise HTTPException(400, "Only image files allowed")
495
  contents = await file.read()
496
+ input_image = Image.open(io.BytesIO(contents)).convert("RGB")
497
+
498
+ loop = asyncio.get_event_loop()
499
  try:
500
+ # 9.5 second timeout for full pipeline
501
+ result = await asyncio.wait_for(
502
+ loop.run_in_executor(executor, process_face_aging, input_image),
503
+ timeout=9.5
504
+ )
505
+ except asyncio.TimeoutError:
506
+ logger.warning("⏱️ Full processing timeout – using fast fallback")
507
+ result = await loop.run_in_executor(executor, fast_aging_fallback, input_image)
508
+
509
+ buf = io.BytesIO()
510
+ result.save(buf, format="JPEG", quality=90, optimize=True)
511
+ buf.seek(0)
512
+ return StreamingResponse(buf, media_type="image/jpeg")
513
 
514
  if __name__ == "__main__":
515
  import uvicorn