Update app.py
Browse files
app.py
CHANGED
|
@@ -132,7 +132,7 @@ class YOLOv8Model:
|
|
| 132 |
def predict(self, image):
|
| 133 |
try:
|
| 134 |
results = self.model(image)
|
| 135 |
-
return results # Return full results for
|
| 136 |
except Exception as e:
|
| 137 |
logger.error(f"Prediction error: {e}")
|
| 138 |
raise
|
|
@@ -149,29 +149,6 @@ def preprocess_frame(frame):
|
|
| 149 |
logger.error(f"Frame preprocessing error: {e}")
|
| 150 |
raise
|
| 151 |
|
| 152 |
-
def draw_bounding_boxes(frame, results):
|
| 153 |
-
try:
|
| 154 |
-
logger.info("Drawing bounding boxes on frame")
|
| 155 |
-
violation_detected = False
|
| 156 |
-
for result in results:
|
| 157 |
-
for box in result.boxes:
|
| 158 |
-
# Extract bounding box coordinates
|
| 159 |
-
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
| 160 |
-
conf = box.conf[0]
|
| 161 |
-
cls = int(box.cls[0])
|
| 162 |
-
label = f"{result.names[cls]} {conf:.2f}"
|
| 163 |
-
# Draw bounding box and label
|
| 164 |
-
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
| 165 |
-
cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
|
| 166 |
-
violation_detected = True
|
| 167 |
-
logger.info(f"Violation detected: {label} at coordinates ({x1}, {y1}, {x2}, {y2})")
|
| 168 |
-
if not violation_detected:
|
| 169 |
-
logger.info("No violations detected in the frame")
|
| 170 |
-
return frame
|
| 171 |
-
except Exception as e:
|
| 172 |
-
logger.error(f"Error drawing bounding boxes: {e}")
|
| 173 |
-
raise
|
| 174 |
-
|
| 175 |
def capture_rtsp_frames(rtsp_url: str, max_frames=10):
|
| 176 |
try:
|
| 177 |
logger.info(f"Attempting to connect to RTSP stream: {rtsp_url}")
|
|
@@ -368,8 +345,6 @@ def process_image(image):
|
|
| 368 |
processed_frame = preprocess_frame(frame)
|
| 369 |
start_time = time.time()
|
| 370 |
results = yolo_model.predict(processed_frame)
|
| 371 |
-
# Skip drawing bounding boxes as per request; return the original frame
|
| 372 |
-
# annotated_frame = draw_bounding_boxes(frame.copy(), results)
|
| 373 |
processing_time = time.time() - start_time
|
| 374 |
logger.info(f"Image processing completed in {processing_time:.2f} seconds")
|
| 375 |
|
|
@@ -383,12 +358,14 @@ def process_image(image):
|
|
| 383 |
violations_detected = True
|
| 384 |
cls = int(box.cls[0])
|
| 385 |
violation_type = result.names[cls]
|
| 386 |
-
# Temporary workaround: Simulate all violation types for testing
|
| 387 |
# TODO: Fine-tune the YOLOv8 model with custom classes (no_helmet, unsafe_distance, unauthorized_area)
|
| 388 |
if violation_type == "person":
|
| 389 |
-
logger.warning("Temporary workaround: Simulating violation
|
|
|
|
| 390 |
possible_violations = ["no_helmet", "unsafe_distance", "unauthorized_area"]
|
| 391 |
violation_type = random.choice(possible_violations)
|
|
|
|
| 392 |
# Map YOLO labels to Salesforce picklist values (used when Salesforce is enabled)
|
| 393 |
violation_mapping = {
|
| 394 |
"no_helmet": "No Helmet",
|
|
@@ -448,9 +425,7 @@ def process_rtsp_stream():
|
|
| 448 |
for frame, timestamp in capture_rtsp_frames(RTSP_URL, max_frames=10):
|
| 449 |
processed_frame = preprocess_frame(frame)
|
| 450 |
results = yolo_model.predict(processed_frame)
|
| 451 |
-
#
|
| 452 |
-
# annotated_frame = draw_bounding_boxes(frame.copy(), results)
|
| 453 |
-
frames.append(frame)
|
| 454 |
|
| 455 |
# Log violations if detected
|
| 456 |
for result in results:
|
|
@@ -459,12 +434,14 @@ def process_rtsp_stream():
|
|
| 459 |
if conf > 0.5: # Confidence threshold for >90% accuracy
|
| 460 |
cls = int(box.cls[0])
|
| 461 |
violation_type = result.names[cls]
|
| 462 |
-
# Temporary workaround: Simulate all violation types for testing
|
| 463 |
# TODO: Fine-tune the YOLOv8 model with custom classes (no_helmet, unsafe_distance, unauthorized_area)
|
| 464 |
if violation_type == "person":
|
| 465 |
-
logger.warning("Temporary workaround: Simulating violation
|
|
|
|
| 466 |
possible_violations = ["no_helmet", "unsafe_distance", "unauthorized_area"]
|
| 467 |
violation_type = random.choice(possible_violations)
|
|
|
|
| 468 |
# Map YOLO labels to Salesforce picklist values (used when Salesforce is enabled)
|
| 469 |
violation_mapping = {
|
| 470 |
"no_helmet": "No Helmet",
|
|
|
|
| 132 |
def predict(self, image):
|
| 133 |
try:
|
| 134 |
results = self.model(image)
|
| 135 |
+
return results # Return full results for violation detection
|
| 136 |
except Exception as e:
|
| 137 |
logger.error(f"Prediction error: {e}")
|
| 138 |
raise
|
|
|
|
| 149 |
logger.error(f"Frame preprocessing error: {e}")
|
| 150 |
raise
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
def capture_rtsp_frames(rtsp_url: str, max_frames=10):
|
| 153 |
try:
|
| 154 |
logger.info(f"Attempting to connect to RTSP stream: {rtsp_url}")
|
|
|
|
| 345 |
processed_frame = preprocess_frame(frame)
|
| 346 |
start_time = time.time()
|
| 347 |
results = yolo_model.predict(processed_frame)
|
|
|
|
|
|
|
| 348 |
processing_time = time.time() - start_time
|
| 349 |
logger.info(f"Image processing completed in {processing_time:.2f} seconds")
|
| 350 |
|
|
|
|
| 358 |
violations_detected = True
|
| 359 |
cls = int(box.cls[0])
|
| 360 |
violation_type = result.names[cls]
|
| 361 |
+
# Temporary workaround: Simulate detection of all violation types for testing
|
| 362 |
# TODO: Fine-tune the YOLOv8 model with custom classes (no_helmet, unsafe_distance, unauthorized_area)
|
| 363 |
if violation_type == "person":
|
| 364 |
+
logger.warning("Temporary workaround: Simulating safety violation detection. Fine-tune the model with custom safety violation classes for accurate detection.")
|
| 365 |
+
# Randomly assign one of the three violation types for testing
|
| 366 |
possible_violations = ["no_helmet", "unsafe_distance", "unauthorized_area"]
|
| 367 |
violation_type = random.choice(possible_violations)
|
| 368 |
+
logger.info(f"Simulated violation type: {violation_type}")
|
| 369 |
# Map YOLO labels to Salesforce picklist values (used when Salesforce is enabled)
|
| 370 |
violation_mapping = {
|
| 371 |
"no_helmet": "No Helmet",
|
|
|
|
| 425 |
for frame, timestamp in capture_rtsp_frames(RTSP_URL, max_frames=10):
|
| 426 |
processed_frame = preprocess_frame(frame)
|
| 427 |
results = yolo_model.predict(processed_frame)
|
| 428 |
+
frames.append(frame) # Store raw frame without annotations
|
|
|
|
|
|
|
| 429 |
|
| 430 |
# Log violations if detected
|
| 431 |
for result in results:
|
|
|
|
| 434 |
if conf > 0.5: # Confidence threshold for >90% accuracy
|
| 435 |
cls = int(box.cls[0])
|
| 436 |
violation_type = result.names[cls]
|
| 437 |
+
# Temporary workaround: Simulate detection of all violation types for testing
|
| 438 |
# TODO: Fine-tune the YOLOv8 model with custom classes (no_helmet, unsafe_distance, unauthorized_area)
|
| 439 |
if violation_type == "person":
|
| 440 |
+
logger.warning("Temporary workaround: Simulating safety violation detection. Fine-tune the model with custom safety violation classes for accurate detection.")
|
| 441 |
+
# Randomly assign one of the three violation types for testing
|
| 442 |
possible_violations = ["no_helmet", "unsafe_distance", "unauthorized_area"]
|
| 443 |
violation_type = random.choice(possible_violations)
|
| 444 |
+
logger.info(f"Simulated violation type: {violation_type}")
|
| 445 |
# Map YOLO labels to Salesforce picklist values (used when Salesforce is enabled)
|
| 446 |
violation_mapping = {
|
| 447 |
"no_helmet": "No Helmet",
|