HayLahav committed on
Commit
6d66207
·
verified ·
1 Parent(s): 287cfc9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +547 -79
app.py CHANGED
@@ -4,119 +4,473 @@ import requests
4
  import pytz
5
  import yaml
6
  from tools.final_answer import FinalAnswerTool
7
- from ultralytics import YOLO # YOLOv5 model
8
  import cv2
9
  import numpy as np
 
 
 
10
  from Gradio_UI import GradioUI
11
 
12
@tool
def get_yolov5_coco_detections(image_path: str) -> dict:
    """Detects objects in an image using a YOLOv5 checkpoint on COCO classes.

    Args:
        image_path: Path to the input image file.

    Returns:
        Dict with key "detected_objects": a list of
        {"object": <class name>, "confidence": <float>, "bbox": (x1, y1, x2, y2)}.
    """
    model = YOLO("yolov5s.pt")
    image = cv2.imread(image_path)
    # cv2.imread returns None for unreadable paths instead of raising.
    if image is None:
        return {"detected_objects": []}
    results = model(image)

    detections = []
    # BUG FIX: the ultralytics YOLO API returns a list of Results objects with
    # a `.boxes` attribute; the torch.hub-style `results.pred[0]` attribute
    # does not exist here and raised AttributeError.
    for r in results:
        for box in r.boxes:
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            conf = box.conf[0].item()
            cls = int(box.cls[0].item())
            detections.append({
                "object": model.names[cls],
                "confidence": conf,
                "bbox": (x1, y1, x2, y2),
            })

    return {"detected_objects": detections}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
@tool
def detect_road_lanes(image_path: str) -> dict:
    """Detects road lanes using a YOLOv5-style model trained for lane detection.

    Args:
        image_path: Path to the input image file.

    Returns:
        Dict with key "detected_lanes": a list of
        {"lane": "Lane <class>", "confidence": <float>, "bbox": (x1, y1, x2, y2)}.
    """
    # NOTE(review): "yolov5-lane.pt" is not an official ultralytics asset;
    # it must be present locally — confirm the weights ship with the app.
    model = YOLO("yolov5-lane.pt")
    image = cv2.imread(image_path)
    # cv2.imread returns None for unreadable paths instead of raising.
    if image is None:
        return {"detected_lanes": []}
    results = model(image)

    lane_detections = []
    # BUG FIX: iterate ultralytics Results/.boxes instead of the nonexistent
    # `results.pred[0]` attribute (torch.hub YOLOv5 API, not ultralytics).
    for r in results:
        for box in r.boxes:
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            conf = box.conf[0].item()
            cls = int(box.cls[0].item())
            lane_detections.append({
                "lane": f"Lane {cls}",
                "confidence": conf,
                "bbox": (x1, y1, x2, y2),
            })

    return {"detected_lanes": lane_detections}
 
42
 
43
- @tool
44
- def driving_situation_analyzer(image_path: str) -> dict:
45
- """Analyzes road conditions by integrating object detections and lane information."""
46
- objects_info = get_yolov5_coco_detections(image_path)
47
- lanes_info = detect_road_lanes(image_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- detected_objects = objects_info.get("detected_objects", [])
50
- detected_lanes = lanes_info.get("detected_lanes", [])
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  situation = []
53
 
54
  if any(obj["object"] in ["car", "truck", "bus"] for obj in detected_objects):
55
  situation.append("Traffic detected ahead, maintain safe distance.")
56
 
57
- if any(obj["object"] == "pedestrian" for obj in detected_objects):
58
  situation.append("Pedestrian detected, be prepared to stop.")
59
 
60
  if any(obj["object"] == "traffic light" for obj in detected_objects):
61
  situation.append("Traffic light detected, slow down if red.")
 
 
 
62
 
63
- if not detected_lanes:
64
  situation.append("Lane markings not detected, potential risk of veering.")
65
-
66
- elif len(detected_lanes) == 1:
67
  situation.append("Single lane detected, ensure proper lane following.")
68
-
69
- elif len(detected_lanes) >= 2:
70
- situation.append("Multiple lanes detected, stay within lane boundaries.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  return {
73
  "situation_summary": " | ".join(situation) if situation else "Road situation unclear, proceed with caution.",
74
  "detected_objects": detected_objects,
75
- "detected_lanes": detected_lanes
 
 
 
 
 
76
  }
77
 
 
78
  @tool
79
- def predict_trajectory(image_path: str) -> dict:
80
  """Predicts vehicle trajectory based on the driving situation analysis.
81
 
82
- Uses OpenPilot-style motion prediction based on detected lanes, objects, and road conditions.
 
 
 
 
 
 
83
  """
84
- analysis = driving_situation_analyzer(image_path)
85
- detected_objects = analysis["detected_objects"]
86
- detected_lanes = analysis["detected_lanes"]
87
- summary = analysis["situation_summary"]
88
-
 
 
 
 
 
 
 
 
 
 
89
  trajectory = []
90
-
91
- # Define simple trajectory logic
92
- if "Traffic detected" in summary:
93
- trajectory.append("Reduce speed, maintain a safe distance.")
94
-
95
- if "Pedestrian detected" in summary:
96
- trajectory.append("Prepare for sudden braking or yielding.")
97
-
98
- if "Traffic light detected" in summary:
99
- trajectory.append("Adjust speed based on light status.")
100
-
101
- if "Lane markings not detected" in summary:
102
- trajectory.append("Risk of lane departure, drive cautiously.")
103
-
104
- if len(detected_lanes) >= 2:
105
- trajectory.append("Stay centered in the lane, adjust for merging vehicles.")
106
-
107
- # Generate simple trajectory points (Mocked example)
108
- future_positions = []
109
- x, y = 0, 0
110
- for t in range(10): # Predict 10 future steps
111
- x += np.random.uniform(-0.5, 0.5) # Small lateral deviation
112
- y += 1 # Move forward
113
- future_positions.append((x, y))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  return {
116
  "trajectory_recommendation": " | ".join(trajectory) if trajectory else "Maintain current path.",
117
- "future_positions": future_positions
 
 
 
 
 
 
 
118
  }
119
 
 
120
  @tool
121
  def get_current_time_in_timezone(timezone: str) -> str:
122
  """Fetches the current local time in a specified timezone."""
@@ -127,8 +481,11 @@ def get_current_time_in_timezone(timezone: str) -> str:
127
  except Exception as e:
128
  return f"Error fetching time for timezone '{timezone}': {str(e)}"
129
 
 
 
130
  final_answer = FinalAnswerTool()
131
 
 
132
  model = HfApiModel(
133
  max_tokens=2096,
134
  temperature=0.5,
@@ -136,19 +493,29 @@ model = HfApiModel(
136
  custom_role_conversions=None,
137
  )
138
 
139
- image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
140
-
141
- with open("prompts.yaml", 'r') as stream:
142
- prompt_templates = yaml.safe_load(stream)
 
 
 
 
 
 
 
 
143
 
 
144
  agent = CodeAgent(
145
  model=model,
146
  tools=[
147
  final_answer,
148
- get_yolov5_coco_detections,
149
  detect_road_lanes,
150
  driving_situation_analyzer,
151
- predict_trajectory
 
152
  ],
153
  max_steps=6,
154
  verbosity_level=1,
@@ -159,4 +526,105 @@ agent = CodeAgent(
159
  prompt_templates=prompt_templates
160
  )
161
 
162
- GradioUI(agent).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import pytz
5
  import yaml
6
  from tools.final_answer import FinalAnswerTool
7
+ from ultralytics import YOLO # YOLOv8 model
8
  import cv2
9
  import numpy as np
10
+ import os
11
+ import tempfile
12
+ import gradio as gr
13
  from Gradio_UI import GradioUI
14
 
15
@tool
def get_yolov8_coco_detections(video_path: str) -> dict:
    """Detects objects in an MP4 video using YOLOv8 and annotates every frame.

    Args:
        video_path: Path to the input video.

    Returns:
        On success: {"output_path": <annotated video path>,
                     "detected_objects": [{"object": <class name>}, ...]}.
        On failure: {"error": <message>}.
    """
    model = YOLO("yolov8s.pt")  # Pre-trained YOLOv8 COCO model
    cap = cv2.VideoCapture(video_path)

    # BUG FIX: the original returned a bare string here while the signature and
    # callers expect a dict; `"error" in result` then failed (case-sensitive
    # substring miss on "Error: ...") and the caller crashed on `.get()`.
    if not cap.isOpened():
        return {"error": f"Error: Could not open video file at {video_path}"}

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for the output video
    output_path = "output_video.mp4"
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    unique_detections = set()  # Class names seen anywhere in the video

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break  # End of video

            results = model(frame)  # YOLOv8 inference on a single frame

            for r in results:
                for box in r.boxes:
                    x1, y1, x2, y2 = box.xyxy[0].tolist()
                    conf = box.conf[0].item()
                    cls = int(box.cls[0].item())
                    class_name = model.names[cls]

                    unique_detections.add(class_name)

                    # Draw bounding box and confidence label on the frame.
                    cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                    label = f"{class_name} {conf:.2f}"
                    cv2.putText(frame, label, (int(x1), int(y1) - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            out.write(frame)  # Save annotated frame to the output video
    finally:
        # BUG FIX: release the capture and writer even if inference raises,
        # so a partially written output file is not left open.
        cap.release()
        out.release()

    return {
        "output_path": output_path,
        "detected_objects": [{"object": obj} for obj in unique_detections],
    }
76
 
 
 
 
 
 
 
77
 
78
@tool
def detect_road_lanes(video_path: str) -> dict:
    """Detects lane markings in an MP4 video using YOLOv8-seg plus classic CV.

    Lane lines come from Canny edges + a probabilistic Hough transform applied
    to the lower half of each frame; YOLOv8 segmentation masks are overlaid as
    a semi-transparent layer for context.

    Args:
        video_path: Path to the input video.

    Returns:
        On success: {"output_path": <annotated video path>,
                     "detected_lanes": [{"lane_id": int, "slope": float}, ...],
                     "lane_count": <max distinct lanes seen in any frame>}.
        On failure: {"error": <message>}.
    """
    # Reuse a locally cached copy of the segmentation weights when present;
    # YOLO() downloads the file automatically otherwise.
    model_path = "yolov8s-seg.pt"
    model = YOLO(model_path)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return {"error": f"Could not open video file at {video_path}"}

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    output_path = "lanes_output.mp4"
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    lane_count = 0        # Max distinct lanes seen in any single frame
    detected_lanes = []   # Lane info from the most recent frame that had lanes

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Restrict the model to person/vehicle COCO classes, as the original did.
        results = model(frame, classes=[0, 1, 2, 3, 7])

        vis_frame = frame.copy()

        # Lane extraction is gated on the model producing segmentation masks
        # (kept from the original behavior).  The dead `masks = ...` binding
        # that was never used has been removed.
        if hasattr(results[0], 'masks') and results[0].masks is not None:
            # Classic CV pipeline: grayscale -> blur -> edges -> ROI -> Hough.
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            blur = cv2.GaussianBlur(gray, (5, 5), 0)
            edges = cv2.Canny(blur, 50, 150)

            # Mask the lower half of the image, where lane markings appear.
            roi_mask = np.zeros_like(edges)
            # BUG FIX: use local names instead of reassigning `height`/`width`,
            # which shadowed the writer/resize dimensions below.
            fh, fw = edges.shape
            polygon = np.array([[(0, fh), (fw, fh), (fw, fh // 2), (0, fh // 2)]], dtype=np.int32)
            cv2.fillPoly(roi_mask, polygon, 255)
            masked_edges = cv2.bitwise_and(edges, roi_mask)

            lines = cv2.HoughLinesP(masked_edges, 1, np.pi / 180, 50,
                                    minLineLength=100, maxLineGap=50)

            current_lane_count = 0
            lane_lines = []

            if lines is not None:
                for line in lines:
                    x1, y1, x2, y2 = line[0]
                    # Keep only sufficiently steep lines (lanes, not horizon).
                    if abs(x2 - x1) > 0 and abs(y2 - y1) / abs(x2 - x1) > 0.5:
                        cv2.line(vis_frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
                        lane_lines.append(((x1, y1), (x2, y2)))

            if lane_lines:
                # Cluster lines by slope; each slope cluster ~ one lane.
                slopes = []
                for ((x1, y1), (x2, y2)) in lane_lines:
                    if x2 != x1:  # avoid division by zero
                        slopes.append((y2 - y1) / (x2 - x1))

                unique_slopes = []
                for slope in slopes:
                    # Slopes within 0.2 of an existing cluster are merged.
                    if all(abs(slope - us) >= 0.2 for us in unique_slopes):
                        unique_slopes.append(slope)

                current_lane_count = len(unique_slopes)
                lane_count = max(lane_count, current_lane_count)
                # NOTE: intentionally keeps the lanes from the most recent
                # frame that had any, matching the original behavior.
                detected_lanes = [{"lane_id": i, "slope": s} for i, s in enumerate(unique_slopes)]

            cv2.putText(vis_frame, f"Detected lanes: {current_lane_count}", (50, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

        # Overlay segmentation masks as a semi-transparent yellow layer.
        if hasattr(results[0], 'masks') and results[0].masks is not None:
            for mask in results[0].masks:
                seg_mask = mask.data.cpu().numpy()[0].astype(np.uint8) * 255
                seg_mask = cv2.resize(seg_mask, (width, height))
                color_mask = np.zeros_like(vis_frame)
                color_mask[seg_mask > 0] = [0, 255, 255]  # Yellow overlay
                vis_frame = cv2.addWeighted(vis_frame, 1, color_mask, 0.3, 0)

        out.write(vis_frame)

    cap.release()
    out.release()

    return {
        "output_path": output_path,
        "detected_lanes": detected_lanes,
        "lane_count": lane_count,
    }
211
 
 
 
212
 
213
@tool
def driving_situation_analyzer(video_path: str) -> dict:
    """Analyzes road conditions by integrating YOLOv8 object and lane detections.

    Args:
        video_path: Path to the input video.

    Returns:
        A dictionary with the situation summary, raw detections, lane info,
        road complexity, safety level and paths to the annotated videos,
        or {"error": <message>} if either detector failed.
    """
    # Run object detection with YOLOv8.
    object_results = get_yolov8_coco_detections(video_path)

    # ROBUSTNESS FIX: the object detector historically returned a plain error
    # *string*; accept both shapes so `.get()` is never called on a str below.
    if isinstance(object_results, str):
        return {"error": object_results}
    if "error" in object_results:
        return {"error": object_results["error"]}

    # Run lane detection with YOLOv8-seg.
    lane_results = detect_road_lanes(video_path)
    if isinstance(lane_results, str):
        return {"error": lane_results}
    if "error" in lane_results:
        return {"error": lane_results["error"]}

    # Extract information from the detector results.
    detected_objects = object_results.get("detected_objects", [])
    detected_lanes = lane_results.get("detected_lanes", [])
    lane_count = lane_results.get("lane_count", 0)

    # --- Situation analysis -------------------------------------------------
    situation = []

    if any(obj["object"] in ["car", "truck", "bus"] for obj in detected_objects):
        situation.append("Traffic detected ahead, maintain safe distance.")

    # COCO labels pedestrians as "person".
    if any(obj["object"] == "person" for obj in detected_objects):
        situation.append("Pedestrian detected, be prepared to stop.")

    if any(obj["object"] == "traffic light" for obj in detected_objects):
        situation.append("Traffic light detected, slow down if red.")

    if any(obj["object"] == "stop sign" for obj in detected_objects):
        situation.append("Stop sign detected, prepare to stop.")

    if lane_count == 0:
        situation.append("Lane markings not detected, potential risk of veering.")
    elif lane_count == 1:
        situation.append("Single lane detected, ensure proper lane following.")
    else:  # lane_count >= 2
        situation.append(f"{lane_count} lanes detected, stay within lane boundaries.")

    # --- Road complexity ----------------------------------------------------
    road_complexity = "LOW"
    if lane_count > 2:
        road_complexity = "MEDIUM"

    vehicle_count = sum(1 for obj in detected_objects if obj["object"] in ["car", "truck", "bus"])
    if vehicle_count > 3:
        road_complexity = "HIGH"

    if any(obj["object"] == "person" for obj in detected_objects) and vehicle_count > 1:
        road_complexity = "HIGH"

    # --- Safety level (later checks deliberately override earlier ones) -----
    safety_level = "HIGH"
    situation_text = " ".join(situation)
    if "Pedestrian detected" in situation_text:
        safety_level = "MEDIUM"
    if "Lane markings not detected" in situation_text:
        safety_level = "LOW"
    if vehicle_count > 5:
        safety_level = "MEDIUM"

    return {
        "situation_summary": " | ".join(situation) if situation else "Road situation unclear, proceed with caution.",
        "detected_objects": detected_objects,
        "detected_lanes": detected_lanes,
        "lane_count": lane_count,
        "road_complexity": road_complexity,
        "safety_level": safety_level,
        "objects_video": object_results.get("output_path"),
        "lanes_video": lane_results.get("output_path"),
    }
293
 
294
+
295
@tool
def predict_trajectory(video_path: str) -> dict:
    """Predicts vehicle trajectory based on the driving situation analysis.

    Uses the YOLOv8-based analysis of detected lanes, objects, and road
    conditions, then renders a short visualization video of the planned path.

    Args:
        video_path: Path to the input video.

    Returns:
        A dictionary with the textual recommendation, predicted (x, y) points,
        safety/complexity levels and paths to the generated videos,
        or {"error": <message>} if the upstream analysis failed.
    """
    # First get the comprehensive analysis.
    analysis = driving_situation_analyzer(video_path)
    if "error" in analysis:
        return {"error": analysis["error"]}

    # Extract key information for trajectory planning.
    # (The unused `detected_lanes` extraction from the original was dropped.)
    detected_objects = analysis.get("detected_objects", [])
    lane_count = analysis.get("lane_count", 0)
    road_complexity = analysis.get("road_complexity", "MEDIUM")
    safety_level = analysis.get("safety_level", "MEDIUM")
    summary = analysis.get("situation_summary", "")

    trajectory = []

    # Safety level sets the overall driving strategy.
    if safety_level == "LOW":
        trajectory.append("Reduce speed significantly, proceed with extreme caution.")
    elif safety_level == "MEDIUM":
        trajectory.append("Maintain moderate speed, be alert for changing conditions.")
    else:  # HIGH
        trajectory.append("Normal driving conditions, maintain safe speed.")

    # Road complexity affects navigation approach.
    if road_complexity == "HIGH":
        trajectory.append("Complex traffic environment, navigate with extra caution.")

    # Object-specific trajectory adjustments (one note per detected object).
    for obj in detected_objects:
        obj_name = obj["object"].lower()

        if "person" in obj_name:
            trajectory.append("Yield to pedestrians, prepare for potential stopping.")
        if obj_name in ["car", "truck", "bus"]:
            trajectory.append("Vehicle detected, maintain safe following distance.")
        if obj_name == "traffic light":
            trajectory.append("Approach intersection carefully, prepare to stop if light changes.")
        if obj_name == "stop sign":
            trajectory.append("Slow down and prepare to stop completely at the stop sign.")

    # Lane-specific trajectory planning.
    if lane_count == 0:
        trajectory.append("No lanes detected, follow visual road boundaries carefully.")
    elif lane_count == 1:
        trajectory.append("Single lane detected, maintain centered position.")
    else:
        trajectory.append(f"{lane_count} lanes available, stay within current lane.")

    # Visualize the trajectory on the last frame of the lane-annotated video,
    # falling back to the raw input video when it is missing.
    lane_video = analysis.get("lanes_video")
    if lane_video and os.path.exists(lane_video):
        cap = cv2.VideoCapture(lane_video)
    else:
        cap = cv2.VideoCapture(video_path)

    # Scan to the last readable frame.
    frame = None
    while cap.isOpened():
        ret, current_frame = cap.read()
        if not ret:
            break
        frame = current_frame
    cap.release()

    if frame is not None:
        height, width = frame.shape[:2]
        trajectory_frame = frame.copy()

        # Starting point: bottom center of the frame.
        start_x, start_y = width // 2, height - 50
        cv2.circle(trajectory_frame, (start_x, start_y), 5, (255, 255, 0), -1)

        future_positions = []
        x, y = start_x, start_y

        # Lateral wiggle budget depends on safety and road complexity.
        lateral_variation = 5
        if safety_level == "LOW":
            lateral_variation = 3   # Less lateral movement when unsafe
        elif road_complexity == "HIGH":
            lateral_variation = 8   # More potential variation in complex roads

        for t in range(10):
            y -= 30  # Move forward (up in image coordinates)

            if "pedestrian" in summary.lower():
                # Drift away from pedestrians (assumed to be on the right).
                x -= np.random.uniform(0, lateral_variation)
            elif "traffic" in summary.lower() and t > 5:
                # Slight adjustment for traffic ahead.
                x += np.random.uniform(-lateral_variation / 2, lateral_variation / 2)
            else:
                # Normal driving with slight randomness.
                x += np.random.uniform(-lateral_variation / 3, lateral_variation / 3)

            # Clamp to a plausible drivable corridor (20-80% of frame width).
            x = max(width * 0.2, min(width * 0.8, x))
            y = max(0, min(height - 1, y))

            future_positions.append((int(x), int(y)))
            cv2.circle(trajectory_frame, (int(x), int(y)), 5, (255, 255, 0), -1)

        # Connect the predicted points.
        for i in range(1, len(future_positions)):
            cv2.line(trajectory_frame, future_positions[i - 1], future_positions[i], (255, 255, 0), 2)

        # Overlay the recommendation text, wrapped at 60 characters per line.
        trajectory_text = " | ".join(trajectory) if trajectory else "Maintain current path."
        y_pos = 30
        for line in [trajectory_text[i:i + 60] for i in range(0, len(trajectory_text), 60)]:
            cv2.putText(trajectory_frame, line, (30, y_pos),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            y_pos += 30

        # Safety indicator: green HIGH, yellow MEDIUM, red LOW.
        safety_color = (0, 255, 0) if safety_level == "HIGH" else \
                       (0, 255, 255) if safety_level == "MEDIUM" else \
                       (0, 0, 255)
        cv2.putText(trajectory_frame, f"Safety: {safety_level}", (width - 150, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, safety_color, 2)

        # Write the still frame a few times to produce a short video.
        output_path = "trajectory_output.mp4"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, 1, (width, height))
        for _ in range(3):
            out.write(trajectory_frame)
        out.release()
    else:
        future_positions = []
        output_path = None

    return {
        "trajectory_recommendation": " | ".join(trajectory) if trajectory else "Maintain current path.",
        "future_positions": future_positions,
        "safety_level": safety_level,
        "road_complexity": road_complexity,
        # BUG FIX: check the path this call actually produced instead of a
        # hard-coded filename that may be left over from a previous run.
        "trajectory_video": output_path if output_path and os.path.exists(output_path) else None,
        "analysis_videos": {
            "objects": analysis.get("objects_video"),
            "lanes": analysis.get("lanes_video"),
        },
    }
472
 
473
+
474
  @tool
475
  def get_current_time_in_timezone(timezone: str) -> str:
476
  """Fetches the current local time in a specified timezone."""
 
481
  except Exception as e:
482
  return f"Error fetching time for timezone '{timezone}': {str(e)}"
483
 
484
+
485
+ # Setup FinalAnswerTool
486
  final_answer = FinalAnswerTool()
487
 
488
+ # Setup model
489
  model = HfApiModel(
490
  max_tokens=2096,
491
  temperature=0.5,
 
493
  custom_role_conversions=None,
494
  )
495
 
496
# Create prompts.yaml with defaults on first run, then ALWAYS load it so that
# `prompt_templates` is bound on both paths.
# BUG FIX: the original only assigned `prompt_templates` in the else-branch,
# so a fresh checkout (no prompts.yaml) raised NameError at CodeAgent(...).
if not os.path.exists("prompts.yaml"):
    default_prompts = {
        "default": "You are an autonomous driving assistant that helps analyze road scenes and make driving decisions.",
        "prefix": "Analyze the following driving scenario: ",
        "suffix": "Provide a detailed analysis with safety recommendations."
    }
    with open("prompts.yaml", 'w') as file:
        yaml.dump(default_prompts, file)

with open("prompts.yaml", 'r') as stream:
    prompt_templates = yaml.safe_load(stream)
508
 
509
+ # Define agent
510
  agent = CodeAgent(
511
  model=model,
512
  tools=[
513
  final_answer,
514
+ get_yolov8_coco_detections,
515
  detect_road_lanes,
516
  driving_situation_analyzer,
517
+ predict_trajectory,
518
+ get_current_time_in_timezone
519
  ],
520
  max_steps=6,
521
  verbosity_level=1,
 
526
  prompt_templates=prompt_templates
527
  )
528
 
529
# Standalone Gradio interface used when the GradioUI wrapper is unavailable.
def create_gradio_interface():
    """Builds a Gradio Blocks app exposing the four video-analysis tools."""
    with gr.Blocks(title="Autonomous Driving Video Analysis with YOLOv8") as demo:
        gr.Markdown("# Autonomous Driving Video Analysis with YOLOv8")

        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Driving Video")
                analysis_type = gr.Radio(
                    ["Object Detection (YOLOv8)", "Lane Detection (YOLOv8)", "Situation Analysis", "Trajectory Prediction"],
                    label="Analysis Type",
                    value="Object Detection (YOLOv8)"
                )
                analyze_btn = gr.Button("Analyze Video")

            with gr.Column():
                result_text = gr.Textbox(label="Analysis Results", lines=10)

        with gr.Tabs():
            with gr.TabItem("Object Detection"):
                object_video_output = gr.Video(label="Object Detection Results")
            with gr.TabItem("Lane Detection"):
                lane_video_output = gr.Video(label="Lane Detection Results")
            with gr.TabItem("Trajectory Prediction"):
                trajectory_video_output = gr.Video(label="Trajectory Prediction")

        def process_video(video_path, analysis_type):
            """Runs the selected analysis; returns (text, 3 video paths or None)."""
            import shutil  # Local import: only needed for the file copy below.

            if not video_path:
                return "Please upload a video file.", None, None, None

            # Work on a copy in the temp dir so the tools never touch the upload.
            temp_video_path = os.path.join(tempfile.gettempdir(), "input_video.mp4")

            if hasattr(video_path, "read"):
                # File-like object handed over by Gradio.
                with open(temp_video_path, "wb") as f:
                    f.write(video_path.read())
            else:
                # IMPROVED: plain filesystem path — copy with shutil instead of
                # reading the whole video into memory.
                shutil.copyfile(video_path, temp_video_path)

            result = None
            object_video = None
            lane_video = None
            trajectory_video = None

            def _existing(path):
                """Returns path only when it points at an existing file."""
                return path if path and os.path.exists(path) else None

            try:
                if analysis_type == "Object Detection (YOLOv8)":
                    result = get_yolov8_coco_detections(temp_video_path)
                    if isinstance(result, dict):
                        object_video = _existing(result.get("output_path"))

                elif analysis_type == "Lane Detection (YOLOv8)":
                    result = detect_road_lanes(temp_video_path)
                    if isinstance(result, dict):
                        lane_video = _existing(result.get("output_path"))

                elif analysis_type == "Situation Analysis":
                    result = driving_situation_analyzer(temp_video_path)
                    if isinstance(result, dict):
                        object_video = _existing(result.get("objects_video"))
                        lane_video = _existing(result.get("lanes_video"))

                elif analysis_type == "Trajectory Prediction":
                    result = predict_trajectory(temp_video_path)
                    if isinstance(result, dict):
                        videos = result.get("analysis_videos", {})
                        object_video = _existing(videos.get("objects"))
                        lane_video = _existing(videos.get("lanes"))
                        trajectory_video = _existing(result.get("trajectory_video"))

            except Exception as e:
                return f"Error processing video: {str(e)}", None, None, None

            return str(result), object_video, lane_video, trajectory_video

        analyze_btn.click(
            fn=process_video,
            inputs=[video_input, analysis_type],
            outputs=[result_text, object_video_output, lane_video_output, trajectory_video_output]
        )

    return demo
622
+
623
# Prefer the project's GradioUI wrapper; fall back to the standalone
# interface defined above if the wrapper fails to start.
try:
    GradioUI(agent).launch()
except Exception as e:
    print(f"Error using GradioUI wrapper: {e}")
    print("Launching custom Gradio interface instead")
    create_gradio_interface().launch()