"""Face-aging HTTP service.

Pipeline implemented by ``process_face_aging``:
  1. a conditional UNet (input: RGB + a source-age plane + a target-age plane)
     runs at SAFE_IMG_SIZE x SAFE_IMG_SIZE and its output is blended with the
     resized input;
  2. the result is sharpened, contrast-boosted and histogram-matched back to the
     original image;
  3. hair (Segformer face parsing) and beard (YOLO segmentation) regions are
     blended toward white;
  4. GFPGAN optionally restores facial detail.

``POST /age-face`` exposes the pipeline over FastAPI.
"""
import os
import gc
import io
import shutil
import threading
import traceback
import urllib.request

import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image, ImageFilter, ImageEnhance
from torchvision.transforms import functional as TF
from scipy.ndimage import label
import antialiased_cnns
import mediapipe as mp
from skimage.exposure import match_histograms
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from ultralytics import YOLO
from gfpgan import GFPGANer
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware

# ========================= CONFIG =========================
AGING_MODEL_PATH = "face_aging_model/best_unet_model.pth"
BEARD_MODEL_PATH = "models/best_hair_117_epoch_v4.pt"
GFPGAN_MODEL_PATH = "GFPGANv1.4.pth"
GFPGAN_MODEL_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
FACE_PARSER_REPO = "jonathandinu/face-parsing"
HF_CACHE_DIR = "/tmp/hf_cache"

SAFE_IMG_SIZE = 512  # side length of the square image fed to the UNet
SOURCE_AGE = 20
TARGET_AGE = 80
WRINKLE_STRENGTH = 0.42  # blend weight of the UNet output vs. the resized input
CONTRAST_BOOST = 1.10
SHARPNESS_BOOST = 1.20
ALPHA_HAIR = 0.95  # maximum opacity of the white hair/beard layer
BLUR_RADIUS = 7  # Gaussian sigma (kernel = 2*r+1) used to feather the hair/beard mask
EDGE_SMOOTHING = True
USE_GFPGAN = True
GFPGAN_UPSCALE = 1
GFPGAN_WEIGHT = 0.5

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Device: {DEVICE}")
if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True

# huggingface_hub resolves its cache location from HF_HOME when it is first
# imported -- which already happened through the transformers import above --
# so this assignment alone does not redirect the model cache.
# load_face_parser() therefore also passes the equivalent cache_dir explicitly.
os.environ["HF_HOME"] = HF_CACHE_DIR
os.makedirs(HF_CACHE_DIR, exist_ok=True)
_HF_HUB_CACHE = os.path.join(HF_CACHE_DIR, "hub")  # default hub cache under HF_HOME

# Global models (lazy loading: populated on first use by the load_* helpers)
age_model = None
face_processor = None
face_parser = None
beard_model = None
gfpgan_restorer = None

mp_face_mesh = mp.solutions.face_mesh.FaceMesh(
    static_image_mode=True,
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5,
)

# Serializes pipeline runs: the lazily-loaded globals and the shared FaceMesh
# instance above are not guarded against concurrent use.
_PROCESS_LOCK = threading.Lock()


# ================== DOWNLOAD HELPER ==================
def download_file(url, filename, timeout=60):
    """Download ``url`` to ``filename`` unless the file already exists.

    The payload is streamed into ``<filename>.part`` and renamed into place only
    after the transfer completes, so an interrupted download never leaves a
    truncated file behind that a later ``os.path.exists`` check would accept.

    Args:
        url: Source URL.
        filename: Destination path.
        timeout: Socket timeout in seconds for connecting and for each read.

    Returns:
        True if ``filename`` exists afterwards, False if the download failed.
    """
    if os.path.exists(filename):
        print(f"✅ {filename} already exists.")
        return True
    print(f"🔄 Downloading {filename}... (~350 MB)")
    tmp_path = f"{filename}.part"
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response, open(tmp_path, "wb") as out:
            shutil.copyfileobj(response, out)
        os.replace(tmp_path, filename)  # atomic on the same filesystem
        print(f"✅ Download completed: {filename}")
        return True
    except Exception as e:
        print(f"❌ Download failed: {e}")
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        return False


# ================== AGING NETWORK ==================
# Attribute names below (layer, blur_upsample, init_conv, downN, upN, final_conv)
# define the state_dict keys and must match the checkpoint at AGING_MODEL_PATH.
class DownLayer(nn.Module):
    """Downsampling stage: MaxPool(stride 1) + BlurPool(stride 2), then two 3x3 convs."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.layer = nn.Sequential(
            nn.MaxPool2d(2, stride=1),
            antialiased_cnns.BlurPool(in_ch, stride=2),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        return self.layer(x)


class UpLayer(nn.Module):
    """Upsampling stage: transposed conv + BlurPool, concat skip, then two 3x3 convs."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.blur_upsample = nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
            antialiased_cnns.BlurPool(out_ch, stride=1),
        )
        self.layer = nn.Sequential(
            nn.Conv2d(out_ch * 2, out_ch, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = self.blur_upsample(x)
        x = torch.cat([x, skip], dim=1)
        return self.layer(x)


class UNet(nn.Module):
    """Four-level UNet mapping a 5-channel input to a 3-channel output.

    The 5 input channels are built in ``_age_face``: RGB plus two constant
    planes holding SOURCE_AGE/100 and TARGET_AGE/100.
    """

    def __init__(self):
        super().__init__()
        self.init_conv = nn.Sequential(
            nn.Conv2d(5, 64, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        self.down1 = DownLayer(64, 128)
        self.down2 = DownLayer(128, 256)
        self.down3 = DownLayer(256, 512)
        self.down4 = DownLayer(512, 1024)
        self.up1 = UpLayer(1024, 512)
        self.up2 = UpLayer(512, 256)
        self.up3 = UpLayer(256, 128)
        self.up4 = UpLayer(128, 64)
        self.final_conv = nn.Conv2d(64, 3, 1)

    def forward(self, x):
        x0 = self.init_conv(x)
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.up4(x, x0)
        return self.final_conv(x)


# ================== LOAD MODELS (Safer - No torch.compile) ==================
def load_aging_model():
    """Return the cached aging UNet, loading weights from AGING_MODEL_PATH on first use.

    The global cache is populated only after the checkpoint loads successfully,
    so a failed load is retried on the next call instead of leaving a
    randomly-initialised network cached.
    """
    global age_model
    if age_model is not None:
        return age_model
    print("Loading UNet aging model...")
    model = UNet().to(DEVICE)
    state = torch.load(AGING_MODEL_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    model.eval()
    age_model = model
    print("✅ Aging model loaded!")
    return age_model


def load_face_parser():
    """Return the cached (processor, model) pair for Segformer face parsing."""
    global face_processor, face_parser
    if face_parser is not None:
        return face_processor, face_parser
    print("Loading Segformer face-parsing...")
    processor = SegformerImageProcessor.from_pretrained(FACE_PARSER_REPO, cache_dir=_HF_HUB_CACHE)
    parser = SegformerForSemanticSegmentation.from_pretrained(FACE_PARSER_REPO, cache_dir=_HF_HUB_CACHE).to(DEVICE)
    parser.eval()
    face_processor, face_parser = processor, parser
    print("✅ Face parser loaded!")
    return face_processor, face_parser


def load_beard_model():
    """Return the cached YOLO segmentation model loaded from BEARD_MODEL_PATH."""
    global beard_model
    if beard_model is None:
        print("Loading Beard Detection Model (YOLO)...")
        beard_model = YOLO(BEARD_MODEL_PATH)
    return beard_model


def load_gfpgan():
    """Return the cached GFPGAN restorer, downloading weights if needed.

    Returns None (best effort) when the weights cannot be downloaded or the
    restorer fails to initialise; a later call will retry.
    """
    global gfpgan_restorer
    if gfpgan_restorer is not None:
        return gfpgan_restorer
    if not os.path.exists(GFPGAN_MODEL_PATH):
        if not download_file(GFPGAN_MODEL_URL, GFPGAN_MODEL_PATH):
            return None
    print("🔄 Loading GFPGAN v1.4...")
    try:
        gfpgan_restorer = GFPGANer(
            model_path=GFPGAN_MODEL_PATH,
            upscale=GFPGAN_UPSCALE,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=None,
            device=DEVICE,
        )
        print("✅ GFPGAN loaded successfully!")
        return gfpgan_restorer
    except Exception as e:
        print(f"❌ GFPGAN load failed: {e}")
        return None


# ================== MASK FUNCTIONS ==================
# FaceMesh landmark indices whose convex hull is used as the lip region.
_LIP_LANDMARKS = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291,
                  308, 324, 318, 402, 317, 14, 87, 178, 88, 95]


def get_lips_mask(pil_image: Image.Image) -> np.ndarray:
    """Return a soft lip mask in [0, 1] (float32, shape (H, W)).

    Uses the first face found by MediaPipe FaceMesh; returns all zeros when no
    face is detected.
    """
    img_np = np.array(pil_image.convert("RGB"))
    h, w = img_np.shape[:2]
    # MediaPipe FaceMesh expects RGB input, which is what PIL provides.
    results = mp_face_mesh.process(img_np)
    if not results.multi_face_landmarks:
        return np.zeros((h, w), dtype=np.float32)

    landmarks = results.multi_face_landmarks[0].landmark
    points = np.array(
        [[int(landmarks[idx].x * w), int(landmarks[idx].y * h)] for idx in _LIP_LANDMARKS],
        dtype=np.int32,
    )
    lips_mask = np.zeros((h, w), dtype=np.uint8)
    cv2.fillConvexPoly(lips_mask, cv2.convexHull(points), 255)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    lips_mask = cv2.dilate(lips_mask, kernel, iterations=2)
    lips_mask = cv2.GaussianBlur(lips_mask.astype(np.float32), (15, 15), 4)
    return np.clip(lips_mask / 255.0, 0, 1)


def exclude_lips_from_mask(beard_mask: np.ndarray, pil_image: Image.Image) -> np.ndarray:
    """Zero out the (slightly dilated) lip region of ``beard_mask`` and re-feather it."""
    if np.sum(beard_mask) == 0:
        return beard_mask
    lips_mask = get_lips_mask(pil_image)
    lips_region = (lips_mask > 0.3).astype(np.float32)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    lips_region = cv2.dilate(lips_region, kernel, iterations=1)
    beard_mask = beard_mask * (1.0 - lips_region)
    beard_mask = cv2.GaussianBlur(beard_mask, (5, 5), 1)
    return beard_mask


def get_beard_mask(pil_image: Image.Image) -> np.ndarray:
    """Return a soft beard mask in [0, 1] (float32, shape (H, W)) from the YOLO model.

    Only detections of class 0 (the beard class) are used. The image is passed
    to YOLO directly (no temporary file), and ``retina_masks=True`` makes the
    returned masks match the original image resolution, so they stay aligned
    even when letterbox padding was applied for inference.
    """
    model = load_beard_model()
    w, h = pil_image.size
    results = model(pil_image, device=DEVICE.type, conf=0.25, iou=0.5, verbose=False,
                    half=DEVICE.type == "cuda", retina_masks=True)
    result = results[0]

    beard_mask = np.zeros((h, w), dtype=np.uint8)
    if result.masks is not None:
        for i, cls in enumerate(result.boxes.cls):
            if int(cls) != 0:  # beard class
                continue
            # .float(): masks may be half/bool tensors, which cv2.resize rejects.
            mask = result.masks.data[i].float().cpu().numpy()
            mask = cv2.resize(mask, (w, h))  # no-op with retina_masks; kept as a safety net
            mask = (mask > 0.4).astype(np.uint8) * 255
            beard_mask = cv2.bitwise_or(beard_mask, mask)

    if not beard_mask.any():
        return np.zeros((h, w), dtype=np.float32)

    beard_mask_float = beard_mask.astype(np.float32) / 255.0
    beard_mask_float = cv2.dilate(beard_mask_float, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)), iterations=2)
    beard_mask_float = cv2.morphologyEx(beard_mask_float, cv2.MORPH_OPEN,
                                        cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)), iterations=1)
    beard_mask_float = cv2.morphologyEx(beard_mask_float, cv2.MORPH_CLOSE,
                                        cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)), iterations=2)
    beard_mask_float = exclude_lips_from_mask(beard_mask_float, pil_image)
    beard_mask_float = cv2.GaussianBlur(beard_mask_float, (7, 7), 2)
    return np.clip(beard_mask_float, 0, 1)


def clean_mask(mask, min_area=150):
    """Drop 4-connected components smaller than ``min_area`` pixels.

    Returns a uint8 mask of the same shape containing 1 for kept pixels and 0
    elsewhere. Component sizes are computed in one ``np.bincount`` pass instead
    of one full-image comparison per component.
    """
    mask = mask.astype(np.uint8)
    labeled, num = label(mask)
    sizes = np.bincount(labeled.ravel(), minlength=num + 1)
    keep = sizes >= min_area
    keep[0] = False  # label 0 is the background
    return keep[labeled].astype(np.uint8)


def get_hair_mask_segformer(pil_image: Image.Image) -> np.ndarray:
    """Return a soft hair mask in [0, 1] (float32, shape (H, W)) from Segformer parsing.

    Hair is taken as channel 13 of the parser output with a low probability
    threshold; pixels whose argmax label is in ``face_classes`` (dilated) are
    removed, then the mask is cleaned and feathered.
    """
    processor, parser = load_face_parser()
    inputs = processor(images=pil_image, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        logits = parser(**inputs).logits
        # Upsample to (H, W) of the input image so masks align with pixels.
        upsampled = torch.nn.functional.interpolate(logits, size=pil_image.size[::-1],
                                                    mode="bilinear", align_corners=False)
        probs = torch.softmax(upsampled, dim=1)[0]
        hair_prob = probs[13].cpu().numpy()
        parsing = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()

    hair_mask = (hair_prob > 0.12).astype(np.uint8)

    # Label ids treated as "face" and carved out of the hair mask
    # (presumably skin/eyes/ears/mouth/lips/neck/cloth per the model's label map -- TODO confirm).
    face_classes = list(range(1, 6)) + list(range(8, 13)) + [17, 18]
    face_mask = np.isin(parsing, face_classes).astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    face_mask = cv2.dilate(face_mask, kernel, iterations=1)
    hair_mask = hair_mask * (1 - face_mask)

    hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_OPEN,
                                 cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)), iterations=1)
    hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE,
                                 cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11)), iterations=2)
    hair_mask = clean_mask(hair_mask, min_area=100)
    hair_mask = cv2.GaussianBlur(hair_mask.astype(np.float32), (5, 5), 1.5)
    return np.clip(hair_mask, 0, 1)


def apply_hair_and_beard_color(image: Image.Image, hair_mask: np.ndarray, beard_mask: np.ndarray):
    """Blend hair and beard regions toward a luminance-modulated white.

    The union of both masks is feathered, optionally edge-smoothed, boosted by
    1.2x, and used (scaled by ALPHA_HAIR) as per-pixel opacity of a white layer
    whose brightness is 0.6 + 0.4 * grayscale. Returns ``image`` unchanged when
    both masks are empty, otherwise a new sharpened PIL image.
    """
    combined_mask = np.maximum(hair_mask, beard_mask)
    if np.sum(combined_mask) == 0:
        return image

    ksize = BLUR_RADIUS * 2 + 1
    combined_mask = cv2.GaussianBlur(combined_mask, (ksize, ksize), BLUR_RADIUS)
    combined_mask = np.clip(combined_mask, 0, 1)
    if EDGE_SMOOTHING:
        combined_mask = cv2.bilateralFilter(combined_mask.astype(np.float32), 9, 75, 75)
        combined_mask = np.clip(combined_mask, 0, 1)
    combined_mask = np.clip(combined_mask * 1.2, 0, 1)

    img_np = np.array(image).astype(np.float32)
    target_color = np.array([255, 255, 255], dtype=np.float32)
    gray = cv2.cvtColor(img_np.astype(np.uint8), cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0
    lum_factor = 0.6 + 0.4 * gray  # keep some of the original shading in the white layer
    white_layer = target_color * lum_factor[..., np.newaxis]

    weight = ALPHA_HAIR * combined_mask[..., np.newaxis]
    result = (1 - weight) * img_np + weight * white_layer
    result = np.clip(result, 0, 255).astype(np.uint8)
    result_pil = Image.fromarray(result)
    return result_pil.filter(ImageFilter.UnsharpMask(1.2, 140, 2))


def post_correct_aged(original: Image.Image, aged: Image.Image) -> Image.Image:
    """Match ``aged``'s per-channel histograms to ``original``, then brighten 1.10x and add contrast 1.06x."""
    orig_np = np.array(original)
    aged_np = np.array(aged)
    matched = match_histograms(aged_np, orig_np, channel_axis=-1)
    matched_img = Image.fromarray(np.clip(matched, 0, 255).astype(np.uint8))
    matched_img = ImageEnhance.Brightness(matched_img).enhance(1.10)
    matched_img = ImageEnhance.Contrast(matched_img).enhance(1.06)
    return matched_img


def enhance_texture(img: Image.Image) -> Image.Image:
    """Apply unsharp masking, then CONTRAST_BOOST and SHARPNESS_BOOST enhancements."""
    img = img.filter(ImageFilter.UnsharpMask(2, 160, 3))
    img = ImageEnhance.Contrast(img).enhance(CONTRAST_BOOST)
    img = ImageEnhance.Sharpness(img).enhance(SHARPNESS_BOOST)
    return img


# ================== MAIN PROCESSING FUNCTION (Memory Safe) ==================
def _age_face(orig: Image.Image) -> Image.Image:
    """Run the aging UNet on ``orig`` (RGB) and return the post-processed result at the original size."""
    ow, oh = orig.size
    img_resized = orig.resize((SAFE_IMG_SIZE, SAFE_IMG_SIZE), Image.LANCZOS)
    # Load outside inference_mode so the cached model's parameters are ordinary tensors.
    aging_net = load_aging_model()

    with torch.inference_mode():
        rgb_tensor = TF.to_tensor(img_resized).to(DEVICE)
        plane = (1, 1, SAFE_IMG_SIZE, SAFE_IMG_SIZE)
        src_age = torch.full(plane, SOURCE_AGE / 100.0, device=DEVICE)
        tgt_age = torch.full(plane, TARGET_AGE / 100.0, device=DEVICE)
        cond_input = torch.cat([rgb_tensor.unsqueeze(0), src_age, tgt_age], dim=1)  # (1, 5, S, S)
        raw_output = aging_net(cond_input).squeeze(0).float()
        blended = ((1 - WRINKLE_STRENGTH) * rgb_tensor + WRINKLE_STRENGTH * raw_output).clamp(0, 1)

    final_aged = TF.to_pil_image(blended.cpu()).resize((ow, oh), Image.LANCZOS)
    final_aged = enhance_texture(final_aged)
    return post_correct_aged(orig, final_aged)


def _restore_with_gfpgan(img: Image.Image) -> Image.Image:
    """Best-effort GFPGAN restoration; returns ``img`` unchanged if GFPGAN is unavailable or fails."""
    print(" Applying GFPGAN face restoration...")
    gfpgan = load_gfpgan()
    if gfpgan is None:
        return img
    try:
        img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        with torch.inference_mode():
            _, _, restored_cv = gfpgan.enhance(
                img_cv,
                has_aligned=False,
                only_center_face=False,
                paste_back=True,
                weight=GFPGAN_WEIGHT,
            )
        if restored_cv is None:
            print(" GFPGAN returned no image (skipped)")
            return img
        return Image.fromarray(cv2.cvtColor(restored_cv, cv2.COLOR_BGR2RGB))
    except Exception as e:
        print(f" GFPGAN error (skipped): {e}")
        return img


def process_face_aging(input_image: Image.Image) -> Image.Image:
    """Age ``input_image`` and whiten its hair and beard; return a new RGB PIL image.

    Raises:
        ValueError: if ``input_image`` is None.
        Exception: any pipeline failure is logged (with GPU memory usage on
            CUDA) and re-raised. GFPGAN failures are the exception: they are
            logged and skipped.
    """
    if input_image is None:
        raise ValueError("Please provide a valid image!")
    try:
        print(f"→ Processing image: {input_image.size}")
        orig = input_image.convert("RGB")
        final_aged = _age_face(orig)

        # Hair & Beard
        print(" Generating hair mask...")
        hair_mask = get_hair_mask_segformer(final_aged)
        print(" Generating beard mask...")
        beard_mask = get_beard_mask(final_aged)
        print(" Applying white hair & beard...")
        final_img = apply_hair_and_beard_color(final_aged, hair_mask, beard_mask)

        if USE_GFPGAN:
            final_img = _restore_with_gfpgan(final_img)

        print("✓ Processing completed!")
        return final_img
    except Exception as e:
        print(f"❌ Error: {str(e)}")
        traceback.print_exc()
        if DEVICE.type == "cuda":
            print(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")
        raise
    finally:
        gc.collect()
        if DEVICE.type == "cuda":
            torch.cuda.empty_cache()


def _process_serialized(image: Image.Image) -> Image.Image:
    """Run ``process_face_aging`` while holding the global pipeline lock."""
    with _PROCESS_LOCK:
        return process_face_aging(image)


# ================== FASTAPI SETUP ==================
app = FastAPI(title="Face Aging + White Hair & Beard API")
# NOTE(review): wildcard origins combined with allow_credentials=True makes
# Starlette echo any request Origin with credentials allowed. This endpoint uses
# no cookies/auth, so allow_credentials=False is likely safe -- confirm with the
# frontend before changing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/age-face")
async def age_face(file: UploadFile = File(...)):
    """Age the uploaded face image and return the result as a PNG stream.

    Responds 400 for a non-image content type or an undecodable image and 500
    for processing failures.
    """
    # content_type can be None when the client omits it.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Only image files allowed")

    contents = await file.read()
    try:
        input_image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {str(e)}") from e

    try:
        # Inference is blocking CPU/GPU work: run it off the event loop so the
        # server stays responsive, serialized by _PROCESS_LOCK.
        result_image = await run_in_threadpool(_process_serialized, input_image)
        img_byte_arr = io.BytesIO()
        result_image.save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)
        return StreamingResponse(img_byte_arr, media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}") from e
    finally:
        gc.collect()
        if DEVICE.type == "cuda":
            torch.cuda.empty_cache()


# For local testing
if __name__ == "__main__":
    import uvicorn

    print("Starting FastAPI server...")
    uvicorn.run(app, host="0.0.0.0", port=7860)