# PIWM / src/models/contour_detection_model.py
# musictimer — commit c64c726: "Initial Diamond CSGO AI deployment"
# Vehicle Detection and State Estimation using Color-Based Contour Detection
"""
Vehicle Detection Model using Color-Based Contour Detection
This module provides functionality to detect vehicles in Bird's-Eye View (BEV) images
by isolating their specific colors (green and blue) and analyzing the shapes (contours)
of the colored areas. The detected states can be exported to CSV files.
Required Libraries:
- opencv-python (or opencv-python-headless): For image processing, color segmentation, and contour analysis
- numpy: For numerical operations
You can install them using pip:
pip install opencv-python-headless numpy
Example usage:
from models.contour_detection_model import ContourDetectionModel
model = ContourDetectionModel()
annotated_image, vehicle_states = model.detect_vehicles('path/to/image.jpg')
model.save_states_to_csv(vehicle_states, 'output.csv')
"""
import cv2
import numpy as np
import math
import csv
import torch
from typing import List, Dict, Tuple, Optional
class ContourDetectionModel:
    """
    A vehicle detection model using color-based contour detection.

    This model detects vehicles by:
    1. Converting images to HSV color space
    2. Creating color masks for green (ego) and blue (other) vehicles
    3. Finding contours in the masked regions
    4. Estimating vehicle position and heading from contour geometry

    The HSV ranges and ``min_contour_area`` passed to the constructor are the
    values actually used during detection.
    """

    def __init__(self,
                 green_hsv_range: Optional[Tuple[List[int], List[int]]] = None,
                 blue_hsv_range: Optional[Tuple[List[int], List[int]]] = None,
                 min_contour_area: int = 50):
        """
        Initialize the contour detection model.

        Args:
            green_hsv_range: Tuple of (lower, upper) HSV bounds for green (ego)
                vehicles, in OpenCV's 8-bit HSV convention. Defaults to
                ([50, 100, 100], [70, 255, 255]).
            blue_hsv_range: Tuple of (lower, upper) HSV bounds for blue (other)
                vehicles. Defaults to ([80, 80, 80], [115, 255, 255]).
            min_contour_area: Contours whose area (in pixels) is below this
                value are discarded as noise.
        """
        # Default HSV ranges optimized for highway-env vehicles
        if green_hsv_range is None:
            self.lower_green = np.array([50, 100, 100])
            self.upper_green = np.array([70, 255, 255])
        else:
            self.lower_green = np.array(green_hsv_range[0])
            self.upper_green = np.array(green_hsv_range[1])
        if blue_hsv_range is None:
            # Expanded range for better detection in generated images
            self.lower_blue = np.array([80, 80, 80])  # More tolerant lower bounds
            self.upper_blue = np.array([115, 255, 255])  # Wider hue range
        else:
            self.lower_blue = np.array(blue_hsv_range[0])
            self.upper_blue = np.array(blue_hsv_range[1])
        self.min_contour_area = min_contour_area

    def detect_vehicles(self, image_path: str) -> Tuple[Optional[np.ndarray], List[Dict]]:
        """
        Detect vehicles in an image and estimate their states.

        Args:
            image_path (str): Path to the input image

        Returns:
            Tuple containing:
            - annotated_image: Image with detection annotations (or None if error)
            - vehicle_states: List of dictionaries containing vehicle state information
        """
        return self.estimate_vehicle_states_by_color(image_path)

    def estimate_vehicle_states_by_color(self, image_path: str) -> Tuple[Optional[np.ndarray], List[Dict]]:
        """
        Load an image from disk, detect vehicles by color and estimate their states.

        Args:
            image_path (str): The path to the input image.

        Returns:
            tuple: (annotated_image, vehicle_states), or (None, []) if the image
            could not be loaded (an error message is printed in that case).
        """
        try:
            img = cv2.imread(image_path)
        except Exception as e:
            # Best-effort loader: report and return an empty result rather than raise.
            print(f"Error loading image: {e}")
            return None, []
        # cv2.imread reports unreadable/missing files by returning None.
        if img is None:
            print(f"Error: Could not read image from path: {image_path}")
            return None, []
        return self.detect_from_image(img)

    def save_states_to_csv(self, states: List[Dict], csv_file_path: str) -> None:
        """
        Save the list of vehicle states to a CSV file.

        Only the columns class, position_x, position_y, heading and speed are
        written; other keys (e.g. bounding_box_points) are ignored. Write errors
        are printed, not raised.

        Args:
            states (list): A list of dictionaries, where each dictionary is a vehicle's state.
            csv_file_path (str): The path to the output CSV file.
        """
        fieldnames = ['class', 'position_x', 'position_y', 'heading', 'speed']
        try:
            with open(csv_file_path, 'w', newline='', encoding='utf-8') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames, extrasaction='ignore')
                writer.writeheader()
                writer.writerows(states)
            print(f"\nVehicle states successfully saved to: {csv_file_path}")
        except OSError as e:
            print(f"Error writing to CSV file: {e}")

    def process_image(self, input_image_path: str,
                      output_image_path: str = None,
                      output_csv_path: str = None,
                      verbose: bool = True) -> Tuple[Optional[np.ndarray], List[Dict]]:
        """
        Complete processing pipeline for a single image.

        Detected states are sorted in place by ``position_x``, so the printed
        report, the CSV rows and the returned list all share the same order.
        Outputs are only written when at least one vehicle is detected.

        Args:
            input_image_path (str): Path to input image
            output_image_path (str, optional): Path to save annotated image
            output_csv_path (str, optional): Path to save CSV file
            verbose (bool): Whether to print detection results

        Returns:
            Tuple of (annotated_image, vehicle_states)
        """
        annotated_image, states = self.detect_vehicles(input_image_path)

        if annotated_image is None:
            # The loader has already printed the reason; "no vehicles" would mislead.
            return annotated_image, states

        if not states:
            if verbose:
                print("No vehicles were detected in the image.")
            return annotated_image, states

        # Sort regardless of verbosity so CSV row order does not depend on logging.
        states.sort(key=lambda v: v['position_x'])

        if verbose:
            print("--- Detected Vehicle States (Contour Method) ---")
            for i, state in enumerate(states):
                print(f"\nVehicle #{i+1}:")
                print(f" Class: {state['class']}")
                print(f" Position (x, y): ({state['position_x']:.2f}, {state['position_y']:.2f})")
                print(f" Heading (degrees): {state['heading']:.2f}")
                print(f" Speed: {state['speed']:.2f} (Note: Placeholder value)")

        if output_image_path:
            # cv2.imwrite signals failure through its return value, not an exception.
            if cv2.imwrite(output_image_path, annotated_image):
                if verbose:
                    print(f"\nAnnotated image saved to: {output_image_path}")
            else:
                print(f"Error: Could not write annotated image to: {output_image_path}")

        if output_csv_path:
            self.save_states_to_csv(states, output_csv_path)

        return annotated_image, states

    @staticmethod
    def _heading_from_box(box_points) -> float:
        """
        Return the angle (degrees, folded into [-90, 90]) of a rotated rectangle's long side.

        The angle is atan2(dy, dx) of the long edge, i.e. measured from the image
        +x axis toward +y. It is folded into the right-hand half-plane because
        all highway-env vehicles travel to the right.

        Args:
            box_points: The 4 rectangle corners in consecutive order (as produced
                by cv2.boxPoints). Float corners are used as-is for sub-pixel accuracy.
        """
        corners = np.asarray(box_points, dtype=float)
        edge1 = np.linalg.norm(corners[0] - corners[1])
        edge2 = np.linalg.norm(corners[1] - corners[2])
        # The vehicle's length is the longer of two adjacent edges.
        if edge1 > edge2:
            delta = corners[1] - corners[0]
        else:
            delta = corners[2] - corners[1]
        heading = math.degrees(math.atan2(delta[1], delta[0]))
        if heading > 90:
            heading -= 180
        elif heading < -90:
            heading += 180
        return heading

    @staticmethod
    def _draw_vehicle(annotated_img, box_points, pos_x, pos_y, heading) -> None:
        """Draw one vehicle's box, centre dot, heading line and label onto annotated_img in place."""
        cv2.drawContours(annotated_img, [box_points], 0, (0, 255, 255), 2)  # Yellow box
        center = (int(pos_x), int(pos_y))
        cv2.circle(annotated_img, center, 5, (0, 0, 255), -1)  # Red dot
        length = 40  # Length of the heading line in pixels
        angle_rad = np.deg2rad(heading)
        end = (int(pos_x + length * np.cos(angle_rad)), int(pos_y + length * np.sin(angle_rad)))
        cv2.line(annotated_img, center, end, (255, 0, 0), 2)  # Blue line
        # Plain ints: OpenCV drawing functions expect integer point coordinates.
        label_org = (int(box_points[1][0]), int(box_points[1][1]) - 10)
        cv2.putText(annotated_img, f"H: {heading:.1f}", label_org,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 2)

    def detect_from_image(self, img) -> Tuple[Optional[np.ndarray], List[Dict]]:
        """
        Detect vehicles in an already-loaded image array.

        Uses the HSV ranges and ``min_contour_area`` configured in ``__init__``.

        Args:
            img: Input image in BGR channel order (as returned by cv2.imread), or None.

        Returns:
            Tuple of (annotated_image, vehicle_states); (None, []) if img is None.
            Each state dict has keys: "class" ("ego_vehicle" for green,
            "other_vehicle" for blue), "bounding_box_points" (4 integer [x, y]
            corners), "position_x"/"position_y" (rectangle centre in pixels),
            "speed" (always 0.0 — single-frame placeholder) and "heading"
            (degrees in [-90, 90]).
        """
        if img is None:
            return None, []

        annotated_img = img.copy()
        hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

        # Each color gets its own mask so every contour keeps its class label.
        masks = (
            (cv2.inRange(hsv_img, self.lower_green, self.upper_green), "ego_vehicle"),
            (cv2.inRange(hsv_img, self.lower_blue, self.upper_blue), "other_vehicle"),
        )

        vehicle_states = []
        for mask, class_name in masks:
            # [-2] picks the contour list on both OpenCV 3.x (3 return values)
            # and OpenCV 4.x (2 return values).
            contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
            for contour in contours:
                if cv2.contourArea(contour) < self.min_contour_area:
                    continue

                # Minimum-area rotated rectangle: gives centre and orientation.
                rect = cv2.minAreaRect(contour)
                (pos_x, pos_y), _, _ = rect
                corners = cv2.boxPoints(rect)
                # Heading from the float corners; integer corners only for drawing/export.
                heading = self._heading_from_box(corners)
                box_points = corners.astype(np.intp)

                # Speed needs multiple frames; a single image cannot provide it.
                speed = 0.0  # Placeholder

                vehicle_states.append({
                    "class": class_name,
                    "bounding_box_points": box_points.tolist(),
                    "position_x": pos_x,
                    "position_y": pos_y,
                    "speed": speed,
                    "heading": heading
                })
                self._draw_vehicle(annotated_img, box_points, pos_x, pos_y, heading)

        return annotated_img, vehicle_states
# Legacy function for backward compatibility
def estimate_vehicle_states_by_color(image_path: str) -> Tuple[Optional[np.ndarray], List[Dict]]:
    """
    Detect vehicles in the image at ``image_path`` using a default-configured model.

    Kept so callers of the older function-style API keep working; it simply
    delegates to ``ContourDetectionModel.estimate_vehicle_states_by_color``.

    Returns:
        Tuple of (annotated_image or None, list of vehicle state dicts).
    """
    return ContourDetectionModel().estimate_vehicle_states_by_color(image_path)
def save_states_to_csv(states: List[Dict], csv_file_path: str) -> None:
    """
    Write ``states`` to ``csv_file_path`` using a default-configured model.

    Kept so callers of the older function-style API keep working; it simply
    delegates to ``ContourDetectionModel.save_states_to_csv``.
    """
    ContourDetectionModel().save_states_to_csv(states, csv_file_path)
if __name__ == '__main__':
    # Example usage when run as a script. Paths can be overridden on the
    # command line; the defaults are the paths this script originally hard-coded.
    import argparse

    parser = argparse.ArgumentParser(
        description="Detect vehicles in a BEV image using color-based contour detection.")
    parser.add_argument('input_image_path', nargs='?',
                        default='/home/alienware3/Documents/diamond/frames/frame_0.png',
                        help="Path to the input image.")
    parser.add_argument('--output-image', dest='output_image_path',
                        default='frame_000001_contour_detected.png',
                        help="Where to write the annotated image.")
    parser.add_argument('--output-csv', dest='output_csv_path',
                        default='vehicle_states.csv',
                        help="Where to write the detected vehicle states.")
    args = parser.parse_args()

    model = ContourDetectionModel()
    img = cv2.imread(args.input_image_path)
    if img is None:
        # cv2.imread returns None (rather than raising) when the file cannot be read.
        raise SystemExit(f"Error: Could not read image from path: {args.input_image_path}")

    annotated_image, states = model.detect_from_image(img)
    cv2.imwrite(args.output_image_path, annotated_image)
    model.save_states_to_csv(states, args.output_csv_path)