pluto90 commited on
Commit
877ea8f
·
verified ·
1 Parent(s): bb59fed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +238 -43
app.py CHANGED
@@ -1,15 +1,211 @@
1
 
2
- # backend/app.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import os, io, uuid, sys, json, asyncio
4
  from pathlib import Path
5
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks
6
- from fastapi.responses import FileResponse, JSONResponse
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.staticfiles import StaticFiles
9
  from PIL import Image
10
  import torch
11
  from torchvision import transforms
12
 
 
 
 
 
13
  # ------------------ BASE SETUP ------------------
14
 
15
  BASE_DIR = Path(__file__).resolve().parent
@@ -19,31 +215,20 @@ from helpers.transform_net import TransformerNet
19
  app = FastAPI()
20
 
21
  # ------------------ CORS ------------------
22
- # In HF Spaces dashboard, set environment variable:
23
- # FRONTEND_URL = https://your-app.vercel.app
24
- # For local dev it defaults to localhost:5173
25
 
26
  FRONTEND_URL = os.environ.get("FRONTEND_URL")
27
- # FRONTEND_URL = "https://image-stylizer-deploy.vercel.app"
28
 
29
  app.add_middleware(
30
  CORSMiddleware,
31
- allow_origins=[
32
- FRONTEND_URL
33
- # "https://image-stylizer-deploy.vercel.app",
34
- # "http://localhost:5173", # for local testing
35
- ],
36
- allow_credentials=True,
37
  allow_methods=["*"],
38
  allow_headers=["*"],
39
  )
40
 
41
  # ------------------ DEVICE ------------------
42
- # HF Spaces free tier = CPU only
43
- # cuda.amp.autocast is disabled on CPU to avoid warnings
44
 
45
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
- use_amp = device.type == "cuda"
47
  print(f"Running on: {device}")
48
 
49
  # ------------------ OUTPUTS ------------------
@@ -72,6 +257,15 @@ for cat, styles in MODEL_PATHS.items():
72
  # In-memory model cache
73
  models = {}
74
 
 
 
 
 
 
 
 
 
 
75
  def load_model(category: str, style: str):
76
  key = (category, style)
77
  if key in models:
@@ -87,41 +281,37 @@ def load_model(category: str, style: str):
87
  model = TransformerNet().to(device)
88
  model.load_state_dict(torch.load(path, map_location=device))
89
  model.eval()
 
 
 
 
 
 
 
 
 
90
  models[key] = model
91
  print(f"Loaded model: {category}/{style}")
92
- return model
93
 
94
- # Preload all models at startup
95
- # Since each model is only 10-11 MB, all fit easily in 16 GB free RAM
96
- @app.on_event("startup")
97
- async def preload_all_models():
98
- print("Preloading all models...")
99
- for cat, styles in MODEL_PATHS.items():
100
- for style in styles:
101
- try:
102
- load_model(cat, style)
103
- except Exception as e:
104
- print(f"Warning: Could not load {cat}/{style} — {e}")
105
- print(f"Done. {len(models)} model(s) loaded.")
106
 
107
  # ------------------ IMAGE UTILS ------------------
108
 
109
- def save_image_tensor(tensor, path: Path):
110
- img = tensor.detach().float().cpu()[0].clamp(0, 1).permute(1, 2, 0).numpy() * 255
111
- Image.fromarray(img.astype("uint8")).save(path)
112
-
113
- def stylize_image(img: Image.Image, model, img_size: int = 256):
114
- transform = transforms.Compose([
115
- transforms.Resize(img_size),
116
- transforms.ToTensor()
117
- ])
118
  x = transform(img).unsqueeze(0).to(device)
119
  with torch.no_grad():
120
- # autocast only when GPU is available, safe no-op on CPU
121
- # with torch.cuda.amp.autocast(enabled=use_amp):
122
- y = model(x)
123
  return y
124
 
 
 
 
 
 
 
 
 
 
125
  # ------------------ CLEANUP ------------------
126
 
127
  async def delete_file_after_delay(path: Path, delay: int = 180):
@@ -129,7 +319,7 @@ async def delete_file_after_delay(path: Path, delay: int = 180):
129
  try:
130
  if path.exists():
131
  path.unlink()
132
- print(f"Deleted {path} after {delay}s")
133
  except Exception as e:
134
  print(f"Error deleting file: {e}")
135
 
@@ -155,10 +345,15 @@ async def stylize(
155
 
156
  contents = await file.read()
157
  input_img = Image.open(io.BytesIO(contents)).convert("RGB")
158
- output_tensor = stylize_image(input_img, model)
 
 
 
 
159
 
160
  filename = f"{uuid.uuid4().hex}.jpg"
161
  out_path = OUTPUT_DIR / filename
 
162
  save_image_tensor(output_tensor, out_path)
163
 
164
  background_tasks.add_task(delete_file_after_delay, out_path, 180)
@@ -170,5 +365,5 @@ async def stylize(
170
  async def download(filename: str):
171
  path = OUTPUT_DIR / filename
172
  if not path.exists():
173
- raise HTTPException(status_code=404, detail="File not found or already deleted")
174
  return FileResponse(path, media_type="image/jpeg", filename=filename)
 
1
 
2
+ # # backend/app.py
3
+ # import os, io, uuid, sys, json, asyncio
4
+ # from pathlib import Path
5
+ # from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks
6
+ # from fastapi.responses import FileResponse, JSONResponse
7
+ # from fastapi.middleware.cors import CORSMiddleware
8
+ # from fastapi.staticfiles import StaticFiles
9
+ # from PIL import Image
10
+ # import torch
11
+ # from torchvision import transforms
12
+
13
+ # # ------------------ BASE SETUP ------------------
14
+
15
+ # BASE_DIR = Path(__file__).resolve().parent
16
+ # sys.path.append(str(BASE_DIR / "helpers"))
17
+ # from helpers.transform_net import TransformerNet
18
+
19
+ # app = FastAPI()
20
+
21
+ # # ------------------ CORS ------------------
22
+ # # In HF Spaces dashboard, set environment variable:
23
+ # # FRONTEND_URL = https://your-app.vercel.app
24
+ # # For local dev it defaults to localhost:5173
25
+
26
+ # FRONTEND_URL = os.environ.get("FRONTEND_URL")
27
+ # # FRONTEND_URL = "https://image-stylizer-deploy.vercel.app"
28
+
29
+ # app.add_middleware(
30
+ # CORSMiddleware,
31
+ # allow_origins=[
32
+ # FRONTEND_URL
33
+ # # "https://image-stylizer-deploy.vercel.app",
34
+ # # "http://localhost:5173", # for local testing
35
+ # ],
36
+ # allow_credentials=True,
37
+ # allow_methods=["*"],
38
+ # allow_headers=["*"],
39
+ # )
40
+
41
+ # # ------------------ DEVICE ------------------
42
+ # # HF Spaces free tier = CPU only
43
+ # # cuda.amp.autocast is disabled on CPU to avoid warnings
44
+
45
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+ # use_amp = device.type == "cuda"
47
+ # print(f"Running on: {device}")
48
+
49
+ # # ------------------ OUTPUTS ------------------
50
+
51
+ # OUTPUT_DIR = BASE_DIR / "outputs"
52
+ # OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
53
+
54
+ # app.mount("/download", StaticFiles(directory=str(OUTPUT_DIR)), name="download")
55
+
56
+ # # ------------------ MODELS ------------------
57
+
58
+ # models_json_path = BASE_DIR / "models.json"
59
+ # if not models_json_path.exists():
60
+ # raise RuntimeError(f"models.json not found at {models_json_path}")
61
+
62
+ # with open(models_json_path, "r") as f:
63
+ # MODEL_PATHS = json.load(f)
64
+
65
+ # # Convert relative paths to absolute
66
+ # for cat, styles in MODEL_PATHS.items():
67
+ # for style_name, rel_path in styles.items():
68
+ # p = Path(rel_path)
69
+ # if not p.is_absolute():
70
+ # MODEL_PATHS[cat][style_name] = str((BASE_DIR / rel_path).resolve())
71
+
72
+ # # In-memory model cache
73
+ # models = {}
74
+
75
+ # def load_model(category: str, style: str):
76
+ # key = (category, style)
77
+ # if key in models:
78
+ # return models[key]
79
+
80
+ # if category not in MODEL_PATHS or style not in MODEL_PATHS[category]:
81
+ # raise HTTPException(status_code=400, detail="Invalid category/style")
82
+
83
+ # path = MODEL_PATHS[category][style]
84
+ # if not os.path.exists(path):
85
+ # raise HTTPException(status_code=404, detail=f"Model file not found: {path}")
86
+
87
+ # model = TransformerNet().to(device)
88
+ # model.load_state_dict(torch.load(path, map_location=device))
89
+ # model.eval()
90
+
91
+ # model = torch.jit.script(model)
92
+
93
+ # models[key] = model
94
+ # print(f"Loaded model: {category}/{style}")
95
+ # return model
96
+
97
+ # # Preload all models at startup
98
+ # # Since each model is only 10-11 MB, all fit easily in 16 GB free RAM
99
+ # @app.on_event("startup")
100
+ # async def preload_all_models():
101
+ # print("Preloading all models...")
102
+ # for cat, styles in MODEL_PATHS.items():
103
+ # for style in styles:
104
+ # try:
105
+ # load_model(cat, style)
106
+ # except Exception as e:
107
+ # print(f"Warning: Could not load {cat}/{style} — {e}")
108
+ # print(f"Done. {len(models)} model(s) loaded.")
109
+
110
+ # # ------------------ IMAGE UTILS ------------------
111
+
112
+ # def save_image_tensor(tensor, path: Path):
113
+ # img = tensor.detach().float().cpu()[0].clamp(0, 1).permute(1, 2, 0).numpy() * 255
114
+ # Image.fromarray(img.astype("uint8")).save(path)
115
+
116
+ # def stylize_image(img: Image.Image, model, img_size: int = 256):
117
+ # transform = transforms.Compose([
118
+ # transforms.Resize(img_size),
119
+ # transforms.ToTensor()
120
+ # ])
121
+ # x = transform(img).unsqueeze(0).to(device)
122
+ # with torch.no_grad():
123
+ # # autocast only when GPU is available, safe no-op on CPU
124
+ # y = model(x)
125
+ # return y
126
+
127
+ # # ------------------ CLEANUP ------------------
128
+
129
+ # async def delete_file_after_delay(path: Path, delay: int = 180):
130
+ # await asyncio.sleep(delay)
131
+ # try:
132
+ # if path.exists():
133
+ # path.unlink()
134
+ # print(f"Deleted {path} after {delay}s")
135
+ # except Exception as e:
136
+ # print(f"Error deleting file: {e}")
137
+
138
+ # # ------------------ ROUTES ------------------
139
+
140
+ # @app.get("/")
141
+ # async def root():
142
+ # return {"message": "Backend is running!", "device": str(device)}
143
+
144
+ # @app.get("/api/styles")
145
+ # async def get_styles():
146
+ # return MODEL_PATHS
147
+
148
+ # @app.post("/api/stylize")
149
+ # async def stylize(
150
+ # request: Request,
151
+ # background_tasks: BackgroundTasks,
152
+ # file: UploadFile = File(...),
153
+ # category: str = Form(...),
154
+ # style: str = Form(...),
155
+ # ):
156
+ # model = load_model(category, style)
157
+
158
+ # contents = await file.read()
159
+ # input_img = Image.open(io.BytesIO(contents)).convert("RGB")
160
+ # output_tensor = stylize_image(input_img, model)
161
+
162
+ # filename = f"{uuid.uuid4().hex}.jpg"
163
+ # out_path = OUTPUT_DIR / filename
164
+ # save_image_tensor(output_tensor, out_path)
165
+
166
+ # background_tasks.add_task(delete_file_after_delay, out_path, 180)
167
+
168
+ # base_url = str(request.base_url).rstrip("/")
169
+ # return {"image_url": f"{base_url}/download/{filename}"}
170
+
171
+ # @app.get("/api/download/{filename}")
172
+ # async def download(filename: str):
173
+ # path = OUTPUT_DIR / filename
174
+ # if not path.exists():
175
+ # raise HTTPException(status_code=404, detail="File not found or already deleted")
176
+ # return FileResponse(path, media_type="image/jpeg", filename=filename)
177
+
178
+
179
+
180
+
181
+
182
+
183
+
184
+
185
+
186
+
187
+
188
+
189
+
190
+
191
+
192
+
193
+
194
+
195
  import os, io, uuid, sys, json, asyncio
196
  from pathlib import Path
197
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request, BackgroundTasks
198
+ from fastapi.responses import FileResponse
199
  from fastapi.middleware.cors import CORSMiddleware
200
  from fastapi.staticfiles import StaticFiles
201
  from PIL import Image
202
  import torch
203
  from torchvision import transforms
204
 
205
+ # ------------------ PERFORMANCE SETTINGS ------------------
206
+
207
+ torch.set_num_threads(1) # 🔥 critical for HF CPU
208
+
209
  # ------------------ BASE SETUP ------------------
210
 
211
  BASE_DIR = Path(__file__).resolve().parent
 
215
  app = FastAPI()
216
 
217
  # ------------------ CORS ------------------
 
 
 
218
 
219
  FRONTEND_URL = os.environ.get("FRONTEND_URL")
 
220
 
221
  app.add_middleware(
222
  CORSMiddleware,
223
+ allow_origins=[FRONTEND_URL] if FRONTEND_URL else ["*"],
224
+ allow_credentials=False,
 
 
 
 
225
  allow_methods=["*"],
226
  allow_headers=["*"],
227
  )
228
 
229
  # ------------------ DEVICE ------------------
 
 
230
 
231
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
232
  print(f"Running on: {device}")
233
 
234
  # ------------------ OUTPUTS ------------------
 
257
  # In-memory model cache
258
  models = {}
259
 
260
+ # ------------------ GLOBAL TRANSFORM ------------------
261
+
262
+ transform = transforms.Compose([
263
+ transforms.Resize(256),
264
+ transforms.ToTensor()
265
+ ])
266
+
267
+ # ------------------ MODEL LOADER ------------------
268
+
269
  def load_model(category: str, style: str):
270
  key = (category, style)
271
  if key in models:
 
281
  model = TransformerNet().to(device)
282
  model.load_state_dict(torch.load(path, map_location=device))
283
  model.eval()
284
+
285
+ # 🔥 TorchScript optimization
286
+ model = torch.jit.script(model)
287
+
288
+ # 🔥 Warmup (removes first-request delay)
289
+ dummy = torch.randn(1, 3, 256, 256).to(device)
290
+ with torch.no_grad():
291
+ model(dummy)
292
+
293
  models[key] = model
294
  print(f"Loaded model: {category}/{style}")
 
295
 
296
+ return model
 
 
 
 
 
 
 
 
 
 
 
297
 
298
  # ------------------ IMAGE UTILS ------------------
299
 
300
+ def stylize_image(img: Image.Image, model):
 
 
 
 
 
 
 
 
301
  x = transform(img).unsqueeze(0).to(device)
302
  with torch.no_grad():
303
+ y = model(x)
 
 
304
  return y
305
 
306
+ def save_image_tensor(tensor, path: Path):
307
+ img = tensor.detach().cpu()[0].clamp(0, 1).permute(1, 2, 0).numpy() * 255
308
+ Image.fromarray(img.astype("uint8")).save(
309
+ path,
310
+ format="JPEG",
311
+ quality=85,
312
+ optimize=True
313
+ )
314
+
315
  # ------------------ CLEANUP ------------------
316
 
317
  async def delete_file_after_delay(path: Path, delay: int = 180):
 
319
  try:
320
  if path.exists():
321
  path.unlink()
322
+ print(f"Deleted {path}")
323
  except Exception as e:
324
  print(f"Error deleting file: {e}")
325
 
 
345
 
346
  contents = await file.read()
347
  input_img = Image.open(io.BytesIO(contents)).convert("RGB")
348
+
349
+ # 🔥 Run heavy task in background thread
350
+ output_tensor = await asyncio.to_thread(
351
+ stylize_image, input_img, model
352
+ )
353
 
354
  filename = f"{uuid.uuid4().hex}.jpg"
355
  out_path = OUTPUT_DIR / filename
356
+
357
  save_image_tensor(output_tensor, out_path)
358
 
359
  background_tasks.add_task(delete_file_after_delay, out_path, 180)
 
365
  async def download(filename: str):
366
  path = OUTPUT_DIR / filename
367
  if not path.exists():
368
+ raise HTTPException(status_code=404, detail="File not found")
369
  return FileResponse(path, media_type="image/jpeg", filename=filename)