jenithjain committed on
Commit
b408502
·
1 Parent(s): 59c882f

Deploy DuneNet FastAPI server with Segformer model

Browse files
Files changed (5) hide show
  1. Dockerfile +26 -0
  2. README.md +20 -5
  3. main.py +496 -0
  4. models/latest_model_ft.pth +3 -0
  5. requirements.txt +10 -0
Dockerfile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.12-slim

WORKDIR /app

# Install system dependencies.
# NOTE: `libgl1` is used instead of the old transitional package
# `libgl1-mesa-glx`, which has no install candidate on the Debian release
# that current python:3.12-slim images are built from (apt-get would fail).
RUN apt-get update && apt-get install -y --no-install-recommends \
        libgl1 \
        libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for caching
COPY requirements.txt .

# Install CPU-only PyTorch first, then other deps.
# The torch/torchvision pins in requirements.txt are then already satisfied
# by the CPU wheels, so pip does not pull the default (CUDA) builds.
RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
    pip install --no-cache-dir -r requirements.txt

# Copy application code and model
COPY main.py .
COPY models/ models/

# Expose port 7860 (Hugging Face Spaces default)
EXPOSE 7860

# Run the server
CMD ["python", "main.py"]
README.md CHANGED
@@ -1,10 +1,25 @@
1
  ---
2
- title: Dunenet Api
3
- emoji: 📉
4
- colorFrom: indigo
5
- colorTo: green
6
  sdk: docker
 
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: DuneNet Model API
3
+ emoji: 🚀
4
+ colorFrom: yellow
5
+ colorTo: orange
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
  ---
10
 
11
+ # DuneNet Model API
12
+
13
+ FastAPI backend for DuneNet - Autonomous UGV Perception Platform.
14
+
15
+ Runs a fine-tuned Segformer (nvidia/mit-b4) model for semantic segmentation of desert terrain, providing:
16
+ - Semantic segmentation masks
17
+ - Traversability maps for autonomous navigation
18
+ - Live simulation inference with costmap grids
19
+
20
+ ## Endpoints
21
+
22
+ - `GET /` — Health check
23
+ - `POST /predict` — Full segmentation prediction
24
+ - `POST /predict/sim` — Simulation-optimized prediction with traversability grid
25
+ - `GET /model/info` — Model metadata
main.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from PIL import Image
7
+ import io
8
+ import base64
9
+ import numpy as np
10
+ from typing import Optional
11
+ import uvicorn
12
+ import albumentations as A
13
+ from albumentations.pytorch import ToTensorV2
14
+ from transformers import SegformerConfig, SegformerForSemanticSegmentation
15
+
16
app = FastAPI(title="DuneNet Model API", version="1.0.0")

# CORS middleware for the Next.js frontend.
#
# The API is open to every origin ("*"). Credentials are therefore disabled:
# combining a wildcard origin with allow_credentials=True makes the
# middleware echo back whatever Origin the browser sends, i.e. any website
# could issue credentialed cross-site requests. The previous explicit
# entries ("http://localhost:3000", "https://*.vercel.app") were redundant
# next to "*", and the "*.vercel.app" glob is never pattern-matched by
# allow_origins anyway (use allow_origin_regex if origins are ever restricted).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
30
+
31
# Configuration
NUM_CLASSES = 10  # number of segmentation labels; assigned to config.num_labels when the model is built
IMG_SIZE = 512  # side length (pixels) of the square the input is resized to before inference
MODEL_NAME = 'nvidia/mit-b4'  # identifier passed to SegformerConfig.from_pretrained

# Display name of each class; the list index is the class id.
CLASS_NAMES = [
    'Trees', 'Lush Bushes', 'Dry Grass', 'Dry Bushes', 'Ground Clutter',
    'Flowers', 'Logs', 'Rocks', 'Landscape', 'Sky'
]

# RGB colour used to paint each class id in the segmentation mask
# (row index is the class id, same order as CLASS_NAMES).
CLASS_COLORS = np.array([
    [34, 139, 34],    # Trees
    [0, 255, 127],    # Lush Bushes
    [189, 183, 107],  # Dry Grass
    [139, 119, 101],  # Dry Bushes
    [160, 82, 45],    # Ground Clutter
    [255, 105, 180],  # Flowers
    [139, 69, 19],    # Logs
    [128, 128, 128],  # Rocks
    [210, 180, 140],  # Landscape
    [135, 206, 235],  # Sky
], dtype=np.uint8)

# Traversability mapping: class id -> category name.
# 'sky' is kept apart from the three ground categories; the statistics and
# grid helpers skip it when counting pixels.
TRAVERSABILITY = {
    0: 'no_go',    # Trees
    1: 'no_go',    # Lush Bushes
    2: 'go',       # Dry Grass
    3: 'caution',  # Dry Bushes
    4: 'caution',  # Ground Clutter
    5: 'go',       # Flowers
    6: 'no_go',    # Logs
    7: 'caution',  # Rocks
    8: 'go',       # Landscape
    9: 'sky',      # Sky
}

# RGB colour used to paint each traversability category.
TRAV_COLORS = {
    'go': np.array([0, 200, 0], dtype=np.uint8),         # Green
    'caution': np.array([255, 180, 0], dtype=np.uint8),  # Orange
    'no_go': np.array([220, 30, 30], dtype=np.uint8),    # Red
    'sky': np.array([180, 210, 240], dtype=np.uint8),    # Light blue
}

# Global model variable; stays None until the startup hook loads the
# checkpoint, and the endpoints answer 503 while it is None.
model = None
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
+
79
# Response schema of POST /predict.
class PredictionResponse(BaseModel):
    prediction: int  # id of the class covering the most pixels of the predicted mask
    class_name: str  # CLASS_NAMES entry for `prediction`
    confidence: float  # mean softmax probability of that class over the whole probability map
    device_used: str  # str(device) the model ran on
    class_distribution: dict  # class name -> "xx.x%" share of pixels; absent classes are omitted
    segmentation_mask: str  # base64 encoded image (PNG data URI)
    overlay_image: str  # base64 encoded overlay (input blended with the class colours)
    traversability_map: str  # base64 encoded traversability
    traversability_overlay: str  # base64 encoded traversability overlay
    traversability_stats: dict  # safe, caution, blocked percentages (of non-sky pixels)
89
+ traversability_stats: dict # safe, caution, blocked percentages
90
+
91
# Response schema of the GET / health check.
class HealthResponse(BaseModel):
    status: str  # the health endpoint always reports "running"
    model_loaded: bool  # True once the global `model` is not None
    device: str  # str(device), e.g. "cpu" or "cuda"
95
+
96
@app.on_event("startup")
async def load_model():
    """Load the fine-tuned Segformer checkpoint on application startup.

    Looks for ``latest_model_ft.pth`` in a fixed list of candidate locations,
    builds a ``SegformerForSemanticSegmentation`` from the ``MODEL_NAME``
    config with ``NUM_CLASSES`` labels, loads the checkpoint's
    ``model_state_dict`` into it and publishes the model (moved to ``device``
    and switched to eval mode) through the module-level ``model``.

    Failures are not propagated: the error and traceback are printed and
    ``model`` is left as ``None`` so the server still starts and the
    endpoints can report that the model is unavailable.
    """
    global model
    try:
        import os

        possible_paths = [
            "models/latest_model_ft.pth",
            "api_server/models/latest_model_ft.pth",
            os.path.join(os.path.dirname(__file__), "models/latest_model_ft.pth"),
        ]

        # First candidate that exists on disk, or None.
        model_path = next((p for p in possible_paths if os.path.exists(p)), None)
        if model_path is None:
            # Name every location that was tried, not just one of them.
            searched = ", ".join(os.path.abspath(p) for p in possible_paths)
            raise FileNotFoundError(
                f"latest_model_ft.pth not found; searched: {searched}"
            )

        print(f"Loading Segformer model from: {os.path.abspath(model_path)}")

        # Build Segformer model (architecture only; weights come from the checkpoint)
        config = SegformerConfig.from_pretrained(MODEL_NAME)
        config.num_labels = NUM_CLASSES
        loaded = SegformerForSemanticSegmentation(config)

        # Load checkpoint.
        # NOTE(review): weights_only=False unpickles arbitrary Python objects.
        # That is only acceptable because the checkpoint is shipped with the
        # application; never point this at an untrusted file.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        loaded.load_state_dict(checkpoint['model_state_dict'])

        loaded = loaded.to(device)
        loaded.eval()

        miou = checkpoint.get('miou', 0)
        epoch = checkpoint.get('epoch', '?')

        # Publish only a fully initialised model.
        model = loaded

        print(f"✓ Segformer model loaded successfully on {device}")
        print(f"  Epoch: {epoch}, Val mIoU: {miou:.4f}")
        print(f"  Classes: {NUM_CLASSES}")
        print(f"  Model: {MODEL_NAME}")

    except Exception as e:
        # Best effort: keep the server up and let the endpoints answer 503.
        print(f"✗ Error loading model: {e}")
        import traceback
        traceback.print_exc()
        model = None
144
+
145
@app.get("/", response_model=HealthResponse)
async def health_check():
    """Report that the server is up, whether the model is loaded, and the device in use."""
    is_loaded = model is not None
    return dict(status="running", model_loaded=is_loaded, device=str(device))
153
+
154
def colorize_mask(class_mask):
    """Paint a 2-D class-id mask with the per-class colours.

    Returns an ``(h, w, 3)`` uint8 array in which every pixel whose id is in
    ``range(NUM_CLASSES)`` takes the matching ``CLASS_COLORS`` row; any other
    pixel stays black.
    """
    height, width = class_mask.shape
    painted = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id in range(NUM_CLASSES):
        selected = class_mask == class_id
        painted[selected] = CLASS_COLORS[class_id]
    return painted
161
+
162
+
163
def create_overlay(image_np, class_mask, alpha=0.5):
    """Alpha-blend the coloured segmentation mask over the original image.

    ``alpha`` is the weight of the coloured mask and ``1 - alpha`` the weight
    of the image; the blend is done in float32 and returned as uint8.
    """
    mask_rgb = colorize_mask(class_mask)
    base = image_np.astype(np.float32)
    blended = base * (1 - alpha) + mask_rgb.astype(np.float32) * alpha
    return blended.astype(np.uint8)
168
+
169
+
170
def create_traversability_map(class_mask):
    """Paint a 2-D class-id mask with the colour of each class's traversability category.

    Returns an ``(h, w, 3)`` uint8 array; pixels whose id is not a key of
    ``TRAVERSABILITY`` stay black.
    """
    rows, cols = class_mask.shape
    trav_rgb = np.zeros((rows, cols, 3), dtype=np.uint8)

    for cls_id in TRAVERSABILITY:
        category_color = TRAV_COLORS[TRAVERSABILITY[cls_id]]
        trav_rgb[class_mask == cls_id] = category_color

    return trav_rgb
180
+
181
+
182
def calculate_traversability_stats(class_mask):
    """Summarise how much of the non-sky area is safe, cautionary or blocked.

    Pixels are grouped by the category that ``TRAVERSABILITY`` assigns to
    their class id. Sky pixels are excluded from the denominator; which
    classes count as sky is taken from ``TRAVERSABILITY`` rather than from a
    hard-coded class id, so the two cannot drift apart.

    Args:
        class_mask: array of class ids.

    Returns:
        Dict with keys ``'safe'``, ``'caution'`` and ``'blocked'``, each a
        percentage string with one decimal (e.g. ``"42.0%"``). All three are
        ``'0%'`` when the mask contains no non-sky pixels.
    """
    # Pixel count per traversability category.
    totals = {'go': 0, 'caution': 0, 'no_go': 0, 'sky': 0}
    for class_id, category in TRAVERSABILITY.items():
        pixel_count = int((class_mask == class_id).sum())
        totals[category] = totals.get(category, 0) + pixel_count

    ground_pixels = int(class_mask.size) - totals['sky']
    if ground_pixels == 0:
        return {'safe': '0%', 'caution': '0%', 'blocked': '0%'}

    def _share(count):
        return f"{(count / ground_pixels * 100):.1f}%"

    return {
        'safe': _share(totals['go']),
        'caution': _share(totals['caution']),
        'blocked': _share(totals['no_go']),
    }
211
+
212
+
213
def numpy_to_base64(image_np):
    """Encode an image array as a PNG and return it as a ``data:image/png;base64,`` URI."""
    with io.BytesIO() as png_buffer:
        Image.fromarray(image_np).save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode()
    return "data:image/png;base64," + encoded
220
+
221
+
222
@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """Make prediction on uploaded image using Segformer.

    The upload is decoded to RGB, resized to ``IMG_SIZE`` x ``IMG_SIZE``,
    normalised and run through the model. The per-pixel argmax mask is
    resized back to the original resolution (nearest neighbour) and used to
    build the coloured mask, the overlays, the class distribution and the
    traversability statistics. Images are returned as PNG data URIs.

    Raises:
        HTTPException: 503 when the model is not loaded, 400 when the upload
            cannot be decoded as an image, 500 for any other failure.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Read and process image
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents)).convert('RGB')
        except (OSError, ValueError) as e:
            # An unreadable or truncated upload is a client error, not a server fault.
            raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e
        image_np = np.array(image)
        orig_h, orig_w = image_np.shape[:2]

        # Preprocessing with albumentations
        transform = A.Compose([
            A.Resize(height=IMG_SIZE, width=IMG_SIZE),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

        aug = transform(image=image_np)
        tensor = aug['image'].unsqueeze(0).to(device)

        # Inference (mixed precision only on CUDA)
        with torch.no_grad():
            use_fp16 = device.type == 'cuda'
            with torch.amp.autocast(device_type=device.type, enabled=use_fp16):
                outputs = model(pixel_values=tensor)

            # Upsample the logits to the network input size
            logits = F.interpolate(
                outputs.logits,
                size=(IMG_SIZE, IMG_SIZE),
                mode='bilinear',
                align_corners=False
            )

            # Per-class probabilities, indexed [class, row, col]
            probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()

        # Get prediction mask
        pred_mask = np.argmax(probs, axis=0).astype(np.uint8)

        # Resize prediction to original image size; nearest neighbour keeps
        # the values valid class ids.
        pred_mask_orig = np.array(
            Image.fromarray(pred_mask).resize((orig_w, orig_h), Image.NEAREST)
        )

        # Calculate class distribution (only classes that actually occur)
        class_dist = {}
        total_pixels = pred_mask_orig.size
        for c in range(NUM_CLASSES):
            count = (pred_mask_orig == c).sum()
            if count > 0:
                class_dist[CLASS_NAMES[c]] = f"{(count / total_pixels * 100):.1f}%"

        # Dominant class = most frequent id; its confidence is the mean
        # probability of that class over the whole probability map.
        dominant_class = np.bincount(pred_mask_orig.flatten()).argmax()
        confidence = probs[dominant_class].mean()

        # Generate visualizations
        colored_mask = colorize_mask(pred_mask_orig)
        overlay = create_overlay(image_np, pred_mask_orig, alpha=0.5)

        # Generate traversability map
        print("Generating traversability map...")
        trav_map = create_traversability_map(pred_mask_orig)
        print(f"Traversability map shape: {trav_map.shape}")

        # Blend the traversability colours over the image. Start from a plain
        # copy (as /predict/sim does): every region below is overwritten, so
        # pre-blending the class colours first was wasted work.
        trav_overlay = image_np.copy()
        for class_id, category in TRAVERSABILITY.items():
            region = (pred_mask_orig == class_id)
            trav_overlay[region] = (
                image_np[region].astype(np.float32) * 0.4 +
                TRAV_COLORS[category].astype(np.float32) * 0.6
            ).astype(np.uint8)

        trav_stats = calculate_traversability_stats(pred_mask_orig)
        print(f"Traversability stats: {trav_stats}")

        # Convert to base64
        mask_base64 = numpy_to_base64(colored_mask)
        overlay_base64 = numpy_to_base64(overlay)
        trav_map_base64 = numpy_to_base64(trav_map)
        trav_overlay_base64 = numpy_to_base64(trav_overlay)
        print("All images converted to base64 successfully")

        return {
            "prediction": int(dominant_class),
            "class_name": CLASS_NAMES[dominant_class],
            "confidence": float(confidence),
            "device_used": str(device),
            "class_distribution": class_dist,
            "segmentation_mask": mask_base64,
            "overlay_image": overlay_base64,
            "traversability_map": trav_map_base64,
            "traversability_overlay": trav_overlay_base64,
            "traversability_stats": trav_stats
        }
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) intact.
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
326
+
327
+
328
+ # ═══════════════════════════════════════════════════════════════
329
+ # Simulation Live Inference
330
+ # ═══════════════════════════════════════════════════════════════
331
+
332
# Response schema of POST /predict/sim.
class SimPredictionResponse(BaseModel):
    segmentation_mask: str  # PNG data URI of the class-coloured mask
    traversability_map: str  # PNG data URI of the traversability-coloured mask
    traversability_overlay: str  # PNG data URI of the input blended with the traversability colours
    traversability_stats: dict  # 'safe' / 'caution' / 'blocked' -> "xx.x%" of non-sky pixels
    traversability_grid: list  # 2-D list (rows of columns) of costmap values: 0 = go, 5 = caution, 10 = no_go
    class_distribution: dict  # class name -> "xx.x%" share of pixels; absent classes are omitted
    inference_time_ms: float  # wall-clock milliseconds from before the upload is read until the stats are ready
    dominant_class: str  # CLASS_NAMES entry of the class covering the most pixels
    confidence: float  # mean softmax probability of the dominant class over the whole probability map
342
+
343
+
344
def create_traversability_grid(class_mask, grid_cols=12, grid_rows=8):
    """Create a coarse traversability grid from the prediction mask.

    Uses the bottom 65 % of the image (ground portion, excluding sky) and
    splits it into ``grid_rows`` x ``grid_cols`` cells. Cell edges are placed
    proportionally, so the cells tile the whole ground region; a plain floor
    division of the region size would silently drop the remainder rows at the
    bottom and the remainder columns at the right.

    Per cell, pixels are counted by traversability category (sky pixels are
    ignored). A cell is ``10`` when more than 30 % of its counted pixels are
    no_go, otherwise ``5`` when more than 30 % are caution, otherwise ``0``.
    Empty cells and cells without any counted pixel are ``0``.

    Args:
        class_mask: 2-D array of class ids.
        grid_cols: number of grid columns.
        grid_rows: number of grid rows.

    Returns:
        2-D list (``grid_rows`` lists of ``grid_cols`` ints) of costmap
        values: 0 = go, 5 = caution, 10 = no_go.
    """
    h, w = class_mask.shape
    ground_start = int(h * 0.35)
    ground_mask = class_mask[ground_start:, :]
    gh, gw = ground_mask.shape

    grid = []
    for r in range(grid_rows):
        # Proportional edges: row r spans [r*gh/rows, (r+1)*gh/rows).
        y0 = r * gh // grid_rows
        y1 = (r + 1) * gh // grid_rows
        row = []
        for c in range(grid_cols):
            x0 = c * gw // grid_cols
            x1 = (c + 1) * gw // grid_cols

            cell = ground_mask[y0:y1, x0:x1]
            if cell.size == 0:
                # Region smaller than the grid: nothing to evaluate here.
                row.append(0)
                continue

            counts = {'go': 0, 'caution': 0, 'no_go': 0}
            for cid in range(NUM_CLASSES):
                cat = TRAVERSABILITY[cid]
                if cat in counts:  # skips 'sky'
                    counts[cat] += int((cell == cid).sum())

            total = counts['go'] + counts['caution'] + counts['no_go']
            if total == 0:
                row.append(0)
            elif counts['no_go'] / total > 0.3:
                row.append(10)
            elif counts['caution'] / total > 0.3:
                row.append(5)
            else:
                row.append(0)
        grid.append(row)
    return grid
393
+
394
+
395
@app.post("/predict/sim", response_model=SimPredictionResponse)
async def predict_sim(file: UploadFile = File(...)):
    """Prediction endpoint optimised for simulation live inference.

    Runs the same segmentation pipeline as ``/predict`` and additionally
    returns a traversability grid suitable for direct costmap updates,
    together with the elapsed time in milliseconds.

    Raises:
        HTTPException: 503 when the model is not loaded, 400 when the upload
            cannot be decoded as an image, 500 for any other failure.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    import time
    t0 = time.time()

    try:
        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents)).convert('RGB')
        except (OSError, ValueError) as e:
            # An unreadable or truncated upload is a client error, not a server fault.
            raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e
        image_np = np.array(image)
        orig_h, orig_w = image_np.shape[:2]

        transform = A.Compose([
            A.Resize(height=IMG_SIZE, width=IMG_SIZE),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

        aug = transform(image=image_np)
        tensor = aug['image'].unsqueeze(0).to(device)

        # Inference (mixed precision only on CUDA)
        with torch.no_grad():
            use_fp16 = device.type == 'cuda'
            with torch.amp.autocast(device_type=device.type, enabled=use_fp16):
                outputs = model(pixel_values=tensor)

            logits = F.interpolate(
                outputs.logits,
                size=(IMG_SIZE, IMG_SIZE),
                mode='bilinear',
                align_corners=False,
            )

            # Per-class probabilities, indexed [class, row, col]
            probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()

        pred_mask = np.argmax(probs, axis=0).astype(np.uint8)
        # Nearest neighbour keeps the resized values valid class ids.
        pred_mask_orig = np.array(
            Image.fromarray(pred_mask).resize((orig_w, orig_h), Image.NEAREST)
        )

        # Visualisations
        colored_mask = colorize_mask(pred_mask_orig)
        trav_map = create_traversability_map(pred_mask_orig)

        # Blend the traversability colours (60 %) over the input (40 %).
        trav_overlay_img = image_np.copy()
        for cid, category in TRAVERSABILITY.items():
            region = (pred_mask_orig == cid)
            trav_overlay_img[region] = (
                image_np[region].astype(np.float32) * 0.4
                + TRAV_COLORS[category].astype(np.float32) * 0.6
            ).astype(np.uint8)

        trav_stats = calculate_traversability_stats(pred_mask_orig)
        trav_grid = create_traversability_grid(pred_mask_orig)

        # Share of pixels per class (only classes that actually occur)
        class_dist = {}
        total_pixels = pred_mask_orig.size
        for cid in range(NUM_CLASSES):
            cnt = int((pred_mask_orig == cid).sum())
            if cnt > 0:
                class_dist[CLASS_NAMES[cid]] = f"{cnt / total_pixels * 100:.1f}%"

        # Dominant class = most frequent id; its confidence is the mean
        # probability of that class over the whole probability map.
        dominant = int(np.bincount(pred_mask_orig.flatten()).argmax())
        conf = float(probs[dominant].mean())
        elapsed = (time.time() - t0) * 1000

        return {
            "segmentation_mask": numpy_to_base64(colored_mask),
            "traversability_map": numpy_to_base64(trav_map),
            "traversability_overlay": numpy_to_base64(trav_overlay_img),
            "traversability_stats": trav_stats,
            "traversability_grid": trav_grid,
            "class_distribution": class_dist,
            "inference_time_ms": round(elapsed, 1),
            "dominant_class": CLASS_NAMES[dominant],
            "confidence": conf,
        }
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) intact.
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Sim prediction error: {str(e)}")
479
+
480
+
481
@app.get("/model/info")
async def model_info():
    """Get model information.

    Returns the model's class name, the device it runs on and its total
    number of parameters.

    Raises:
        HTTPException: 503 when the model is not loaded.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    return {
        "model_type": str(type(model).__name__),
        "device": str(device),
        # The former `if hasattr(model, 'parameters')` filter was dead code:
        # it ran once per parameter, after model.parameters() had already
        # been called, so it could never guard anything.
        "parameters": sum(p.numel() for p in model.parameters()),
    }
492
+
493
if __name__ == "__main__":
    import os

    # Serve on every interface; the PORT environment variable overrides the
    # default 7860.
    env_port = os.environ.get("PORT", 7860)
    uvicorn.run(app, port=int(env_port), host="0.0.0.0")
models/latest_model_ft.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e8980995a98075f6e15639eeeeacccd947d72dcd3360448f5af5f9b1c7defc0
3
+ size 256399803
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.109.0
2
+ uvicorn[standard]==0.27.0
3
+ torch>=2.6.0
4
+ torchvision>=0.21.0
5
+ pillow>=10.2.0
6
+ python-multipart>=0.0.6
7
+ numpy>=1.26.0
8
+ pydantic>=2.5.0
9
+ albumentations>=1.3.0
10
+ transformers>=4.30.0