# Face-Aging / app.py
# Author: kingloft
# Last commit: 7f91008 ("Update app.py")
import asyncio
import concurrent.futures
import logging
import os
import shutil
import tempfile
import uuid

import torch
from dotenv import load_dotenv
from fastapi import (
    BackgroundTasks,
    FastAPI,
    File,
    Form,
    HTTPException,
    UploadFile,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (
    FileResponse,
    HTMLResponse,
    JSONResponse,
)
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel

from models import UNet
from test_functions import process_image
# =========================================================
# LOAD ENV
# =========================================================
# Populate os.environ from a local .env file (if one exists) before any
# code below reads environment variables, e.g. HUGGINGFACE_HUB_CACHE in
# download_model().
load_dotenv()

# =========================================================
# CPU OPTIMIZATION
# =========================================================
# Limit torch's intra-op threads per inference call; up to MAX_WORKERS
# inferences can run concurrently on the pool below.
torch.set_num_threads(2)
torch.backends.mkldnn.enabled = True

# =========================================================
# THREAD POOL
# =========================================================
# Blocking model inference is dispatched here by age_face() so the event
# loop keeps serving other requests.
MAX_WORKERS = 4
executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)

# =========================================================
# APP
# =========================================================
app = FastAPI(title="Face Aging API", version="4.0.0")

# =========================================================
# CORS
# =========================================================
# Content-Disposition is exposed so browser clients can read the suggested
# filename of /age-face responses.
# NOTE(review): "*" origins together with allow_credentials=True makes
# Starlette reflect the caller's Origin for credentialed requests. If no
# client relies on cookies/auth headers, consider allow_credentials=False.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    expose_headers=["Content-Disposition"],
)
# =========================================================
# SETTINGS
# =========================================================
class AppSettings(BaseModel):
    """Service configuration, also served verbatim by ``GET /settings``."""

    # Hugging Face Hub repository hosting ``best_unet_model.pth``.
    model_repo: str = "Robys01/face-aging"
    # Upload size limit in megabytes (reported by GET /settings).
    max_upload_size_mb: int = 10
    # Accepted file extensions; validate_image() lower-cases the uploaded
    # extension before comparing, so entries must be lower-case.
    allowed_extensions: list = ["jpg", "jpeg", "png", "webp"]


settings = AppSettings()
# =========================================================
# MODEL PATH
# =========================================================
# Local home of the UNet checkpoint; download_model() fills it on first run.
MODEL_DIR = "/tmp/model"
MODEL_PATH = os.path.join(MODEL_DIR, "best_unet_model.pth")
os.makedirs(MODEL_DIR, exist_ok=True)
# =========================================================
# DOWNLOAD MODEL
# =========================================================
def download_model():
    """Fetch ``best_unet_model.pth`` from ``settings.model_repo`` into MODEL_DIR.

    The ``HUGGINGFACE_HUB_CACHE`` environment variable, when set, is passed
    as the hub cache directory; otherwise ``cache_dir`` is ``None``.
    """
    print("Downloading model...")
    hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE")
    hf_hub_download(
        repo_id=settings.model_repo,
        filename="best_unet_model.pth",
        local_dir=MODEL_DIR,
        cache_dir=hub_cache,
    )
# =========================================================
# LOAD MODEL
# =========================================================
# Download only when the checkpoint is not already on disk.
if not os.path.exists(MODEL_PATH):
    download_model()

model = UNet()
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, i.e. run code embedded in the downloaded checkpoint. The result is
# handed straight to load_state_dict, so weights_only=True may well suffice —
# verify against the published checkpoint before switching.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=torch.device("cpu"), weights_only=False)
)
# Inference only: puts layers such as dropout/batch-norm (if the UNet has
# any) into evaluation behaviour.
model.eval()
print("Model loaded successfully")

# =========================================================
# IMAGE SETTINGS
# =========================================================
# Upper bound, in pixels, on the longest side of images given to and
# returned from the model (see resize_for_mobile).
MAX_IMAGE_SIZE = 768
# zlib level for PNG output (0 = no compression, 9 = smallest/slowest).
PNG_COMPRESS_LEVEL = 9
# =========================================================
# UTILITIES
# =========================================================
def validate_image(filename: str):
    """Reject uploads whose filename lacks an allowed image extension.

    Args:
        filename: Client-supplied upload filename. It can be ``None`` or
            empty when the client omits it.

    Raises:
        HTTPException: 400 "Invalid filename" when the name is missing or
            contains no ``.``; 400 "Unsupported image format" when the
            extension (compared case-insensitively) is not listed in
            ``settings.allowed_extensions``.
    """
    # Guard against a missing name first: ``"." in None`` would raise
    # TypeError and surface as a server error instead of a 400.
    if not filename or "." not in filename:
        raise HTTPException(
            status_code=400,
            detail="Invalid filename"
        )
    ext = filename.rsplit(".", 1)[-1].lower()
    if ext not in settings.allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail="Unsupported image format"
        )
def save_upload_temp(upload_file: UploadFile, max_bytes: "int | None" = None):
    """Copy an uploaded file into a named temporary file on disk.

    Args:
        upload_file: The incoming upload; only its ``filename`` and ``file``
            attributes are used.
        max_bytes: Largest accepted upload in bytes. ``None`` (the default)
            uses ``settings.max_upload_size_mb`` megabytes.

    Returns:
        Path of the temporary file. The file is created with
        ``delete=False``, so the caller must remove it (e.g. with
        ``cleanup_temp``).

    Raises:
        HTTPException: 413 when the upload is larger than ``max_bytes``.
            On this or any other failure the partial temp file is removed
            before the exception propagates.
    """
    if max_bytes is None:
        max_bytes = settings.max_upload_size_mb * 1024 * 1024

    # The suffix becomes part of the on-disk path, so only keep a purely
    # alphanumeric extension: a client-controlled value containing path
    # separators must never reach NamedTemporaryFile.
    name = upload_file.filename or ""
    ext = name.rsplit(".", 1)[-1] if "." in name else ""
    suffix = "." + ext if ext.isalnum() else ""

    temp_file = tempfile.NamedTemporaryFile(
        delete=False,
        suffix=suffix
    )
    try:
        with temp_file as buffer:
            written = 0
            # Copy in chunks so the size limit is enforced while streaming,
            # without ever holding the whole upload in memory.
            while True:
                chunk = upload_file.file.read(1024 * 1024)
                if not chunk:
                    break
                written += len(chunk)
                if written > max_bytes:
                    raise HTTPException(
                        status_code=413,
                        detail=(
                            f"Image exceeds the "
                            f"{max_bytes / (1024 * 1024):g} MB upload limit"
                        ),
                    )
                buffer.write(chunk)
    except BaseException:
        # delete=False means nothing else would remove a half-written file.
        cleanup_temp(temp_file.name)
        raise
    return temp_file.name
def resize_for_mobile(image: Image.Image):
    """Shrink ``image`` in place so neither side exceeds MAX_IMAGE_SIZE.

    ``Image.thumbnail`` keeps the aspect ratio and never enlarges, so small
    images pass through unchanged. The same object is returned so calls can
    be chained.
    """
    bounding_box = (MAX_IMAGE_SIZE,) * 2
    image.thumbnail(bounding_box, Image.LANCZOS)
    return image
def create_png_output(image: Image.Image):
    """Save ``image`` as an optimized PNG under a random name in the temp dir.

    Returns:
        Path of the written file. This function does not delete it; the
        caller owns the file.
    """
    destination = os.path.join(
        tempfile.gettempdir(),
        uuid.uuid4().hex + ".png",
    )
    image.save(
        destination,
        format="PNG",
        optimize=True,
        compress_level=PNG_COMPRESS_LEVEL,
    )
    return destination
def cleanup_temp(path):
    """Best-effort removal of a temporary file.

    Args:
        path: File to delete. ``None`` or ``""`` is a no-op.

    A file that is already gone is not an error. Any other ``OSError`` is
    ignored as well (for example, the path is a directory or is not
    removable), because cleanup must never turn a finished request into a
    failure. Unlike a bare ``except:``, this does not hide programming
    errors, ``KeyboardInterrupt``, or ``SystemExit``.
    """
    if not path:
        return
    # EAFP: remove directly instead of exists()+remove(), which races.
    try:
        os.remove(path)
    except OSError:
        # Includes FileNotFoundError for an already-deleted file.
        pass
# =========================================================
# AI PROCESSING
# =========================================================
def run_face_aging(image_path, source_age, target_age):
    """Run the aging model on the image stored at ``image_path``.

    This is blocking, CPU-bound work; ``age_face`` dispatches it to the
    shared ``executor`` thread pool.

    Args:
        image_path: Path of the uploaded input image on disk.
        source_age: Passed to ``process_image`` as the starting age.
        target_age: Passed to ``process_image`` as the age to render.

    Returns:
        Path of a PNG in the system temp directory holding the result,
        bounded to MAX_IMAGE_SIZE on its longest side.
    """
    source = Image.open(image_path)
    # process_image always receives RGB; other modes are converted first.
    rgb_input = source if source.mode == "RGB" else source.convert("RGB")
    rgb_input = resize_for_mobile(rgb_input)

    # No autograd bookkeeping is needed for a pure forward pass.
    with torch.inference_mode():
        aged = process_image(model, rgb_input, source_age, target_age)

    return create_png_output(resize_for_mobile(aged))
# =========================================================
# HOME
# =========================================================
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the single-page demo UI for ``POST /age-face``.

    The page sends the chosen image plus ``source_age``/``target_age`` as
    multipart form data and previews the returned PNG.

    Status and error messages are rendered with ``textContent`` rather than
    ``innerHTML``, so server/exception text cannot inject markup. Each new
    preview revokes the previous object URL, so repeated runs do not pile
    up blobs in memory. The button is disabled while a request is in
    flight.
    """
    return """
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport"
        content="width=device-width, initial-scale=1.0">
  <title>Face Aging API</title>
  <style>
    *{
      margin:0;
      padding:0;
      box-sizing:border-box;
    }
    body{
      font-family:Arial,sans-serif;
      background:#0f172a;
      color:white;
      min-height:100vh;
      padding:20px;
    }
    .container{
      max-width:700px;
      margin:auto;
    }
    .title{
      text-align:center;
      margin-bottom:30px;
    }
    .title h1{
      font-size:40px;
      margin-bottom:10px;
    }
    .title p{
      color:#94a3b8;
    }
    .card{
      background:#1e293b;
      border-radius:20px;
      padding:25px;
    }
    input{
      width:100%;
      padding:14px;
      margin-top:14px;
      border:none;
      border-radius:10px;
      background:#334155;
      color:white;
      font-size:16px;
    }
    button{
      width:100%;
      padding:15px;
      margin-top:20px;
      border:none;
      border-radius:12px;
      background:#2563eb;
      color:white;
      font-size:16px;
      cursor:pointer;
      font-weight:bold;
    }
    button:hover{
      background:#1d4ed8;
    }
    button:disabled{
      opacity:0.6;
      cursor:not-allowed;
    }
    .preview{
      width:100%;
      margin-top:20px;
      border-radius:16px;
    }
    .loader{
      width:100%;
      text-align:center;
      margin-top:15px;
      display:none;
    }
    .status{
      margin-top:15px;
    }
    .success{
      color:#22c55e;
    }
    .error{
      color:#ef4444;
    }
  </style>
</head>
<body>
  <div class="container">
    <div class="title">
      <h1>Face Aging AI</h1>
      <p>Fast Multi Request CPU API</p>
    </div>
    <div class="card">
      <input
        type="file"
        id="faceImage"
        accept="image/*">
      <input
        type="number"
        id="sourceAge"
        placeholder="Current Age"
        value="20">
      <input
        type="number"
        id="targetAge"
        placeholder="Target Age"
        value="70">
      <button id="generateBtn" onclick="ageFace()">
        Generate Aged Face
      </button>
      <div
        class="loader"
        id="loader">
        Processing...
      </div>
      <div
        class="status"
        id="status">
      </div>
      <img
        id="preview"
        class="preview">
    </div>
  </div>
  <script>
    // Object URL of the image currently shown; revoked before replacement.
    let previewUrl = null

    function showLoader(){
      document.getElementById("loader").style.display = "block"
    }

    function hideLoader(){
      document.getElementById("loader").style.display = "none"
    }

    // Show a status message as plain text (never parsed as HTML).
    function setStatus(message, cssClass){
      const status = document.getElementById("status")
      status.textContent = ""
      if(!message){
        return
      }
      const span = document.createElement("span")
      span.className = cssClass
      span.textContent = message
      status.appendChild(span)
    }

    // Prefer the API's JSON "error" / "detail" field; fall back to raw text.
    async function readError(response){
      const text = await response.text()
      try{
        const data = JSON.parse(text)
        if(data && data.error !== undefined){
          return String(data.error)
        }
        if(data && data.detail !== undefined){
          return typeof data.detail === "string"
            ? data.detail
            : JSON.stringify(data.detail)
        }
      }catch(parseError){
        // Body is not JSON; show it as-is.
      }
      return text || ("Request failed with status " + response.status)
    }

    async function ageFace(){
      const file = document.getElementById("faceImage").files[0]
      if(!file){
        alert("Select image")
        return
      }

      const button = document.getElementById("generateBtn")
      setStatus("")
      showLoader()
      button.disabled = true

      try{
        const formData = new FormData()
        formData.append("image", file)
        formData.append(
          "source_age",
          document.getElementById("sourceAge").value
        )
        formData.append(
          "target_age",
          document.getElementById("targetAge").value
        )

        const response = await fetch("/age-face", {
          method: "POST",
          body: formData,
          cache: "no-cache"
        })

        if(!response.ok){
          setStatus(await readError(response), "error")
          return
        }

        const blob = await response.blob()
        const sizeMB = (blob.size / 1024 / 1024).toFixed(2)
        if(previewUrl){
          URL.revokeObjectURL(previewUrl)
        }
        previewUrl = URL.createObjectURL(blob)
        document.getElementById("preview").src = previewUrl
        setStatus("Done • " + sizeMB + " MB", "success")
      }catch(error){
        setStatus(String(error), "error")
      }finally{
        hideLoader()
        button.disabled = false
      }
    }
  </script>
</body>
</html>
"""
# =========================================================
# HEALTH
# =========================================================
@app.get("/health")
def health():
    """Report service status.

    ``model_loaded`` is constant: the model is loaded at import time, so a
    running process always has it.
    """
    report = dict(
        status="healthy",
        model_loaded=True,
        device="cpu",
        max_workers=MAX_WORKERS,
    )
    return report
# =========================================================
# SETTINGS
# =========================================================
@app.get("/settings")
def get_settings():
    """Return the current ``AppSettings`` values as a JSON object."""
    # Pydantic v2 deprecates ``.dict()`` in favour of ``.model_dump()`` and
    # warns on every call. Use the new API when it exists and fall back to
    # ``.dict()`` so Pydantic v1 keeps working.
    dump = getattr(settings, "model_dump", None) or settings.dict
    return dump()
# =========================================================
# AGE FACE
# =========================================================
@app.post("/age-face")
async def age_face(
    image: UploadFile = File(...),
    source_age: int = Form(...),
    target_age: int = Form(...)
):
    """Age the face in an uploaded photo from ``source_age`` to ``target_age``.

    Form fields:
        image: Image upload; its extension must be in
            ``settings.allowed_extensions``.
        source_age: Current age of the person in the photo.
        target_age: Age to render.

    Returns:
        On success, the result as an inline ``image/png``. On failure, a JSON
        body ``{"success": false, "error": ...}``:
        - With the status of any ``HTTPException`` raised while validating
          or saving the upload (e.g. 400).
        - With 500 for any other error.
    """
    temp_input = None
    loop = asyncio.get_running_loop()
    try:
        validate_image(image.filename)
        # Copying the upload to disk is blocking I/O; keep it off the loop.
        temp_input = await loop.run_in_executor(None, save_upload_temp, image)
        # Inference is blocking and CPU-bound; run it on the shared pool.
        output_path = await loop.run_in_executor(
            executor,
            run_face_aging,
            temp_input,
            source_age,
            target_age
        )
    except HTTPException as exc:
        # Client errors keep their own status code instead of becoming 500s.
        return JSONResponse(
            status_code=exc.status_code,
            content={
                "success": False,
                "error": exc.detail
            },
            headers=exc.headers
        )
    except Exception as exc:
        logging.getLogger(__name__).exception("Face aging request failed")
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "error": str(exc)
            }
        )
    finally:
        cleanup_temp(temp_input)

    # run_face_aging writes the PNG into the temp directory; delete it once
    # the response has been sent so the directory does not grow unboundedly.
    after_send = BackgroundTasks()
    after_send.add_task(cleanup_temp, output_path)
    return FileResponse(
        path=output_path,
        media_type="image/png",
        filename="aged_face.png",
        headers={
            "Content-Disposition":
                "inline; filename=aged_face.png",
            "Cache-Control":
                "public, max-age=86400"
        },
        background=after_send
    )
# =========================================================
# MAIN
# =========================================================
if __name__ == "__main__":
    import uvicorn

    # Pass the already-imported app object, not the "app:app" import string.
    # When run as a script this module is "__main__", so an import string
    # makes uvicorn import the file a second time as "app". That re-runs
    # every module-level side effect, including loading the model, and it
    # fails outright if the file is not named app.py. An app object is
    # permitted here because reload is off and there is a single worker.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7000,
        reload=False,
        workers=1
    )