OldServer1 / app.py
Seniordev22's picture
Update app.py
dacc8dc verified
import os
import torch
import torch.nn as nn
import numpy as np
import cv2
import traceback
import gc
from PIL import Image, ImageFilter, ImageEnhance
from torchvision.transforms import functional as TF
from scipy.ndimage import label
import antialiased_cnns
import mediapipe as mp
from skimage.exposure import match_histograms
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from ultralytics import YOLO
from gfpgan import GFPGANer
import urllib.request
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import io
# ========================= CONFIG =========================
# Checkpoint locations, resolved relative to the working directory.
AGING_MODEL_PATH = "face_aging_model/best_unet_model.pth"
BEARD_MODEL_PATH = "models/best_hair_117_epoch_v4.pt"
GFPGAN_MODEL_PATH = "GFPGANv1.4.pth"

# Side length of the square image the aging network is run on.
SAFE_IMG_SIZE = 512

# Ages fed to the aging network (divided by 100 before use).
SOURCE_AGE, TARGET_AGE = 20, 80

# Blend weight of the aging network output against the input image.
WRINKLE_STRENGTH = 0.42

# Post-processing gains used by enhance_texture().
CONTRAST_BOOST = 1.10
SHARPNESS_BOOST = 1.20

# Hair/beard whitening parameters used by apply_hair_and_beard_color().
ALPHA_HAIR = 0.95
BLUR_RADIUS = 7
EDGE_SMOOTHING = True

# GFPGAN face-restoration switches.
USE_GFPGAN = True
GFPGAN_UPSCALE = 1
GFPGAN_WEIGHT = 0.5

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Device: {DEVICE}")
if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True

# NOTE(review): this assignment runs after `transformers` has already been
# imported at the top of the file. The Hugging Face libraries appear to read
# HF_HOME when they are imported, so this may not redirect the cache -- confirm,
# and if so move it above the imports.
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.makedirs(os.environ["HF_HOME"], exist_ok=True)

# Global models (Lazy Loading): populated on first use by the load_* helpers.
age_model = face_processor = face_parser = beard_model = gfpgan_restorer = None

# The face-mesh detector is the one model that is built eagerly at import time.
mp_face_mesh = mp.solutions.face_mesh.FaceMesh(
    static_image_mode=True,
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5,
)
# ================== DOWNLOAD HELPER ==================
def download_file(url, filename):
    """Download *url* to *filename* unless that file already exists.

    The payload is written to ``<filename>.part`` first and only renamed to
    *filename* once the transfer has finished. An interrupted download
    therefore never leaves a truncated file at the final path, where the
    ``os.path.exists`` shortcut would mistake it for a complete one on the
    next call.

    Args:
        url: Location to fetch.
        filename: Destination path on the local filesystem.

    Returns:
        True if *filename* exists when the function returns, False if the
        download failed.
    """
    if os.path.exists(filename):
        print(f"✅ {filename} already exists.")
        return True
    print(f"🔄 Downloading {filename}... (~350 MB)")
    partial_path = f"{filename}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        # Atomic on the same filesystem: readers see either nothing or the
        # whole file.
        os.replace(partial_path, filename)
        print(f"✅ Download completed: {filename}")
        return True
    except Exception as e:
        print(f"❌ Download failed: {e}")
        return False
    finally:
        # After a successful rename this is a no-op; after a failure it
        # removes whatever fragment was written.
        if os.path.exists(partial_path):
            try:
                os.remove(partial_path)
            except OSError:
                pass
# ================== LOAD MODELS (Safer - No torch.compile) ==================
def load_aging_model():
    """Build the aging UNet, load its weights and cache it in ``age_model``.

    The network takes a 5-channel input and produces a 3-channel output.
    Layer and attribute names are part of the checkpoint format (they are
    the ``state_dict`` keys) and must not be renamed.

    The module-level cache is only assigned after the checkpoint has been
    loaded successfully. If loading raises (missing or corrupt file), the
    cache stays ``None`` and the next call retries instead of silently
    returning a randomly initialised network.

    Returns:
        The cached ``nn.Module`` in eval mode on ``DEVICE``.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise when the
        checkpoint at ``AGING_MODEL_PATH`` cannot be loaded.
    """
    global age_model
    if age_model is not None:
        return age_model
    print("Loading UNet aging model...")

    class DownLayer(nn.Module):
        """Anti-aliased 2x downsampling followed by two 3x3 convolutions."""

        def __init__(self, in_ch, out_ch):
            super().__init__()
            self.layer = nn.Sequential(
                nn.MaxPool2d(2, stride=1),
                antialiased_cnns.BlurPool(in_ch, stride=2),
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.LeakyReLU(inplace=True),
            )

        def forward(self, x):
            return self.layer(x)

    class UpLayer(nn.Module):
        """Transposed-conv 2x upsampling, skip concatenation, two convs."""

        def __init__(self, in_ch, out_ch):
            super().__init__()
            self.blur_upsample = nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
                antialiased_cnns.BlurPool(out_ch, stride=1),
            )
            self.layer = nn.Sequential(
                nn.Conv2d(out_ch * 2, out_ch, 3, padding=1),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.LeakyReLU(inplace=True),
            )

        def forward(self, x, skip):
            x = self.blur_upsample(x)
            x = torch.cat([x, skip], dim=1)
            return self.layer(x)

    class UNet(nn.Module):
        """Four-level encoder/decoder with skip connections."""

        def __init__(self):
            super().__init__()
            self.init_conv = nn.Sequential(
                nn.Conv2d(5, 64, 3, padding=1),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(64, 64, 3, padding=1),
                nn.LeakyReLU(inplace=True),
            )
            self.down1 = DownLayer(64, 128)
            self.down2 = DownLayer(128, 256)
            self.down3 = DownLayer(256, 512)
            self.down4 = DownLayer(512, 1024)
            self.up1 = UpLayer(1024, 512)
            self.up2 = UpLayer(512, 256)
            self.up3 = UpLayer(256, 128)
            self.up4 = UpLayer(128, 64)
            self.final_conv = nn.Conv2d(64, 3, 1)

        def forward(self, x):
            x0 = self.init_conv(x)
            x1 = self.down1(x0)
            x2 = self.down2(x1)
            x3 = self.down3(x2)
            x4 = self.down4(x3)
            x = self.up1(x4, x3)
            x = self.up2(x, x2)
            x = self.up3(x, x1)
            x = self.up4(x, x0)
            return self.final_conv(x)

    # Build and initialise a local instance first; publish it to the global
    # cache only once the weights are in place.
    model = UNet().to(DEVICE)
    state = torch.load(AGING_MODEL_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    model.eval()
    age_model = model
    print("✅ Aging model loaded!")
    return age_model
def load_face_parser():
    """Return the cached ``(processor, model)`` pair for face parsing.

    On first use both are fetched from the ``jonathandinu/face-parsing``
    repository; the model is moved to ``DEVICE`` and put in eval mode.
    """
    global face_processor, face_parser
    if face_parser is None:
        print("Loading Segformer face-parsing...")
        face_processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
        face_parser = SegformerForSemanticSegmentation.from_pretrained(
            "jonathandinu/face-parsing"
        ).to(DEVICE)
        face_parser.eval()
        print("✅ Face parser loaded!")
    return face_processor, face_parser
def load_beard_model():
    """Return the cached YOLO beard model, creating it on first use."""
    global beard_model
    if beard_model is not None:
        return beard_model
    print("Loading Beard Detection Model (YOLO)...")
    beard_model = YOLO(BEARD_MODEL_PATH)
    return beard_model
def load_gfpgan():
    """Return the cached GFPGAN restorer, or ``None`` if it cannot be built.

    The checkpoint is downloaded to ``GFPGAN_MODEL_PATH`` when it is not
    present. Construction errors are printed and reported as ``None``.
    """
    global gfpgan_restorer
    if gfpgan_restorer is None:
        if not os.path.exists(GFPGAN_MODEL_PATH):
            model_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
            download_file(model_url, GFPGAN_MODEL_PATH)
        print("🔄 Loading GFPGAN v1.4...")
        restorer_options = {
            "model_path": GFPGAN_MODEL_PATH,
            "upscale": GFPGAN_UPSCALE,
            "arch": 'clean',
            "channel_multiplier": 2,
            "bg_upsampler": None,
            "device": DEVICE,
        }
        try:
            gfpgan_restorer = GFPGANer(**restorer_options)
            print("✅ GFPGAN loaded successfully!")
        except Exception as e:
            print(f"❌ GFPGAN load failed: {e}")
            return None
    return gfpgan_restorer
# ================== MASK FUNCTIONS ==================
def get_lips_mask(pil_image: Image.Image) -> np.ndarray:
    """Return a soft float32 mask in [0, 1] covering the lips.

    Lip landmarks of the first detected face are converted to pixel
    coordinates, their convex hull is filled, and the result is dilated and
    Gaussian-blurred to feather the edge.

    Args:
        pil_image: Input image; it is converted to RGB before detection.

    Returns:
        An ``(height, width)`` float32 array. All zeros when no face is
        detected.
    """
    # The face-mesh detector expects RGB channel order. A PIL image is
    # already RGB, so the array is passed as-is; swapping to BGR here would
    # feed the detector red/blue-swapped pixels.
    img_np = np.array(pil_image.convert("RGB"))
    h, w = img_np.shape[:2]
    results = mp_face_mesh.process(img_np)
    if not results.multi_face_landmarks:
        return np.zeros((h, w), dtype=np.float32)

    lip_landmarks = (
        61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291,
        308, 324, 318, 402, 317, 14, 87, 178, 88, 95,
    )
    # Only the first face is used (the detector is configured for one face).
    face_landmarks = results.multi_face_landmarks[0]
    # Landmark coordinates are normalised; scale them to pixel positions.
    points_np = np.array(
        [
            [int(face_landmarks.landmark[idx].x * w), int(face_landmarks.landmark[idx].y * h)]
            for idx in lip_landmarks
        ],
        np.int32,
    )

    lips_mask = np.zeros((h, w), dtype=np.uint8)
    cv2.fillConvexPoly(lips_mask, cv2.convexHull(points_np), 255)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    lips_mask = cv2.dilate(lips_mask, kernel, iterations=2)
    lips_mask = cv2.GaussianBlur(lips_mask.astype(np.float32), (15, 15), 4)
    return np.clip(lips_mask / 255.0, 0, 1)
def exclude_lips_from_mask(beard_mask: np.ndarray, pil_image: Image.Image) -> np.ndarray:
    """Zero out the lip area of *beard_mask* and lightly re-blur it.

    An all-zero mask is returned unchanged without running lip detection.
    """
    if np.sum(beard_mask) == 0:
        return beard_mask

    # Binarise the soft lip mask, then grow it slightly for a safety margin.
    soft_lips = get_lips_mask(pil_image)
    hard_lips = (soft_lips > 0.3).astype(np.float32)
    grow_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    hard_lips = cv2.dilate(hard_lips, grow_kernel, iterations=1)

    without_lips = beard_mask * (1.0 - hard_lips)
    return cv2.GaussianBlur(without_lips, (5, 5), 1)
def get_beard_mask(pil_image: Image.Image) -> np.ndarray:
    """Return a soft float32 beard mask in [0, 1] for *pil_image*.

    The image is written to a JPEG scratch file and passed to the YOLO
    segmentation model. Masks of class 0 are merged, cleaned up with
    morphology, have the lip area removed, and are blurred.

    The scratch file gets a unique name from ``tempfile.mkstemp``. A fixed
    name in the working directory would be shared by concurrent requests or
    worker processes, which could then overwrite or delete each other's
    input.

    Args:
        pil_image: Image to analyse.

    Returns:
        An ``(height, width)`` float32 array; all zeros when no beard is found.
    """
    import tempfile

    fd, temp_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)  # PIL reopens the path itself; only the unique name is needed.
    try:
        pil_image.save(temp_path)
        model = load_beard_model()
        results = model(
            temp_path,
            device=DEVICE.type,
            conf=0.25,
            iou=0.5,
            verbose=False,
            half=DEVICE.type == "cuda",
        )
        img_np = np.array(pil_image)
        h, w = img_np.shape[:2]
        beard_mask = np.zeros((h, w), dtype=np.uint8)

        detection = results[0]
        if detection.masks is not None:
            for i, cls in enumerate(detection.boxes.cls):
                if int(cls) != 0:  # only class 0 (beard) is used
                    continue
                # Cast to float32 so cv2.resize gets a dtype it supports
                # regardless of the precision the model ran in.
                mask = detection.masks.data[i].cpu().numpy().astype(np.float32)
                mask = cv2.resize(mask, (w, h))
                mask = (mask > 0.4).astype(np.uint8) * 255
                beard_mask = cv2.bitwise_or(beard_mask, mask)

        if np.sum(beard_mask) == 0:
            return np.zeros((h, w), dtype=np.float32)

        beard_mask_float = beard_mask.astype(np.float32) / 255.0
        beard_mask_float = cv2.dilate(
            beard_mask_float,
            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)),
            iterations=2,
        )
        beard_mask_float = cv2.morphologyEx(
            beard_mask_float,
            cv2.MORPH_OPEN,
            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)),
            iterations=1,
        )
        beard_mask_float = cv2.morphologyEx(
            beard_mask_float,
            cv2.MORPH_CLOSE,
            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)),
            iterations=2,
        )
        beard_mask_float = exclude_lips_from_mask(beard_mask_float, pil_image)
        beard_mask_float = cv2.GaussianBlur(beard_mask_float, (7, 7), 2)
        return np.clip(beard_mask_float, 0, 1)
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)
def clean_mask(mask, min_area=150):
    """Remove connected components smaller than *min_area* pixels.

    The mask is cast to uint8 (so fractional values below 1 become
    background) and labelled with ``scipy.ndimage.label`` using its default
    connectivity. Component sizes come from a single ``np.bincount`` pass and
    are applied through a lookup table, so the cost does not grow with the
    number of components.

    Args:
        mask: Array whose non-zero entries (after the uint8 cast) are
            foreground.
        min_area: Minimum pixel count for a component to be kept.

    Returns:
        A uint8 array of the same shape, with 1 on kept components and 0
        elsewhere.
    """
    binary = mask.astype(np.uint8)
    labeled, num = label(binary)
    # areas[k] is the pixel count of component k; index 0 is the background.
    areas = np.bincount(labeled.ravel(), minlength=num + 1)
    keep = areas >= min_area
    keep[0] = False
    # Indexing the lookup table with the label image keeps the input shape.
    return keep[labeled].astype(np.uint8)
def get_hair_mask_segformer(pil_image: Image.Image) -> np.ndarray:
    """Return a soft float32 hair mask in [0, 1] from the face-parsing model.

    Channel 13 of the class probabilities is used as the hair probability
    and thresholded at 0.12. Pixels whose arg-max class is in the face
    class list (dilated by a 7x7 ellipse) are removed. The mask is then
    opened, closed, stripped of small components, and blurred.
    """
    processor, parser = load_face_parser()
    model_inputs = processor(images=pil_image, return_tensors="pt").to(DEVICE)
    target_hw = pil_image.size[::-1]  # PIL gives (width, height)

    with torch.inference_mode():
        logits = parser(**model_inputs).logits
        upsampled = torch.nn.functional.interpolate(
            logits, size=target_hw, mode="bilinear", align_corners=False
        )
        hair_prob = torch.softmax(upsampled, dim=1)[0][13].cpu().numpy()
        parsing = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()

    hair_mask = (hair_prob > 0.12).astype(np.uint8)

    # Remove everything the parser labels as one of the face classes.
    face_classes = [*range(1, 6), *range(8, 13), 17, 18]
    face_mask = np.isin(parsing, face_classes).astype(np.uint8)
    face_mask = cv2.dilate(
        face_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)), iterations=1
    )
    hair_mask = hair_mask * (1 - face_mask)

    open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11))
    hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_OPEN, open_kernel, iterations=1)
    hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, close_kernel, iterations=2)
    hair_mask = clean_mask(hair_mask, min_area=100)

    softened = cv2.GaussianBlur(hair_mask.astype(np.float32), (5, 5), 1.5)
    return np.clip(softened, 0, 1)
def apply_hair_and_beard_color(image: Image.Image, hair_mask: np.ndarray, beard_mask: np.ndarray):
    """Blend a luminance-modulated white layer into the hair and beard areas.

    The two masks are merged with an element-wise maximum, feathered, and
    boosted by 1.2. The image is then mixed with a white layer (scaled by
    0.6 + 0.4 * grayscale) using ``ALPHA_HAIR`` times the mask as weight,
    and finally sharpened. The input is returned untouched when the merged
    mask is empty.
    """
    mask = np.maximum(hair_mask, beard_mask)
    if np.sum(mask) == 0:
        return image

    # Feather the mask edge.
    kernel_size = BLUR_RADIUS * 2 + 1
    mask = np.clip(cv2.GaussianBlur(mask, (kernel_size, kernel_size), BLUR_RADIUS), 0, 1)
    if EDGE_SMOOTHING:
        mask = cv2.bilateralFilter(mask.astype(np.float32), 9, 75, 75)
        mask = np.clip(mask, 0, 1)
    mask = np.clip(mask * 1.2, 0, 1)

    pixels = np.array(image).astype(np.float32)

    # White layer that keeps some of the original shading.
    luminance = cv2.cvtColor(pixels.astype(np.uint8), cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0
    shading = 0.6 + 0.4 * luminance
    white = np.array([255, 255, 255], dtype=np.float32)
    white_layer = white * shading[..., np.newaxis]

    weight = ALPHA_HAIR * mask[..., np.newaxis]
    mixed = (1 - weight) * pixels + weight * white_layer
    mixed = np.clip(mixed, 0, 255).astype(np.uint8)

    return Image.fromarray(mixed).filter(
        ImageFilter.UnsharpMask(radius=1.2, percent=140, threshold=2)
    )
def post_correct_aged(original: Image.Image, aged: Image.Image) -> Image.Image:
    """Match *aged* to the histogram of *original*, then brighten slightly.

    After per-channel histogram matching, brightness is raised by a factor
    of 1.10 and contrast by 1.06.
    """
    reference = np.array(original)
    candidate = np.array(aged)
    matched = match_histograms(candidate, reference, channel_axis=-1)
    corrected = Image.fromarray(np.clip(matched, 0, 255).astype(np.uint8))
    for enhancer, factor in ((ImageEnhance.Brightness, 1.10), (ImageEnhance.Contrast, 1.06)):
        corrected = enhancer(corrected).enhance(factor)
    return corrected
def enhance_texture(img: Image.Image) -> Image.Image:
    """Apply an unsharp mask, then the configured contrast and sharpness gains."""
    sharpened = img.filter(ImageFilter.UnsharpMask(radius=2, percent=160, threshold=3))
    contrasted = ImageEnhance.Contrast(sharpened).enhance(CONTRAST_BOOST)
    return ImageEnhance.Sharpness(contrasted).enhance(SHARPNESS_BOOST)
# ================== MAIN PROCESSING FUNCTION (Memory Safe) ==================
def process_face_aging(input_image: Image.Image) -> Image.Image:
    """Run the full pipeline: aging, hair/beard whitening and GFPGAN.

    Steps:
      1. Resize to ``SAFE_IMG_SIZE`` and run the aging network on the image
         plus two constant age planes.
      2. Blend the network output with the input (``WRINKLE_STRENGTH``),
         resize back, enhance texture and colour-correct.
      3. Build hair and beard masks and whiten those regions.
      4. Optionally restore the face with GFPGAN; failures there are
         printed and skipped.

    Args:
        input_image: Source picture.

    Returns:
        The processed image at the original resolution.

    Raises:
        ValueError: If *input_image* is ``None``.
        Exception: Any pipeline error is logged and re-raised.
    """
    if input_image is None:
        raise ValueError("Please provide a valid image!")

    on_gpu = DEVICE.type == "cuda"
    try:
        print(f"→ Processing image: {input_image.size}")
        orig = input_image.convert("RGB")
        original_size = orig.size  # (width, height)

        # Network input: RGB plus constant source-age and target-age planes.
        side = SAFE_IMG_SIZE
        rgb_tensor = TF.to_tensor(orig.resize((side, side), Image.LANCZOS)).to(DEVICE)
        src_age = torch.full((1, side, side), SOURCE_AGE / 100.0, device=DEVICE)
        tgt_age = torch.full((1, side, side), TARGET_AGE / 100.0, device=DEVICE)
        cond_input = torch.cat([rgb_tensor, src_age, tgt_age], dim=0).unsqueeze(0)

        # Aging Model
        with torch.inference_mode():
            aging_net = load_aging_model()
            raw_output = aging_net(cond_input).squeeze(0).float()

        strength = WRINKLE_STRENGTH
        blended = ((1 - strength) * rgb_tensor + strength * raw_output).clamp(0, 1)

        final_aged = TF.to_pil_image(blended).resize(original_size, Image.LANCZOS)
        final_aged = post_correct_aged(orig, enhance_texture(final_aged))

        # Hair & Beard
        print(" Generating hair mask...")
        hair_mask = get_hair_mask_segformer(final_aged)
        print(" Generating beard mask...")
        beard_mask = get_beard_mask(final_aged)
        print(" Applying white hair & beard...")
        final_img = apply_hair_and_beard_color(final_aged, hair_mask, beard_mask)

        # GFPGAN
        if USE_GFPGAN:
            print(" Applying GFPGAN face restoration...")
            gfpgan = load_gfpgan()
            if gfpgan is not None:
                try:
                    bgr_input = cv2.cvtColor(np.array(final_img), cv2.COLOR_RGB2BGR)
                    with torch.inference_mode():
                        _, _, restored_cv = gfpgan.enhance(
                            bgr_input,
                            has_aligned=False,
                            only_center_face=False,
                            paste_back=True,
                            weight=GFPGAN_WEIGHT,
                        )
                    final_img = Image.fromarray(cv2.cvtColor(restored_cv, cv2.COLOR_BGR2RGB))
                except Exception as e:
                    # Restoration is best-effort: keep the unrestored image.
                    print(f" GFPGAN error (skipped): {e}")

        print("✓ Processing completed!")
        return final_img
    except Exception as e:
        print(f"❌ Error: {str(e)}")
        traceback.print_exc()
        if on_gpu:
            print(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")
        raise
    finally:
        # Release Python and CUDA memory after every run, success or not.
        gc.collect()
        if on_gpu:
            torch.cuda.empty_cache()
# ================== FASTAPI SETUP ==================
app = FastAPI(title="Face Aging + White Hair & Beard API")

# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site send credentialed requests. FastAPI's CORS documentation advises listing
# origins explicitly when credentials are enabled -- confirm whether clients
# rely on credentials and, if not, turn allow_credentials off.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
@app.post("/age-face")
async def age_face(file: UploadFile = File(...)):
    """Age the uploaded face image and return the result as a PNG.

    Args:
        file: Multipart upload whose declared content type must start with
            ``image/``.

    Returns:
        A ``StreamingResponse`` carrying the processed image
        (``image/png``).

    Raises:
        HTTPException: 400 when the upload is not declared as an image or
            cannot be decoded; 500 when the processing pipeline fails.
    """
    # content_type is None when the client sends no Content-Type for the
    # part; treat that as "not an image" instead of raising AttributeError.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Only image files allowed")
    contents = await file.read()
    try:
        try:
            input_image = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            # Undecodable or truncated data is a client error, not a server
            # fault (PIL's UnidentifiedImageError is an OSError subclass).
            raise HTTPException(
                status_code=400, detail="Invalid or unsupported image data"
            ) from e
        result_image = process_face_aging(input_image)
        img_byte_arr = io.BytesIO()
        result_image.save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)
        return StreamingResponse(img_byte_arr, media_type="image/png")
    except HTTPException:
        # Keep the deliberate 4xx above from being rewrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}") from e
    finally:
        gc.collect()
        if DEVICE.type == "cuda":
            torch.cuda.empty_cache()
# For local testing
if __name__ == "__main__":
    from uvicorn import run as serve_app

    print("Starting FastAPI server...")
    serve_app(app, host="0.0.0.0", port=7860)