Denny Lulak commited on
Commit
38d965c
·
1 Parent(s): 9dcd0d7
Files changed (1) hide show
  1. app.py +188 -51
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import os
2
  import cv2
3
  import numpy as np
4
- from fastapi import FastAPI, WebSocket, status, UploadFile, File
5
  from fastapi.responses import JSONResponse
6
  from contextlib import asynccontextmanager
7
  import uvicorn
8
  import base64
 
9
  import onnxruntime as ort
10
  from ultralytics import YOLO
11
 
@@ -13,79 +14,215 @@ from ultralytics import YOLO
13
  MODEL_PT_PATH = "model.pt"
14
  MODEL_ONNX_PATH = "model.onnx"
15
  INPUT_SIZE = 640
16
- CLASS_NAMES = ["class0", "class1"] # Replace with your class names
17
  CONF_THRESHOLD = 0.5
18
  IOU_THRESHOLD = 0.45
19
 
20
- # --- Modern FastAPI Lifespan Setup ---
21
  @asynccontextmanager
22
  async def lifespan(app: FastAPI):
23
- """Initialize model on startup."""
24
  # Convert PyTorch to ONNX if needed
25
  if not os.path.exists(MODEL_ONNX_PATH):
26
- print("Converting PyTorch model to ONNX...")
27
- model = YOLO(MODEL_PT_PATH)
28
- model.export(
29
- format="onnx",
30
- imgsz=INPUT_SIZE,
31
- opset=12,
32
- simplify=True,
33
- dynamic=False,
34
- half=False
35
- )
36
- if os.path.exists("yolov8n.onnx"):
37
- os.rename("yolov8n.onnx", MODEL_ONNX_PATH)
38
-
39
- # Load ONNX model with GPU
40
- providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
41
- app.state.model = ort.InferenceSession(MODEL_ONNX_PATH, providers=providers)
42
 
43
- # Warm-up
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  dummy_input = np.random.randn(1, 3, INPUT_SIZE, INPUT_SIZE).astype(np.float32)
45
  app.state.model.run(None, {"images": dummy_input})
46
- print(" Model loaded successfully!")
47
- yield
 
 
 
48
 
49
  # --- Initialize FastAPI App ---
50
- app = FastAPI(lifespan=lifespan) # Must be named 'app' for Hugging Face Spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # --- Rest of your code (WebSocket and HTTP endpoints) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  @app.get("/")
54
  async def health_check():
55
- return {"status": "OK", "message": "API is running"}
56
 
57
  @app.websocket("/ws/detect")
58
  async def websocket_detection(websocket: WebSocket):
59
  await websocket.accept()
60
  try:
61
  while True:
62
- data = await websocket.receive_text()
63
- _, encoded = data.split(",", 1)
64
- image_bytes = base64.b64decode(encoded)
65
- nparr = np.frombuffer(image_bytes, np.uint8)
66
- image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
67
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
68
-
69
- # Process and return detections
70
- results = await detect_objects(image)
71
- await websocket.send_json(results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  except Exception as e:
73
- print(f"WebSocket error: {e}")
74
- await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
 
75
 
76
  @app.post("/detect")
77
  async def http_detect(image: UploadFile = File(...)):
78
- """HTTP endpoint for single-image detection."""
79
- contents = await image.read()
80
- nparr = np.frombuffer(contents, np.uint8)
81
- img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
82
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
83
- return await detect_objects(img)
84
-
85
- # --- Helper Functions ---
86
- async def detect_objects(image: np.ndarray) -> dict:
87
- """Your existing detection logic here"""
88
- return {"detections": []} # Replace with actual implementation
 
89
 
90
- # --- Hugging Face Spaces Requirement ---
91
- # The variable `app` must be defined at the top level
 
 
1
  import os
2
  import cv2
3
  import numpy as np
4
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, status
5
  from fastapi.responses import JSONResponse
6
  from contextlib import asynccontextmanager
7
  import uvicorn
8
  import base64
9
+ from typing import List, Tuple
10
  import onnxruntime as ort
11
  from ultralytics import YOLO
12
 
 
14
# Paths and inference configuration
MODEL_PT_PATH = "model.pt"      # source PyTorch checkpoint
MODEL_ONNX_PATH = "model.onnx"  # exported ONNX model used at runtime
INPUT_SIZE = 640                # square input resolution fed to the network
CLASS_NAMES = ["class0", "class1"]  # Replace with your actual class names
CONF_THRESHOLD = 0.5            # minimum score for a detection to be kept
IOU_THRESHOLD = 0.45            # NMS overlap threshold
20
 
21
# --- Lifespan Management ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize and clean up model resources.

    On startup: exports the PyTorch checkpoint to ONNX (once), builds an
    ONNX Runtime session (CUDA preferred, CPU fallback) and runs a single
    warm-up inference so the first real request is not slowed by lazy
    initialization. On shutdown: logs cleanup (the session is released
    with the process).
    """
    # Convert PyTorch to ONNX if needed
    if not os.path.exists(MODEL_ONNX_PATH):
        print("🔄 Converting PyTorch model to ONNX...")
        try:
            model = YOLO(MODEL_PT_PATH)
            # export() returns the path of the file it wrote. Use that
            # instead of the previous hard-coded "yolov8n.onnx" check,
            # which only matched the pretrained checkpoint's filename and
            # was dead code when exporting "model.pt".
            exported_path = model.export(
                format="onnx",
                imgsz=INPUT_SIZE,
                opset=12,
                simplify=True,
                dynamic=False,
                half=False
            )
            if exported_path and os.path.exists(exported_path) and exported_path != MODEL_ONNX_PATH:
                os.replace(exported_path, MODEL_ONNX_PATH)
            print("✅ ONNX conversion successful!")
        except Exception as e:
            # Chain the original traceback so conversion failures stay debuggable.
            raise RuntimeError(f"ONNX conversion failed: {str(e)}") from e

    # Initialize ONNX Runtime session with GPU first, CPU as fallback
    print("⚙️ Initializing ONNX Runtime session...")
    providers = [
        ('CUDAExecutionProvider', {
            'device_id': 0,
            'arena_extend_strategy': 'kNextPowerOfTwo',
            'gpu_mem_limit': 2 * 1024 * 1024 * 1024,  # 2GB
            'cudnn_conv_algo_search': 'HEURISTIC',
            'do_copy_in_default_stream': True,
        }),
        'CPUExecutionProvider'
    ]

    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

    app.state.model = ort.InferenceSession(
        MODEL_ONNX_PATH,
        providers=providers,
        sess_options=sess_options
    )

    # Warm-up run: the first run() pays one-time allocation/optimization costs
    print("🔥 Warming up model...")
    dummy_input = np.random.randn(1, 3, INPUT_SIZE, INPUT_SIZE).astype(np.float32)
    app.state.model.run(None, {"images": dummy_input})
    print("🚀 Model ready for inference!")

    yield  # App runs here

    print("🛑 Cleaning up resources...")
76
 
77
# --- Initialize FastAPI App ---
# The variable must be named `app` at module top level (e.g. so
# `uvicorn app:app` and Hugging Face Spaces can discover it).
app = FastAPI(
    title="YOLOv8 Object Detection API",
    description="Real-time object detection with WebSocket and HTTP endpoints",
    lifespan=lifespan
)
83
+
84
# --- Core Detection Functions ---
def letterbox_image(image: np.ndarray) -> Tuple[np.ndarray, float, Tuple[int, int]]:
    """Fit an image into an INPUT_SIZE square without distorting it.

    The image is shrunk (or grown) by one uniform factor, then centered on
    a gray (value 114) canvas. Returns the NCHW float32 blob scaled to
    [0, 1], the resize factor, and the (left, top) padding offsets needed
    to map detections back to the original image.
    """
    src_h, src_w = image.shape[:2]
    ratio = min(INPUT_SIZE / src_h, INPUT_SIZE / src_w)
    dst_h = int(src_h * ratio)
    dst_w = int(src_w * ratio)

    scaled = cv2.resize(image, (dst_w, dst_h), interpolation=cv2.INTER_LINEAR)

    pad_top = (INPUT_SIZE - dst_h) // 2
    pad_left = (INPUT_SIZE - dst_w) // 2
    padded = np.full((INPUT_SIZE, INPUT_SIZE, 3), 114, dtype=np.uint8)
    padded[pad_top:pad_top + dst_h, pad_left:pad_left + dst_w] = scaled

    normalized = padded.astype(np.float32) / 255.0
    blob = normalized.transpose(2, 0, 1)[None, ...]
    return blob, ratio, (pad_left, pad_top)
98
+
99
def nms(boxes: np.ndarray, scores: np.ndarray, iou_threshold: float) -> List[int]:
    """Greedy Non-Maximum Suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every
    other candidate whose IoU with it exceeds `iou_threshold`. Returns the
    indices of kept boxes, best score first.
    """
    selected = []
    if len(boxes) == 0:
        return selected

    left, top = boxes[:, 0], boxes[:, 1]
    right, bottom = boxes[:, 2], boxes[:, 3]
    box_areas = (right - left) * (bottom - top)
    candidates = scores.argsort()[::-1]

    while candidates.size > 0:
        best = candidates[0]
        selected.append(best)
        rest = candidates[1:]

        # Intersection of the best box with every remaining candidate
        overlap_w = np.maximum(
            0.0, np.minimum(right[best], right[rest]) - np.maximum(left[best], left[rest])
        )
        overlap_h = np.maximum(
            0.0, np.minimum(bottom[best], bottom[rest]) - np.maximum(top[best], top[rest])
        )
        intersection = overlap_w * overlap_h
        iou = intersection / (box_areas[best] + box_areas[rest] - intersection)

        # Keep only candidates that do not heavily overlap the kept box
        candidates = rest[iou <= iou_threshold]

    return selected
127
+
128
async def detect_objects(image: np.ndarray) -> dict:
    """Run the full detection pipeline on one RGB image.

    Args:
        image: HxWx3 uint8 RGB image (callers convert from BGR first).

    Returns:
        {"detections": [...]} where each detection has class_id,
        class_name, confidence, and bbox as [x1, y1, x2, y2] in
        original-image pixel coordinates.

    NOTE(review): synchronous inference inside an async def blocks the
    event loop for the duration of one run — fine for light traffic,
    consider an executor if throughput matters.
    """
    # Preprocess (letterbox to INPUT_SIZE; keep scale/padding to undo later)
    input_tensor, scale, (pad_w, pad_h) = letterbox_image(image)

    # Inference: squeeze (1, 4+C, N) and transpose to (N, 4+C)
    outputs = app.state.model.run(None, {"images": input_tensor})
    predictions = np.squeeze(outputs[0]).T

    # Confidence filter; filter scores once so later indexing stays aligned
    # (the original re-applied `scores[valid]` at every use)
    scores = np.max(predictions[:, 4:], axis=1)
    valid = scores > CONF_THRESHOLD
    predictions = predictions[valid]
    scores = scores[valid]

    if predictions.size == 0:
        return {"detections": []}

    class_ids = np.argmax(predictions[:, 4:], axis=1)

    # Decode boxes: copy so the in-place math doesn't mutate `predictions`
    # through a slice view
    boxes = predictions[:, :4].copy()
    # (cx, cy, w, h) -> (x1, y1, x2, y2)
    boxes[:, [0, 1]] = boxes[:, [0, 1]] - boxes[:, [2, 3]] / 2
    boxes[:, [2, 3]] = boxes[:, [0, 1]] + boxes[:, [2, 3]]
    # Undo letterbox: remove padding, then rescale to original resolution
    boxes[:, [0, 2]] = (boxes[:, [0, 2]] - pad_w) / scale
    boxes[:, [1, 3]] = (boxes[:, [1, 3]] - pad_h) / scale

    # Clip to image bounds
    h, w = image.shape[:2]
    boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, w)
    boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, h)

    # NMS on the filtered, decoded boxes
    keep = nms(boxes, scores, IOU_THRESHOLD)

    # Format results
    detections = []
    for i in keep:
        cid = int(class_ids[i])
        # Guard: a model with more classes than CLASS_NAMES lists would
        # otherwise raise IndexError here
        name = CLASS_NAMES[cid] if cid < len(CLASS_NAMES) else f"class{cid}"
        detections.append({
            "class_id": cid,
            "class_name": name,
            "confidence": float(scores[i]),
            "bbox": [float(boxes[i][0]), float(boxes[i][1]),
                     float(boxes[i][2]), float(boxes[i][3])]
        })

    return {"detections": detections}
171
+
172
# --- API Endpoints ---
@app.get("/")
async def health_check():
    """Liveness probe: confirms the service is up and reachable."""
    payload = {"status": "OK", "message": "Object Detection API is running"}
    return payload
176
 
177
@app.websocket("/ws/detect")
async def websocket_detection(websocket: WebSocket):
    """Stream detection: receive base64 data-URL frames, reply with JSON.

    Per-frame errors are reported back to the client and the loop
    continues; only a disconnect or a fatal error ends the session.
    """
    await websocket.accept()
    client_connected = True
    try:
        while True:
            try:
                # Receive base64 image (expected "data:image/...;base64,...")
                data = await websocket.receive_text()
                if not data.startswith("data:"):
                    await websocket.send_json({"error": "Invalid image format"})
                    continue

                _, encoded = data.split(",", 1)
                image_bytes = base64.b64decode(encoded)
                nparr = np.frombuffer(image_bytes, np.uint8)
                image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                if image is None:
                    # imdecode signals failure by returning None, not raising
                    await websocket.send_json({"error": "Could not decode image"})
                    continue
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

                # Process and return detections
                results = await detect_objects(image)
                await websocket.send_json(results)

            except WebSocketDisconnect:
                # Must not be swallowed by the broad handler below, which
                # would try to send_json on a dead socket — re-raise so the
                # outer handler ends the loop cleanly.
                raise
            except Exception as e:
                print(f"⚠️ Processing error: {str(e)}")
                await websocket.send_json({"error": str(e)})
                continue

    except WebSocketDisconnect:
        client_connected = False
        print("Client disconnected")
    except Exception as e:
        print(f"WebSocket error: {str(e)}")
    finally:
        # Closing an already-disconnected websocket raises; only attempt
        # close while the client is (as far as we know) still connected.
        if client_connected:
            try:
                await websocket.close()
            except Exception:
                pass
210
 
211
@app.post("/detect")
async def http_detect(image: UploadFile = File(...)):
    """Process a single uploaded image via HTTP POST.

    Returns the detection payload on success, or a 400 JSON error when
    the upload cannot be decoded or processed.
    """
    try:
        contents = await image.read()
        nparr = np.frombuffer(contents, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img is None:
            # imdecode returns None on undecodable data (no exception);
            # fail fast with a clear message instead of a cvtColor traceback
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"error": "Uploaded file is not a decodable image"}
            )
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return await detect_objects(img)
    except Exception as e:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"error": f"Image processing failed: {str(e)}"}
        )
225
 
226
# --- For Local Development ---
if __name__ == "__main__":
    # reload=True auto-restarts on code changes — development only; port
    # 7860 matches the Hugging Face Spaces convention used by this app.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)