Update app.py
Browse files
app.py
CHANGED
|
@@ -2,277 +2,24 @@ import gradio as gr
|
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
| 4 |
from groq import Groq
|
|
|
|
| 5 |
from PIL import Image as PILImage
|
| 6 |
import io
|
| 7 |
-
import base64
|
| 8 |
-
import torch
|
| 9 |
-
import warnings
|
| 10 |
-
from typing import Tuple, List, Dict, Optional
|
| 11 |
import os
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
warnings.filterwarnings('ignore', category=FutureWarning)
|
| 15 |
-
warnings.filterwarnings('ignore', category=UserWarning)
|
| 16 |
-
|
| 17 |
-
class RobustSafetyMonitor:
|
| 18 |
-
def __init__(self):
|
| 19 |
-
"""Initialize the safety detection tool with improved configuration."""
|
| 20 |
-
self.client = Groq()
|
| 21 |
-
self.model_name = "llama-3.2-11b-vision-preview"
|
| 22 |
-
self.max_image_size = (800, 800)
|
| 23 |
-
self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
|
| 24 |
-
|
| 25 |
-
# Load YOLOv5 with optimized settings
|
| 26 |
-
self.yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
|
| 27 |
-
self.yolo_model.conf = 0.25 # Lower confidence threshold
|
| 28 |
-
self.yolo_model.iou = 0.45 # Adjusted IOU threshold
|
| 29 |
-
self.yolo_model.classes = None # Detect all classes
|
| 30 |
-
self.yolo_model.max_det = 50 # Increased maximum detections
|
| 31 |
-
self.yolo_model.cpu()
|
| 32 |
-
self.yolo_model.eval()
|
| 33 |
-
|
| 34 |
-
# Construction-specific keywords
|
| 35 |
-
self.construction_keywords = [
|
| 36 |
-
'person', 'worker', 'helmet', 'tool', 'machine', 'equipment',
|
| 37 |
-
'brick', 'block', 'pile', 'stack', 'surface', 'floor', 'ground',
|
| 38 |
-
'construction', 'building', 'structure'
|
| 39 |
-
]
|
| 40 |
-
|
| 41 |
-
def preprocess_image(self, frame: np.ndarray) -> np.ndarray:
|
| 42 |
-
"""Process image for analysis."""
|
| 43 |
-
if frame is None:
|
| 44 |
-
raise ValueError("No image provided")
|
| 45 |
-
|
| 46 |
-
if len(frame.shape) == 2:
|
| 47 |
-
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
|
| 48 |
-
elif len(frame.shape) == 3 and frame.shape[2] == 4:
|
| 49 |
-
frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
|
| 50 |
-
|
| 51 |
-
return self.resize_image(frame)
|
| 52 |
-
|
| 53 |
-
def resize_image(self, image: np.ndarray) -> np.ndarray:
|
| 54 |
-
"""Resize image while maintaining aspect ratio."""
|
| 55 |
-
height, width = image.shape[:2]
|
| 56 |
-
if height > self.max_image_size[1] or width > self.max_image_size[0]:
|
| 57 |
-
aspect = width / height
|
| 58 |
-
if width > height:
|
| 59 |
-
new_width = self.max_image_size[0]
|
| 60 |
-
new_height = int(new_width / aspect)
|
| 61 |
-
else:
|
| 62 |
-
new_height = self.max_image_size[1]
|
| 63 |
-
new_width = int(new_height * aspect)
|
| 64 |
-
return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
| 65 |
-
return image
|
| 66 |
-
|
| 67 |
-
def encode_image(self, frame: np.ndarray) -> str:
|
| 68 |
-
"""Convert image to base64 encoding."""
|
| 69 |
-
try:
|
| 70 |
-
frame_pil = PILImage.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
| 71 |
-
buffered = io.BytesIO()
|
| 72 |
-
frame_pil.save(buffered, format="JPEG", quality=95)
|
| 73 |
-
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
| 74 |
-
return f"data:image/jpeg;base64,{img_base64}"
|
| 75 |
-
except Exception as e:
|
| 76 |
-
raise ValueError(f"Error encoding image: {str(e)}")
|
| 77 |
-
|
| 78 |
-
def detect_objects(self, frame: np.ndarray) -> Tuple[np.ndarray, Dict]:
|
| 79 |
-
"""Enhanced object detection using YOLOv5."""
|
| 80 |
-
try:
|
| 81 |
-
# Ensure proper image format
|
| 82 |
-
if len(frame.shape) == 2:
|
| 83 |
-
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
|
| 84 |
-
elif frame.shape[2] == 4:
|
| 85 |
-
frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
|
| 86 |
-
|
| 87 |
-
# Run inference with augmentation
|
| 88 |
-
with torch.no_grad():
|
| 89 |
-
results = self.yolo_model(frame, augment=True)
|
| 90 |
-
|
| 91 |
-
# Get detections
|
| 92 |
-
bbox_data = results.xyxy[0].cpu().numpy()
|
| 93 |
-
labels = results.names
|
| 94 |
-
|
| 95 |
-
# Filter and process detections
|
| 96 |
-
processed_boxes = []
|
| 97 |
-
for box in bbox_data:
|
| 98 |
-
x1, y1, x2, y2, conf, cls = box
|
| 99 |
-
if conf > 0.25: # Keep lower confidence threshold
|
| 100 |
-
processed_boxes.append(box)
|
| 101 |
-
|
| 102 |
-
return np.array(processed_boxes), labels
|
| 103 |
-
except Exception as e:
|
| 104 |
-
print(f"Error in object detection: {str(e)}")
|
| 105 |
-
return np.array([]), {}
|
| 106 |
-
|
| 107 |
-
def analyze_frame(self, frame: np.ndarray) -> Tuple[List[Dict], str]:
|
| 108 |
-
"""Perform safety analysis using Llama Vision."""
|
| 109 |
-
if frame is None:
|
| 110 |
-
return [], "No frame received"
|
| 111 |
-
|
| 112 |
-
try:
|
| 113 |
-
frame = self.preprocess_image(frame)
|
| 114 |
-
image_base64 = self.encode_image(frame)
|
| 115 |
-
|
| 116 |
-
completion = self.client.chat.completions.create(
|
| 117 |
-
model=self.model_name,
|
| 118 |
-
messages=[
|
| 119 |
-
{
|
| 120 |
-
"role": "user",
|
| 121 |
-
"content": [
|
| 122 |
-
{
|
| 123 |
-
"type": "text",
|
| 124 |
-
"text": """Analyze this workplace image for safety risks. Focus on:
|
| 125 |
-
1. Worker posture and positioning
|
| 126 |
-
2. Equipment and tool safety
|
| 127 |
-
3. Environmental hazards
|
| 128 |
-
4. PPE compliance
|
| 129 |
-
5. Material handling
|
| 130 |
-
|
| 131 |
-
List each risk on a new line starting with 'Risk:'.
|
| 132 |
-
Format: Risk: [Object/Area] - [Detailed description of hazard]"""
|
| 133 |
-
},
|
| 134 |
-
{
|
| 135 |
-
"type": "image_url",
|
| 136 |
-
"image_url": {
|
| 137 |
-
"url": image_base64
|
| 138 |
-
}
|
| 139 |
-
}
|
| 140 |
-
]
|
| 141 |
-
}
|
| 142 |
-
],
|
| 143 |
-
temperature=0.7,
|
| 144 |
-
max_tokens=1024,
|
| 145 |
-
stream=False
|
| 146 |
-
)
|
| 147 |
-
|
| 148 |
-
try:
|
| 149 |
-
response = completion.choices[0].message.content
|
| 150 |
-
except AttributeError:
|
| 151 |
-
response = str(completion.choices[0].message)
|
| 152 |
-
|
| 153 |
-
safety_issues = self.parse_safety_analysis(response)
|
| 154 |
-
return safety_issues, response
|
| 155 |
-
|
| 156 |
-
except Exception as e:
|
| 157 |
-
print(f"Analysis error: {str(e)}")
|
| 158 |
-
return [], f"Analysis Error: {str(e)}"
|
| 159 |
-
|
| 160 |
-
def draw_bounding_boxes(self, image: np.ndarray, bboxes: np.ndarray,
|
| 161 |
-
labels: Dict, safety_issues: List[Dict]) -> np.ndarray:
|
| 162 |
-
"""Improved bounding box visualization."""
|
| 163 |
-
image_copy = image.copy()
|
| 164 |
-
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 165 |
-
font_scale = 0.5
|
| 166 |
-
thickness = 2
|
| 167 |
-
|
| 168 |
-
for idx, bbox in enumerate(bboxes):
|
| 169 |
-
try:
|
| 170 |
-
x1, y1, x2, y2, conf, class_id = bbox
|
| 171 |
-
label = labels[int(class_id)]
|
| 172 |
-
|
| 173 |
-
# Check if object is construction-related
|
| 174 |
-
is_relevant = any(keyword in label.lower() for keyword in self.construction_keywords)
|
| 175 |
-
|
| 176 |
-
if is_relevant or conf > 0.35:
|
| 177 |
-
color = self.colors[idx % len(self.colors)]
|
| 178 |
-
|
| 179 |
-
# Convert coordinates to integers
|
| 180 |
-
x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
|
| 181 |
-
|
| 182 |
-
# Draw bounding box
|
| 183 |
-
cv2.rectangle(image_copy, (x1, y1), (x2, y2), color, thickness)
|
| 184 |
-
|
| 185 |
-
# Check for associated safety issues
|
| 186 |
-
risk_found = False
|
| 187 |
-
for safety_issue in safety_issues:
|
| 188 |
-
issue_keywords = safety_issue.get('object', '').lower().split()
|
| 189 |
-
if any(keyword in label.lower() for keyword in issue_keywords):
|
| 190 |
-
label_text = f"Risk: {safety_issue.get('description', '')}"
|
| 191 |
-
y_pos = max(y1 - 10, 20)
|
| 192 |
-
cv2.putText(image_copy, label_text, (x1, y_pos), font,
|
| 193 |
-
font_scale, (0, 0, 255), thickness)
|
| 194 |
-
risk_found = True
|
| 195 |
-
break
|
| 196 |
-
|
| 197 |
-
if not risk_found:
|
| 198 |
-
label_text = f"{label} {conf:.2f}"
|
| 199 |
-
y_pos = max(y1 - 10, 20)
|
| 200 |
-
cv2.putText(image_copy, label_text, (x1, y_pos), font,
|
| 201 |
-
font_scale, color, thickness)
|
| 202 |
-
|
| 203 |
-
# Mark high-risk areas
|
| 204 |
-
if conf > 0.5 and any(risk_word in label.lower() for risk_word in
|
| 205 |
-
['worker', 'person', 'equipment', 'machine']):
|
| 206 |
-
cv2.circle(image_copy, (int((x1 + x2)/2), int((y1 + y2)/2)),
|
| 207 |
-
5, (0, 0, 255), -1)
|
| 208 |
-
|
| 209 |
-
except Exception as e:
|
| 210 |
-
print(f"Error drawing box: {str(e)}")
|
| 211 |
-
continue
|
| 212 |
-
|
| 213 |
-
return image_copy
|
| 214 |
-
|
| 215 |
-
def process_frame(self, frame: np.ndarray) -> Tuple[Optional[np.ndarray], str]:
|
| 216 |
-
"""Main processing pipeline for safety analysis."""
|
| 217 |
-
if frame is None:
|
| 218 |
-
return None, "No image provided"
|
| 219 |
-
|
| 220 |
-
try:
|
| 221 |
-
# Detect objects
|
| 222 |
-
bbox_data, labels = self.detect_objects(frame)
|
| 223 |
-
|
| 224 |
-
# Get safety analysis
|
| 225 |
-
safety_issues, analysis = self.analyze_frame(frame)
|
| 226 |
-
|
| 227 |
-
# Draw annotations
|
| 228 |
-
annotated_frame = self.draw_bounding_boxes(frame, bbox_data, labels, safety_issues)
|
| 229 |
-
|
| 230 |
-
return annotated_frame, analysis
|
| 231 |
-
|
| 232 |
-
except Exception as e:
|
| 233 |
-
print(f"Processing error: {str(e)}")
|
| 234 |
-
return None, f"Error processing image: {str(e)}"
|
| 235 |
-
|
| 236 |
-
def parse_safety_analysis(self, analysis: str) -> List[Dict]:
|
| 237 |
-
"""Parse the safety analysis text."""
|
| 238 |
-
safety_issues = []
|
| 239 |
-
|
| 240 |
-
if not isinstance(analysis, str):
|
| 241 |
-
return safety_issues
|
| 242 |
-
|
| 243 |
-
for line in analysis.split('\n'):
|
| 244 |
-
if "risk:" in line.lower():
|
| 245 |
-
try:
|
| 246 |
-
parts = line.lower().split('risk:', 1)[1].strip()
|
| 247 |
-
if '-' in parts:
|
| 248 |
-
obj, desc = parts.split('-', 1)
|
| 249 |
-
else:
|
| 250 |
-
obj, desc = parts, parts
|
| 251 |
-
|
| 252 |
-
safety_issues.append({
|
| 253 |
-
"object": obj.strip(),
|
| 254 |
-
"description": desc.strip()
|
| 255 |
-
})
|
| 256 |
-
except Exception as e:
|
| 257 |
-
print(f"Error parsing line: {line}, Error: {str(e)}")
|
| 258 |
-
continue
|
| 259 |
-
|
| 260 |
-
return safety_issues
|
| 261 |
-
|
| 262 |
|
| 263 |
def create_monitor_interface():
|
| 264 |
api_key = os.getenv("GROQ_API_KEY")
|
| 265 |
|
| 266 |
class SafetyMonitor:
|
| 267 |
def __init__(self):
|
| 268 |
-
"""Initialize Safety Monitor with configuration."""
|
| 269 |
self.client = Groq()
|
| 270 |
self.model_name = "llama-3.2-90b-vision-preview"
|
| 271 |
-
self.max_image_size = (800, 800)
|
| 272 |
-
self.colors = [(
|
| 273 |
-
|
| 274 |
def resize_image(self, image):
|
| 275 |
-
"""Resize image while maintaining aspect ratio."""
|
| 276 |
height, width = image.shape[:2]
|
| 277 |
aspect = width / height
|
| 278 |
|
|
@@ -286,7 +33,6 @@ def create_monitor_interface():
|
|
| 286 |
return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
| 287 |
|
| 288 |
def analyze_frame(self, frame: np.ndarray) -> str:
|
| 289 |
-
"""Analyze frame for safety concerns."""
|
| 290 |
if frame is None:
|
| 291 |
return "No frame received"
|
| 292 |
|
|
@@ -299,11 +45,11 @@ def create_monitor_interface():
|
|
| 299 |
frame = self.resize_image(frame)
|
| 300 |
frame_pil = PILImage.fromarray(frame)
|
| 301 |
|
| 302 |
-
# Convert to base64
|
| 303 |
buffered = io.BytesIO()
|
| 304 |
frame_pil.save(buffered,
|
| 305 |
format="JPEG",
|
| 306 |
-
quality=
|
| 307 |
optimize=True)
|
| 308 |
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
| 309 |
image_url = f"data:image/jpeg;base64,{img_base64}"
|
|
@@ -317,21 +63,9 @@ def create_monitor_interface():
|
|
| 317 |
"content": [
|
| 318 |
{
|
| 319 |
"type": "text",
|
| 320 |
-
"text": """Analyze this workplace image
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
Format each finding as:
|
| 325 |
-
- <location>position:detailed safety description</location>
|
| 326 |
-
|
| 327 |
-
Consider:
|
| 328 |
-
- PPE usage and compliance
|
| 329 |
-
- Ergonomic risks
|
| 330 |
-
- Equipment safety
|
| 331 |
-
- Environmental hazards
|
| 332 |
-
- Work procedures
|
| 333 |
-
- Material handling
|
| 334 |
-
"""
|
| 335 |
},
|
| 336 |
{
|
| 337 |
"type": "image_url",
|
|
@@ -340,139 +74,78 @@ def create_monitor_interface():
|
|
| 340 |
}
|
| 341 |
}
|
| 342 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
}
|
| 344 |
],
|
| 345 |
-
temperature=0.
|
| 346 |
-
max_tokens=
|
| 347 |
-
|
|
|
|
|
|
|
| 348 |
)
|
| 349 |
return completion.choices[0].message.content
|
| 350 |
except Exception as e:
|
| 351 |
-
print(f"
|
| 352 |
return f"Analysis Error: {str(e)}"
|
| 353 |
|
| 354 |
def draw_observations(self, image, observations):
|
| 355 |
-
"""Draw safety observations with accurate locations."""
|
| 356 |
height, width = image.shape[:2]
|
| 357 |
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 358 |
font_scale = 0.5
|
| 359 |
thickness = 2
|
| 360 |
|
| 361 |
-
|
| 362 |
-
"""Get coordinates based on location description."""
|
| 363 |
-
location_text = location_text.lower()
|
| 364 |
-
regions = {
|
| 365 |
-
# Basic positions
|
| 366 |
-
'center': (width//3, height//3, 2*width//3, 2*height//3),
|
| 367 |
-
'top': (width//4, 0, 3*width//4, height//3),
|
| 368 |
-
'bottom': (width//4, 2*height//3, 3*width//4, height),
|
| 369 |
-
'left': (0, height//4, width//3, 3*height//4),
|
| 370 |
-
'right': (2*width//3, height//4, width, 3*height//4),
|
| 371 |
-
'top-left': (0, 0, width//3, height//3),
|
| 372 |
-
'top-right': (2*width//3, 0, width, height//3),
|
| 373 |
-
'bottom-left': (0, 2*height//3, width//3, height),
|
| 374 |
-
'bottom-right': (2*width//3, 2*height//3, width, height),
|
| 375 |
-
|
| 376 |
-
# Work areas
|
| 377 |
-
'workspace': (width//4, height//4, 3*width//4, 3*height//4),
|
| 378 |
-
'machine': (2*width//3, 0, width, height),
|
| 379 |
-
'equipment': (2*width//3, height//3, width, 2*height//3),
|
| 380 |
-
'material': (0, 2*height//3, width//3, height),
|
| 381 |
-
'ground': (0, 2*height//3, width, height),
|
| 382 |
-
'floor': (0, 3*height//4, width, height),
|
| 383 |
-
|
| 384 |
-
# Body regions
|
| 385 |
-
'body': (width//3, height//3, 2*width//3, 2*height//3),
|
| 386 |
-
'hands': (width//2, height//2, 3*width//4, 2*height//3),
|
| 387 |
-
'head': (width//3, 0, 2*width//3, height//4),
|
| 388 |
-
'feet': (width//3, 3*height//4, 2*width//3, height),
|
| 389 |
-
'back': (width//3, height//3, 2*width//3, 2*height//3),
|
| 390 |
-
'knees': (width//3, 2*height//3, 2*width//3, height),
|
| 391 |
-
|
| 392 |
-
# Special areas
|
| 393 |
-
'workspace': (width//4, height//4, 3*width//4, 3*height//4),
|
| 394 |
-
'working-area': (width//4, height//4, 3*width//4, 3*height//4),
|
| 395 |
-
'surrounding': (0, 0, width, height),
|
| 396 |
-
'background': (0, 0, width, height)
|
| 397 |
-
}
|
| 398 |
-
|
| 399 |
-
# Find best matching region
|
| 400 |
-
best_match = 'center' # default
|
| 401 |
-
max_match_length = 0
|
| 402 |
-
|
| 403 |
-
for region_name in regions.keys():
|
| 404 |
-
if region_name in location_text and len(region_name) > max_match_length:
|
| 405 |
-
best_match = region_name
|
| 406 |
-
max_match_length = len(region_name)
|
| 407 |
-
|
| 408 |
-
return regions[best_match]
|
| 409 |
-
|
| 410 |
for idx, obs in enumerate(observations):
|
| 411 |
color = self.colors[idx % len(self.colors)]
|
| 412 |
|
| 413 |
-
#
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
else:
|
| 419 |
-
location = 'center'
|
| 420 |
-
description = obs
|
| 421 |
-
|
| 422 |
-
# Get region coordinates
|
| 423 |
-
x1, y1, x2, y2 = get_region_coordinates(location)
|
| 424 |
|
| 425 |
# Draw rectangle
|
| 426 |
-
cv2.rectangle(image, (
|
| 427 |
|
| 428 |
-
# Add label
|
| 429 |
-
label =
|
| 430 |
label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
text_x = max(0, x1)
|
| 434 |
-
text_y = max(20, y1 - 5)
|
| 435 |
-
|
| 436 |
-
# Draw text background
|
| 437 |
-
cv2.rectangle(image,
|
| 438 |
-
(text_x, text_y - label_size[1] - 5),
|
| 439 |
-
(text_x + label_size[0], text_y),
|
| 440 |
-
color, -1)
|
| 441 |
-
|
| 442 |
-
# Draw text
|
| 443 |
-
cv2.putText(image, label, (text_x, text_y - 5),
|
| 444 |
-
font, font_scale, (255, 255, 255), thickness)
|
| 445 |
|
| 446 |
return image
|
| 447 |
|
| 448 |
def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
|
| 449 |
-
"""Process frame and generate safety analysis."""
|
| 450 |
if frame is None:
|
| 451 |
return None, "No image provided"
|
| 452 |
|
| 453 |
analysis = self.analyze_frame(frame)
|
| 454 |
display_frame = self.resize_image(frame.copy())
|
| 455 |
|
| 456 |
-
# Parse observations
|
| 457 |
observations = []
|
| 458 |
for line in analysis.split('\n'):
|
| 459 |
line = line.strip()
|
| 460 |
if line.startswith('-'):
|
|
|
|
| 461 |
if '<location>' in line and '</location>' in line:
|
| 462 |
start = line.find('<location>') + len('<location>')
|
| 463 |
end = line.find('</location>')
|
| 464 |
-
observation = line[
|
| 465 |
-
|
| 466 |
-
|
|
|
|
|
|
|
| 467 |
|
| 468 |
-
# Draw observations
|
| 469 |
-
|
| 470 |
-
annotated_frame = self.draw_observations(display_frame, observations)
|
| 471 |
-
return annotated_frame, analysis
|
| 472 |
|
| 473 |
-
return
|
| 474 |
|
| 475 |
-
# Create interface
|
| 476 |
monitor = SafetyMonitor()
|
| 477 |
|
| 478 |
with gr.Blocks() as demo:
|
|
@@ -480,7 +153,7 @@ def create_monitor_interface():
|
|
| 480 |
|
| 481 |
with gr.Row():
|
| 482 |
input_image = gr.Image(label="Upload Image")
|
| 483 |
-
output_image = gr.Image(label="
|
| 484 |
|
| 485 |
analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)
|
| 486 |
|
|
@@ -500,15 +173,7 @@ def create_monitor_interface():
|
|
| 500 |
outputs=[output_image, analysis_text]
|
| 501 |
)
|
| 502 |
|
| 503 |
-
gr.Markdown("""
|
| 504 |
-
## Instructions:
|
| 505 |
-
1. Upload a workplace image
|
| 506 |
-
2. View detected safety concerns
|
| 507 |
-
3. Check detailed analysis
|
| 508 |
-
""")
|
| 509 |
-
|
| 510 |
return demo
|
| 511 |
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
demo.launch()
|
|
|
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
| 4 |
from groq import Groq
|
| 5 |
+
import time
|
| 6 |
from PIL import Image as PILImage
|
| 7 |
import io
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
import os
|
| 9 |
+
import base64
|
| 10 |
+
import random
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def create_monitor_interface():
|
| 13 |
api_key = os.getenv("GROQ_API_KEY")
|
| 14 |
|
| 15 |
class SafetyMonitor:
|
| 16 |
def __init__(self):
|
|
|
|
| 17 |
self.client = Groq()
|
| 18 |
self.model_name = "llama-3.2-90b-vision-preview"
|
| 19 |
+
self.max_image_size = (800, 800) # Increased size for better visibility
|
| 20 |
+
self.colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
|
| 21 |
+
|
| 22 |
def resize_image(self, image):
|
|
|
|
| 23 |
height, width = image.shape[:2]
|
| 24 |
aspect = width / height
|
| 25 |
|
|
|
|
| 33 |
return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
| 34 |
|
| 35 |
def analyze_frame(self, frame: np.ndarray) -> str:
|
|
|
|
| 36 |
if frame is None:
|
| 37 |
return "No frame received"
|
| 38 |
|
|
|
|
| 45 |
frame = self.resize_image(frame)
|
| 46 |
frame_pil = PILImage.fromarray(frame)
|
| 47 |
|
| 48 |
+
# Convert to base64 with minimal quality
|
| 49 |
buffered = io.BytesIO()
|
| 50 |
frame_pil.save(buffered,
|
| 51 |
format="JPEG",
|
| 52 |
+
quality=30,
|
| 53 |
optimize=True)
|
| 54 |
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
| 55 |
image_url = f"data:image/jpeg;base64,{img_base64}"
|
|
|
|
| 63 |
"content": [
|
| 64 |
{
|
| 65 |
"type": "text",
|
| 66 |
+
"text": """Analyze this workplace image and describe each safety concern in this format:
|
| 67 |
+
- <location>Description</location>
|
| 68 |
+
Use one line per issue, starting with a dash and location in tags."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
},
|
| 70 |
{
|
| 71 |
"type": "image_url",
|
|
|
|
| 74 |
}
|
| 75 |
}
|
| 76 |
]
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
"role": "assistant",
|
| 80 |
+
"content": ""
|
| 81 |
}
|
| 82 |
],
|
| 83 |
+
temperature=0.1,
|
| 84 |
+
max_tokens=150,
|
| 85 |
+
top_p=1,
|
| 86 |
+
stream=False,
|
| 87 |
+
stop=None
|
| 88 |
)
|
| 89 |
return completion.choices[0].message.content
|
| 90 |
except Exception as e:
|
| 91 |
+
print(f"Detailed error: {str(e)}")
|
| 92 |
return f"Analysis Error: {str(e)}"
|
| 93 |
|
| 94 |
def draw_observations(self, image, observations):
|
|
|
|
| 95 |
height, width = image.shape[:2]
|
| 96 |
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 97 |
font_scale = 0.5
|
| 98 |
thickness = 2
|
| 99 |
|
| 100 |
+
# Generate random positions for each observation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
for idx, obs in enumerate(observations):
|
| 102 |
color = self.colors[idx % len(self.colors)]
|
| 103 |
|
| 104 |
+
# Generate random box position
|
| 105 |
+
box_width = width // 3
|
| 106 |
+
box_height = height // 3
|
| 107 |
+
x = random.randint(0, width - box_width)
|
| 108 |
+
y = random.randint(0, height - box_height)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
# Draw rectangle
|
| 111 |
+
cv2.rectangle(image, (x, y), (x + box_width, y + box_height), color, 2)
|
| 112 |
|
| 113 |
+
# Add label with background
|
| 114 |
+
label = obs[:40] + "..." if len(obs) > 40 else obs
|
| 115 |
label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
|
| 116 |
+
cv2.rectangle(image, (x, y - 20), (x + label_size[0], y), color, -1)
|
| 117 |
+
cv2.putText(image, label, (x, y - 5), font, font_scale, (255, 255, 255), thickness)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
return image
|
| 120 |
|
| 121 |
def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
|
|
|
|
| 122 |
if frame is None:
|
| 123 |
return None, "No image provided"
|
| 124 |
|
| 125 |
analysis = self.analyze_frame(frame)
|
| 126 |
display_frame = self.resize_image(frame.copy())
|
| 127 |
|
| 128 |
+
# Parse observations from the analysis
|
| 129 |
observations = []
|
| 130 |
for line in analysis.split('\n'):
|
| 131 |
line = line.strip()
|
| 132 |
if line.startswith('-'):
|
| 133 |
+
# Extract text between <location> tags if present
|
| 134 |
if '<location>' in line and '</location>' in line:
|
| 135 |
start = line.find('<location>') + len('<location>')
|
| 136 |
end = line.find('</location>')
|
| 137 |
+
observation = line[end + len('</location>'):].strip()
|
| 138 |
+
else:
|
| 139 |
+
observation = line[1:].strip() # Remove the dash
|
| 140 |
+
if observation:
|
| 141 |
+
observations.append(observation)
|
| 142 |
|
| 143 |
+
# Draw observations on the image
|
| 144 |
+
annotated_frame = self.draw_observations(display_frame, observations)
|
|
|
|
|
|
|
| 145 |
|
| 146 |
+
return annotated_frame, analysis
|
| 147 |
|
| 148 |
+
# Create the main interface
|
| 149 |
monitor = SafetyMonitor()
|
| 150 |
|
| 151 |
with gr.Blocks() as demo:
|
|
|
|
| 153 |
|
| 154 |
with gr.Row():
|
| 155 |
input_image = gr.Image(label="Upload Image")
|
| 156 |
+
output_image = gr.Image(label="Annotated Results")
|
| 157 |
|
| 158 |
analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)
|
| 159 |
|
|
|
|
| 173 |
outputs=[output_image, analysis_text]
|
| 174 |
)
|
| 175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
return demo
|
| 177 |
|
| 178 |
+
demo = create_monitor_interface()
|
| 179 |
+
demo.launch()
|
|
|