pluto90 commited on
Commit
9346913
·
verified ·
1 Parent(s): a2c6496
Files changed (1) hide show
  1. app.py +168 -328
app.py CHANGED
@@ -1,329 +1,169 @@
1
-
2
- # # backend/main.py
3
- # import os, io, uuid, sys, json, asyncio
4
- # from pathlib import Path
5
- # from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks
6
- # from fastapi.responses import FileResponse, JSONResponse
7
- # from fastapi.middleware.cors import CORSMiddleware
8
- # from fastapi.staticfiles import StaticFiles
9
- # from PIL import Image
10
- # import torch
11
- # from torchvision import transforms
12
-
13
-
14
-
15
-
16
- # BASE_DIR = Path(__file__).resolve().parent
17
- # sys.path.append(str(BASE_DIR / "helpers"))
18
- # from helpers.transform_net import TransformerNet
19
-
20
- # app = FastAPI()
21
-
22
- # # -------- CORS: add your frontend origin (dev: http://localhost:5173) ----------
23
- # FRONTEND_URL = os.environ.get("FRONTEND_URL", "http://localhost:5173")
24
- # app.add_middleware(
25
- # CORSMiddleware,
26
- # allow_origins=[FRONTEND_URL],
27
- # allow_credentials=True,
28
- # allow_methods=["*"],
29
- # allow_headers=["*"],
30
- # )
31
-
32
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
-
34
- # # ---------- outputs dir ------------
35
- # OUTPUT_DIR = BASE_DIR / "outputs"
36
- # OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
37
-
38
- # # mount static so /download/<file> serves images directly
39
- # app.mount("/download", StaticFiles(directory=str(OUTPUT_DIR)), name="download")
40
-
41
- # # ---------- load models.json + your model cache (keep your existing logic) ----------
42
- # models_json_path = BASE_DIR / "models.json"
43
- # if not models_json_path.exists():
44
- # raise RuntimeError(f"models.json not found at {models_json_path}")
45
-
46
- # with open(models_json_path, "r") as f:
47
- # MODEL_PATHS = json.load(f)
48
-
49
- # # convert to absolute paths (your existing code)
50
- # for cat, styles in MODEL_PATHS.items():
51
- # for style_name, rel_path in styles.items():
52
- # p = Path(rel_path)
53
- # if not p.is_absolute():
54
- # MODEL_PATHS[cat][style_name] = str((BASE_DIR / rel_path).resolve())
55
-
56
- # models = {}
57
- # def load_model(category: str, style: str):
58
- # key = (category, style)
59
- # if key in models:
60
- # return models[key]
61
- # if category not in MODEL_PATHS or style not in MODEL_PATHS[category]:
62
- # raise HTTPException(status_code=400, detail="Invalid category/style")
63
- # path = MODEL_PATHS[category][style]
64
- # if not os.path.exists(path):
65
- # raise HTTPException(status_code=404, detail=f"Model file not found: {path}")
66
- # model = TransformerNet().to(device)
67
- # model.load_state_dict(torch.load(path, map_location=device))
68
- # model.eval()
69
- # models[key] = model
70
- # return model
71
-
72
- # def save_image_tensor(tensor, path: Path):
73
- # img = tensor.detach().float().cpu()[0].clamp(0,1).permute(1,2,0).numpy() * 255
74
- # Image.fromarray(img.astype("uint8")).save(path)
75
-
76
- # def stylize_image(img: Image.Image, model, img_size: int = 512):
77
- # transform = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()])
78
- # x = transform(img).unsqueeze(0).to(device)
79
- # with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
80
- # y = model(x)
81
- # return y
82
-
83
- # # cleanup helper
84
- # async def delete_file_after_delay(path: Path, delay: int = 180):
85
- # await asyncio.sleep(delay)
86
- # try:
87
- # if path.exists():
88
- # path.unlink()
89
- # print(f"🧹 Deleted {path} after {delay} sec")
90
- # except Exception as e:
91
- # print("⚠️ error deleting file:", e)
92
-
93
- # # --------- routes ----------
94
-
95
-
96
- # @app.get("/")
97
- # async def root():
98
- # return {"message": "Backend is running!"}
99
-
100
-
101
- # # # -------------------- API Routes --------------------
102
- # @app.get("/api/styles")
103
- # async def get_styles():
104
- # return MODEL_PATHS
105
-
106
-
107
-
108
- # @app.post("/api/stylize")
109
- # async def stylize(
110
- # request: Request,
111
- # background_tasks: BackgroundTasks,
112
- # file: UploadFile = File(...),
113
- # category: str = Form(...),
114
- # style: str = Form(...),
115
- # ):
116
- # model = load_model(category, style)
117
- # contents = await file.read()
118
- # input_img = Image.open(io.BytesIO(contents)).convert("RGB")
119
- # output_tensor = stylize_image(input_img, model)
120
- # filename = f"{uuid.uuid4().hex}.jpg"
121
- # out_path = OUTPUT_DIR / filename
122
- # save_image_tensor(output_tensor, out_path)
123
-
124
- # # schedule deletion
125
- # background_tasks.add_task(delete_file_after_delay, out_path, 180)
126
-
127
- # # build absolute URL using the request base URL (works in dev and production)
128
- # base = str(request.base_url).rstrip('/')
129
- # image_url = f"{base}/download/{filename}"
130
- # return {"image_url": image_url}
131
-
132
- # # (OPTIONAL) keep a file endpoint if you want; not necessary because StaticFiles serves it:
133
- # @app.get("/api/download/{filename}")
134
- # async def download(filename: str):
135
- # path = OUTPUT_DIR / filename
136
- # return FileResponse(path, media_type="image/jpeg", filename=filename)
137
-
138
-
139
-
140
-
141
-
142
-
143
-
144
-
145
-
146
-
147
-
148
-
149
-
150
-
151
-
152
-
153
-
154
-
155
-
156
-
157
-
158
-
159
-
160
-
161
-
162
- # backend/main.py
163
- import os, io, uuid, sys, json, asyncio
164
- from pathlib import Path
165
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks
166
- from fastapi.responses import FileResponse, JSONResponse
167
- from fastapi.middleware.cors import CORSMiddleware
168
- from fastapi.staticfiles import StaticFiles
169
- from PIL import Image
170
- import torch
171
- from torchvision import transforms
172
-
173
- # ------------------ BASE SETUP ------------------
174
-
175
- BASE_DIR = Path(__file__).resolve().parent
176
- sys.path.append(str(BASE_DIR / "helpers"))
177
- from helpers.transform_net import TransformerNet
178
-
179
- app = FastAPI()
180
-
181
- # ------------------ CORS ------------------
182
- # In HF Spaces dashboard, set environment variable:
183
- # FRONTEND_URL = https://your-app.vercel.app
184
- # For local dev it defaults to localhost:5173
185
-
186
- FRONTEND_URL = os.environ.get("FRONTEND_URL", "http://localhost:5173")
187
-
188
- app.add_middleware(
189
- CORSMiddleware,
190
- allow_origins=[FRONTEND_URL],
191
- allow_credentials=True,
192
- allow_methods=["*"],
193
- allow_headers=["*"],
194
- )
195
-
196
- # ------------------ DEVICE ------------------
197
- # HF Spaces free tier = CPU only
198
- # cuda.amp.autocast is disabled on CPU to avoid warnings
199
-
200
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
201
- use_amp = device.type == "cuda"
202
- print(f"Running on: {device}")
203
-
204
- # ------------------ OUTPUTS ------------------
205
-
206
- OUTPUT_DIR = BASE_DIR / "outputs"
207
- OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
208
-
209
- app.mount("/download", StaticFiles(directory=str(OUTPUT_DIR)), name="download")
210
-
211
- # ------------------ MODELS ------------------
212
-
213
- models_json_path = BASE_DIR / "models.json"
214
- if not models_json_path.exists():
215
- raise RuntimeError(f"models.json not found at {models_json_path}")
216
-
217
- with open(models_json_path, "r") as f:
218
- MODEL_PATHS = json.load(f)
219
-
220
- # Convert relative paths to absolute
221
- for cat, styles in MODEL_PATHS.items():
222
- for style_name, rel_path in styles.items():
223
- p = Path(rel_path)
224
- if not p.is_absolute():
225
- MODEL_PATHS[cat][style_name] = str((BASE_DIR / rel_path).resolve())
226
-
227
- # In-memory model cache
228
- models = {}
229
-
230
- def load_model(category: str, style: str):
231
- key = (category, style)
232
- if key in models:
233
- return models[key]
234
-
235
- if category not in MODEL_PATHS or style not in MODEL_PATHS[category]:
236
- raise HTTPException(status_code=400, detail="Invalid category/style")
237
-
238
- path = MODEL_PATHS[category][style]
239
- if not os.path.exists(path):
240
- raise HTTPException(status_code=404, detail=f"Model file not found: {path}")
241
-
242
- model = TransformerNet().to(device)
243
- model.load_state_dict(torch.load(path, map_location=device))
244
- model.eval()
245
- models[key] = model
246
- print(f"Loaded model: {category}/{style}")
247
- return model
248
-
249
- # Preload all models at startup
250
- # Since each model is only 10-11 MB, all fit easily in 16 GB free RAM
251
- @app.on_event("startup")
252
- async def preload_all_models():
253
- print("Preloading all models...")
254
- for cat, styles in MODEL_PATHS.items():
255
- for style in styles:
256
- try:
257
- load_model(cat, style)
258
- except Exception as e:
259
- print(f"Warning: Could not load {cat}/{style} — {e}")
260
- print(f"Done. {len(models)} model(s) loaded.")
261
-
262
- # ------------------ IMAGE UTILS ------------------
263
-
264
- def save_image_tensor(tensor, path: Path):
265
- img = tensor.detach().float().cpu()[0].clamp(0, 1).permute(1, 2, 0).numpy() * 255
266
- Image.fromarray(img.astype("uint8")).save(path)
267
-
268
- def stylize_image(img: Image.Image, model, img_size: int = 512):
269
- transform = transforms.Compose([
270
- transforms.Resize(img_size),
271
- transforms.ToTensor()
272
- ])
273
- x = transform(img).unsqueeze(0).to(device)
274
- with torch.no_grad():
275
- # autocast only when GPU is available, safe no-op on CPU
276
- with torch.cuda.amp.autocast(enabled=use_amp):
277
- y = model(x)
278
- return y
279
-
280
- # ------------------ CLEANUP ------------------
281
-
282
- async def delete_file_after_delay(path: Path, delay: int = 180):
283
- await asyncio.sleep(delay)
284
- try:
285
- if path.exists():
286
- path.unlink()
287
- print(f"Deleted {path} after {delay}s")
288
- except Exception as e:
289
- print(f"Error deleting file: {e}")
290
-
291
- # ------------------ ROUTES ------------------
292
-
293
- @app.get("/")
294
- async def root():
295
- return {"message": "Backend is running!", "device": str(device)}
296
-
297
- @app.get("/api/styles")
298
- async def get_styles():
299
- return MODEL_PATHS
300
-
301
- @app.post("/api/stylize")
302
- async def stylize(
303
- request: Request,
304
- background_tasks: BackgroundTasks,
305
- file: UploadFile = File(...),
306
- category: str = Form(...),
307
- style: str = Form(...),
308
- ):
309
- model = load_model(category, style)
310
-
311
- contents = await file.read()
312
- input_img = Image.open(io.BytesIO(contents)).convert("RGB")
313
- output_tensor = stylize_image(input_img, model)
314
-
315
- filename = f"{uuid.uuid4().hex}.jpg"
316
- out_path = OUTPUT_DIR / filename
317
- save_image_tensor(output_tensor, out_path)
318
-
319
- background_tasks.add_task(delete_file_after_delay, out_path, 180)
320
-
321
- base_url = str(request.base_url).rstrip("/")
322
- return {"image_url": f"{base_url}/download/{filename}"}
323
-
324
- @app.get("/api/download/{filename}")
325
- async def download(filename: str):
326
- path = OUTPUT_DIR / filename
327
- if not path.exists():
328
- raise HTTPException(status_code=404, detail="File not found or already deleted")
329
  return FileResponse(path, media_type="image/jpeg", filename=filename)
 
1
+
2
+ # backend/app.py
3
+ import os, io, uuid, sys, json, asyncio
4
+ from pathlib import Path
5
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks
6
+ from fastapi.responses import FileResponse, JSONResponse
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.staticfiles import StaticFiles
9
+ from PIL import Image
10
+ import torch
11
+ from torchvision import transforms
12
+
13
+ # ------------------ BASE SETUP ------------------
14
+
15
+ BASE_DIR = Path(__file__).resolve().parent
16
+ sys.path.append(str(BASE_DIR / "helpers"))
17
+ from helpers.transform_net import TransformerNet
18
+
19
+ app = FastAPI()
20
+
21
+ # ------------------ CORS ------------------
22
+ # In HF Spaces dashboard, set environment variable:
23
+ # FRONTEND_URL = https://your-app.vercel.app
24
+ # For local dev it defaults to localhost:5173
25
+
26
+ FRONTEND_URL = os.environ.get("FRONTEND_URL", "https://image-stylizer-deploy.vercel.app/")
27
+
28
+ app.add_middleware(
29
+ CORSMiddleware,
30
+ allow_origins=[FRONTEND_URL],
31
+ allow_credentials=True,
32
+ allow_methods=["*"],
33
+ allow_headers=["*"],
34
+ )
35
+
36
+ # ------------------ DEVICE ------------------
37
+ # HF Spaces free tier = CPU only
38
+ # cuda.amp.autocast is disabled on CPU to avoid warnings
39
+
40
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
+ use_amp = device.type == "cuda"
42
+ print(f"Running on: {device}")
43
+
44
+ # ------------------ OUTPUTS ------------------
45
+
46
+ OUTPUT_DIR = BASE_DIR / "outputs"
47
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
48
+
49
+ app.mount("/download", StaticFiles(directory=str(OUTPUT_DIR)), name="download")
50
+
51
+ # ------------------ MODELS ------------------
52
+
53
+ models_json_path = BASE_DIR / "models.json"
54
+ if not models_json_path.exists():
55
+ raise RuntimeError(f"models.json not found at {models_json_path}")
56
+
57
+ with open(models_json_path, "r") as f:
58
+ MODEL_PATHS = json.load(f)
59
+
60
+ # Convert relative paths to absolute
61
+ for cat, styles in MODEL_PATHS.items():
62
+ for style_name, rel_path in styles.items():
63
+ p = Path(rel_path)
64
+ if not p.is_absolute():
65
+ MODEL_PATHS[cat][style_name] = str((BASE_DIR / rel_path).resolve())
66
+
67
+ # In-memory model cache
68
+ models = {}
69
+
70
+ def load_model(category: str, style: str):
71
+ key = (category, style)
72
+ if key in models:
73
+ return models[key]
74
+
75
+ if category not in MODEL_PATHS or style not in MODEL_PATHS[category]:
76
+ raise HTTPException(status_code=400, detail="Invalid category/style")
77
+
78
+ path = MODEL_PATHS[category][style]
79
+ if not os.path.exists(path):
80
+ raise HTTPException(status_code=404, detail=f"Model file not found: {path}")
81
+
82
+ model = TransformerNet().to(device)
83
+ model.load_state_dict(torch.load(path, map_location=device))
84
+ model.eval()
85
+ models[key] = model
86
+ print(f"Loaded model: {category}/{style}")
87
+ return model
88
+
89
+ # Preload all models at startup
90
+ # Since each model is only 10-11 MB, all fit easily in 16 GB free RAM
91
+ @app.on_event("startup")
92
+ async def preload_all_models():
93
+ print("Preloading all models...")
94
+ for cat, styles in MODEL_PATHS.items():
95
+ for style in styles:
96
+ try:
97
+ load_model(cat, style)
98
+ except Exception as e:
99
+ print(f"Warning: Could not load {cat}/{style} — {e}")
100
+ print(f"Done. {len(models)} model(s) loaded.")
101
+
102
+ # ------------------ IMAGE UTILS ------------------
103
+
104
+ def save_image_tensor(tensor, path: Path):
105
+ img = tensor.detach().float().cpu()[0].clamp(0, 1).permute(1, 2, 0).numpy() * 255
106
+ Image.fromarray(img.astype("uint8")).save(path)
107
+
108
+ def stylize_image(img: Image.Image, model, img_size: int = 512):
109
+ transform = transforms.Compose([
110
+ transforms.Resize(img_size),
111
+ transforms.ToTensor()
112
+ ])
113
+ x = transform(img).unsqueeze(0).to(device)
114
+ with torch.no_grad():
115
+ # autocast only when GPU is available, safe no-op on CPU
116
+ with torch.cuda.amp.autocast(enabled=use_amp):
117
+ y = model(x)
118
+ return y
119
+
120
+ # ------------------ CLEANUP ------------------
121
+
122
+ async def delete_file_after_delay(path: Path, delay: int = 180):
123
+ await asyncio.sleep(delay)
124
+ try:
125
+ if path.exists():
126
+ path.unlink()
127
+ print(f"Deleted {path} after {delay}s")
128
+ except Exception as e:
129
+ print(f"Error deleting file: {e}")
130
+
131
+ # ------------------ ROUTES ------------------
132
+
133
+ @app.get("/")
134
+ async def root():
135
+ return {"message": "Backend is running!", "device": str(device)}
136
+
137
+ @app.get("/api/styles")
138
+ async def get_styles():
139
+ return MODEL_PATHS
140
+
141
+ @app.post("/api/stylize")
142
+ async def stylize(
143
+ request: Request,
144
+ background_tasks: BackgroundTasks,
145
+ file: UploadFile = File(...),
146
+ category: str = Form(...),
147
+ style: str = Form(...),
148
+ ):
149
+ model = load_model(category, style)
150
+
151
+ contents = await file.read()
152
+ input_img = Image.open(io.BytesIO(contents)).convert("RGB")
153
+ output_tensor = stylize_image(input_img, model)
154
+
155
+ filename = f"{uuid.uuid4().hex}.jpg"
156
+ out_path = OUTPUT_DIR / filename
157
+ save_image_tensor(output_tensor, out_path)
158
+
159
+ background_tasks.add_task(delete_file_after_delay, out_path, 180)
160
+
161
+ base_url = str(request.base_url).rstrip("/")
162
+ return {"image_url": f"{base_url}/download/{filename}"}
163
+
164
+ @app.get("/api/download/{filename}")
165
+ async def download(filename: str):
166
+ path = OUTPUT_DIR / filename
167
+ if not path.exists():
168
+ raise HTTPException(status_code=404, detail="File not found or already deleted")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  return FileResponse(path, media_type="image/jpeg", filename=filename)