# backend/app.py
"""FastAPI service that stylizes uploaded images with TransformerNet models listed in models.json."""
import os
import io
import uuid
import sys
import json
import asyncio
from pathlib import Path

from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from PIL import Image
import torch
from torchvision import transforms

# ------------------ BASE SETUP ------------------
BASE_DIR = Path(__file__).resolve().parent
sys.path.append(str(BASE_DIR / "helpers"))

from helpers.transform_net import TransformerNet

app = FastAPI()

# ------------------ CORS ------------------
FRONTEND_URL = os.environ.get("FRONTEND_URL")

app.add_middleware(
    CORSMiddleware,
    # If FRONTEND_URL is unset the list would otherwise contain None; allow no
    # cross-origin callers in that case rather than a meaningless entry.
    allow_origins=[FRONTEND_URL] if FRONTEND_URL else [],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ------------------ DEVICE ------------------
# HF Spaces free tier = CPU only
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): use_amp is computed but nothing in this module applies
# autocast; inference always runs in the model's default precision.
use_amp = device.type == "cuda"
print(f"Running on: {device}")

# ------------------ OUTPUTS ------------------
OUTPUT_DIR = BASE_DIR / "outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
app.mount("/download", StaticFiles(directory=str(OUTPUT_DIR)), name="download")

# ------------------ MODELS ------------------
models_json_path = BASE_DIR / "models.json"
if not models_json_path.exists():
    raise RuntimeError(f"models.json not found at {models_json_path}")

# Expected layout: {category: {style_name: path_to_state_dict}}
with open(models_json_path, "r", encoding="utf-8") as f:
    MODEL_PATHS = json.load(f)

# Convert relative paths to absolute (relative to this file's directory).
# Reassigning existing keys does not change the dict's size, so this is safe
# while iterating.
for cat, styles in MODEL_PATHS.items():
    for style_name, rel_path in styles.items():
        p = Path(rel_path)
        if not p.is_absolute():
            MODEL_PATHS[cat][style_name] = str((BASE_DIR / rel_path).resolve())

# In-memory model cache keyed by (category, style)
models = {}


def load_model(category: str, style: str):
    """Return the TorchScript model for ``category``/``style``, loading and caching it on first use.

    Raises:
        HTTPException(400): the category/style pair is not listed in models.json.
        HTTPException(404): the listed weights file does not exist on disk.
    """
    key = (category, style)
    if key in models:
        return models[key]

    if category not in MODEL_PATHS or style not in MODEL_PATHS[category]:
        raise HTTPException(status_code=400, detail="Invalid category/style")

    path = MODEL_PATHS[category][style]
    if not os.path.exists(path):
        raise HTTPException(status_code=404, detail=f"Model file not found: {path}")

    model = TransformerNet().to(device)
    # The weights file path comes from the server-side models.json, not from
    # the client; torch.load unpickles it.
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    model = torch.jit.script(model)
    models[key] = model
    print(f"Loaded model: {category}/{style}")
    return model


# Preload all models at startup
# Since each model is only 10-11 MB, all fit easily in 16 GB free RAM
@app.on_event("startup")
async def preload_all_models():
    """Load every model listed in models.json; failures are logged, not fatal."""
    print("Preloading all models...")
    for cat, styles in MODEL_PATHS.items():
        for style in styles:
            try:
                load_model(cat, style)
            except Exception as e:
                # Best-effort: one broken entry must not stop the server from
                # starting; it will raise again (as an HTTP error) on request.
                print(f"Warning: Could not load {cat}/{style} — {e}")
    print(f"Done. {len(models)} model(s) loaded.")


# ------------------ IMAGE UTILS ------------------
def save_image_tensor(tensor, path: Path):
    """Save the first image of a batched CHW tensor to ``path``.

    Values are clamped to [0, 1] and scaled to 0-255 before conversion to uint8.
    """
    img = tensor.detach().float().cpu()[0].clamp(0, 1).permute(1, 2, 0).numpy() * 255
    Image.fromarray(img.astype("uint8")).save(path)


def stylize_image(img: Image.Image, model, img_size: int = 256):
    """Resize ``img`` (smaller edge -> ``img_size``), run ``model`` on it and return the output tensor."""
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor()
    ])
    x = transform(img).unsqueeze(0).to(device)
    # Inference only: disable autograd bookkeeping. No autocast is applied here.
    with torch.no_grad():
        y = model(x)
    return y


def _render_upload(contents: bytes, model) -> str:
    """Decode an uploaded image, stylize it, save it to OUTPUT_DIR and return the filename.

    Every step is blocking CPU/disk work, so ``stylize`` runs this in a worker
    thread instead of on the event loop.

    Raises:
        HTTPException(400): ``contents`` is not a decodable image.
    """
    try:
        # Image.open is lazy; truncated data only fails during convert().
        input_img = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from e

    output_tensor = stylize_image(input_img, model)
    filename = f"{uuid.uuid4().hex}.jpg"
    save_image_tensor(output_tensor, OUTPUT_DIR / filename)
    return filename


# ------------------ CLEANUP ------------------
async def delete_file_after_delay(path: Path, delay: int = 180):
    """Sleep ``delay`` seconds, then delete ``path`` if it still exists (errors are logged)."""
    await asyncio.sleep(delay)
    try:
        if path.exists():
            path.unlink()
            print(f"Deleted {path} after {delay}s")
    except Exception as e:
        print(f"Error deleting file: {e}")


# ------------------ ROUTES ------------------
@app.get("/")
async def root():
    """Health check reporting the inference device."""
    return {"message": "Backend is running!", "device": str(device)}


@app.get("/api/styles")
async def get_styles():
    """Return the category -> style -> model-path mapping from models.json."""
    # NOTE(review): the values are absolute server filesystem paths; consider
    # returning only category/style names if the frontend does not need them.
    return MODEL_PATHS


@app.post("/api/stylize")
async def stylize(
    request: Request,
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    category: str = Form(...),
    style: str = Form(...),
):
    """Stylize an uploaded image and return the generated file's name.

    The output is scheduled for deletion 180 s after the response is sent.
    """
    model = load_model(category, style)
    contents = await file.read()

    # Offload decode + inference + save to the default thread pool so a slow
    # request does not stall every other request on the event loop.
    loop = asyncio.get_running_loop()
    filename = await loop.run_in_executor(None, _render_upload, contents, model)

    background_tasks.add_task(delete_file_after_delay, OUTPUT_DIR / filename, 180)
    return {"filename": filename}


@app.get("/api/download/{filename}")
async def download(filename: str):
    """Serve a previously generated image from OUTPUT_DIR.

    Raises:
        HTTPException(404): the file does not exist, was already deleted, or the
        name does not refer to a regular file directly inside OUTPUT_DIR.
    """
    path = (OUTPUT_DIR / filename).resolve()
    # Reject names such as ".." that would resolve outside the outputs folder.
    if path.parent != OUTPUT_DIR.resolve() or not path.is_file():
        raise HTTPException(status_code=404, detail="File not found or already deleted")
    return FileResponse(path, media_type="image/jpeg", filename=filename)