Update app.py
Browse files
app.py
CHANGED
|
@@ -16,16 +16,15 @@ def create_monitor_interface():
|
|
| 16 |
def __init__(self):
|
| 17 |
self.client = Groq()
|
| 18 |
self.model_name = "llama-3.2-90b-vision-preview"
|
| 19 |
-
self.max_image_size = (800, 800)
|
| 20 |
self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
|
| 21 |
self.last_analysis_time = 0
|
| 22 |
-
self.analysis_interval = 2
|
| 23 |
-
self.last_observations = []
|
| 24 |
|
| 25 |
def resize_image(self, image):
|
| 26 |
height, width = image.shape[:2]
|
| 27 |
|
| 28 |
-
# Only resize if image is too large
|
| 29 |
if height > self.max_image_size[1] or width > self.max_image_size[0]:
|
| 30 |
aspect = width / height
|
| 31 |
if width > height:
|
|
@@ -50,11 +49,10 @@ def create_monitor_interface():
|
|
| 50 |
frame = self.resize_image(frame)
|
| 51 |
frame_pil = PILImage.fromarray(frame)
|
| 52 |
|
| 53 |
-
# Convert to base64 with better quality
|
| 54 |
buffered = io.BytesIO()
|
| 55 |
frame_pil.save(buffered,
|
| 56 |
format="JPEG",
|
| 57 |
-
quality=85,
|
| 58 |
optimize=True)
|
| 59 |
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
| 60 |
image_url = f"data:image/jpeg;base64,{img_base64}"
|
|
@@ -111,12 +109,10 @@ def create_monitor_interface():
|
|
| 111 |
'bottom-right': (2*width//3, 2*height//3, width, height)
|
| 112 |
}
|
| 113 |
|
| 114 |
-
# Find the best matching region
|
| 115 |
for region_name, coords in regions.items():
|
| 116 |
if region_name in position.lower():
|
| 117 |
return coords
|
| 118 |
|
| 119 |
-
# Default to center if no match
|
| 120 |
return regions['center']
|
| 121 |
|
| 122 |
def draw_observations(self, image, observations):
|
|
@@ -128,7 +124,6 @@ def create_monitor_interface():
|
|
| 128 |
for idx, obs in enumerate(observations):
|
| 129 |
color = self.colors[idx % len(self.colors)]
|
| 130 |
|
| 131 |
-
# Try to extract position from observation
|
| 132 |
parts = obs.split(':')
|
| 133 |
if len(parts) >= 2:
|
| 134 |
position = parts[0]
|
|
@@ -137,17 +132,13 @@ def create_monitor_interface():
|
|
| 137 |
position = 'center'
|
| 138 |
description = obs
|
| 139 |
|
| 140 |
-
# Get coordinates based on position
|
| 141 |
x1, y1, x2, y2 = self.get_region_coordinates(position, image.shape)
|
| 142 |
|
| 143 |
-
# Draw rectangle
|
| 144 |
cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
|
| 145 |
|
| 146 |
-
# Add label with background
|
| 147 |
label = description[:50] + "..." if len(description) > 50 else description
|
| 148 |
label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
|
| 149 |
|
| 150 |
-
# Ensure label stays within image bounds
|
| 151 |
label_x = max(0, min(x1, width - label_size[0]))
|
| 152 |
label_y = max(20, y1 - 5)
|
| 153 |
|
|
@@ -164,12 +155,10 @@ def create_monitor_interface():
|
|
| 164 |
|
| 165 |
current_time = time.time()
|
| 166 |
|
| 167 |
-
# Only perform analysis if enough time has passed
|
| 168 |
if current_time - self.last_analysis_time >= self.analysis_interval:
|
| 169 |
analysis = self.analyze_frame(frame)
|
| 170 |
self.last_analysis_time = current_time
|
| 171 |
|
| 172 |
-
# Parse observations
|
| 173 |
observations = []
|
| 174 |
for line in analysis.split('\n'):
|
| 175 |
line = line.strip()
|
|
@@ -183,7 +172,6 @@ def create_monitor_interface():
|
|
| 183 |
|
| 184 |
self.last_observations = observations
|
| 185 |
|
| 186 |
-
# Draw observations on the frame
|
| 187 |
display_frame = frame.copy()
|
| 188 |
annotated_frame = self.draw_observations(display_frame, self.last_observations)
|
| 189 |
|
|
@@ -196,12 +184,12 @@ def create_monitor_interface():
|
|
| 196 |
gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
|
| 197 |
|
| 198 |
with gr.Row():
|
| 199 |
-
|
| 200 |
output_image = gr.Image(label="Analysis")
|
| 201 |
|
| 202 |
analysis_text = gr.Textbox(label="Safety Concerns", lines=5)
|
| 203 |
|
| 204 |
-
def
|
| 205 |
if image is None:
|
| 206 |
return None, "No image provided"
|
| 207 |
try:
|
|
@@ -211,12 +199,19 @@ def create_monitor_interface():
|
|
| 211 |
print(f"Processing error: {str(e)}")
|
| 212 |
return None, f"Error processing image: {str(e)}"
|
| 213 |
|
| 214 |
-
|
| 215 |
-
fn=
|
| 216 |
-
|
| 217 |
-
|
| 218 |
)
|
| 219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
return demo
|
| 221 |
|
| 222 |
demo = create_monitor_interface()
|
|
|
|
| 16 |
def __init__(self):
|
| 17 |
self.client = Groq()
|
| 18 |
self.model_name = "llama-3.2-90b-vision-preview"
|
| 19 |
+
self.max_image_size = (800, 800)
|
| 20 |
self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
|
| 21 |
self.last_analysis_time = 0
|
| 22 |
+
self.analysis_interval = 2
|
| 23 |
+
self.last_observations = []
|
| 24 |
|
| 25 |
def resize_image(self, image):
|
| 26 |
height, width = image.shape[:2]
|
| 27 |
|
|
|
|
| 28 |
if height > self.max_image_size[1] or width > self.max_image_size[0]:
|
| 29 |
aspect = width / height
|
| 30 |
if width > height:
|
|
|
|
| 49 |
frame = self.resize_image(frame)
|
| 50 |
frame_pil = PILImage.fromarray(frame)
|
| 51 |
|
|
|
|
| 52 |
buffered = io.BytesIO()
|
| 53 |
frame_pil.save(buffered,
|
| 54 |
format="JPEG",
|
| 55 |
+
quality=85,
|
| 56 |
optimize=True)
|
| 57 |
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
| 58 |
image_url = f"data:image/jpeg;base64,{img_base64}"
|
|
|
|
| 109 |
'bottom-right': (2*width//3, 2*height//3, width, height)
|
| 110 |
}
|
| 111 |
|
|
|
|
| 112 |
for region_name, coords in regions.items():
|
| 113 |
if region_name in position.lower():
|
| 114 |
return coords
|
| 115 |
|
|
|
|
| 116 |
return regions['center']
|
| 117 |
|
| 118 |
def draw_observations(self, image, observations):
|
|
|
|
| 124 |
for idx, obs in enumerate(observations):
|
| 125 |
color = self.colors[idx % len(self.colors)]
|
| 126 |
|
|
|
|
| 127 |
parts = obs.split(':')
|
| 128 |
if len(parts) >= 2:
|
| 129 |
position = parts[0]
|
|
|
|
| 132 |
position = 'center'
|
| 133 |
description = obs
|
| 134 |
|
|
|
|
| 135 |
x1, y1, x2, y2 = self.get_region_coordinates(position, image.shape)
|
| 136 |
|
|
|
|
| 137 |
cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
|
| 138 |
|
|
|
|
| 139 |
label = description[:50] + "..." if len(description) > 50 else description
|
| 140 |
label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
|
| 141 |
|
|
|
|
| 142 |
label_x = max(0, min(x1, width - label_size[0]))
|
| 143 |
label_y = max(20, y1 - 5)
|
| 144 |
|
|
|
|
| 155 |
|
| 156 |
current_time = time.time()
|
| 157 |
|
|
|
|
| 158 |
if current_time - self.last_analysis_time >= self.analysis_interval:
|
| 159 |
analysis = self.analyze_frame(frame)
|
| 160 |
self.last_analysis_time = current_time
|
| 161 |
|
|
|
|
| 162 |
observations = []
|
| 163 |
for line in analysis.split('\n'):
|
| 164 |
line = line.strip()
|
|
|
|
| 172 |
|
| 173 |
self.last_observations = observations
|
| 174 |
|
|
|
|
| 175 |
display_frame = frame.copy()
|
| 176 |
annotated_frame = self.draw_observations(display_frame, self.last_observations)
|
| 177 |
|
|
|
|
| 184 |
gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
|
| 185 |
|
| 186 |
with gr.Row():
|
| 187 |
+
input_image = gr.Image(label="Upload Image")
|
| 188 |
output_image = gr.Image(label="Analysis")
|
| 189 |
|
| 190 |
analysis_text = gr.Textbox(label="Safety Concerns", lines=5)
|
| 191 |
|
| 192 |
+
def analyze_image(image):
|
| 193 |
if image is None:
|
| 194 |
return None, "No image provided"
|
| 195 |
try:
|
|
|
|
| 199 |
print(f"Processing error: {str(e)}")
|
| 200 |
return None, f"Error processing image: {str(e)}"
|
| 201 |
|
| 202 |
+
input_image.change(
|
| 203 |
+
fn=analyze_image,
|
| 204 |
+
inputs=input_image,
|
| 205 |
+
outputs=[output_image, analysis_text]
|
| 206 |
)
|
| 207 |
|
| 208 |
+
gr.Markdown("""
|
| 209 |
+
## Instructions:
|
| 210 |
+
1. Upload an image to analyze safety concerns
|
| 211 |
+
2. View annotated results and detailed analysis
|
| 212 |
+
3. Each box highlights a potential safety issue
|
| 213 |
+
""")
|
| 214 |
+
|
| 215 |
return demo
|
| 216 |
|
| 217 |
demo = create_monitor_interface()
|