Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import cv2 | |
| import traceback | |
| import gc | |
| from PIL import Image, ImageFilter, ImageEnhance | |
| from torchvision.transforms import functional as TF | |
| from scipy.ndimage import label | |
| import antialiased_cnns | |
| import mediapipe as mp | |
| from skimage.exposure import match_histograms | |
| from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation | |
| from ultralytics import YOLO | |
| from gfpgan import GFPGANer | |
| import urllib.request | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import io | |
# ========================= CONFIG =========================
# Checkpoint locations, relative to the working directory.
AGING_MODEL_PATH = "face_aging_model/best_unet_model.pth"
BEARD_MODEL_PATH = "models/best_hair_117_epoch_v4.pt"
GFPGAN_MODEL_PATH = "GFPGANv1.4.pth"

# Square side length (px) the image is resized to before the aging UNet.
SAFE_IMG_SIZE = 512

# Source/target ages; fed to the UNet as constant planes divided by 100.
SOURCE_AGE = 20
TARGET_AGE = 80

# Weight of the UNet output when blended with the resized input image.
WRINKLE_STRENGTH = 0.42
# ImageEnhance factors applied in enhance_texture (1.0 leaves the image unchanged).
CONTRAST_BOOST = 1.10
SHARPNESS_BOOST = 1.20

# White hair/beard overlay: opacity, mask blur radius, optional bilateral smoothing.
ALPHA_HAIR = 0.95
BLUR_RADIUS = 7
EDGE_SMOOTHING = True

# GFPGAN face-restoration toggle, upscale factor and blend weight.
USE_GFPGAN = True
GFPGAN_UPSCALE = 1
GFPGAN_WEIGHT = 0.5

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"🚀 Device: {DEVICE}")
if DEVICE.type == "cuda":
    # The aging UNet always receives SAFE_IMG_SIZE inputs, so cuDNN autotuning
    # can settle on the fastest kernels for that fixed shape.
    torch.backends.cudnn.benchmark = True

# Writable Hugging Face cache location.
# NOTE(review): huggingface_hub resolves HF_HOME when it is first imported, and the
# `transformers` import above has most likely done that already, so this assignment
# probably does not change where from_pretrained() caches files — confirm.
os.makedirs("/tmp/hf_cache", exist_ok=True)
os.environ["HF_HOME"] = "/tmp/hf_cache"

# Lazily-initialised model singletons (populated by the load_* functions below).
age_model = face_processor = face_parser = beard_model = gfpgan_restorer = None

# Shared MediaPipe FaceMesh instance used by get_lips_mask (one face, refined landmarks).
mp_face_mesh = mp.solutions.face_mesh.FaceMesh(
    static_image_mode=True,
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5,
)
# ================== DOWNLOAD HELPER ==================
def download_file(url, filename):
    """Download ``url`` to ``filename`` unless the file already exists.

    The payload is first written to ``filename + ".part"`` and only renamed to
    ``filename`` once the transfer has finished. An interrupted download
    therefore never leaves a truncated file behind that a later call would
    mistake for a complete one (the existence check below short-circuits).

    Args:
        url: Source URL (any scheme ``urllib.request`` supports).
        filename: Destination path.

    Returns:
        True if ``filename`` is present afterwards, False if the download failed.
    """
    if os.path.exists(filename):
        print(f"✅ {filename} already exists.")
        return True
    print(f"🔄 Downloading {filename}... (~350 MB)")
    partial = f"{filename}.part"
    try:
        urllib.request.urlretrieve(url, partial)
        # Atomic on the same filesystem: readers see either nothing or the full file.
        os.replace(partial, filename)
    except Exception as e:
        print(f"❌ Download failed: {e}")
        # Best-effort removal of the incomplete payload.
        try:
            os.remove(partial)
        except OSError:
            pass
        return False
    print(f"✅ Download completed: {filename}")
    return True
# ================== LOAD MODELS (Safer - No torch.compile) ==================
def load_aging_model():
    """Lazily build the aging UNet and load its weights from AGING_MODEL_PATH.

    The global ``age_model`` is assigned only after the checkpoint has loaded
    successfully. A failed load (missing/corrupt file, key mismatch) is
    therefore retried on the next call, instead of leaving a randomly
    initialised network cached and silently used.

    Returns:
        The UNet on DEVICE in eval mode.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise on failure.
    """
    global age_model
    if age_model is not None:
        return age_model
    print("Loading UNet aging model...")

    # NOTE: attribute names and the order of layers inside each nn.Sequential
    # determine the state_dict keys, so they must match the checkpoint exactly.
    class DownLayer(nn.Module):
        """Max-pool (stride 1) + BlurPool (stride 2) downsampling, then two 3x3 convs."""

        def __init__(self, in_ch, out_ch):
            super().__init__()
            self.layer = nn.Sequential(
                nn.MaxPool2d(2, stride=1),
                antialiased_cnns.BlurPool(in_ch, stride=2),
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.LeakyReLU(inplace=True)
            )

        def forward(self, x):
            return self.layer(x)

    class UpLayer(nn.Module):
        """2x transposed-conv upsampling + BlurPool, concat with the skip, two 3x3 convs."""

        def __init__(self, in_ch, out_ch):
            super().__init__()
            self.blur_upsample = nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
                antialiased_cnns.BlurPool(out_ch, stride=1)
            )
            # Input channels are doubled by concatenation with the skip tensor.
            self.layer = nn.Sequential(
                nn.Conv2d(out_ch * 2, out_ch, 3, padding=1),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.LeakyReLU(inplace=True)
            )

        def forward(self, x, skip):
            x = self.blur_upsample(x)
            x = torch.cat([x, skip], dim=1)
            return self.layer(x)

    class UNet(nn.Module):
        """4-level UNet: 5 input channels (RGB + source-age + target-age planes) -> 3 output channels."""

        def __init__(self):
            super().__init__()
            self.init_conv = nn.Sequential(
                nn.Conv2d(5, 64, 3, padding=1),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(64, 64, 3, padding=1),
                nn.LeakyReLU(inplace=True)
            )
            self.down1 = DownLayer(64, 128)
            self.down2 = DownLayer(128, 256)
            self.down3 = DownLayer(256, 512)
            self.down4 = DownLayer(512, 1024)
            self.up1 = UpLayer(1024, 512)
            self.up2 = UpLayer(512, 256)
            self.up3 = UpLayer(256, 128)
            self.up4 = UpLayer(128, 64)
            self.final_conv = nn.Conv2d(64, 3, 1)

        def forward(self, x):
            x0 = self.init_conv(x)
            x1 = self.down1(x0)
            x2 = self.down2(x1)
            x3 = self.down3(x2)
            x4 = self.down4(x3)
            x = self.up1(x4, x3)
            x = self.up2(x, x2)
            x = self.up3(x, x1)
            x = self.up4(x, x0)
            return self.final_conv(x)

    # Build and load into a local first; publish to the global only on success.
    model = UNet().to(DEVICE)
    state = torch.load(AGING_MODEL_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    model.eval()
    age_model = model
    print("✅ Aging model loaded!")
    return age_model
def load_face_parser():
    """Lazily load the SegFormer face-parsing processor and model.

    Returns:
        Tuple ``(processor, model)``. The model is on DEVICE in eval mode.
        Both are cached in module globals after the first successful call.
    """
    global face_processor, face_parser
    if face_parser is not None:
        return face_processor, face_parser
    print("Loading Segformer face-parsing...")
    # huggingface_hub resolves HF_HOME when it is first imported (the
    # `transformers` import at the top of the file triggers that), so setting
    # os.environ["HF_HOME"] afterwards does not move the download cache.
    # Pass the directory explicitly instead; None keeps the library default.
    cache_dir = os.environ.get("HF_HOME")
    repo_id = "jonathandinu/face-parsing"
    processor = SegformerImageProcessor.from_pretrained(repo_id, cache_dir=cache_dir)
    model = SegformerForSemanticSegmentation.from_pretrained(repo_id, cache_dir=cache_dir).to(DEVICE)
    model.eval()
    # Publish both together so a failed load never leaves half-initialised globals.
    face_processor, face_parser = processor, model
    print("✅ Face parser loaded!")
    return face_processor, face_parser
def load_beard_model():
    """Return the cached YOLO beard-segmentation model, loading it on first use."""
    global beard_model
    if beard_model is not None:
        return beard_model
    print("Loading Beard Detection Model (YOLO)...")
    beard_model = YOLO(BEARD_MODEL_PATH)
    return beard_model
def load_gfpgan():
    """Lazily construct the GFPGAN v1.4 face restorer.

    If the checkpoint is missing, it is downloaded to GFPGAN_MODEL_PATH first.

    Returns:
        The cached GFPGANer, or None if the checkpoint could not be obtained or
        the restorer failed to initialise. Failures are not cached, so the next
        call tries again.
    """
    global gfpgan_restorer
    if gfpgan_restorer is not None:
        return gfpgan_restorer
    if not os.path.exists(GFPGAN_MODEL_PATH):
        # Upstream publishes the v1.4 weights under the v1.3.0 release tag.
        model_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
        if not download_file(model_url, GFPGAN_MODEL_PATH):
            # Without a checkpoint, constructing GFPGANer can only fail.
            print("❌ GFPGAN load failed: checkpoint could not be downloaded")
            return None
    print("🔄 Loading GFPGAN v1.4...")
    try:
        restorer = GFPGANer(
            model_path=GFPGAN_MODEL_PATH,
            upscale=GFPGAN_UPSCALE,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=None,
            device=DEVICE
        )
    except Exception as e:
        # Best-effort: callers skip restoration when this returns None.
        print(f"❌ GFPGAN load failed: {e}")
        return None
    gfpgan_restorer = restorer
    print("✅ GFPGAN loaded successfully!")
    return gfpgan_restorer
# ================== MASK FUNCTIONS ==================
def get_lips_mask(pil_image: Image.Image) -> np.ndarray:
    """Build a soft lip mask from MediaPipe FaceMesh landmarks.

    Args:
        pil_image: Input image (converted to RGB if necessary).

    Returns:
        float32 array of shape (h, w) with values in [0, 1]. It is ~1 over the
        dilated convex hull of the lip landmarks and feathered at the edges.
        It is all zeros when no face is detected.
    """
    # MediaPipe expects an RGB uint8 array, which np.array() of an RGB PIL image
    # already is. The former RGB->BGR conversion fed the detector swapped channels.
    img_np = np.array(pil_image.convert("RGB"))
    h, w = img_np.shape[:2]
    results = mp_face_mesh.process(img_np)
    if not results.multi_face_landmarks:
        return np.zeros((h, w), dtype=np.float32)

    # Only the first face is used (mp_face_mesh is created with max_num_faces=1).
    landmarks = results.multi_face_landmarks[0].landmark
    lip_landmarks = (61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291,
                     308, 324, 318, 402, 317, 14, 87, 178, 88, 95)
    # Landmarks are normalised to [0, 1]; scale to pixel coordinates.
    points = np.array(
        [[int(landmarks[idx].x * w), int(landmarks[idx].y * h)] for idx in lip_landmarks],
        dtype=np.int32,
    )

    lips_mask = np.zeros((h, w), dtype=np.uint8)
    cv2.fillConvexPoly(lips_mask, cv2.convexHull(points), 255)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    lips_mask = cv2.dilate(lips_mask, kernel, iterations=2)
    lips_mask = cv2.GaussianBlur(lips_mask.astype(np.float32), (15, 15), 4)
    return np.clip(lips_mask / 255.0, 0, 1)
def exclude_lips_from_mask(beard_mask: np.ndarray, pil_image: Image.Image) -> np.ndarray:
    """Zero out the lip area of ``beard_mask`` and lightly re-feather the result.

    An all-zero mask is returned as-is (same object) without running FaceMesh.
    """
    if np.sum(beard_mask) == 0:
        return beard_mask
    # Hard lip region: soft lip mask thresholded at 0.3, then grown a little.
    lip_zone = (get_lips_mask(pil_image) > 0.3).astype(np.float32)
    grow = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    lip_zone = cv2.dilate(lip_zone, grow, iterations=1)
    kept = beard_mask * (1.0 - lip_zone)
    return cv2.GaussianBlur(kept, (5, 5), 1)
def get_beard_mask(pil_image: Image.Image) -> np.ndarray:
    """Segment beard pixels with the YOLO model and return a soft mask.

    Args:
        pil_image: RGB image.

    Returns:
        float32 array of shape (h, w) with values in [0, 1]. It is all zeros
        when no class-0 (beard) instance is found. The lip area is carved out.
    """
    model = load_beard_model()
    # The PIL image is passed straight to YOLO rather than round-tripped through a
    # fixed "temp_input.jpg". That avoids lossy JPEG re-encoding and clashes between
    # processes sharing the working directory.
    # retina_masks=True returns masks.data at the original image size. The default
    # masks live in the letterboxed network input, and stretching them to (w, h)
    # also stretches the letterbox padding, which misaligns the mask.
    results = model(pil_image, device=DEVICE.type, conf=0.25, iou=0.5, verbose=False,
                    half=DEVICE.type == "cuda", retina_masks=True)
    w, h = pil_image.size
    beard_mask = np.zeros((h, w), dtype=np.uint8)
    result = results[0]
    if result.masks is not None:
        masks = result.masks.data.cpu().numpy().astype(np.float32)
        for i, cls in enumerate(result.boxes.cls.tolist()):
            if int(cls) != 0:  # class 0 = beard
                continue
            mask = masks[i]
            if mask.shape != (h, w):
                mask = cv2.resize(mask, (w, h))
            beard_mask |= (mask > 0.4).astype(np.uint8) * 255

    if not beard_mask.any():
        return np.zeros((h, w), dtype=np.float32)

    soft = beard_mask.astype(np.float32) / 255.0
    # Grow, remove specks, then fill small gaps.
    soft = cv2.dilate(soft, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)), iterations=2)
    soft = cv2.morphologyEx(soft, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)), iterations=1)
    soft = cv2.morphologyEx(soft, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)), iterations=2)
    soft = exclude_lips_from_mask(soft, pil_image)
    soft = cv2.GaussianBlur(soft, (7, 7), 2)
    return np.clip(soft, 0, 1)
def clean_mask(mask, min_area=150):
    """Drop connected components smaller than ``min_area`` pixels.

    Components come from ``scipy.ndimage.label`` with its default structuring
    element, which gives 4-connectivity in 2-D.

    Args:
        mask: Array cast to uint8 first, so fractional values below 1 become 0.
            Any non-zero value after the cast counts as foreground.
        min_area: Minimum component size, in pixels, to keep.

    Returns:
        uint8 array of the same shape, with 1 on kept components and 0 elsewhere.
    """
    mask = mask.astype(np.uint8)
    labeled, _ = label(mask)
    # One pass over the image: pixel count per label (label 0 = background).
    sizes = np.bincount(labeled.ravel(), minlength=1)
    keep = sizes >= min_area
    keep[0] = False  # the background is never part of the result
    return keep[labeled].astype(np.uint8)
def get_hair_mask_segformer(pil_image: Image.Image) -> np.ndarray:
    """Return a soft hair mask (float32, (h, w), values in [0, 1]) from SegFormer parsing.

    Hair is label index 13 with probability > 0.12. Pixels whose arg-max label
    is a face/neck/cloth class (dilated slightly) are removed, then the mask is
    cleaned morphologically and feathered.
    """
    hair_label = 13
    non_hair_labels = (1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 17, 18)

    processor, parser = load_face_parser()
    batch = processor(images=pil_image, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        logits = parser(**batch).logits

    width, height = pil_image.size
    full_res = torch.nn.functional.interpolate(logits, size=(height, width), mode="bilinear", align_corners=False)

    hair_prob = torch.softmax(full_res, dim=1)[0][hair_label].cpu().numpy()
    label_map = full_res.argmax(dim=1).squeeze(0).cpu().numpy()

    face_zone = np.isin(label_map, non_hair_labels).astype(np.uint8)
    face_zone = cv2.dilate(face_zone, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)), iterations=1)

    hair = (hair_prob > 0.12).astype(np.uint8) * (1 - face_zone)
    hair = cv2.morphologyEx(hair, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)), iterations=1)
    hair = cv2.morphologyEx(hair, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11)), iterations=2)
    hair = clean_mask(hair, min_area=100)

    feathered = cv2.GaussianBlur(hair.astype(np.float32), (5, 5), 1.5)
    return np.clip(feathered, 0, 1)
def apply_hair_and_beard_color(image: Image.Image, hair_mask: np.ndarray, beard_mask: np.ndarray):
    """Blend a luminance-modulated white layer over the union of hair and beard masks.

    Returns ``image`` unchanged when both masks are empty. Otherwise it returns a
    new PIL image, sharpened with an unsharp mask.
    """
    region = np.maximum(hair_mask, beard_mask)
    if np.sum(region) == 0:
        return image

    # Soften the mask edges, optionally smooth them further, then strengthen by 20%.
    ksize = BLUR_RADIUS * 2 + 1
    region = np.clip(cv2.GaussianBlur(region, (ksize, ksize), BLUR_RADIUS), 0, 1)
    if EDGE_SMOOTHING:
        region = np.clip(cv2.bilateralFilter(region.astype(np.float32), 9, 75, 75), 0, 1)
    region = np.clip(region * 1.2, 0, 1)

    pixels = np.array(image).astype(np.float32)
    luminance = cv2.cvtColor(pixels.astype(np.uint8), cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0
    # Darker areas get a dimmer white (60%..100%) so hair texture survives.
    white = np.array([255, 255, 255], dtype=np.float32) * (0.6 + 0.4 * luminance)[..., np.newaxis]

    weight = ALPHA_HAIR * region[..., np.newaxis]
    blended = (1 - weight) * pixels + weight * white
    out = Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))
    return out.filter(ImageFilter.UnsharpMask(radius=1.2, percent=140, threshold=2))
def post_correct_aged(original: Image.Image, aged: Image.Image) -> Image.Image:
    """Match the aged image's per-channel histogram to the original, then lift brightness and contrast."""
    reference = np.array(original)
    matched = match_histograms(np.array(aged), reference, channel_axis=-1)
    corrected = Image.fromarray(np.clip(matched, 0, 255).astype(np.uint8))
    for enhancer_cls, factor in ((ImageEnhance.Brightness, 1.10), (ImageEnhance.Contrast, 1.06)):
        corrected = enhancer_cls(corrected).enhance(factor)
    return corrected
def enhance_texture(img: Image.Image) -> Image.Image:
    """Sharpen fine detail, then apply the global contrast and sharpness boosts."""
    sharpened = img.filter(ImageFilter.UnsharpMask(radius=2, percent=160, threshold=3))
    contrasted = ImageEnhance.Contrast(sharpened).enhance(CONTRAST_BOOST)
    return ImageEnhance.Sharpness(contrasted).enhance(SHARPNESS_BOOST)
# ================== MAIN PROCESSING FUNCTION (Memory Safe) ==================
def _age_with_unet(orig: Image.Image) -> Image.Image:
    """Run the aging UNet on a SAFE_IMG_SIZE square of *orig*, blend, and resize back."""
    side = SAFE_IMG_SIZE
    square = orig.resize((side, side), Image.LANCZOS)
    rgb = TF.to_tensor(square).to(DEVICE)
    # Age conditioning: two constant planes (source, target) scaled by 1/100.
    src_plane = torch.full((1, side, side), SOURCE_AGE / 100.0, device=DEVICE)
    tgt_plane = torch.full((1, side, side), TARGET_AGE / 100.0, device=DEVICE)
    net_input = torch.cat([rgb, src_plane, tgt_plane], dim=0).unsqueeze(0)
    with torch.inference_mode():
        aging_net = load_aging_model()
        raw = aging_net(net_input).squeeze(0).float()
    mix = WRINKLE_STRENGTH
    blended = ((1 - mix) * rgb + mix * raw).clamp(0, 1)
    return TF.to_pil_image(blended).resize(orig.size, Image.LANCZOS)


def _restore_faces(img: Image.Image) -> Image.Image:
    """Best-effort GFPGAN pass; returns *img* unchanged if GFPGAN is unavailable or errors."""
    print(" Applying GFPGAN face restoration...")
    restorer = load_gfpgan()
    if restorer is None:
        return img
    try:
        bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        with torch.inference_mode():
            _, _, restored_bgr = restorer.enhance(
                bgr, has_aligned=False, only_center_face=False,
                paste_back=True, weight=GFPGAN_WEIGHT
            )
        return Image.fromarray(cv2.cvtColor(restored_bgr, cv2.COLOR_BGR2RGB))
    except Exception as e:
        print(f" GFPGAN error (skipped): {e}")
        return img


def process_face_aging(input_image: Image.Image) -> Image.Image:
    """Age a face photo and whiten its hair and beard.

    Pipeline: UNet aging -> texture/colour correction -> white hair & beard
    overlay -> optional GFPGAN restoration (when USE_GFPGAN is set).

    Raises:
        ValueError: if ``input_image`` is None.
        Exception: any pipeline error is printed with its traceback and re-raised.
    """
    if input_image is None:
        raise ValueError("Please provide a valid image!")
    try:
        print(f"→ Processing image: {input_image.size}")
        orig = input_image.convert("RGB")
        aged = post_correct_aged(orig, enhance_texture(_age_with_unet(orig)))

        print(" Generating hair mask...")
        hair_mask = get_hair_mask_segformer(aged)
        print(" Generating beard mask...")
        beard_mask = get_beard_mask(aged)
        print(" Applying white hair & beard...")
        result = apply_hair_and_beard_color(aged, hair_mask, beard_mask)

        if USE_GFPGAN:
            result = _restore_faces(result)
        print("✓ Processing completed!")
        return result
    except Exception as e:
        print(f"❌ Error: {e}")
        traceback.print_exc()
        if DEVICE.type == "cuda":
            print(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / (1 << 20):.1f} MB")
        raise
    finally:
        # Free intermediate tensors/arrays between requests.
        gc.collect()
        if DEVICE.type == "cuda":
            torch.cuda.empty_cache()
# ================== FASTAPI SETUP ==================
app = FastAPI(title="Face Aging + White Hair & Beard API")

# NOTE(review): a wildcard origin combined with allow_credentials=True may let any
# site make credentialed cross-origin requests. This API does not appear to use
# cookies or auth, so allow_credentials=False may be the safer setting — confirm.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# NOTE(review): no route decorator was present, so this handler was never
# exposed; the path below is derived from the function name — confirm against clients.
@app.post("/age_face")
async def age_face(file: UploadFile = File(...)):
    """Age the face in an uploaded image and return the result as a PNG stream.

    Raises:
        HTTPException(400): the upload is not declared as an image or cannot be decoded.
        HTTPException(500): the aging pipeline failed.
    """
    # content_type can be None for some clients; treat that as "not an image".
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Only image files allowed")
    contents = await file.read()
    try:
        # Undecodable/truncated data is a client error, not a server failure.
        input_image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e
    # NOTE: the blocking pipeline runs on the event-loop thread, so requests are
    # processed one at a time; the shared models are not guarded for concurrent use.
    try:
        result_image = process_face_aging(input_image)
        img_byte_arr = io.BytesIO()
        result_image.save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)
        return StreamingResponse(img_byte_arr, media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}") from e
    finally:
        gc.collect()
        if DEVICE.type == "cuda":
            torch.cuda.empty_cache()
# For local testing: serve the API with uvicorn on all interfaces.
if __name__ == "__main__":
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    print("Starting FastAPI server...")
    uvicorn.run(app, **server_options)