import os import uuid import shutil import tempfile import asyncio import concurrent.futures import torch from fastapi import ( FastAPI, UploadFile, File, Form, HTTPException ) from fastapi.responses import ( FileResponse, HTMLResponse, JSONResponse ) from fastapi.middleware.cors import CORSMiddleware from PIL import Image from pydantic import BaseModel from dotenv import load_dotenv from huggingface_hub import hf_hub_download from models import UNet from test_functions import process_image # ========================================================= # LOAD ENV # ========================================================= load_dotenv() # ========================================================= # CPU OPTIMIZATION # ========================================================= torch.set_num_threads(2) torch.backends.mkldnn.enabled = True # ========================================================= # THREAD POOL # ========================================================= MAX_WORKERS = 4 executor = concurrent.futures.ThreadPoolExecutor( max_workers=MAX_WORKERS ) # ========================================================= # APP # ========================================================= app = FastAPI( title="Face Aging API", version="4.0.0" ) # ========================================================= # CORS # ========================================================= app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], expose_headers=["Content-Disposition"] ) # ========================================================= # SETTINGS # ========================================================= class AppSettings(BaseModel): model_repo: str = "Robys01/face-aging" max_upload_size_mb: int = 10 allowed_extensions: list = [ "jpg", "jpeg", "png", "webp" ] settings = AppSettings() # ========================================================= # MODEL PATH # ========================================================= MODEL_DIR = "/tmp/model" os.makedirs(MODEL_DIR, exist_ok=True) MODEL_PATH = os.path.join( MODEL_DIR, "best_unet_model.pth" ) # ========================================================= # DOWNLOAD MODEL # ========================================================= def download_model(): print("Downloading model...") hf_hub_download( repo_id=settings.model_repo, filename="best_unet_model.pth", local_dir=MODEL_DIR, cache_dir=os.environ.get( "HUGGINGFACE_HUB_CACHE" ), ) # ========================================================= # LOAD MODEL # ========================================================= if not os.path.exists(MODEL_PATH): download_model() model = UNet() model.load_state_dict( torch.load( MODEL_PATH, map_location=torch.device("cpu"), weights_only=False ) ) model.eval() print("Model loaded successfully") # ========================================================= # IMAGE SETTINGS # ========================================================= MAX_IMAGE_SIZE = 768 PNG_COMPRESS_LEVEL = 9 # ========================================================= # UTILITIES # ========================================================= def validate_image(filename: str): if "." not in filename: raise HTTPException( status_code=400, detail="Invalid filename" ) ext = filename.split(".")[-1].lower() if ext not in settings.allowed_extensions: raise HTTPException( status_code=400, detail="Unsupported image format" ) def save_upload_temp(upload_file: UploadFile): suffix = "." + upload_file.filename.split(".")[-1] temp_file = tempfile.NamedTemporaryFile( delete=False, suffix=suffix ) with temp_file as buffer: shutil.copyfileobj( upload_file.file, buffer ) return temp_file.name def resize_for_mobile(image: Image.Image): image.thumbnail( (MAX_IMAGE_SIZE, MAX_IMAGE_SIZE), Image.LANCZOS ) return image def create_png_output(image: Image.Image): output_filename = f"{uuid.uuid4().hex}.png" output_path = os.path.join( tempfile.gettempdir(), output_filename ) image.save( output_path, format="PNG", optimize=True, compress_level=PNG_COMPRESS_LEVEL ) return output_path def cleanup_temp(path): try: if path and os.path.exists(path): os.remove(path) except: pass # ========================================================= # AI PROCESSING # ========================================================= def run_face_aging( image_path, source_age, target_age ): pil_image = Image.open(image_path) if pil_image.mode != "RGB": pil_image = pil_image.convert("RGB") pil_image = resize_for_mobile( pil_image ) with torch.inference_mode(): processed_image = process_image( model, pil_image, source_age, target_age ) processed_image = resize_for_mobile( processed_image ) output_path = create_png_output( processed_image ) return output_path # ========================================================= # HOME # ========================================================= @app.get("/", response_class=HTMLResponse) async def home(): return """
Fast Multi Request CPU API