Spaces:
Running
Running
| # backend/app.py | |
| import os, io, uuid, sys, json, asyncio | |
| from pathlib import Path | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from PIL import Image | |
| import torch | |
| from torchvision import transforms | |
# ------------------ BASE SETUP ------------------
# Directory containing this file; the models config and the outputs directory
# are resolved relative to it rather than to the current working directory.
BASE_DIR = Path(__file__).resolve().parent
# `from helpers.transform_net import ...` below needs BASE_DIR itself on
# sys.path, which is not guaranteed when the server is launched from another
# directory. The `helpers` directory is still added as before (presumably so
# modules inside it can import siblings by bare name — confirm). Entries that
# are already present are not appended twice.
for _extra_path in (BASE_DIR, BASE_DIR / "helpers"):
    if str(_extra_path) not in sys.path:
        sys.path.append(str(_extra_path))
del _extra_path
from helpers.transform_net import TransformerNet
app = FastAPI()
# ------------------ CORS ------------------
# FRONTEND_URL holds the browser origin(s) allowed to call this API. A
# comma-separated list is accepted so several frontends can be configured; a
# single URL behaves exactly as before. A browser's Origin header never ends in
# "/", so trailing slashes are stripped to keep matching exact. If the variable
# is unset, no cross-origin origins are allowed (instead of passing `None` into
# the origin list).
FRONTEND_URL = os.environ.get("FRONTEND_URL")
_allowed_origins = [o.strip().rstrip("/") for o in (FRONTEND_URL or "").split(",")]
_allowed_origins = [o for o in _allowed_origins if o]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ------------------ DEVICE ------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
# (the HF Spaces free tier is CPU only).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# True only on CUDA, where mixed precision would make sense.
# NOTE(review): stylize_image does not read `use_amp` and runs inference in
# default precision (no autocast) — confirm whether autocast should be wired in.
use_amp = device.type == "cuda"
print(f"Running on: {device}")
# ------------------ OUTPUTS ------------------
# Generated images are written here and also exposed as static files under
# the "/download" prefix.
# NOTE(review): a Mount on "/download" matches every path below that prefix,
# so a "/download/{filename}" route registered after it would never be reached
# — confirm which of the two is meant to serve downloads.
OUTPUT_DIR = BASE_DIR.joinpath("outputs")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
_static_outputs = StaticFiles(directory=str(OUTPUT_DIR))
app.mount("/download", _static_outputs, name="download")
# ------------------ MODELS ------------------
# models.json maps category -> {style name -> checkpoint path}. Relative
# checkpoint paths are interpreted relative to BASE_DIR, not the working dir.
models_json_path = BASE_DIR / "models.json"
if not models_json_path.exists():
    raise RuntimeError(f"models.json not found at {models_json_path}")
# Decode explicitly as UTF-8 so parsing does not depend on the platform locale.
with open(models_json_path, "r", encoding="utf-8") as f:
    MODEL_PATHS = json.load(f)
# Convert relative paths to absolute; absolute entries are kept verbatim.
MODEL_PATHS = {
    cat: {
        style_name: (
            rel_path
            if Path(rel_path).is_absolute()
            else str((BASE_DIR / rel_path).resolve())
        )
        for style_name, rel_path in styles.items()
    }
    for cat, styles in MODEL_PATHS.items()
}
# In-memory model cache: (category, style) -> loaded model, filled by load_model.
models = {}
def load_model(category: str, style: str):
    """Return the TorchScript-compiled style model for ``category``/``style``.

    Results are cached in the module-level ``models`` dict, so each checkpoint
    is read from disk and scripted at most once per process.

    Raises:
        HTTPException: 400 if the category/style pair is not listed in
            MODEL_PATHS; 404 if the listed checkpoint file does not exist.
    """
    key = (category, style)
    if key in models:
        return models[key]
    if category not in MODEL_PATHS or style not in MODEL_PATHS[category]:
        raise HTTPException(status_code=400, detail="Invalid category/style")
    path = MODEL_PATHS[category][style]
    if not os.path.exists(path):
        # Log the absolute path server-side only; the client-facing detail must
        # not reveal the server's filesystem layout.
        print(f"Model file not found for {category}/{style}: {path}")
        raise HTTPException(
            status_code=404,
            detail=f"Model file not found for {category}/{style}",
        )
    model = TransformerNet().to(device)
    # map_location makes checkpoints saved on another device load on this one.
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    model = torch.jit.script(model)
    models[key] = model
    print(f"Loaded model: {category}/{style}")
    return model
# Preload every model up front. Each model is said to be only about 10-11 MB,
# so all of them fit comfortably in the 16 GB of RAM on the free tier.
async def preload_all_models():
    """Load and cache every (category, style) listed in MODEL_PATHS, best-effort.

    A model that fails to load is reported and skipped, so one bad checkpoint
    does not prevent the rest from loading.
    NOTE(review): no startup hook registering this coroutine is visible here —
    confirm it is attached to the app's startup event.
    """
    print("Preloading all models...")
    pending = [
        (cat, style)
        for cat, styles in MODEL_PATHS.items()
        for style in styles
    ]
    for cat, style in pending:
        try:
            load_model(cat, style)
        except Exception as e:
            print(f"Warning: Could not load {cat}/{style} — {e}")
    print(f"Done. {len(models)} model(s) loaded.")
# ------------------ IMAGE UTILS ------------------
def save_image_tensor(tensor, path: Path):
    """Write the first image of a batched tensor to ``path``.

    ``tensor`` must be 4-D with the channel axis at index 1 (it is indexed with
    ``[0]`` and then permuted C,H,W -> H,W,C). Values are assumed to be nominally
    in [0, 1] — TODO confirm the model's output range. They are clamped to
    [0, 1] and quantized to 8 bits. PIL picks the file format from the path's
    extension.
    """
    pixels = (
        tensor.detach()[0]
        .float()
        .clamp(0, 1)
        .mul(255)
        .round()            # round to nearest; plain uint8 casting truncates
        .to(torch.uint8)
        .permute(1, 2, 0)   # C,H,W -> H,W,C as PIL expects
        .contiguous()
        .cpu()
        .numpy()
    )
    Image.fromarray(pixels).save(path)
def stylize_image(img: Image.Image, model, img_size: int = 256):
    """Run ``model`` on ``img`` and return the raw output tensor.

    The image is resized so its shorter side is ``img_size`` pixels (aspect
    ratio preserved) and converted to a float tensor in [0, 1]. It then gets a
    leading batch dimension and is moved to ``device`` before the forward pass.

    Returns:
        The model output, left on ``device``. It is presumably shaped
        (1, C, H, W) — TODO confirm against TransformerNet.
    """
    preprocess = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ])
    batch = preprocess(img).unsqueeze(0).to(device)
    # Inference only, so no autograd graph is recorded. The forward pass runs
    # in default precision: `use_amp` / autocast is NOT applied here.
    with torch.no_grad():
        output = model(batch)
    return output
# ------------------ CLEANUP ------------------
async def delete_file_after_delay(path: Path, delay: int = 180):
    """Sleep ``delay`` seconds, then delete ``path`` if it still exists.

    Scheduled as a background task after each stylization so generated files
    do not accumulate. A file that is already gone is silently ignored. Other
    filesystem errors are logged rather than raised. ``path`` may also be
    given as a ``str``.
    """
    path = Path(path)
    await asyncio.sleep(delay)
    try:
        # Unlink directly instead of exists()-then-unlink(): the file could
        # disappear between the check and the delete.
        path.unlink()
    except FileNotFoundError:
        return
    except OSError as e:
        print(f"Error deleting file: {e}")
        return
    print(f"Deleted {path} after {delay}s")
# ------------------ ROUTES ------------------
async def root():
    """Report that the backend is running and which torch device it uses."""
    device_name = str(device)
    status = {"message": "Backend is running!", "device": device_name}
    return status
async def get_styles():
    """Return the category -> {style -> checkpoint path} mapping as loaded.

    NOTE(review): the values are server-side absolute filesystem paths. If the
    client only needs the names, consider returning just the keys.
    """
    catalog = MODEL_PATHS
    return catalog
async def stylize(
    request: Request,
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    category: str = Form(...),
    style: str = Form(...),
):
    """Stylize an uploaded image and return the generated file's name.

    The upload is decoded to RGB, run through the model selected by
    ``category``/``style``, and saved as a JPEG in OUTPUT_DIR. A background
    task deletes the file again after 180 seconds. ``request`` is unused.

    Decoding, inference and encoding are CPU-bound. They run in the default
    thread-pool executor because doing them directly in this coroutine would
    block the event loop, and with it every other request, for the whole
    inference.

    Raises:
        HTTPException: 400 for an unknown category/style or an upload that is
            not a readable image; 404 if the model checkpoint is missing.
    """
    # Resolve the model on the event-loop thread so the shared model cache is
    # never mutated from worker threads.
    model = load_model(category, style)
    contents = await file.read()
    filename = f"{uuid.uuid4().hex}.jpg"
    out_path = OUTPUT_DIR / filename

    def _render() -> None:
        # Decode, stylize and encode; runs in a worker thread.
        try:
            input_img = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as err:
            # Undecodable / truncated / oversized uploads are client errors.
            raise HTTPException(
                status_code=400, detail="Invalid or unsupported image file"
            ) from err
        output_tensor = stylize_image(input_img, model)
        save_image_tensor(output_tensor, out_path)

    await asyncio.get_running_loop().run_in_executor(None, _render)
    background_tasks.add_task(delete_file_after_delay, out_path, 180)
    return {"filename": filename}
async def download(filename: str):
    """Serve a generated image from OUTPUT_DIR as a JPEG attachment.

    ``filename`` comes from the client. It is resolved and must name a regular
    file directly inside OUTPUT_DIR. Anything else is reported as not found:
    ``..``, nested or absolute paths, directories, or symlinks pointing outside
    OUTPUT_DIR.

    Raises:
        HTTPException: 404 if the file does not exist, was already cleaned up,
            or does not lie directly inside OUTPUT_DIR.
    """
    output_root = OUTPUT_DIR.resolve()
    path = (output_root / filename).resolve()
    if path.parent != output_root or not path.is_file():
        raise HTTPException(status_code=404, detail="File not found or already deleted")
    # Use the resolved basename in Content-Disposition, not the raw input.
    return FileResponse(path, media_type="image/jpeg", filename=path.name)