"""FastAPI service that greys the hair and beard in an uploaded face photo.

``POST /age-face`` accepts an image and returns a JPEG in which hair (found by a
Segformer face parser) and beard/mustache (found by a YOLO segmentation model)
are desaturated and brightened.
"""
import os

# Both variables must be set BEFORE transformers / ultralytics are imported:
# huggingface_hub and ultralytics resolve their cache/config directories from
# the environment at import time, so assigning them afterwards has no effect.
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["YOLO_CONFIG_DIR"] = "/tmp/Ultralytics"

import asyncio
import functools
import gc
import io
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor

import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image, ImageFilter
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from ultralytics import YOLO

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ====================== CONFIG ======================
BEARD_MODEL_PATH = "models/best_hair_117_epoch_v4.pt"
SAFE_IMG_SIZE = 384  # upper bound on the square working resolution
PARSE_SIZE = 128  # resolution fed to both models and used for the parsing maps
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger.info(f"Using Device: {DEVICE}")
logger.info(f"CUDA Available: {torch.cuda.is_available()}")

if DEVICE.type == "cpu":
    torch.set_num_threads(4)
    torch.set_num_interop_threads(1)
    cv2.setNumThreads(4)
else:
    torch.set_num_threads(1)

executor = ThreadPoolExecutor(max_workers=2)

# Lazily-initialised model singletons (see load_face_parser / load_beard_model).
face_processor = None
face_parser = None
beard_model = None

# Guards one-time model loading against concurrent callers.
_model_load_lock = threading.Lock()
# The Ultralytics predictor keeps per-call state on the model object and is not
# thread-safe; `executor` can run two requests at once, so beard inference is
# serialised through this lock.
_beard_predict_lock = threading.Lock()


# ====================== TIMED DECORATOR ======================
def timed(name: str):
    """Return a decorator that logs each call's wall-clock time as ``"<name>: <ms> ms"``.

    The duration is logged even when the wrapped function raises.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                elapsed = (time.perf_counter() - start) * 1000
                logger.info(f"{name}: {elapsed:.1f} ms")
        return wrapper
    return decorator


# ====================== MODEL LOADING ======================
def load_face_parser():
    """Load the Segformer face-parsing processor and model once.

    Safe to call repeatedly and from several threads; only the first call loads.
    """
    global face_processor, face_parser
    if face_parser is not None:
        return
    with _model_load_lock:
        if face_parser is not None:
            return
        logger.info("Loading Segformer Face Parser...")
        processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
        model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
        model.to(DEVICE)
        model.eval()
        # Publish only fully-initialised objects: callers test `face_parser`
        # without holding the lock, so it must be assigned last.
        face_processor = processor
        face_parser = model
        logger.info("✅ Face parser loaded")


def load_beard_model():
    """Return the YOLO beard model, loading it from BEARD_MODEL_PATH on first use."""
    global beard_model
    if beard_model is None:
        with _model_load_lock:
            if beard_model is None:
                logger.info("Loading YOLO Beard Model...")
                beard_model = YOLO(BEARD_MODEL_PATH)
    return beard_model


# ====================== MUSTACHE MASK ======================
@timed("Mustache Mask")
def get_mustache_mask(probs, orig_w, orig_h, exclude_mask):
    """Build a binary mustache mask from face-parsing probabilities.

    Args:
        probs: CPU tensor of per-class softmax probabilities, indexed as
            ``probs[class_id]`` -> 2-D map. Classes 10 / 11 / 12 are read here
            as mouth / upper lip / lower lip.
        orig_w, orig_h: Output mask size in pixels.
        exclude_mask: Float mask of shape ``(orig_h, orig_w)``. Regions it
            covers are damped before the final threshold.

    Returns:
        float32 array of shape ``(orig_h, orig_w)`` containing only 0.0 / 1.0.
    """
    u_lip = (probs[11].numpy() > 0.13).astype(np.float32)
    l_lip = (probs[12].numpy() > 0.13).astype(np.float32)
    mouth = (probs[10].numpy() > 0.18).astype(np.float32)

    # Seed the region from the lip/mouth area, weighting the upper lip most.
    mustache = np.maximum(u_lip * 1.15, l_lip)
    mustache = np.maximum(mustache, mouth * 0.45)

    kernel_h = cv2.getStructuringElement(cv2.MORPH_RECT, (9, 3))
    kernel_e = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mustache = cv2.dilate(mustache, kernel_e, iterations=1)
    # Wide, short closing kernel joins fragments horizontally.
    mustache = cv2.morphologyEx(mustache, cv2.MORPH_CLOSE, kernel_h, iterations=2)
    mustache = cv2.GaussianBlur(mustache, (7, 5), 1.2)

    # Translate the mask down by one row; vacated pixels are filled with 0.
    shift_y = 1
    M = np.float32([[1, 0, 0], [0, 1, shift_y]])
    mustache = cv2.warpAffine(
        mustache,
        M,
        (mustache.shape[1], mustache.shape[0]),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=0,
    )

    mustache = cv2.resize(mustache, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
    mustache = np.maximum(mustache - exclude_mask * 0.5, 0)
    mustache = cv2.GaussianBlur(mustache, (5, 5), 1.0)
    mustache = (mustache > 0.15).astype(np.float32)
    return mustache


# ====================== HAIR + EXCLUDE + LIP MASK ======================
def _union_above(probs, class_ids, threshold):
    """Return a float32 0/1 map that is 1 wherever any listed class's probability exceeds *threshold*."""
    hits = np.zeros(tuple(probs.shape[1:]), dtype=np.float32)
    for class_id in class_ids:
        hits = np.maximum(hits, (probs[class_id].numpy() > threshold).astype(np.float32))
    return hits


@timed("Hair + Exclude + Lip Mask")
def get_hair_and_exclude_masks(pil_image: Image.Image):
    """Run face parsing and derive the hair, exclusion, mustache and lip masks.

    Args:
        pil_image: RGB image. Every returned mask has the same width/height.

    Returns:
        ``(hair, exclude, mustache, lip_mask)``, each a float32 array of shape
        ``(height, width)``:

        * ``hair`` is soft (0..1).
        * ``exclude`` is soft; it covers classes 10, 11, 12, 4, 5 (mouth/lips,
          and presumably the eyes per the model's label map — verify).
        * ``mustache`` and ``lip_mask`` are binary.
    """
    load_face_parser()
    orig_w, orig_h = pil_image.size

    img_small = pil_image.resize((PARSE_SIZE, PARSE_SIZE), Image.BILINEAR)
    # NOTE(review): SegformerImageProcessor resizes to its configured size
    # (512x512 unless the checkpoint overrides it). If so, this downscale
    # discards detail rather than saving compute — confirm intent.
    inputs = face_processor(images=img_small, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        logits = face_parser(**inputs).logits
        up = torch.nn.functional.interpolate(
            logits, size=(PARSE_SIZE, PARSE_SIZE), mode="bilinear", align_corners=False
        )
        probs = torch.softmax(up, dim=1)[0].cpu()
        parsing = up.argmax(dim=1).squeeze(0).cpu().numpy()

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

    # --- Hair: class-13 probability, minus anything parsed as face ---
    hair = (probs[13].numpy() > 0.035).astype(np.float32)
    hair = cv2.GaussianBlur(hair, (3, 3), 1.0)

    # NOTE(review): classes 1-5, 8-12, 17, 18 are treated as "face"; check them
    # against the model's id2label (17/18 are presumably neck/cloth).
    face_cls = list(range(1, 6)) + list(range(8, 13)) + [17, 18]
    face_m = np.isin(parsing, face_cls).astype(np.float32)
    face_m = cv2.dilate(face_m, kernel, iterations=1)
    # Weaken the face mask in the top 32% of rows so that hair over the
    # forehead is not fully removed.
    h = face_m.shape[0]
    forehead = np.zeros_like(face_m, dtype=np.float32)
    forehead[: int(h * 0.32)] = 1.0
    face_m = face_m * (1 - forehead * 0.45)
    hair = hair * (1 - face_m)
    hair = cv2.resize(hair, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)

    # --- Exclude mask: regions that must never be greyed ---
    exclude = _union_above(probs, (10, 11, 12, 4, 5), 0.35)
    exclude = cv2.dilate(exclude, kernel, iterations=2)
    exclude = cv2.GaussianBlur(exclude, (5, 5), 1.2)
    exclude = cv2.resize(exclude, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)

    # --- Lip mask: hard 0/1 mask of mouth + lips ---
    lip_mask = _union_above(probs, (10, 11, 12), 0.42)
    lip_mask = cv2.dilate(lip_mask, kernel, iterations=1)
    lip_mask = cv2.resize(lip_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
    lip_mask = (lip_mask > 0.5).astype(np.float32)

    mustache = get_mustache_mask(probs, orig_w, orig_h, exclude)
    return hair, exclude, mustache, lip_mask


# ====================== BEARD MASK ======================
def _refine_beard_mask(mask):
    """Clean up a raw beard mask: morphology, polygon smoothing, blur, re-threshold and shrink."""
    kernel_erode = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    mask = cv2.erode(mask, kernel_erode, iterations=2)
    kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel_close, iterations=3)
    kernel_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel_open, iterations=1)

    # Replace each sizeable blob by a simplified, filled polygon; blobs of
    # area <= 50 px are dropped.
    contours, _ = cv2.findContours(
        (mask > 0.1).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    if contours:
        smooth_mask = np.zeros_like(mask, dtype=np.float32)
        for cnt in contours:
            if cv2.contourArea(cnt) > 50:
                epsilon = 0.008 * cv2.arcLength(cnt, True)
                approx = cv2.approxPolyDP(cnt, epsilon, True)
                cv2.drawContours(smooth_mask, [approx], -1, 1.0, thickness=cv2.FILLED)
        mask = smooth_mask

    mask = cv2.GaussianBlur(mask, (9, 9), 2.0)
    mask = (mask > 0.28).astype(np.float32)
    mask = cv2.erode(mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)), iterations=1)
    return mask


@timed("Beard Mask")
def get_beard_mask_fast(pil_image: Image.Image, exclude_mask: np.ndarray, lip_mask: np.ndarray):
    """Segment the beard with YOLO and report whether a confident beard was found.

    Args:
        pil_image: RGB image.
        exclude_mask: Float mask matching the image size. It is subtracted
            (x0.6) from the beard mask.
        lip_mask: 0/1 mask matching the image size. Pixels it marks are zeroed
            in the result.

    Returns:
        ``(mask, beard_present)``:

        * ``mask`` is float32 of shape ``(height, width)``.
        * ``beard_present`` is True when any class-0 detection has confidence
          above 0.25.
    """
    model = load_beard_model()
    orig_w, orig_h = pil_image.size

    img_small = pil_image.resize((PARSE_SIZE, PARSE_SIZE), Image.BILINEAR)
    # Ultralytics interprets numpy image input as BGR; the PIL image is RGB.
    img_array = cv2.cvtColor(np.array(img_small), cv2.COLOR_RGB2BGR)

    with _beard_predict_lock:
        results = model.predict(
            img_array,
            device=DEVICE.type,
            conf=0.18,
            iou=0.45,
            imgsz=PARSE_SIZE,
            half=False,
            verbose=False,
            max_det=8,
        )
    result = results[0]

    mask = np.zeros((orig_h, orig_w), dtype=np.float32)
    beard_present = False
    if result.masks is not None:
        for i, cls in enumerate(result.boxes.cls):
            if int(cls) != 0:
                continue
            conf = result.boxes.conf[i].item()
            # Confidence threshold for considering the detection a real beard.
            # NOTE(review): weaker detections (0.18-0.25) still add to the mask
            # but do not set the flag, so they skip refinement — confirm intent.
            if conf > 0.25:
                beard_present = True
            m = result.masks.data[i].cpu().numpy()
            m = cv2.resize(m, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
            mask = np.maximum(mask, (m > 0.25).astype(np.float32))

    mask = np.maximum(mask - exclude_mask * 0.6, 0)

    # Only refine when a confident beard exists and the mask is non-trivial.
    if beard_present and mask.sum() > 25:
        mask = _refine_beard_mask(mask)

    mask[lip_mask > 0] = 0
    return mask, beard_present


# ====================== COLOR TRANSFER - BEARD SAME AS HAIR ======================
@timed("Color Transfer")
def apply_strong_grey_hair(image: Image.Image, hair_mask: np.ndarray, beard_mask: np.ndarray):
    """Desaturate and brighten the hair + beard regions of *image*.

    Args:
        image: RGB PIL image.
        hair_mask: Float mask, 0..1, of shape ``(height, width)``.
        beard_mask: Float mask, 0..1, of shape ``(height, width)``.

    Returns:
        A new RGB PIL image of the same size, lightly unsharp-masked.
    """
    combined_mask = np.maximum(hair_mask, beard_mask)
    # NOTE(review): the blur is applied only to very small masks (sum < 100) —
    # confirm this condition is intended.
    if combined_mask.sum() < 100:
        combined_mask = cv2.GaussianBlur(combined_mask, (5, 5), 1.5)

    # Convert to HSV from the exact uint8 pixels. The old code went
    # uint8 -> /255 -> *255 -> astype(uint8), which truncates and can lose 1 level.
    rgb_u8 = np.array(image)
    img = rgb_u8.astype(np.float32) / 255.0
    hsv = cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2HSV).astype(np.float32)

    hsv_transformed = hsv.copy()
    # Reduce saturation by up to 78% inside the mask.
    hsv_transformed[..., 1] = hsv_transformed[..., 1] * (1 - 0.78 * combined_mask)
    # Lift value toward light grey. The clip to [110, 210] applies everywhere,
    # but only masked pixels survive the blend below.
    original_v = hsv[..., 2]
    boost_amount = 89 * combined_mask
    hsv_transformed[..., 2] = np.clip(
        original_v + boost_amount - (original_v * 0.35 * combined_mask), 110, 210
    )
    transformed_rgb = (
        cv2.cvtColor(hsv_transformed.astype(np.uint8), cv2.COLOR_HSV2RGB).astype(np.float32) / 255.0
    )

    combined_mask_3ch = np.stack([combined_mask, combined_mask, combined_mask], axis=2)
    final = transformed_rgb * combined_mask_3ch + img * (1 - combined_mask_3ch)
    # Slight warm tint in the greyed area.
    final = final + (np.array([9, 7, 5], dtype=np.float32) / 255.0 * combined_mask[..., None] * 0.18)
    final = np.clip(final * 255, 0, 255).astype(np.uint8)

    result = Image.fromarray(final)
    result = result.filter(ImageFilter.UnsharpMask(radius=0.8, percent=75, threshold=1))
    return result


# ====================== MAIN PROCESSING ======================
@timed("Total Processing")
def process_face_whitening(input_image: Image.Image):
    """Grey the hair (and beard, when detected) of *input_image*.

    The image is processed at a square working size of at most SAFE_IMG_SIZE
    (even, >= 2). The result is resized back to the input's original size.

    Returns:
        RGB PIL image with the same size as *input_image*.
    """
    orig = input_image.convert("RGB")
    ow, oh = orig.size

    target = min(SAFE_IMG_SIZE, max(ow, oh))
    if target % 2 != 0:
        target -= 1
    # A 1-px input would otherwise give target == 0 and make resize() fail.
    target = max(target, 2)
    img_resized = orig.resize((target, target), Image.BILINEAR)

    hair_mask, exclude_mask, mustache_mask, lip_mask = get_hair_and_exclude_masks(img_resized)
    beard_mask, beard_present = get_beard_mask_fast(img_resized, exclude_mask, lip_mask)

    # Merge the mustache into the beard ONLY when a confident beard was found;
    # otherwise the mustache mask is ignored completely.
    if beard_present:
        beard_mask = np.maximum(beard_mask, mustache_mask * 0.98)
        weak_mustache = (mustache_mask > 0.18) & (beard_mask < 0.48)
        beard_mask[weak_mustache] = np.maximum(beard_mask[weak_mustache], 0.75)
        beard_mask[lip_mask > 0] = 0

    final_resized = apply_strong_grey_hair(img_resized, hair_mask, beard_mask)
    final_img = final_resized.resize((ow, oh), Image.LANCZOS)

    gc.collect()
    if DEVICE.type == "cuda":
        torch.cuda.empty_cache()

    return final_img


# ====================== FASTAPI ======================
app = FastAPI()
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])


@app.on_event("startup")
async def startup():
    """Load both models and run one warm-up pass, all off the event loop."""
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(executor, load_face_parser)
    await loop.run_in_executor(executor, load_beard_model)
    logger.info("✅ Models loaded")

    logger.info("Running light warmup...")
    dummy = Image.new("RGB", (256, 256))
    await loop.run_in_executor(executor, process_face_whitening, dummy)
    logger.info("✅ Server Ready!")


@app.post("/age-face")
async def age_face(file: UploadFile = File(...)):
    """Return the uploaded image with hair/beard greyed, as a JPEG stream.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    start_total = time.perf_counter()
    contents = await file.read()
    try:
        img = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # Covers UnidentifiedImageError (an OSError subclass) and truncated files.
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from err

    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(executor, process_face_whitening, img)

    buf = io.BytesIO()
    result.save(buf, format="JPEG", quality=92, optimize=True)
    buf.seek(0)

    total_time = (time.perf_counter() - start_total) * 1000
    logger.info(f"✅ Total Request Time: {total_time:.1f} ms")
    return StreamingResponse(buf, media_type="image/jpeg")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)