pluto90 commited on
Commit
1322ef8
·
verified ·
1 Parent(s): 877ea8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -237
app.py CHANGED
@@ -1,211 +1,14 @@
1
-
2
- # # backend/app.py
3
- # import os, io, uuid, sys, json, asyncio
4
- # from pathlib import Path
5
- # from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks
6
- # from fastapi.responses import FileResponse, JSONResponse
7
- # from fastapi.middleware.cors import CORSMiddleware
8
- # from fastapi.staticfiles import StaticFiles
9
- # from PIL import Image
10
- # import torch
11
- # from torchvision import transforms
12
-
13
- # # ------------------ BASE SETUP ------------------
14
-
15
- # BASE_DIR = Path(__file__).resolve().parent
16
- # sys.path.append(str(BASE_DIR / "helpers"))
17
- # from helpers.transform_net import TransformerNet
18
-
19
- # app = FastAPI()
20
-
21
- # # ------------------ CORS ------------------
22
- # # In HF Spaces dashboard, set environment variable:
23
- # # FRONTEND_URL = https://your-app.vercel.app
24
- # # For local dev it defaults to localhost:5173
25
-
26
- # FRONTEND_URL = os.environ.get("FRONTEND_URL")
27
- # # FRONTEND_URL = "https://image-stylizer-deploy.vercel.app"
28
-
29
- # app.add_middleware(
30
- # CORSMiddleware,
31
- # allow_origins=[
32
- # FRONTEND_URL
33
- # # "https://image-stylizer-deploy.vercel.app",
34
- # # "http://localhost:5173", # for local testing
35
- # ],
36
- # allow_credentials=True,
37
- # allow_methods=["*"],
38
- # allow_headers=["*"],
39
- # )
40
-
41
- # # ------------------ DEVICE ------------------
42
- # # HF Spaces free tier = CPU only
43
- # # cuda.amp.autocast is disabled on CPU to avoid warnings
44
-
45
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
- # use_amp = device.type == "cuda"
47
- # print(f"Running on: {device}")
48
-
49
- # # ------------------ OUTPUTS ------------------
50
-
51
- # OUTPUT_DIR = BASE_DIR / "outputs"
52
- # OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
53
-
54
- # app.mount("/download", StaticFiles(directory=str(OUTPUT_DIR)), name="download")
55
-
56
- # # ------------------ MODELS ------------------
57
-
58
- # models_json_path = BASE_DIR / "models.json"
59
- # if not models_json_path.exists():
60
- # raise RuntimeError(f"models.json not found at {models_json_path}")
61
-
62
- # with open(models_json_path, "r") as f:
63
- # MODEL_PATHS = json.load(f)
64
-
65
- # # Convert relative paths to absolute
66
- # for cat, styles in MODEL_PATHS.items():
67
- # for style_name, rel_path in styles.items():
68
- # p = Path(rel_path)
69
- # if not p.is_absolute():
70
- # MODEL_PATHS[cat][style_name] = str((BASE_DIR / rel_path).resolve())
71
-
72
- # # In-memory model cache
73
- # models = {}
74
-
75
- # def load_model(category: str, style: str):
76
- # key = (category, style)
77
- # if key in models:
78
- # return models[key]
79
-
80
- # if category not in MODEL_PATHS or style not in MODEL_PATHS[category]:
81
- # raise HTTPException(status_code=400, detail="Invalid category/style")
82
-
83
- # path = MODEL_PATHS[category][style]
84
- # if not os.path.exists(path):
85
- # raise HTTPException(status_code=404, detail=f"Model file not found: {path}")
86
-
87
- # model = TransformerNet().to(device)
88
- # model.load_state_dict(torch.load(path, map_location=device))
89
- # model.eval()
90
-
91
- # model = torch.jit.script(model)
92
-
93
- # models[key] = model
94
- # print(f"Loaded model: {category}/{style}")
95
- # return model
96
-
97
- # # Preload all models at startup
98
- # # Since each model is only 10-11 MB, all fit easily in 16 GB free RAM
99
- # @app.on_event("startup")
100
- # async def preload_all_models():
101
- # print("Preloading all models...")
102
- # for cat, styles in MODEL_PATHS.items():
103
- # for style in styles:
104
- # try:
105
- # load_model(cat, style)
106
- # except Exception as e:
107
- # print(f"Warning: Could not load {cat}/{style} — {e}")
108
- # print(f"Done. {len(models)} model(s) loaded.")
109
-
110
- # # ------------------ IMAGE UTILS ------------------
111
-
112
- # def save_image_tensor(tensor, path: Path):
113
- # img = tensor.detach().float().cpu()[0].clamp(0, 1).permute(1, 2, 0).numpy() * 255
114
- # Image.fromarray(img.astype("uint8")).save(path)
115
-
116
- # def stylize_image(img: Image.Image, model, img_size: int = 256):
117
- # transform = transforms.Compose([
118
- # transforms.Resize(img_size),
119
- # transforms.ToTensor()
120
- # ])
121
- # x = transform(img).unsqueeze(0).to(device)
122
- # with torch.no_grad():
123
- # # autocast only when GPU is available, safe no-op on CPU
124
- # y = model(x)
125
- # return y
126
-
127
- # # ------------------ CLEANUP ------------------
128
-
129
- # async def delete_file_after_delay(path: Path, delay: int = 180):
130
- # await asyncio.sleep(delay)
131
- # try:
132
- # if path.exists():
133
- # path.unlink()
134
- # print(f"Deleted {path} after {delay}s")
135
- # except Exception as e:
136
- # print(f"Error deleting file: {e}")
137
-
138
- # # ------------------ ROUTES ------------------
139
-
140
- # @app.get("/")
141
- # async def root():
142
- # return {"message": "Backend is running!", "device": str(device)}
143
-
144
- # @app.get("/api/styles")
145
- # async def get_styles():
146
- # return MODEL_PATHS
147
-
148
- # @app.post("/api/stylize")
149
- # async def stylize(
150
- # request: Request,
151
- # background_tasks: BackgroundTasks,
152
- # file: UploadFile = File(...),
153
- # category: str = Form(...),
154
- # style: str = Form(...),
155
- # ):
156
- # model = load_model(category, style)
157
-
158
- # contents = await file.read()
159
- # input_img = Image.open(io.BytesIO(contents)).convert("RGB")
160
- # output_tensor = stylize_image(input_img, model)
161
-
162
- # filename = f"{uuid.uuid4().hex}.jpg"
163
- # out_path = OUTPUT_DIR / filename
164
- # save_image_tensor(output_tensor, out_path)
165
-
166
- # background_tasks.add_task(delete_file_after_delay, out_path, 180)
167
-
168
- # base_url = str(request.base_url).rstrip("/")
169
- # return {"image_url": f"{base_url}/download/{filename}"}
170
-
171
- # @app.get("/api/download/{filename}")
172
- # async def download(filename: str):
173
- # path = OUTPUT_DIR / filename
174
- # if not path.exists():
175
- # raise HTTPException(status_code=404, detail="File not found or already deleted")
176
- # return FileResponse(path, media_type="image/jpeg", filename=filename)
177
-
178
-
179
-
180
-
181
-
182
-
183
-
184
-
185
-
186
-
187
-
188
-
189
-
190
-
191
-
192
-
193
-
194
-
195
  import os, io, uuid, sys, json, asyncio
196
  from pathlib import Path
197
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks
198
- from fastapi.responses import FileResponse
199
  from fastapi.middleware.cors import CORSMiddleware
200
  from fastapi.staticfiles import StaticFiles
201
  from PIL import Image
202
  import torch
203
  from torchvision import transforms
204
 
205
- # ------------------ PERFORMANCE SETTINGS ------------------
206
-
207
- torch.set_num_threads(1) # 🔥 critical for HF CPU
208
-
209
  # ------------------ BASE SETUP ------------------
210
 
211
  BASE_DIR = Path(__file__).resolve().parent
@@ -215,20 +18,31 @@ from helpers.transform_net import TransformerNet
215
  app = FastAPI()
216
 
217
  # ------------------ CORS ------------------
 
 
 
218
 
219
  FRONTEND_URL = os.environ.get("FRONTEND_URL")
 
220
 
221
  app.add_middleware(
222
  CORSMiddleware,
223
- allow_origins=[FRONTEND_URL] if FRONTEND_URL else ["*"],
224
- allow_credentials=False,
 
 
 
 
225
  allow_methods=["*"],
226
  allow_headers=["*"],
227
  )
228
 
229
  # ------------------ DEVICE ------------------
 
 
230
 
231
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
232
  print(f"Running on: {device}")
233
 
234
  # ------------------ OUTPUTS ------------------
@@ -257,15 +71,6 @@ for cat, styles in MODEL_PATHS.items():
257
  # In-memory model cache
258
  models = {}
259
 
260
- # ------------------ GLOBAL TRANSFORM ------------------
261
-
262
- transform = transforms.Compose([
263
- transforms.Resize(256),
264
- transforms.ToTensor()
265
- ])
266
-
267
- # ------------------ MODEL LOADER ------------------
268
-
269
  def load_model(category: str, style: str):
270
  key = (category, style)
271
  if key in models:
@@ -282,36 +87,42 @@ def load_model(category: str, style: str):
282
  model.load_state_dict(torch.load(path, map_location=device))
283
  model.eval()
284
 
285
- # 🔥 TorchScript optimization
286
  model = torch.jit.script(model)
287
-
288
- # 🔥 Warmup (removes first-request delay)
289
- dummy = torch.randn(1, 3, 256, 256).to(device)
290
- with torch.no_grad():
291
- model(dummy)
292
-
293
  models[key] = model
294
  print(f"Loaded model: {category}/{style}")
295
-
296
  return model
297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  # ------------------ IMAGE UTILS ------------------
299
 
300
- def stylize_image(img: Image.Image, model):
 
 
 
 
 
 
 
 
301
  x = transform(img).unsqueeze(0).to(device)
302
  with torch.no_grad():
 
303
  y = model(x)
304
  return y
305
 
306
- def save_image_tensor(tensor, path: Path):
307
- img = tensor.detach().cpu()[0].clamp(0, 1).permute(1, 2, 0).numpy() * 255
308
- Image.fromarray(img.astype("uint8")).save(
309
- path,
310
- format="JPEG",
311
- quality=85,
312
- optimize=True
313
- )
314
-
315
  # ------------------ CLEANUP ------------------
316
 
317
  async def delete_file_after_delay(path: Path, delay: int = 180):
@@ -319,7 +130,7 @@ async def delete_file_after_delay(path: Path, delay: int = 180):
319
  try:
320
  if path.exists():
321
  path.unlink()
322
- print(f"Deleted {path}")
323
  except Exception as e:
324
  print(f"Error deleting file: {e}")
325
 
@@ -345,15 +156,10 @@ async def stylize(
345
 
346
  contents = await file.read()
347
  input_img = Image.open(io.BytesIO(contents)).convert("RGB")
348
-
349
- # 🔥 Run heavy task in background thread
350
- output_tensor = await asyncio.to_thread(
351
- stylize_image, input_img, model
352
- )
353
 
354
  filename = f"{uuid.uuid4().hex}.jpg"
355
  out_path = OUTPUT_DIR / filename
356
-
357
  save_image_tensor(output_tensor, out_path)
358
 
359
  background_tasks.add_task(delete_file_after_delay, out_path, 180)
@@ -365,5 +171,6 @@ async def stylize(
365
  async def download(filename: str):
366
  path = OUTPUT_DIR / filename
367
  if not path.exists():
368
- raise HTTPException(status_code=404, detail="File not found")
369
- return FileResponse(path, media_type="image/jpeg", filename=filename)
 
 
1
+ # backend/app.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import os, io, uuid, sys, json, asyncio
3
  from pathlib import Path
4
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks
5
+ from fastapi.responses import FileResponse, JSONResponse
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from fastapi.staticfiles import StaticFiles
8
  from PIL import Image
9
  import torch
10
  from torchvision import transforms
11
 
 
 
 
 
12
  # ------------------ BASE SETUP ------------------
13
 
14
  BASE_DIR = Path(__file__).resolve().parent
 
18
  app = FastAPI()
19
 
20
  # ------------------ CORS ------------------
21
+ # In HF Spaces dashboard, set environment variable:
22
+ # FRONTEND_URL = https://your-app.vercel.app
23
+ # For local dev it defaults to localhost:5173
24
 
25
  FRONTEND_URL = os.environ.get("FRONTEND_URL")
26
+ # FRONTEND_URL = "https://image-stylizer-deploy.vercel.app"
27
 
28
  app.add_middleware(
29
  CORSMiddleware,
30
+ allow_origins=[
31
+ FRONTEND_URL
32
+ # "https://image-stylizer-deploy.vercel.app",
33
+ # "http://localhost:5173", # for local testing
34
+ ],
35
+ allow_credentials=True,
36
  allow_methods=["*"],
37
  allow_headers=["*"],
38
  )
39
 
40
  # ------------------ DEVICE ------------------
41
+ # HF Spaces free tier = CPU only
42
+ # cuda.amp.autocast is disabled on CPU to avoid warnings
43
 
44
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+ use_amp = device.type == "cuda"
46
  print(f"Running on: {device}")
47
 
48
  # ------------------ OUTPUTS ------------------
 
71
  # In-memory model cache
72
  models = {}
73
 
 
 
 
 
 
 
 
 
 
74
  def load_model(category: str, style: str):
75
  key = (category, style)
76
  if key in models:
 
87
  model.load_state_dict(torch.load(path, map_location=device))
88
  model.eval()
89
 
 
90
  model = torch.jit.script(model)
91
+
 
 
 
 
 
92
  models[key] = model
93
  print(f"Loaded model: {category}/{style}")
 
94
  return model
95
 
96
+ # Preload all models at startup
97
+ # Since each model is only 10-11 MB, all fit easily in 16 GB free RAM
98
+ @app.on_event("startup")
99
+ async def preload_all_models():
100
+ print("Preloading all models...")
101
+ for cat, styles in MODEL_PATHS.items():
102
+ for style in styles:
103
+ try:
104
+ load_model(cat, style)
105
+ except Exception as e:
106
+ print(f"Warning: Could not load {cat}/{style} — {e}")
107
+ print(f"Done. {len(models)} model(s) loaded.")
108
+
109
  # ------------------ IMAGE UTILS ------------------
110
 
111
+ def save_image_tensor(tensor, path: Path):
112
+ img = tensor.detach().float().cpu()[0].clamp(0, 1).permute(1, 2, 0).numpy() * 255
113
+ Image.fromarray(img.astype("uint8")).save(path)
114
+
115
+ def stylize_image(img: Image.Image, model, img_size: int = 256):
116
+ transform = transforms.Compose([
117
+ transforms.Resize(img_size),
118
+ transforms.ToTensor()
119
+ ])
120
  x = transform(img).unsqueeze(0).to(device)
121
  with torch.no_grad():
122
+ # autocast only when GPU is available, safe no-op on CPU
123
  y = model(x)
124
  return y
125
 
 
 
 
 
 
 
 
 
 
126
  # ------------------ CLEANUP ------------------
127
 
128
  async def delete_file_after_delay(path: Path, delay: int = 180):
 
130
  try:
131
  if path.exists():
132
  path.unlink()
133
+ print(f"Deleted {path} after {delay}s")
134
  except Exception as e:
135
  print(f"Error deleting file: {e}")
136
 
 
156
 
157
  contents = await file.read()
158
  input_img = Image.open(io.BytesIO(contents)).convert("RGB")
159
+ output_tensor = stylize_image(input_img, model)
 
 
 
 
160
 
161
  filename = f"{uuid.uuid4().hex}.jpg"
162
  out_path = OUTPUT_DIR / filename
 
163
  save_image_tensor(output_tensor, out_path)
164
 
165
  background_tasks.add_task(delete_file_after_delay, out_path, 180)
 
171
  async def download(filename: str):
172
  path = OUTPUT_DIR / filename
173
  if not path.exists():
174
+ raise HTTPException(status_code=404, detail="File not found or already deleted")
175
+ return FileResponse(path, media_type="image/jpeg", filename=filename)
176
+