Spaces:
Running
Running
import asyncio
import concurrent.futures
import os
import shutil
import tempfile
import uuid

import torch
from dotenv import load_dotenv
from fastapi import (
    BackgroundTasks,
    FastAPI,
    File,
    Form,
    HTTPException,
    UploadFile,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (
    FileResponse,
    HTMLResponse,
    JSONResponse,
)
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel

from models import UNet
from test_functions import process_image
# =========================================================
# LOAD ENV
# =========================================================
# Populate os.environ from a local .env file before anything below reads the
# environment (download_model() reads HUGGINGFACE_HUB_CACHE during startup).
load_dotenv()
# =========================================================
# CPU OPTIMIZATION
# =========================================================
# Limit PyTorch's intra-op CPU parallelism to 2 threads.
# NOTE(review): up to MAX_WORKERS inferences run concurrently, so the total
# number of busy CPU threads may approach 2 * MAX_WORKERS — check this against
# the host's core count.
torch.set_num_threads(2)
# Allow PyTorch to use its oneDNN (MKL-DNN) CPU kernels.
torch.backends.mkldnn.enabled = True
# =========================================================
# THREAD POOL
# =========================================================
# Worker pool that age_face() hands model inference to, keeping the
# CPU-heavy work off the asyncio event loop.
MAX_WORKERS = 4
executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
# =========================================================
# APP
# =========================================================
# The ASGI application; title/version show up in the generated OpenAPI docs.
app = FastAPI(title="Face Aging API", version="4.0.0")
# =========================================================
# CORS
# =========================================================
# Let browser clients on any origin call the API and read the
# Content-Disposition header that age_face() sets on the returned image.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# discouraged — Starlette reportedly reflects the caller's Origin when the
# request carries cookies, so any site could make credentialed calls (verify
# for the installed version). Nothing visible here uses cookies or auth; if
# credentials are not needed, allow_credentials=False is the safer setting.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["Content-Disposition"]
)
# =========================================================
# SETTINGS
# =========================================================
class AppSettings(BaseModel):
    """Service configuration; the values are published by get_settings()."""

    # Hugging Face Hub repository that hosts the model checkpoint.
    model_repo: str = "Robys01/face-aging"
    # Upload size limit in megabytes.
    # NOTE(review): confirm that the upload path actually enforces this.
    max_upload_size_mb: int = 10
    # Lower-case file extensions that validate_image() accepts.
    allowed_extensions: list = ["jpg", "jpeg", "png", "webp"]


settings = AppSettings()
# =========================================================
# MODEL PATH
# =========================================================
# Local location of the UNet checkpoint. download_model() fills it in at
# startup when the file is missing.
MODEL_DIR = "/tmp/model"
os.makedirs(MODEL_DIR, exist_ok=True)
MODEL_PATH = os.path.join(MODEL_DIR, "best_unet_model.pth")
# =========================================================
# DOWNLOAD MODEL
# =========================================================
def download_model():
    """Fetch ``best_unet_model.pth`` from ``settings.model_repo`` into MODEL_DIR.

    The Hub cache directory comes from the ``HUGGINGFACE_HUB_CACHE``
    environment variable; when that is unset, ``None`` is passed through.
    """
    print("Downloading model...")
    hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE")
    hf_hub_download(
        repo_id=settings.model_repo,
        filename="best_unet_model.pth",
        local_dir=MODEL_DIR,
        cache_dir=hub_cache,
    )
# =========================================================
# LOAD MODEL
# =========================================================
# Download the checkpoint only when no local copy exists; if the file is
# already in MODEL_DIR, it is reused.
if not os.path.exists(MODEL_PATH):
    download_model()
model = UNet()
# Restore the trained weights, mapping every tensor onto the CPU.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# so a tampered checkpoint from the remote repo could execute code at startup.
# If the file is a plain state_dict, weights_only=True should load it safely —
# confirm before switching.
model.load_state_dict(
    torch.load(
        MODEL_PATH,
        map_location=torch.device("cpu"),
        weights_only=False
    )
)
# Inference only: put the module in evaluation mode.
model.eval()
print("Model loaded successfully")
# =========================================================
# IMAGE SETTINGS
# =========================================================
# Bounding-box edge (px) that resize_for_mobile() shrinks images into.
MAX_IMAGE_SIZE = 768
# PNG zlib compression level for generated outputs (9 = smallest file).
PNG_COMPRESS_LEVEL = 9
# =========================================================
# UTILITIES
# =========================================================
def validate_image(filename: str):
    """Reject uploads whose file extension is not an allowed image type.

    Args:
        filename: client-supplied upload name. It may be None or empty
            because UploadFile.filename is optional.

    Raises:
        HTTPException: 400 when the name is missing or has no extension, or
            when its extension (compared case-insensitively) is not in
            settings.allowed_extensions.
    """
    # Treat a missing name as invalid input: `"." in None` raises TypeError,
    # which the caller would otherwise report as a 500.
    if not filename or "." not in filename:
        raise HTTPException(
            status_code=400,
            detail="Invalid filename"
        )
    ext = filename.rsplit(".", 1)[-1].lower()
    if ext not in settings.allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail="Unsupported image format"
        )
def save_upload_temp(upload_file: UploadFile):
    """Copy an upload into a named temporary file and return that file's path.

    The temp file takes the upload's extension as its suffix. It is created
    with delete=False, so the caller is responsible for removing it.

    Raises:
        HTTPException: 413 when the upload is larger than
            settings.max_upload_size_mb. The partial copy is deleted first.
    """
    max_bytes = settings.max_upload_size_mb * 1024 * 1024
    chunk_size = 1024 * 1024
    suffix = "." + upload_file.filename.split(".")[-1]
    temp_file = tempfile.NamedTemporaryFile(
        delete=False,
        suffix=suffix
    )
    try:
        with temp_file as buffer:
            written = 0
            # Copy in chunks so an oversized upload is rejected as soon as it
            # crosses the limit, instead of being copied to disk in full.
            for chunk in iter(lambda: upload_file.file.read(chunk_size), b""):
                written += len(chunk)
                if written > max_bytes:
                    raise HTTPException(
                        status_code=413,
                        detail=(
                            f"Image exceeds the "
                            f"{settings.max_upload_size_mb} MB upload limit"
                        ),
                    )
                buffer.write(chunk)
    except BaseException:
        # Don't leave a partial copy behind when the copy fails or is rejected.
        try:
            os.remove(temp_file.name)
        except OSError:
            pass
        raise
    return temp_file.name
def resize_for_mobile(image: Image.Image):
    """Shrink *image* in place to fit a MAX_IMAGE_SIZE square, then return it.

    ``Image.thumbnail`` keeps the aspect ratio and never enlarges, so images
    that already fit are left at their original size.
    """
    bounding_box = (MAX_IMAGE_SIZE,) * 2
    image.thumbnail(bounding_box, resample=Image.LANCZOS)
    return image
def create_png_output(image: Image.Image):
    """Save *image* as an optimised PNG under a random name in the temp dir.

    Returns:
        Path of the newly written file. This function does not delete it.
    """
    destination = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + ".png")
    image.save(
        destination,
        format="PNG",
        optimize=True,
        compress_level=PNG_COMPRESS_LEVEL,
    )
    return destination
def cleanup_temp(path):
    """Remove the file at *path* on a best-effort basis.

    A falsy *path* (None or "") does nothing. Filesystem errors, such as the
    file already being gone or a permission problem, are ignored so that
    cleanup never masks the caller's real result. Unlike the previous bare
    ``except:``, this does not swallow KeyboardInterrupt or SystemExit.
    """
    if not path:
        return
    try:
        os.remove(path)
    except OSError:
        pass
# =========================================================
# AI PROCESSING
# =========================================================
def run_face_aging(
    image_path,
    source_age,
    target_age
):
    """Age the face in *image_path* from *source_age* to *target_age*.

    This call is blocking and CPU-bound; age_face() runs it on the worker pool.
    The input is converted to RGB and shrunk to MAX_IMAGE_SIZE before
    inference, and the result is shrunk again before it is saved.

    Returns:
        Path of a new PNG in the system temp directory. The caller is
        responsible for deleting it.
    """
    # Decode inside a context manager so the source file handle is released
    # right away. convert() loads the pixels and returns a detached copy, even
    # when the image is already RGB.
    with Image.open(image_path) as source:
        pil_image = source.convert("RGB")
    pil_image = resize_for_mobile(pil_image)
    with torch.inference_mode():
        # NOTE(review): process_image is assumed to return a PIL image; the
        # code below calls thumbnail()/save() on it.
        processed_image = process_image(
            model,
            pil_image,
            source_age,
            target_age
        )
    processed_image = resize_for_mobile(processed_image)
    return create_png_output(processed_image)
# =========================================================
# HOME
# =========================================================
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the single-page browser UI for the /age-face endpoint.

    The page posts the chosen image plus ``source_age``/``target_age`` as
    multipart form data and previews the PNG that comes back. Server errors
    and exception messages are rendered as text, never as HTML, so echoed
    content cannot inject markup.
    """
    return """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport"
          content="width=device-width, initial-scale=1.0">
    <title>Face Aging API</title>
    <style>
        *{
            margin:0;
            padding:0;
            box-sizing:border-box;
        }
        body{
            font-family:Arial,sans-serif;
            background:#0f172a;
            color:white;
            min-height:100vh;
            padding:20px;
        }
        .container{
            max-width:700px;
            margin:auto;
        }
        .title{
            text-align:center;
            margin-bottom:30px;
        }
        .title h1{
            font-size:40px;
            margin-bottom:10px;
        }
        .title p{
            color:#94a3b8;
        }
        .card{
            background:#1e293b;
            border-radius:20px;
            padding:25px;
        }
        input{
            width:100%;
            padding:14px;
            margin-top:14px;
            border:none;
            border-radius:10px;
            background:#334155;
            color:white;
            font-size:16px;
        }
        button{
            width:100%;
            padding:15px;
            margin-top:20px;
            border:none;
            border-radius:12px;
            background:#2563eb;
            color:white;
            font-size:16px;
            cursor:pointer;
            font-weight:bold;
        }
        button:hover{
            background:#1d4ed8;
        }
        .preview{
            width:100%;
            margin-top:20px;
            border-radius:16px;
        }
        .loader{
            width:100%;
            text-align:center;
            margin-top:15px;
            display:none;
        }
        .status{
            margin-top:15px;
        }
        .success{
            color:#22c55e;
        }
        .error{
            color:#ef4444;
        }
    </style>
</head>
<body>
    <div class="container">
        <div class="title">
            <h1>Face Aging AI</h1>
            <p>Fast Multi Request CPU API</p>
        </div>
        <div class="card">
            <input
                type="file"
                id="faceImage"
                accept="image/*">
            <input
                type="number"
                id="sourceAge"
                placeholder="Current Age"
                value="20">
            <input
                type="number"
                id="targetAge"
                placeholder="Target Age"
                value="70">
            <button onclick="ageFace()">
                Generate Aged Face
            </button>
            <div
                class="loader"
                id="loader">
                Processing...
            </div>
            <div
                class="status"
                id="status">
            </div>
            <img
                id="preview"
                class="preview">
        </div>
    </div>
    <script>
        // Object URL of the image currently shown, so it can be released later.
        let previewUrl = null

        function showLoader(){
            document.getElementById("loader").style.display = "block"
        }

        function hideLoader(){
            document.getElementById("loader").style.display = "none"
        }

        // Show a message as plain text. Error bodies can echo arbitrary
        // strings, and inserting them via innerHTML would allow script injection.
        function setStatus(message, cssClass){
            const status = document.getElementById("status")
            status.textContent = ""
            if(!message){
                return
            }
            const span = document.createElement("span")
            span.className = cssClass
            span.textContent = message
            status.appendChild(span)
        }

        async function ageFace(){
            try{
                showLoader()
                setStatus("")
                const file = document.getElementById("faceImage").files[0]
                if(!file){
                    alert("Select image")
                    return
                }
                const formData = new FormData()
                formData.append("image", file)
                formData.append(
                    "source_age",
                    document.getElementById("sourceAge").value
                )
                formData.append(
                    "target_age",
                    document.getElementById("targetAge").value
                )
                const response = await fetch("/age-face", {
                    method:"POST",
                    body:formData,
                    cache:"no-cache"
                })
                if(!response.ok){
                    setStatus(await response.text(), "error")
                    return
                }
                const blob = await response.blob()
                const sizeMB = (blob.size / 1024 / 1024).toFixed(2)
                const previousUrl = previewUrl
                previewUrl = URL.createObjectURL(blob)
                document.getElementById("preview").src = previewUrl
                // Free the previous result so repeated runs do not leak blobs.
                if(previousUrl){
                    URL.revokeObjectURL(previousUrl)
                }
                setStatus("Done • " + sizeMB + " MB", "success")
            }catch(error){
                setStatus(String(error), "error")
            }finally{
                hideLoader()
            }
        }
    </script>
</body>
</html>
"""
# =========================================================
# HEALTH
# =========================================================
# NOTE(review): the route path is taken from the section name; confirm it
# matches the path that health checks poll.
@app.get("/health")
def health():
    """Report liveness plus basic runtime configuration.

    ``model_loaded`` is always True. The model is loaded while this module is
    imported, and a failed load aborts the import, so the app cannot be
    serving without a model.
    """
    return {
        "status": "healthy",
        "model_loaded": True,
        "device": "cpu",
        "max_workers": MAX_WORKERS
    }
# =========================================================
# SETTINGS ENDPOINT
# =========================================================
# NOTE(review): the route path is taken from the section name; confirm with clients.
@app.get("/settings")
def get_settings():
    """Return the current AppSettings values as a plain dict."""
    # Pydantic v2 deprecates .dict() in favour of .model_dump(); use whichever
    # method the installed Pydantic version provides.
    dump = getattr(settings, "model_dump", None) or settings.dict
    return dump()
# =========================================================
# AGE FACE
# =========================================================
# The path matches the fetch("/age-face") call made by the home page.
@app.post("/age-face")
async def age_face(
    image: UploadFile = File(...),
    source_age: int = Form(...),
    target_age: int = Form(...)
):
    """Age the uploaded face and return the result as an inline PNG.

    Form fields:
        image: the uploaded picture; its extension must be in
            settings.allowed_extensions.
        source_age: the subject's current age.
        target_age: the age to render the face at.

    Returns:
        A FileResponse streaming the generated PNG. The file is deleted after
        it has been sent. If processing fails, a 500 JSONResponse
        ``{"success": False, "error": ...}`` is returned instead.

    Raises:
        HTTPException: re-raised unchanged, so client errors such as a bad
            filename or an oversized upload keep their own status code.
    """
    temp_input = None
    try:
        validate_image(image.filename)
        loop = asyncio.get_running_loop()
        # Copying the upload is blocking file I/O. Run it on the default
        # executor so it neither stalls the event loop nor occupies a model worker.
        temp_input = await loop.run_in_executor(
            None,
            save_upload_temp,
            image
        )
        # Inference is CPU-bound; run it on the dedicated model pool.
        output_path = await loop.run_in_executor(
            executor,
            run_face_aging,
            temp_input,
            source_age,
            target_age
        )
    except HTTPException:
        # Let FastAPI return the 4xx as-is instead of turning it into a 500.
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "error": str(e)
            }
        )
    finally:
        cleanup_temp(temp_input)

    # The generated PNG is only needed until it has been streamed. Delete it
    # afterwards so output files do not pile up in the temp directory.
    after_send = BackgroundTasks()
    after_send.add_task(cleanup_temp, output_path)
    return FileResponse(
        path=output_path,
        media_type="image/png",
        filename="aged_face.png",
        headers={
            "Content-Disposition":
                "inline; filename=aged_face.png",
            "Cache-Control":
                "public, max-age=86400"
        },
        background=after_send
    )
# =========================================================
# MAIN
# =========================================================
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "app:app" import string. When this
    # file runs as __main__, an import string makes uvicorn import it again
    # under the module name "app", which repeats all module setup (including
    # the model load) and requires the file to be named app.py. An import
    # string is only needed for reload or multiple workers, and both are off.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7000,
        reload=False,
        workers=1
    )