Spaces:
Running
Running
File size: 5,257 Bytes
# backend/app.py
import os, io, uuid, sys, json, asyncio
from pathlib import Path
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from PIL import Image
import torch
from torchvision import transforms
# ------------------ BASE SETUP ------------------
# Directory containing this file; the models.json and outputs/ paths below are
# built relative to it.
BASE_DIR = Path(__file__).resolve().parent
# Put helpers/ itself on the import path — presumably so modules inside it can
# be imported by bare name; confirm against helpers/transform_net.py.
sys.path.append(str(BASE_DIR / "helpers"))
# Network class that load_model() instantiates for each style checkpoint.
from helpers.transform_net import TransformerNet
# The FastAPI application that the middleware, static mount and routes attach to.
app = FastAPI()
# ------------------ CORS ------------------
# Browser origin allowed to call this API, taken from the FRONTEND_URL
# environment variable (None when the variable is not set).
FRONTEND_URL = os.environ.get("FRONTEND_URL")

# Build the allow-list defensively:
#  * an unset/blank variable yields an empty list instead of [None];
#  * a trailing "/" is dropped because the Origin header sent by browsers is
#    scheme://host[:port] with no path, so "https://site/" would never match.
_frontend_origin = (FRONTEND_URL or "").strip().rstrip("/")
_allowed_origins = [_frontend_origin] if _frontend_origin else []

app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ------------------ DEVICE ------------------
# Pick CUDA when PyTorch reports it as available, otherwise fall back to CPU
# (the original note says the HF Spaces free tier is CPU only).
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
# True only on CUDA. Per the original note this is meant to gate
# cuda.amp.autocast; confirm where it is actually consumed.
use_amp = _device_name == "cuda"
print(f"Running on: {device}")
# ------------------ OUTPUTS ------------------
# Folder that receives stylized images; created on import if it is missing.
OUTPUT_DIR = BASE_DIR.joinpath("outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Expose the folder's contents as static files under /download.
app.mount("/download", StaticFiles(directory=os.fspath(OUTPUT_DIR)), name="download")
# ------------------ MODELS ------------------
# models.json maps category -> {style name -> checkpoint path}.
models_json_path = BASE_DIR / "models.json"
if not models_json_path.exists():
    raise RuntimeError(f"models.json not found at {models_json_path}")

# Read explicitly as UTF-8: the default text encoding depends on the host
# locale, which could garble non-ASCII category or style names.
with open(models_json_path, "r", encoding="utf-8") as f:
    MODEL_PATHS = json.load(f)

# Convert relative paths to absolute ones anchored at BASE_DIR. Only existing
# keys are reassigned, so updating the inner dicts while iterating is safe.
for cat, styles in MODEL_PATHS.items():
    for style_name, rel_path in styles.items():
        if not Path(rel_path).is_absolute():
            styles[style_name] = str((BASE_DIR / rel_path).resolve())

# In-memory model cache, keyed by (category, style) and filled by load_model().
models = {}
def load_model(category: str, style: str):
    """Return the scripted network for ``category``/``style``, loading it once.

    Loaded models are kept in the module-level ``models`` dict, so later calls
    for the same pair return the cached object.

    Raises:
        HTTPException: 400 if the pair is not listed in ``MODEL_PATHS``;
            404 if the listed checkpoint file does not exist on disk.
    """
    cache_key = (category, style)
    try:
        return models[cache_key]
    except KeyError:
        pass  # not loaded yet; fall through to the slow path

    known_styles = MODEL_PATHS.get(category)
    if known_styles is None or style not in known_styles:
        raise HTTPException(status_code=400, detail="Invalid category/style")

    weights_path = known_styles[style]
    if not os.path.exists(weights_path):
        raise HTTPException(status_code=404, detail=f"Model file not found: {weights_path}")

    net = TransformerNet().to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    net.load_state_dict(torch.load(weights_path, map_location=device))
    net.eval()
    scripted = torch.jit.script(net)

    models[cache_key] = scripted
    print(f"Loaded model: {category}/{style}")
    return scripted
# Warm the model cache when the application starts.
# The original author notes each model is only 10-11 MB, so keeping all of
# them resident is expected to fit comfortably in memory.
@app.on_event("startup")
async def preload_all_models():
    """Load every category/style pair listed in ``MODEL_PATHS`` into the cache.

    A failure for one pair is reported and skipped so the remaining models
    still load and the application still starts.
    """
    print("Preloading all models...")
    every_pair = [
        (cat, style)
        for cat, styles in MODEL_PATHS.items()
        for style in styles
    ]
    for cat, style in every_pair:
        try:
            load_model(cat, style)
        except Exception as err:
            print(f"Warning: Could not load {cat}/{style} — {err}")
    print(f"Done. {len(models)} model(s) loaded.")
# ------------------ IMAGE UTILS ------------------
def save_image_tensor(tensor, path: Path):
    """Write the first image of a batched tensor to ``path``.

    Values are clamped to [0, 1], scaled by 255 and cast to uint8 before
    saving. The indexing and permute assume a (batch, channel, height, width)
    layout — confirm against the model output.
    """
    first_image = tensor.detach().float().cpu()[0]
    height_width_channel = first_image.clamp(0, 1).permute(1, 2, 0)
    pixels = (height_width_channel.numpy() * 255).astype("uint8")
    Image.fromarray(pixels).save(path)
def stylize_image(img: Image.Image, model, img_size: int = 256):
    """Run ``model`` on ``img`` and return the raw output tensor.

    The image is passed through ``transforms.Resize(img_size)``, converted to
    a tensor, given a leading batch dimension and moved to ``device``.
    Inference runs under ``torch.no_grad()``; no autocast context is applied.
    """
    resized = transforms.Resize(img_size)(img)
    batch = transforms.ToTensor()(resized).unsqueeze(0).to(device)
    with torch.no_grad():
        return model(batch)
# ------------------ CLEANUP ------------------
async def delete_file_after_delay(path: Path, delay: int = 180):
    """Sleep for ``delay`` seconds, then remove ``path`` if it still exists.

    Best-effort cleanup: any error while checking or deleting is printed
    rather than raised, and a file that is already gone is silently ignored.
    """
    await asyncio.sleep(delay)
    try:
        if not path.exists():
            return
        path.unlink()
        print(f"Deleted {path} after {delay}s")
    except Exception as exc:
        print(f"Error deleting file: {exc}")
# ------------------ ROUTES ------------------
@app.get("/")
async def root():
    """Health-check endpoint: report that the backend is up and its device."""
    status = {"message": "Backend is running!", "device": str(device)}
    return status
@app.get("/api/styles")
async def get_styles():
    """Return the category -> style -> checkpoint-path mapping.

    NOTE(review): the values are the resolved absolute paths on the server,
    so this response exposes filesystem layout to clients.
    """
    available_styles = MODEL_PATHS
    return available_styles
@app.post("/api/stylize")
async def stylize(
    request: Request,
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    category: str = Form(...),
    style: str = Form(...),
):
    """Stylize an uploaded image and return the name of the generated file.

    The result is written to ``OUTPUT_DIR`` as a ``.jpg`` under a random name
    and scheduled for deletion 180 seconds later. The client uses the returned
    filename to fetch the image.

    Returns:
        ``{"filename": <generated name>}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image, plus
            whatever ``load_model`` raises for an unknown category/style or a
            missing checkpoint.
    """
    model = load_model(category, style)

    contents = await file.read()
    try:
        input_img = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # Bad client input (not an image, truncated, or oversized) is a 400,
        # not an unhandled server error. PIL's UnidentifiedImageError is an
        # OSError subclass, so it is covered here.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from exc

    # NOTE(review): inference is a synchronous call inside an async handler,
    # so other requests wait while it runs.
    output_tensor = stylize_image(input_img, model)

    filename = f"{uuid.uuid4().hex}.jpg"
    out_path = OUTPUT_DIR / filename
    save_image_tensor(output_tensor, out_path)

    # Remove the generated file after three minutes to bound disk usage.
    background_tasks.add_task(delete_file_after_delay, out_path, 180)
    return {"filename": filename}
@app.get("/api/download/{filename}")
async def download(filename: str):
    """Serve a previously generated image from ``OUTPUT_DIR`` as JPEG.

    Raises:
        HTTPException: 404 if ``filename`` does not name a regular file
            located directly inside ``OUTPUT_DIR`` (missing, already deleted,
            a directory such as "..", or a path that escapes the folder).
    """
    # Resolve before checking so ".." segments or alternate separators cannot
    # point outside the outputs folder; is_file() also rejects directories,
    # which exists() would have accepted.
    candidate = (OUTPUT_DIR / filename).resolve()
    if candidate.parent != OUTPUT_DIR.resolve() or not candidate.is_file():
        raise HTTPException(status_code=404, detail="File not found or already deleted")
    return FileResponse(candidate, media_type="image/jpeg", filename=filename)
|