"""Face Aging API.

FastAPI service that accepts an uploaded image plus ``source_age`` /
``target_age`` form fields and passes it through a UNet checkpoint via
``process_image``, running on CPU. The result is returned as a PNG.
"""

import asyncio
import concurrent.futures
import logging
import os
import shutil  # noqa: F401 - no longer used by this module; kept for compatibility
import tempfile
import uuid

import torch
from fastapi import (
    BackgroundTasks,
    FastAPI,
    File,
    Form,
    HTTPException,
    UploadFile,
)
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download

from models import UNet
from test_functions import process_image

logger = logging.getLogger(__name__)

# =========================================================
# LOAD ENV
# =========================================================

load_dotenv()

# =========================================================
# CPU OPTIMIZATION
# =========================================================

# Each inference call uses up to 2 intra-op threads. With MAX_WORKERS
# concurrent requests, that can mean up to MAX_WORKERS * 2 busy threads.
torch.set_num_threads(2)
torch.backends.mkldnn.enabled = True

# =========================================================
# THREAD POOL
# =========================================================

MAX_WORKERS = 4

# Dedicated pool for model inference, so concurrent requests don't block
# the event loop.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)

# =========================================================
# APP
# =========================================================

app = FastAPI(title="Face Aging API", version="4.0.0")

# =========================================================
# CORS
# =========================================================

# NOTE(review): the CORS spec does not allow a literal "*" origin together
# with credentials. Restrict allow_origins if this API ever relies on
# cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["Content-Disposition"],
)

# =========================================================
# SETTINGS
# =========================================================


class AppSettings(BaseModel):
    """Runtime configuration; also served as-is by ``GET /settings``."""

    # Hugging Face repo that hosts the checkpoint file.
    model_repo: str = "Robys01/face-aging"
    # Upper bound on accepted upload size, enforced in save_upload_temp.
    max_upload_size_mb: int = 10
    # Lower-case filename extensions accepted by validate_image.
    allowed_extensions: list = ["jpg", "jpeg", "png", "webp"]


settings = AppSettings()

# =========================================================
# MODEL PATH
# =========================================================

MODEL_DIR = "/tmp/model"
os.makedirs(MODEL_DIR, exist_ok=True)

MODEL_PATH = os.path.join(MODEL_DIR, "best_unet_model.pth")

# =========================================================
# DOWNLOAD MODEL
# =========================================================


def download_model():
    """Download ``best_unet_model.pth`` from ``settings.model_repo`` into MODEL_DIR."""
    print("Downloading model...")
    hf_hub_download(
        repo_id=settings.model_repo,
        filename="best_unet_model.pth",
        local_dir=MODEL_DIR,
        cache_dir=os.environ.get("HUGGINGFACE_HUB_CACHE"),
    )


# =========================================================
# LOAD MODEL
# =========================================================

if not os.path.exists(MODEL_PATH):
    download_model()

model = UNet()
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects
# from a file fetched over the network. If the checkpoint is a plain
# state_dict, prefer weights_only=True so no pickled code can run.
model.load_state_dict(
    torch.load(
        MODEL_PATH,
        map_location=torch.device("cpu"),
        weights_only=False,
    )
)
model.eval()

print("Model loaded successfully")

# =========================================================
# IMAGE SETTINGS
# =========================================================

# Longest side, in pixels, of both the model input and the returned PNG.
MAX_IMAGE_SIZE = 768
# zlib level for PNG output (9 = smallest file, slowest encode).
PNG_COMPRESS_LEVEL = 9

# Read size used when copying uploads to disk.
_UPLOAD_CHUNK_SIZE = 1024 * 1024

# =========================================================
# UTILITIES
# =========================================================


def validate_image(filename: str):
    """Reject filenames that are missing or have a non-allowed extension.

    Raises:
        HTTPException(400): filename is empty/None, has no extension, or its
            extension is not in ``settings.allowed_extensions``.
    """
    # UploadFile.filename may be None/empty; treat that as invalid rather
    # than crashing with a TypeError (which would surface as a 500).
    if not filename or "." not in filename:
        raise HTTPException(status_code=400, detail="Invalid filename")

    ext = filename.rsplit(".", 1)[-1].lower()

    if ext not in settings.allowed_extensions:
        raise HTTPException(status_code=400, detail="Unsupported image format")


def save_upload_temp(upload_file: UploadFile):
    """Copy an upload to a named temp file and return that file's path.

    The caller owns the returned file and must delete it (see cleanup_temp).
    The temp file keeps the upload's extension as its suffix.

    Raises:
        HTTPException(413): the upload exceeds ``settings.max_upload_size_mb``.
            The partially written temp file is removed first.
    """
    max_bytes = settings.max_upload_size_mb * 1024 * 1024
    suffix = "." + upload_file.filename.rsplit(".", 1)[-1]

    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with temp_file as buffer:
            written = 0
            while True:
                chunk = upload_file.file.read(_UPLOAD_CHUNK_SIZE)
                if not chunk:
                    break
                written += len(chunk)
                # Stop as soon as the limit is crossed instead of spooling
                # an arbitrarily large body to disk.
                if written > max_bytes:
                    raise HTTPException(
                        status_code=413,
                        detail=(
                            "File too large "
                            f"(max {settings.max_upload_size_mb} MB)"
                        ),
                    )
                buffer.write(chunk)
    except BaseException:
        # Don't leave a partial file behind; the error is always re-raised.
        cleanup_temp(temp_file.name)
        raise

    return temp_file.name


def resize_for_mobile(image: Image.Image):
    """Shrink ``image`` in place to fit within MAX_IMAGE_SIZE, keeping aspect ratio.

    Images already within bounds are left unchanged, since thumbnail() never
    enlarges. Returns the same object.
    """
    image.thumbnail((MAX_IMAGE_SIZE, MAX_IMAGE_SIZE), Image.LANCZOS)
    return image


def create_png_output(image: Image.Image):
    """Save ``image`` as an optimized PNG with a random name in the temp dir.

    Returns the file path. The caller is responsible for deleting it.
    """
    output_filename = f"{uuid.uuid4().hex}.png"
    output_path = os.path.join(tempfile.gettempdir(), output_filename)

    image.save(
        output_path,
        format="PNG",
        optimize=True,
        compress_level=PNG_COMPRESS_LEVEL,
    )

    return output_path


def cleanup_temp(path):
    """Best-effort deletion of a temp file. Never raises.

    ``None`` or an empty path is ignored, as is a file that is already gone.
    Other OS errors are logged rather than propagated.
    """
    if not path:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError:
        logger.warning("Could not remove temp file %s", path, exc_info=True)


# =========================================================
# AI PROCESSING
# =========================================================


def run_face_aging(image_path, source_age, target_age):
    """Run the model on the image at ``image_path`` and write the result as PNG.

    Blocking and CPU-heavy; intended to run in ``executor``.

    Returns:
        Path of the generated PNG. The caller must delete it.

    Raises:
        HTTPException(400): the file cannot be decoded as an image.
    """
    try:
        source = Image.open(image_path)
    except UnidentifiedImageError as err:
        raise HTTPException(status_code=400, detail="Invalid image file") from err

    # convert() returns a fully loaded, independent copy, even when the mode
    # is already RGB. That lets the input file handle be closed right away
    # instead of leaking until garbage collection.
    with source:
        pil_image = source.convert("RGB")

    pil_image = resize_for_mobile(pil_image)

    with torch.inference_mode():
        processed_image = process_image(model, pil_image, source_age, target_age)

    processed_image = resize_for_mobile(processed_image)

    return create_png_output(processed_image)


# =========================================================
# HOME
# =========================================================


@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the static landing page."""
    return (
        " Face Aging API\n"
        "\n"
        "Face Aging AI\n"
        "\n"
        "Fast Multi Request CPU API\n"
        "\n"
        "Processing...\n"
    )


# =========================================================
# HEALTH
# =========================================================


@app.get("/health")
def health():
    """Liveness probe.

    The model is loaded at import time, so reaching this handler implies it
    is loaded.
    """
    return {
        "status": "healthy",
        "model_loaded": True,
        "device": "cpu",
        "max_workers": MAX_WORKERS,
    }


# =========================================================
# SETTINGS
# =========================================================


@app.get("/settings")
def get_settings():
    """Return the current AppSettings as a plain dict."""
    return settings.dict()


# =========================================================
# AGE FACE
# =========================================================


@app.post("/age-face")
async def age_face(
    image: UploadFile = File(...),
    source_age: int = Form(...),
    target_age: int = Form(...),
):
    """Transform an uploaded image from ``source_age`` to ``target_age``.

    On success, returns the result inline as ``aged_face.png``.

    Errors are returned as ``{"success": False, "error": ...}``:
    - Client errors (bad filename/extension, undecodable image, upload too
      large) keep their own 4xx status.
    - Any other failure is returned as a 500.
    """
    temp_input = None
    output_path = None
    loop = asyncio.get_running_loop()

    try:
        validate_image(image.filename)

        # The upload copy is blocking file I/O, so it runs in the default
        # executor. That keeps the event loop free and leaves the model pool
        # for inference.
        temp_input = await loop.run_in_executor(None, save_upload_temp, image)

        output_path = await loop.run_in_executor(
            executor,
            run_face_aging,
            temp_input,
            source_age,
            target_age,
        )

    except HTTPException as e:
        # Previously swallowed by the generic handler and reported as a 500.
        return JSONResponse(
            status_code=e.status_code,
            content={"success": False, "error": e.detail},
        )

    except Exception as e:
        logger.exception("Face aging request failed")
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)},
        )

    finally:
        cleanup_temp(temp_input)

    # Remove the generated PNG only after the response body has been sent;
    # otherwise every request leaves a file behind in the temp dir.
    cleanup_tasks = BackgroundTasks()
    cleanup_tasks.add_task(cleanup_temp, output_path)

    return FileResponse(
        path=output_path,
        media_type="image/png",
        filename="aged_face.png",
        headers={
            "Content-Disposition": "inline; filename=aged_face.png",
            "Cache-Control": "public, max-age=86400",
        },
        background=cleanup_tasks,
    )


# =========================================================
# MAIN
# =========================================================

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7000,
        reload=False,
        workers=1,
    )