Denny Lulak committed on
Commit 96cdae8 · 1 Parent(s): 254a582
Files changed (1)
  1. app.py +79 -72
app.py CHANGED
@@ -1,80 +1,103 @@
  import os
  import cv2
  import numpy as np
- from fastapi import FastAPI, WebSocket, status
- from onnxruntime import InferenceSession
- from ultralytics import YOLO
  import uvicorn
  import base64
  from typing import Tuple, List

  # Configuration
  MODEL_PT_PATH = "model.pt"
  MODEL_ONNX_PATH = "model.onnx"
  INPUT_SIZE = 640
- CLASS_NAMES = ["class0", "class1"]  # Your class names
  CONF_THRESHOLD = 0.5
  IOU_THRESHOLD = 0.45

- app = FastAPI(title="Object Detection API")
-
- # Load model once at startup
- @app.on_event("startup")
- async def load_model():
-     # Convert model if needed
      if not os.path.exists(MODEL_ONNX_PATH):
          print("Converting PyTorch model to ONNX...")
-         try:
-             model = YOLO(MODEL_PT_PATH)
-             model.export(
-                 format="onnx",
-                 imgsz=INPUT_SIZE,
-                 opset=12,
-                 simplify=True,
-                 dynamic=False,
-                 half=False
-             )
-             if os.path.exists("yolov8n.onnx"):
-                 os.rename("yolov8n.onnx", MODEL_ONNX_PATH)
-         except Exception as e:
-             raise RuntimeError(f"ONNX conversion failed: {str(e)}")
-
-     # Initialize ONNX runtime session
-     options = ort.SessionOptions()
-     options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
-     app.state.model = InferenceSession(
-         MODEL_ONNX_PATH,
-         providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
-         sess_options=options
-     )
-
      # Warm-up
      dummy_input = np.random.randn(1, 3, INPUT_SIZE, INPUT_SIZE).astype(np.float32)
      app.state.model.run(None, {"images": dummy_input})

  def preprocess_image(image: np.ndarray) -> Tuple[np.ndarray, float, Tuple[int, int]]:
-     """Preprocess image with letterboxing"""
      h, w = image.shape[:2]
      scale = min(INPUT_SIZE / h, INPUT_SIZE / w)
      new_h, new_w = int(h * scale), int(w * scale)
-
      resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
      canvas = np.full((INPUT_SIZE, INPUT_SIZE, 3), 114, dtype=np.uint8)
      ph, pw = (INPUT_SIZE - new_h) // 2, (INPUT_SIZE - new_w) // 2
      canvas[ph:ph+new_h, pw:pw+new_w] = resized
-
      blob = canvas.astype(np.float32) / 255.0
      return blob.transpose(2, 0, 1)[None, ...], scale, (pw, ph)

- async def process_image(image: np.ndarray) -> dict:
-     """Process image and return detection results"""
      # Preprocess
-     input_tensor, scale, padding = preprocess_image(image)

      # Inference
      outputs = app.state.model.run(None, {"images": input_tensor})
-
-     # Post-process
      predictions = np.squeeze(outputs[0]).T
      scores = np.max(predictions[:, 4:], axis=1)
      valid = scores > CONF_THRESHOLD
@@ -82,76 +105,60 @@ async def process_image(image: np.ndarray) -> dict:

      if predictions.size == 0:
          return {"detections": []}
-
-     # Convert boxes
      boxes = predictions[:, :4]
      boxes[:, [0, 1]] = boxes[:, [0, 1]] - boxes[:, [2, 3]] / 2
      boxes[:, [2, 3]] = boxes[:, [0, 1]] + boxes[:, [2, 3]]
-
-     # Adjust coordinates
-     pad_w, pad_h = padding
      boxes[:, [0, 2]] = (boxes[:, [0, 2]] - pad_w) / scale
      boxes[:, [1, 3]] = (boxes[:, [1, 3]] - pad_h) / scale

-     # Clip boxes
      h, w = image.shape[:2]
      boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, w)
      boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, h)

-     # Apply NMS
      class_ids = np.argmax(predictions[:, 4:], axis=1)
-     indices = nms(boxes, scores[valid], IOU_THRESHOLD)

      # Format results
      detections = []
-     for i in indices:
          detections.append({
              "class_id": int(class_ids[i]),
              "class_name": CLASS_NAMES[class_ids[i]],
              "confidence": float(scores[valid][i]),
-             "bbox": {
-                 "x1": float(boxes[i][0]),
-                 "y1": float(boxes[i][1]),
-                 "x2": float(boxes[i][2]),
-                 "y2": float(boxes[i][3])
-             }
          })

      return {"detections": detections}

  @app.websocket("/ws/detect")
- async def websocket_endpoint(websocket: WebSocket):
      await websocket.accept()
      try:
          while True:
-             # Receive base64 image
              data = await websocket.receive_text()
-             header, encoded = data.split(",", 1)
              image_bytes = base64.b64decode(encoded)
-
-             # Convert to numpy array
              nparr = np.frombuffer(image_bytes, np.uint8)
              image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
              image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

-             # Process and return results
-             results = await process_image(image)
              await websocket.send_json(results)
-
      except Exception as e:
          print(f"WebSocket error: {e}")
          await websocket.close(code=status.WS_1011_INTERNAL_ERROR)

  @app.post("/detect")
  async def http_detect(image: UploadFile = File(...)):
-     # Read and decode image
      contents = await image.read()
      nparr = np.frombuffer(contents, np.uint8)
      img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
      img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-
-     # Process and return results
-     return await process_image(img)
-
- if __name__ == "__main__":
-     uvicorn.run(app, host="0.0.0.0", port=7860)
 
  import os
  import cv2
  import numpy as np
+ from fastapi import FastAPI, WebSocket, status, UploadFile, File
+ from fastapi.responses import JSONResponse
+ from contextlib import asynccontextmanager
  import uvicorn
  import base64
  from typing import Tuple, List
+ import onnxruntime as ort
+ from ultralytics import YOLO

  # Configuration
  MODEL_PT_PATH = "model.pt"
  MODEL_ONNX_PATH = "model.onnx"
  INPUT_SIZE = 640
+ CLASS_NAMES = ["class0", "class1"]  # Replace with your class names
  CONF_THRESHOLD = 0.5
  IOU_THRESHOLD = 0.45

+ # --- Modern FastAPI Lifespan Setup (Replaces @app.on_event) ---
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     """Initialize and clean up model resources."""
+     # Convert PyTorch to ONNX if needed
      if not os.path.exists(MODEL_ONNX_PATH):
          print("Converting PyTorch model to ONNX...")
+         model = YOLO(MODEL_PT_PATH)
+         model.export(
+             format="onnx",
+             imgsz=INPUT_SIZE,
+             opset=12,
+             simplify=True,
+             dynamic=False,
+             half=False
+         )
+         if os.path.exists("yolov8n.onnx"):
+             os.rename("yolov8n.onnx", MODEL_ONNX_PATH)
+
+     # Load ONNX model with GPU
+     providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+     app.state.model = ort.InferenceSession(MODEL_ONNX_PATH, providers=providers)
+
      # Warm-up
      dummy_input = np.random.randn(1, 3, INPUT_SIZE, INPUT_SIZE).astype(np.float32)
      app.state.model.run(None, {"images": dummy_input})
+     print("Model loaded and ready!")
+
+     yield  # App runs here
+
+     # Cleanup (optional)
+     print("Shutting down...")

+ # Initialize FastAPI with lifespan
+ app = FastAPI(title="YOLOv8 API", lifespan=lifespan)
+
+ # --- Core Detection Functions (Same as Before) ---
  def preprocess_image(image: np.ndarray) -> Tuple[np.ndarray, float, Tuple[int, int]]:
+     """Resize and normalize image for YOLOv8 input."""
      h, w = image.shape[:2]
      scale = min(INPUT_SIZE / h, INPUT_SIZE / w)
      new_h, new_w = int(h * scale), int(w * scale)
      resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+
+     # Letterboxing
      canvas = np.full((INPUT_SIZE, INPUT_SIZE, 3), 114, dtype=np.uint8)
      ph, pw = (INPUT_SIZE - new_h) // 2, (INPUT_SIZE - new_w) // 2
      canvas[ph:ph+new_h, pw:pw+new_w] = resized
      blob = canvas.astype(np.float32) / 255.0
      return blob.transpose(2, 0, 1)[None, ...], scale, (pw, ph)

+ def nms(boxes, scores, iou_threshold):
+     """Non-Maximum Suppression to filter overlapping boxes."""
+     keep = []
+     if len(boxes) == 0:
+         return keep
+     x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
+     areas = (x2 - x1) * (y2 - y1)
+     order = scores.argsort()[::-1]
+
+     while order.size > 0:
+         i = order[0]
+         keep.append(i)
+         xx1 = np.maximum(x1[i], x1[order[1:]])
+         yy1 = np.maximum(y1[i], y1[order[1:]])
+         xx2 = np.minimum(x2[i], x2[order[1:]])
+         yy2 = np.minimum(y2[i], y2[order[1:]])
+         w, h = np.maximum(0.0, xx2 - xx1), np.maximum(0.0, yy2 - yy1)
+         iou = (w * h) / (areas[i] + areas[order[1:]] - w * h)
+         inds = np.where(iou <= iou_threshold)[0]
+         order = order[inds + 1]
+     return keep
+
+ async def detect_objects(image: np.ndarray) -> dict:
+     """Run YOLOv8 inference and return detections."""
      # Preprocess
+     input_tensor, scale, (pad_w, pad_h) = preprocess_image(image)

      # Inference
      outputs = app.state.model.run(None, {"images": input_tensor})
      predictions = np.squeeze(outputs[0]).T
      scores = np.max(predictions[:, 4:], axis=1)
      valid = scores > CONF_THRESHOLD

      if predictions.size == 0:
          return {"detections": []}
+
+     # Decode boxes
      boxes = predictions[:, :4]
      boxes[:, [0, 1]] = boxes[:, [0, 1]] - boxes[:, [2, 3]] / 2
      boxes[:, [2, 3]] = boxes[:, [0, 1]] + boxes[:, [2, 3]]
      boxes[:, [0, 2]] = (boxes[:, [0, 2]] - pad_w) / scale
      boxes[:, [1, 3]] = (boxes[:, [1, 3]] - pad_h) / scale

+     # Clip to image bounds
      h, w = image.shape[:2]
      boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, w)
      boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, h)

+     # NMS
      class_ids = np.argmax(predictions[:, 4:], axis=1)
+     keep = nms(boxes, scores[valid], IOU_THRESHOLD)

      # Format results
      detections = []
+     for i in keep:
          detections.append({
              "class_id": int(class_ids[i]),
              "class_name": CLASS_NAMES[class_ids[i]],
              "confidence": float(scores[valid][i]),
+             "bbox": [float(x) for x in boxes[i]]  # [x1, y1, x2, y2]
          })

      return {"detections": detections}

+ # --- API Endpoints ---
  @app.websocket("/ws/detect")
+ async def websocket_detection(websocket: WebSocket):
+     """Real-time detection via WebSocket."""
      await websocket.accept()
      try:
          while True:
              data = await websocket.receive_text()
+             _, encoded = data.split(",", 1)
              image_bytes = base64.b64decode(encoded)
              nparr = np.frombuffer(image_bytes, np.uint8)
              image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
              image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

+             results = await detect_objects(image)
              await websocket.send_json(results)
      except Exception as e:
          print(f"WebSocket error: {e}")
          await websocket.close(code=status.WS_1011_INTERNAL_ERROR)

  @app.post("/detect")
  async def http_detect(image: UploadFile = File(...)):
+     """HTTP endpoint for single-image detection."""
      contents = await image.read()
      nparr = np.frombuffer(contents, np.uint8)
      img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
      img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+     return await detect_objects(img)
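
For reference, a minimal client sketch for the two endpoints changed in this commit. It assumes the app is served at localhost:7860 (the commit removes the in-file uvicorn.run(...) call, so the actual host and port depend on how the Space launches the server); the sample file name test.jpg and the data:image/jpeg;base64 prefix are illustrative, not part of the commit.

import asyncio
import base64

import requests
import websockets

BASE_URL = "http://localhost:7860"        # assumed host/port
WS_URL = "ws://localhost:7860/ws/detect"  # WebSocket endpoint from the diff above


def detect_via_http(image_path: str) -> dict:
    """POST one image to /detect as multipart form data (field name 'image')."""
    with open(image_path, "rb") as f:
        response = requests.post(f"{BASE_URL}/detect", files={"image": f})
    response.raise_for_status()
    # Expected shape: {"detections": [{"class_id", "class_name", "confidence", "bbox"}, ...]}
    return response.json()


async def detect_via_websocket(image_path: str) -> str:
    """Send a single base64 data-URL frame over /ws/detect and return the JSON reply."""
    with open(image_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode()
    async with websockets.connect(WS_URL) as ws:
        # The server splits on the first comma, so a data-URL style prefix is expected.
        await ws.send(f"data:image/jpeg;base64,{encoded}")
        return await ws.recv()


if __name__ == "__main__":
    print(detect_via_http("test.jpg"))                     # hypothetical sample image
    print(asyncio.run(detect_via_websocket("test.jpg")))

The WebSocket handler only strips everything before the first comma, so any data-URL-style prefix works, and the HTTP endpoint expects the multipart field to be named image to match the UploadFile parameter.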