e1250 commited on
Commit
c0dbd21
·
1 Parent(s): 739dca3

feat: refactoring code, adding contaxtmanager, create pipeline

Browse files
api/dependencies.py CHANGED
@@ -15,4 +15,4 @@ def get_safety_detection_model(request: HTTPConnection):
15
 
16
 
17
  def get_redis(request: HTTPConnection):
18
- return request.app.state.redis
 
15
 
16
 
17
  def get_redis(request: HTTPConnection):
18
+ return request.app.state.redis
api/routers/camera_stream.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from domain.detection_box_center import calculate_detection_box_center
2
  from api.dependencies import get_safety_detection_model
3
  from api.dependencies import get_detection_model, get_depth_model
@@ -52,8 +54,9 @@ async def websocket_detect(
52
  logger.info(f"Client ID >>{camera_id}<< Connected...")
53
 
54
  step_counter = itertools.count()
 
 
55
 
56
- loop = asyncio.get_running_loop()
57
  # Queue removing old images in case they were being stacked
58
  frame_queue: asyncio.Queue = asyncio.Queue(maxsize=1)
59
 
@@ -78,106 +81,21 @@ async def websocket_detect(
78
  try:
79
  logger.info(f"Camera {camera_id} start sending frames...")
80
 
81
- def decode_frame(fb):
82
- return cv.imdecode(np.frombuffer(fb, np.uint8), cv.IMREAD_COLOR)
83
-
84
  # Keep receiving messages in a loop until disconnection.
85
  while True:
86
  frame_bytes = await frame_queue.get()
87
 
88
- # Profiling
89
- t0 = time.time()
90
- image_array = await loop.run_in_executor(
91
- None, decode_frame, frame_bytes
92
- )
93
- decode_duration_seconds.labels(camera_id).observe(
94
- round(time.time() - t0, 3)
95
- )
96
- mlflow.log_metric(
97
- "frame_processing_time",
98
- round(time.time() - t0, 3),
99
- next(step_counter),
100
- )
101
-
102
- # Apply detection models
103
- t0 = time.time()
104
- detection_task = loop.run_in_executor(
105
- None, detector.detect, image_array
106
- )
107
- safety_task = loop.run_in_executor(
108
- None, safety_detector.detect, image_array
109
- )
110
- detections, safety_detection = await asyncio.gather(
111
- detection_task, safety_task
112
- )
113
- detection_duration_seconds.labels(camera_id).observe(
114
- round(time.time() - t0, 3)
115
- )
116
- mlflow.log_metric(
117
- "detection_duration_seconds",
118
- round(time.time() - t0, 3),
119
- next(step_counter),
120
- )
121
-
122
- # Profiling
123
- frame_processing_duration_seconds.labels(camera_id).observe(
124
- round(time.time() - t0, 3)
125
- )
126
- logger.debug("Frame processed", camera_id=camera_id)
127
- mlflow.log_metric(
128
- "frame_processing duration time",
129
- round(time.time() - t0, 3),
130
- next(step_counter),
131
- )
132
-
133
- boxes_center, boxes_center_ratio = calculate_detection_box_center(detections.detections, image_array.shape[1])
134
-
135
- t0 = time.time()
136
- depth_points = (
137
- await loop.run_in_executor(
138
- None, depth_model.calculate_depth, image_array, boxes_center
139
- )
140
- if boxes_center
141
- else []
142
- )
143
- depth_duration_seconds.labels(camera_id).observe(
144
- round(time.time() - t0, 3)
145
- )
146
- mlflow.log_metric(
147
- "depth_duration_seconds",
148
- round(time.time() - t0, 3),
149
- next(step_counter),
150
- )
151
-
152
- detection_metadata = [
153
- DetectionMetadata(depth=depth, xRatio=xRatio)
154
- for depth, xRatio in zip(depth_points, boxes_center_ratio)
155
- ]
156
- metadata = CameraMetadata(
157
- camera_id=camera_id,
158
- is_danger=True if safety_detection else False,
159
- detection_metadata=detection_metadata,
160
- )
161
-
162
- await redis.publish("dashboard_stream", metadata.model_dump_json())
163
- # Even if the camera was disconnected, redis is still going to show its data, which is not accurate.
164
- # Instead, we set expiry date for the camera data.
165
- await redis.setex(
166
- f"camera:{camera_id}:latest", # And this is the key, or tag
167
- 10, # in seconds
168
- metadata.model_dump_json(),
169
- )
170
 
171
  # Note that JSONResponse doesn't work here, as it is for HTTP
172
- await websocket.send_json({"status": 200, "camera_id": camera_id})
173
 
174
  except Exception as e:
175
  logger.error(f"Processing Error: {e}", camera_id=camera_id)
176
  raise
177
 
178
- with mlflow.start_run(
179
- run_name=f"camera_{camera_id}", nested=True, parent_run_id=state.mlflow_run_id
180
- ):
181
  log_config()
182
 
183
  try:
@@ -197,4 +115,4 @@ async def websocket_detect(
197
  await redis.srem(
198
  "cameras:active", camera_id
199
  ) # Remove the camera from redis connected cameras
200
- active_cameras.dec()
 
1
+ from backend.utils.profiling import profile_step
2
+ from backend.services.pipeline import ProcessingPipeline
3
  from domain.detection_box_center import calculate_detection_box_center
4
  from api.dependencies import get_safety_detection_model
5
  from api.dependencies import get_detection_model, get_depth_model
 
54
  logger.info(f"Client ID >>{camera_id}<< Connected...")
55
 
56
  step_counter = itertools.count()
57
+ pipeline = ProcessingPipeline(detector, depth_model, safety_detector, redis)
58
+
59
 
 
60
  # Queue removing old images in case they were being stacked
61
  frame_queue: asyncio.Queue = asyncio.Queue(maxsize=1)
62
 
 
81
  try:
82
  logger.info(f"Camera {camera_id} start sending frames...")
83
 
 
 
 
84
  # Keep receiving messages in a loop until disconnection.
85
  while True:
86
  frame_bytes = await frame_queue.get()
87
 
88
+
89
+ results = await pipeline.run(camera_id, frame_bytes, next(step_counter))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  # Note that JSONResponse doesn't work here, as it is for HTTP
92
+ await websocket.send_json(results)
93
 
94
  except Exception as e:
95
  logger.error(f"Processing Error: {e}", camera_id=camera_id)
96
  raise
97
 
98
+ with mlflow.start_run(run_name=f"camera_{camera_id}", nested=True, parent_run_id=state.mlflow_run_id):
 
 
99
  log_config()
100
 
101
  try:
 
115
  await redis.srem(
116
  "cameras:active", camera_id
117
  ) # Remove the camera from redis connected cameras
118
+ active_cameras.dec()
services/pipeline.py CHANGED
@@ -1,3 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  class ProcessingPipeline:
2
  def __init__(self, detector, depth_model, safety_detector, redis_client):
3
  self.detector = detector
@@ -5,9 +17,48 @@ class ProcessingPipeline:
5
  self.safety_detector = safety_detector
6
  self.redis_client = redis_client
7
 
8
- async def run(self, camera_id:str, image_array):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Run ai models
11
- # Domain logic
12
- # save to infra
13
- pass
 
1
+ from backend.api.routers.metrics import depth_duration_seconds
2
+ from backend.api.routers.metrics import detection_duration_seconds
3
+ from backend.api.routers.metrics import decode_duration_seconds
4
+ from backend.utils.profiling import profile_step
5
+ from backend.domain.detection_box_center import calculate_detection_box_center
6
+ import asyncio
7
+ from backend.contracts.camera_metadata import DetectionMetadata
8
+ from backend.contracts.camera_metadata import CameraMetadata
9
+ import cv2 as cv
10
+ import numpy as np
11
+
12
+
13
  class ProcessingPipeline:
14
  def __init__(self, detector, depth_model, safety_detector, redis_client):
15
  self.detector = detector
 
17
  self.safety_detector = safety_detector
18
  self.redis_client = redis_client
19
 
20
+ def _decode_frame(fb):
21
+ return cv.imdecode(np.frombuffer(fb, np.uint8), cv.IMREAD_COLOR)
22
+
23
+ def _camera_metadata(self, camera_id, safety_detection, depth_points, boxes_center_ratio) -> CameraMetadata:
24
+ detection_metadata = [
25
+ DetectionMetadata(depth=depth, xRatio=xRatio) for depth, xRatio in zip(depth_points, boxes_center_ratio)
26
+ ]
27
+ metadata = CameraMetadata(
28
+ camera_id=camera_id,
29
+ is_danger=True if safety_detection else False,
30
+ detection_metadata=detection_metadata,
31
+ )
32
+ return metadata
33
+
34
+ async def run(self, camera_id:str, image_array, frame_count):
35
+ loop = asyncio.get_running_loop()
36
+
37
+ with profile_step("frame_processing_time", decode_duration_seconds, camera_id, frame_count):
38
+ image_array = await loop.run_in_executor(None, self._decode_frame, image_array)
39
+
40
+ with profile_step("detection_duration_seconds", detection_duration_seconds, camera_id, frame_count):
41
+ detection_task = loop.run_in_executor(None, self.detector.detect, image_array)
42
+ safety_task = loop.run_in_executor(None, self.safety_detector.detect, image_array)
43
+ detections, safety_detection = await asyncio.gather(detection_task, safety_task)
44
+
45
+ boxes_center, boxes_center_ratio = calculate_detection_box_center(detections.detections, image_array.shape[1])
46
+
47
+ depth_points = []
48
+ if boxes_center:
49
+ with profile_step("depth_duration_seconds", depth_duration_seconds, camera_id, frame_count):
50
+ depth_points = await loop.run_in_executor(None, self.depth_model.calculate_depth, image_array, boxes_center)
51
+
52
+ metadata = self._camera_metadata(camera_id, safety_detection, depth_points, boxes_center_ratio)
53
+
54
+ await self.redis.publish("dashboard_stream", metadata.model_dump_json())
55
+ # Even if the camera was disconnected, redis is still going to show its data, which is not accurate.
56
+ # Instead, we set expiry date for the camera data.
57
+ await self.redis.setex(
58
+ f"camera:{camera_id}:latest", # And this is the key, or tag
59
+ 10, # in seconds
60
+ metadata.model_dump_json(),
61
+ )
62
 
63
+ # Note that JSONResponse doesn't work here, as it is for HTTP
64
+ return {"status": 200, "camera_id": camera_id}
 
 
utils/profiling.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+ import time
3
+ import mlflow
4
+
5
+ @contextmanager
6
+ def profile_step(expr_name: str, prometheus_logger, camera_id, frame_count=None):
7
+ """With statement utility to time block of code"""
8
+ start_time = time.time()
9
+
10
+ try:
11
+ # Code inside with statement
12
+ yield
13
+ finally:
14
+ duration = round(time.time() - start_time, 4)
15
+ prometheus_logger.labels(camera_id).observe(duration)
16
+ mlflow.log_metric(
17
+ expr_name,
18
+ duration,
19
+ frame_count,
20
+ )