pluto90 commited on
Commit
98ce97b
·
verified ·
1 Parent(s): 5b0e7fa

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +329 -0
  2. models.json +26 -0
  3. package-lock.json +6 -0
  4. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # # backend/main.py
3
+ # import os, io, uuid, sys, json, asyncio
4
+ # from pathlib import Path
5
+ # from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks
6
+ # from fastapi.responses import FileResponse, JSONResponse
7
+ # from fastapi.middleware.cors import CORSMiddleware
8
+ # from fastapi.staticfiles import StaticFiles
9
+ # from PIL import Image
10
+ # import torch
11
+ # from torchvision import transforms
12
+
13
+
14
+
15
+
16
+ # BASE_DIR = Path(__file__).resolve().parent
17
+ # sys.path.append(str(BASE_DIR / "helpers"))
18
+ # from helpers.transform_net import TransformerNet
19
+
20
+ # app = FastAPI()
21
+
22
+ # # -------- CORS: add your frontend origin (dev: http://localhost:5173) ----------
23
+ # FRONTEND_URL = os.environ.get("FRONTEND_URL", "http://localhost:5173")
24
+ # app.add_middleware(
25
+ # CORSMiddleware,
26
+ # allow_origins=[FRONTEND_URL],
27
+ # allow_credentials=True,
28
+ # allow_methods=["*"],
29
+ # allow_headers=["*"],
30
+ # )
31
+
32
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+
34
+ # # ---------- outputs dir ------------
35
+ # OUTPUT_DIR = BASE_DIR / "outputs"
36
+ # OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
37
+
38
+ # # mount static so /download/<file> serves images directly
39
+ # app.mount("/download", StaticFiles(directory=str(OUTPUT_DIR)), name="download")
40
+
41
+ # # ---------- load models.json + your model cache (keep your existing logic) ----------
42
+ # models_json_path = BASE_DIR / "models.json"
43
+ # if not models_json_path.exists():
44
+ # raise RuntimeError(f"models.json not found at {models_json_path}")
45
+
46
+ # with open(models_json_path, "r") as f:
47
+ # MODEL_PATHS = json.load(f)
48
+
49
+ # # convert to absolute paths (your existing code)
50
+ # for cat, styles in MODEL_PATHS.items():
51
+ # for style_name, rel_path in styles.items():
52
+ # p = Path(rel_path)
53
+ # if not p.is_absolute():
54
+ # MODEL_PATHS[cat][style_name] = str((BASE_DIR / rel_path).resolve())
55
+
56
+ # models = {}
57
+ # def load_model(category: str, style: str):
58
+ # key = (category, style)
59
+ # if key in models:
60
+ # return models[key]
61
+ # if category not in MODEL_PATHS or style not in MODEL_PATHS[category]:
62
+ # raise HTTPException(status_code=400, detail="Invalid category/style")
63
+ # path = MODEL_PATHS[category][style]
64
+ # if not os.path.exists(path):
65
+ # raise HTTPException(status_code=404, detail=f"Model file not found: {path}")
66
+ # model = TransformerNet().to(device)
67
+ # model.load_state_dict(torch.load(path, map_location=device))
68
+ # model.eval()
69
+ # models[key] = model
70
+ # return model
71
+
72
+ # def save_image_tensor(tensor, path: Path):
73
+ # img = tensor.detach().float().cpu()[0].clamp(0,1).permute(1,2,0).numpy() * 255
74
+ # Image.fromarray(img.astype("uint8")).save(path)
75
+
76
+ # def stylize_image(img: Image.Image, model, img_size: int = 512):
77
+ # transform = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()])
78
+ # x = transform(img).unsqueeze(0).to(device)
79
+ # with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
80
+ # y = model(x)
81
+ # return y
82
+
83
+ # # cleanup helper
84
+ # async def delete_file_after_delay(path: Path, delay: int = 180):
85
+ # await asyncio.sleep(delay)
86
+ # try:
87
+ # if path.exists():
88
+ # path.unlink()
89
+ # print(f"🧹 Deleted {path} after {delay} sec")
90
+ # except Exception as e:
91
+ # print("⚠️ error deleting file:", e)
92
+
93
+ # # --------- routes ----------
94
+
95
+
96
+ # @app.get("/")
97
+ # async def root():
98
+ # return {"message": "Backend is running!"}
99
+
100
+
101
+ # # # -------------------- API Routes --------------------
102
+ # @app.get("/api/styles")
103
+ # async def get_styles():
104
+ # return MODEL_PATHS
105
+
106
+
107
+
108
+ # @app.post("/api/stylize")
109
+ # async def stylize(
110
+ # request: Request,
111
+ # background_tasks: BackgroundTasks,
112
+ # file: UploadFile = File(...),
113
+ # category: str = Form(...),
114
+ # style: str = Form(...),
115
+ # ):
116
+ # model = load_model(category, style)
117
+ # contents = await file.read()
118
+ # input_img = Image.open(io.BytesIO(contents)).convert("RGB")
119
+ # output_tensor = stylize_image(input_img, model)
120
+ # filename = f"{uuid.uuid4().hex}.jpg"
121
+ # out_path = OUTPUT_DIR / filename
122
+ # save_image_tensor(output_tensor, out_path)
123
+
124
+ # # schedule deletion
125
+ # background_tasks.add_task(delete_file_after_delay, out_path, 180)
126
+
127
+ # # build absolute URL using the request base URL (works in dev and production)
128
+ # base = str(request.base_url).rstrip('/')
129
+ # image_url = f"{base}/download/{filename}"
130
+ # return {"image_url": image_url}
131
+
132
+ # # (OPTIONAL) keep a file endpoint if you want; not necessary because StaticFiles serves it:
133
+ # @app.get("/api/download/{filename}")
134
+ # async def download(filename: str):
135
+ # path = OUTPUT_DIR / filename
136
+ # return FileResponse(path, media_type="image/jpeg", filename=filename)
137
+
138
+
139
+
140
+
141
+
142
+
143
+
144
+
145
+
146
+
147
+
148
+
149
+
150
+
151
+
152
+
153
+
154
+
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+ # backend/main.py
163
+ import os, io, uuid, sys, json, asyncio
164
+ from pathlib import Path
165
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks
166
+ from fastapi.responses import FileResponse, JSONResponse
167
+ from fastapi.middleware.cors import CORSMiddleware
168
+ from fastapi.staticfiles import StaticFiles
169
+ from PIL import Image
170
+ import torch
171
+ from torchvision import transforms
172
+
173
+ # ------------------ BASE SETUP ------------------
174
+
175
+ BASE_DIR = Path(__file__).resolve().parent
176
+ sys.path.append(str(BASE_DIR / "helpers"))
177
+ from helpers.transform_net import TransformerNet
178
+
179
+ app = FastAPI()
180
+
181
+ # ------------------ CORS ------------------
182
+ # In HF Spaces dashboard, set environment variable:
183
+ # FRONTEND_URL = https://your-app.vercel.app
184
+ # For local dev it defaults to localhost:5173
185
+
186
+ FRONTEND_URL = os.environ.get("FRONTEND_URL", "http://localhost:5173")
187
+
188
+ app.add_middleware(
189
+ CORSMiddleware,
190
+ allow_origins=[FRONTEND_URL],
191
+ allow_credentials=True,
192
+ allow_methods=["*"],
193
+ allow_headers=["*"],
194
+ )
195
+
196
+ # ------------------ DEVICE ------------------
197
+ # HF Spaces free tier = CPU only
198
+ # cuda.amp.autocast is disabled on CPU to avoid warnings
199
+
200
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
201
+ use_amp = device.type == "cuda"
202
+ print(f"Running on: {device}")
203
+
204
+ # ------------------ OUTPUTS ------------------
205
+
206
+ OUTPUT_DIR = BASE_DIR / "outputs"
207
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
208
+
209
+ app.mount("/download", StaticFiles(directory=str(OUTPUT_DIR)), name="download")
210
+
211
+ # ------------------ MODELS ------------------
212
+
213
+ models_json_path = BASE_DIR / "models.json"
214
+ if not models_json_path.exists():
215
+ raise RuntimeError(f"models.json not found at {models_json_path}")
216
+
217
+ with open(models_json_path, "r") as f:
218
+ MODEL_PATHS = json.load(f)
219
+
220
+ # Convert relative paths to absolute
221
+ for cat, styles in MODEL_PATHS.items():
222
+ for style_name, rel_path in styles.items():
223
+ p = Path(rel_path)
224
+ if not p.is_absolute():
225
+ MODEL_PATHS[cat][style_name] = str((BASE_DIR / rel_path).resolve())
226
+
227
+ # In-memory model cache
228
+ models = {}
229
+
230
+ def load_model(category: str, style: str):
231
+ key = (category, style)
232
+ if key in models:
233
+ return models[key]
234
+
235
+ if category not in MODEL_PATHS or style not in MODEL_PATHS[category]:
236
+ raise HTTPException(status_code=400, detail="Invalid category/style")
237
+
238
+ path = MODEL_PATHS[category][style]
239
+ if not os.path.exists(path):
240
+ raise HTTPException(status_code=404, detail=f"Model file not found: {path}")
241
+
242
+ model = TransformerNet().to(device)
243
+ model.load_state_dict(torch.load(path, map_location=device))
244
+ model.eval()
245
+ models[key] = model
246
+ print(f"Loaded model: {category}/{style}")
247
+ return model
248
+
249
+ # Preload all models at startup
250
+ # Since each model is only 10-11 MB, all fit easily in 16 GB free RAM
251
+ @app.on_event("startup")
252
+ async def preload_all_models():
253
+ print("Preloading all models...")
254
+ for cat, styles in MODEL_PATHS.items():
255
+ for style in styles:
256
+ try:
257
+ load_model(cat, style)
258
+ except Exception as e:
259
+ print(f"Warning: Could not load {cat}/{style} — {e}")
260
+ print(f"Done. {len(models)} model(s) loaded.")
261
+
262
+ # ------------------ IMAGE UTILS ------------------
263
+
264
+ def save_image_tensor(tensor, path: Path):
265
+ img = tensor.detach().float().cpu()[0].clamp(0, 1).permute(1, 2, 0).numpy() * 255
266
+ Image.fromarray(img.astype("uint8")).save(path)
267
+
268
+ def stylize_image(img: Image.Image, model, img_size: int = 512):
269
+ transform = transforms.Compose([
270
+ transforms.Resize(img_size),
271
+ transforms.ToTensor()
272
+ ])
273
+ x = transform(img).unsqueeze(0).to(device)
274
+ with torch.no_grad():
275
+ # autocast only when GPU is available, safe no-op on CPU
276
+ with torch.cuda.amp.autocast(enabled=use_amp):
277
+ y = model(x)
278
+ return y
279
+
280
+ # ------------------ CLEANUP ------------------
281
+
282
+ async def delete_file_after_delay(path: Path, delay: int = 180):
283
+ await asyncio.sleep(delay)
284
+ try:
285
+ if path.exists():
286
+ path.unlink()
287
+ print(f"Deleted {path} after {delay}s")
288
+ except Exception as e:
289
+ print(f"Error deleting file: {e}")
290
+
291
+ # ------------------ ROUTES ------------------
292
+
293
+ @app.get("/")
294
+ async def root():
295
+ return {"message": "Backend is running!", "device": str(device)}
296
+
297
+ @app.get("/api/styles")
298
+ async def get_styles():
299
+ return MODEL_PATHS
300
+
301
+ @app.post("/api/stylize")
302
+ async def stylize(
303
+ request: Request,
304
+ background_tasks: BackgroundTasks,
305
+ file: UploadFile = File(...),
306
+ category: str = Form(...),
307
+ style: str = Form(...),
308
+ ):
309
+ model = load_model(category, style)
310
+
311
+ contents = await file.read()
312
+ input_img = Image.open(io.BytesIO(contents)).convert("RGB")
313
+ output_tensor = stylize_image(input_img, model)
314
+
315
+ filename = f"{uuid.uuid4().hex}.jpg"
316
+ out_path = OUTPUT_DIR / filename
317
+ save_image_tensor(output_tensor, out_path)
318
+
319
+ background_tasks.add_task(delete_file_after_delay, out_path, 180)
320
+
321
+ base_url = str(request.base_url).rstrip("/")
322
+ return {"image_url": f"{base_url}/download/{filename}"}
323
+
324
+ @app.get("/api/download/{filename}")
325
+ async def download(filename: str):
326
+ path = OUTPUT_DIR / filename
327
+ if not path.exists():
328
+ raise HTTPException(status_code=404, detail="File not found or already deleted")
329
+ return FileResponse(path, media_type="image/jpeg", filename=filename)
models.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "human-face": {
3
+ "van-gogh": "./models/van-gogh-humans-epoch15.pth",
4
+ "claude-monet": "./models/claude-monet-humans-epoch10.pth"
5
+ },
6
+
7
+ "cars": {
8
+ "van-gogh": "./models/car-van-gogh-gpt_edgeAMP_epoch25.pth",
9
+ "claude-monet": "./models/car-claude-monet-gpt_edgeAMP_epoch12.pth"
10
+ },
11
+
12
+ "cats": {
13
+ "van-gogh": "./models/cats-van-gogh-gpt_edgeAMP_epoch12.pth",
14
+ "claude-monet": "./models/cat-claude-monet-gpt_edgeAMP_epoch12.pth"
15
+ },
16
+
17
+ "dogs": {
18
+ "van-gogh": "./models/dog-van-gogh-gpt_edgeAMP_epoch12.pth",
19
+ "claude-monet": "./models/dog-claude-monet-gpt_edgeAMP_epoch12.pth"
20
+ },
21
+
22
+ "landscape": {
23
+ "van-gogh": "./models/landscape-van-gogh-gpt_edgeAMP_epoch30.pth",
24
+ "claude-monet": "./models/landscape-claude-monet-gpt_edgeAMP_epoch30.pth"
25
+ }
26
+ }
package-lock.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "name": "backend",
3
+ "lockfileVersion": 3,
4
+ "requires": true,
5
+ "packages": {}
6
+ }
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pillow
4
+ torch
5
+ torchvision
6
+ python-multipart
7
+ aiofiles