pluto90's picture
Update app.py
c8ce396 verified
# backend/app.py
import os, io, uuid, sys, json, asyncio
from pathlib import Path
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from PIL import Image
import torch
from torchvision import transforms
# ------------------ BASE SETUP ------------------
# Directory that contains this file; other on-disk paths are resolved from it.
BASE_DIR = Path(__file__).resolve().parent
# Also put helpers/ itself on sys.path (appended, so it is searched last).
# NOTE(review): presumably needed by bare-name imports inside helpers/ — confirm.
_HELPERS_DIR = BASE_DIR.joinpath("helpers")
sys.path.append(str(_HELPERS_DIR))
from helpers.transform_net import TransformerNet
# The ASGI application object that the routes below register on.
app = FastAPI()
# ------------------ CORS ------------------
# FRONTEND_URL holds the allowed browser origin, or a comma-separated list of
# origins (e.g. "https://a.example,https://b.example").
FRONTEND_URL = os.environ.get("FRONTEND_URL")
# Normalize each entry: browsers send Origin as scheme://host[:port] with no
# trailing slash, so a configured "https://x.example/" would never match.
# Empty entries are dropped so an unset variable yields [] rather than [None].
ALLOWED_ORIGINS = [
    origin
    for origin in (part.strip().rstrip("/") for part in (FRONTEND_URL or "").split(","))
    if origin
]
if not ALLOWED_ORIGINS:
    print("Warning: FRONTEND_URL is not set; CORS will not allow any browser origin.")
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ------------------ DEVICE ------------------
# HF Spaces free tier = CPU only, so this normally resolves to the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
# Mixed-precision flag: true exactly when running on a CUDA device.
# NOTE(review): no visible code reads use_amp; inference runs in default precision.
use_amp = _cuda_available
print(f"Running on: {device}")
# ------------------ OUTPUTS ------------------
# Generated images are written here and also served as static files under /download.
OUTPUT_DIR = BASE_DIR.joinpath("outputs")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
_output_files = StaticFiles(directory=str(OUTPUT_DIR))
app.mount("/download", _output_files, name="download")
# ------------------ MODELS ------------------
# models.json maps category -> {style name: checkpoint path}.
models_json_path = BASE_DIR / "models.json"
if not models_json_path.exists():
    raise RuntimeError(f"models.json not found at {models_json_path}")
# Explicit encoding so parsing does not depend on the host's locale default.
with open(models_json_path, "r", encoding="utf-8") as f:
    MODEL_PATHS = json.load(f)
# Fail fast with a clear message instead of an AttributeError in the loop below.
if not isinstance(MODEL_PATHS, dict) or not all(
    isinstance(styles, dict) for styles in MODEL_PATHS.values()
):
    raise RuntimeError(
        f"{models_json_path} must map each category to an object of {{style: model_path}}"
    )
# Convert relative paths to absolute (relative entries are resolved against BASE_DIR).
for cat, styles in MODEL_PATHS.items():
    for style_name, rel_path in styles.items():
        p = Path(rel_path)
        if not p.is_absolute():
            # Replacing the value of an existing key while iterating is safe:
            # the dict's size does not change.
            MODEL_PATHS[cat][style_name] = str((BASE_DIR / rel_path).resolve())
# In-memory model cache: (category, style) -> scripted model, filled by load_model.
models = {}
def load_model(category: str, style: str):
    """Return the TorchScript-compiled TransformerNet for *category*/*style*.

    Successfully loaded models are cached in ``models`` under the key
    ``(category, style)``, so each checkpoint is loaded and scripted once.
    Failed loads are not cached and will be retried on the next call.

    Raises:
        HTTPException: 400 if the pair is not listed in ``MODEL_PATHS``;
            404 if the configured checkpoint file does not exist.
    """
    key = (category, style)
    if key in models:
        return models[key]
    if category not in MODEL_PATHS or style not in MODEL_PATHS[category]:
        raise HTTPException(status_code=400, detail="Invalid category/style")
    path = MODEL_PATHS[category][style]
    if not os.path.exists(path):
        # Log the absolute path on the server only: putting it in the HTTP
        # response would disclose the server's filesystem layout to clients.
        print(f"Model file not found for {category}/{style}: {path}")
        raise HTTPException(
            status_code=404, detail=f"Model file not found for {category}/{style}"
        )
    model = TransformerNet().to(device)
    # NOTE(review): torch.load unpickles the checkpoint. The files come from local
    # config, not from users; if they hold only tensors, weights_only=True
    # (PyTorch >= 1.13) would harden this — confirm the installed version.
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    # Compile once with TorchScript; the scripted module is what gets cached.
    model = torch.jit.script(model)
    models[key] = model
    print(f"Loaded model: {category}/{style}")
    return model
# Preload all models at startup
# Since each model is only 10-11 MB, all fit easily in 16 GB free RAM
# NOTE(review): FastAPI documents lifespan handlers as the successor to on_event.
@app.on_event("startup")
async def preload_all_models():
    """Load and cache every model listed in MODEL_PATHS before serving requests.

    Best-effort: a model that fails to load is reported and skipped, so one
    bad checkpoint does not stop startup. load_model caches only successes,
    so that model is attempted again when a request asks for it.
    """
    print("Preloading all models...")
    for cat, styles in MODEL_PATHS.items():
        for style in styles:
            try:
                load_model(cat, style)
            except Exception as e:  # startup boundary: report and continue
                # ": " keeps the model id and the error text readable.
                print(f"Warning: Could not load {cat}/{style}: {e}")
    print(f"Done. {len(models)} model(s) loaded.")
# ------------------ IMAGE UTILS ------------------
def save_image_tensor(tensor, path: Path):
    """Write the first image of a batched output tensor to *path* with Pillow.

    Args:
        tensor: 4-D tensor; ``tensor[0]`` is permuted from (C, H, W) to
            (H, W, C). Values are clamped to [0, 1] and mapped to 0..255.
            Presumably C == 3 (RGB) — confirm against the model.
        path: Destination file; Pillow infers the format from the suffix.
    """
    pixels = (
        tensor.detach()[0]
        .float()
        .clamp(0, 1)
        .mul(255)
        # Round instead of truncating: float error such as 127.99999 would
        # otherwise drop the value to 127 and darken the image slightly.
        .round()
        .to(torch.uint8)
        .permute(1, 2, 0)
        # Give Pillow a plain contiguous HWC buffer.
        .contiguous()
        .cpu()
        .numpy()
    )
    Image.fromarray(pixels).save(path)
def stylize_image(img: Image.Image, model, img_size: int = 256):
    """Run *model* on *img* and return the raw output tensor.

    The image is resized so its shorter side equals *img_size* (aspect ratio
    preserved), converted to a float tensor in [0, 1], given a batch
    dimension, and moved to ``device``.

    Returns:
        The model output, still on ``device``. No clamping or conversion is
        applied; see save_image_tensor for that.
    """
    preprocess = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ])
    batch = preprocess(img).unsqueeze(0).to(device)
    # Inference only: skip autograd bookkeeping. This runs in default precision;
    # no autocast/AMP is applied here (use_amp is not consulted).
    with torch.no_grad():
        return model(batch)
# ------------------ CLEANUP ------------------
async def delete_file_after_delay(path: Path, delay: int = 180):
    """Sleep *delay* seconds, then delete *path* if it still exists.

    Best-effort cleanup for generated outputs. A file that is already gone is
    ignored silently; other filesystem errors are printed, not raised.
    """
    await asyncio.sleep(delay)
    try:
        # EAFP: unlink directly instead of exists()-then-unlink. The old order
        # could race with another deletion between the check and the call.
        path.unlink()
    except FileNotFoundError:
        return
    except OSError as e:
        print(f"Error deleting file {path}: {e}")
    else:
        print(f"Deleted {path} after {delay}s")
# ------------------ ROUTES ------------------
@app.get("/")
async def root():
    """Health check: report that the backend is up and which torch device it uses."""
    status = {"message": "Backend is running!"}
    status["device"] = str(device)
    return status
@app.get("/api/styles")
async def get_styles():
    """Return the category -> {style: model path} catalogue loaded from models.json.

    NOTE(review): the values are server filesystem paths (made absolute at
    import time). Clients probably only need the keys; consider returning names
    only if the frontend allows it.
    """
    catalogue = MODEL_PATHS
    return catalogue
def _render_stylized_jpeg(contents: bytes, model, out_path: Path) -> None:
    """Decode *contents* as an image, stylize it with *model*, and save to *out_path*.

    Synchronous and CPU-bound; the endpoint runs it in a worker thread.

    Raises:
        HTTPException: 400 if *contents* cannot be decoded as an image.
    """
    try:
        # convert() forces a full decode, so truncated files also fail here.
        input_img = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Invalid or unsupported image file"
        ) from err
    output_tensor = stylize_image(input_img, model)
    save_image_tensor(output_tensor, out_path)


@app.post("/api/stylize")
async def stylize(
    request: Request,
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    category: str = Form(...),
    style: str = Form(...),
):
    """Stylize an uploaded image and return the generated JPEG's filename.

    The file can then be fetched from ``/api/download/{filename}`` (or the
    ``/download`` static mount) until the background task deletes it ~180 s
    later. ``request`` is currently unused but kept in the signature.

    Raises:
        HTTPException: 400 for an unknown category/style or an undecodable
            upload; 404 if the model checkpoint is missing.
    """
    # load_model stays on the event-loop thread, as before, so its cache dict
    # is never mutated from worker threads.
    model = load_model(category, style)
    contents = await file.read()
    filename = f"{uuid.uuid4().hex}.jpg"
    out_path = OUTPUT_DIR / filename
    # Decoding, inference and JPEG encoding are CPU-bound. Running them inline in
    # this async handler would block the event loop and stall every other request.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, _render_stylized_jpeg, contents, model, out_path)
    background_tasks.add_task(delete_file_after_delay, out_path, 180)
    return {"filename": filename}
@app.get("/api/download/{filename}")
async def download(filename: str):
    """Serve a previously generated stylized JPEG from OUTPUT_DIR as an attachment.

    Raises:
        HTTPException: 404 unless *filename* names a regular file directly
            inside OUTPUT_DIR. This covers missing or already-cleaned-up files
            and path-like values such as ``..`` or ``.``.
    """
    output_root = OUTPUT_DIR.resolve()
    path = (output_root / filename).resolve()
    # filename is client-controlled: refuse anything that escapes OUTPUT_DIR or
    # is not a plain file, instead of handing a directory/foreign path to FileResponse.
    if path.parent != output_root or not path.is_file():
        raise HTTPException(status_code=404, detail="File not found or already deleted")
    return FileResponse(path, media_type="image/jpeg", filename=filename)