Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- app.py +63 -18
- best_model_mobilenet_v3_v2.pth +3 -0
app.py
CHANGED
|
@@ -33,16 +33,16 @@ def initialize_models():
|
|
| 33 |
# Set device
|
| 34 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 35 |
|
| 36 |
-
# Initialize
|
| 37 |
-
model = models.
|
| 38 |
-
model.
|
| 39 |
-
nn.Linear(
|
| 40 |
nn.Softmax(dim=1)
|
| 41 |
)
|
| 42 |
model = model.to(device)
|
| 43 |
|
| 44 |
# Load model weights
|
| 45 |
-
best_model_path = "
|
| 46 |
if not os.path.exists(best_model_path):
|
| 47 |
st.error(f"Model file not found: {best_model_path}")
|
| 48 |
return None, None, None
|
|
@@ -54,7 +54,7 @@ def initialize_models():
|
|
| 54 |
model.eval()
|
| 55 |
|
| 56 |
# Load YOLO model
|
| 57 |
-
yolo_model_path = "yolo11s.onnx"
|
| 58 |
if not os.path.exists(yolo_model_path):
|
| 59 |
st.error(f"YOLO model file not found: {yolo_model_path}")
|
| 60 |
return device, model, None
|
|
@@ -80,7 +80,8 @@ def process_image(image, model, device):
|
|
| 80 |
# Perform inference
|
| 81 |
with torch.no_grad():
|
| 82 |
output = model(input_tensor)
|
| 83 |
-
probabilities = output[0]
|
|
|
|
| 84 |
no_red_light_prob = probabilities[0].item()
|
| 85 |
red_light_prob = probabilities[1].item()
|
| 86 |
is_red_light = red_light_prob > no_red_light_prob
|
|
@@ -128,6 +129,44 @@ def put_text_with_background(img, text, position, font_scale=0.8, thickness=2, f
|
|
| 128 |
# Put text
|
| 129 |
cv2.putText(img, text, (position[0] + padding, position[1]), font, font_scale, (255, 255, 255), thickness)
|
| 130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
def main():
|
| 132 |
st.title("Traffic Light Detection with Protection Area")
|
| 133 |
|
|
@@ -279,17 +318,23 @@ def main():
|
|
| 279 |
'confidence': confidence,
|
| 280 |
'bbox': bbox
|
| 281 |
})
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
# Add status text
|
| 295 |
status_text = f"Red Light: DETECTED ({red_light_prob:.1%})"
|
|
|
|
| 33 |
# Set device
|
| 34 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 35 |
|
| 36 |
+
# Initialize MobileNetV3 model
|
| 37 |
+
model = models.mobilenet_v3_small(weights=None)
|
| 38 |
+
model.classifier = nn.Sequential(
|
| 39 |
+
nn.Linear(576, 2), # Direct mapping to output classes
|
| 40 |
nn.Softmax(dim=1)
|
| 41 |
)
|
| 42 |
model = model.to(device)
|
| 43 |
|
| 44 |
# Load model weights
|
| 45 |
+
best_model_path = "best_model_mobilenet_v3_v2.pth"
|
| 46 |
if not os.path.exists(best_model_path):
|
| 47 |
st.error(f"Model file not found: {best_model_path}")
|
| 48 |
return None, None, None
|
|
|
|
| 54 |
model.eval()
|
| 55 |
|
| 56 |
# Load YOLO model
|
| 57 |
+
yolo_model_path = "../yolo11s.onnx" # Going up one directory since the app.py is in API22_FEB
|
| 58 |
if not os.path.exists(yolo_model_path):
|
| 59 |
st.error(f"YOLO model file not found: {yolo_model_path}")
|
| 60 |
return device, model, None
|
|
|
|
| 80 |
# Perform inference
|
| 81 |
with torch.no_grad():
|
| 82 |
output = model(input_tensor)
|
| 83 |
+
probabilities = output[0] # Get probabilities for both classes
|
| 84 |
+
# Class 0 is "No Red Light", Class 1 is "Red Light"
|
| 85 |
no_red_light_prob = probabilities[0].item()
|
| 86 |
red_light_prob = probabilities[1].item()
|
| 87 |
is_red_light = red_light_prob > no_red_light_prob
|
|
|
|
| 129 |
# Put text
|
| 130 |
cv2.putText(img, text, (position[0] + padding, position[1]), font, font_scale, (255, 255, 255), thickness)
|
| 131 |
|
| 132 |
+
def calculate_iou(box1, box2):
|
| 133 |
+
"""Calculate Intersection over Union between two bounding boxes."""
|
| 134 |
+
x1 = max(box1[0], box2[0])
|
| 135 |
+
y1 = max(box1[1], box2[1])
|
| 136 |
+
x2 = min(box1[2], box2[2])
|
| 137 |
+
y2 = min(box1[3], box2[3])
|
| 138 |
+
|
| 139 |
+
intersection = max(0, x2 - x1) * max(0, y2 - y1)
|
| 140 |
+
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
| 141 |
+
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
| 142 |
+
union = box1_area + box2_area - intersection
|
| 143 |
+
|
| 144 |
+
return intersection / union if union > 0 else 0
|
| 145 |
+
|
| 146 |
+
def merge_overlapping_detections(detections, iou_threshold=0.5):
|
| 147 |
+
"""Merge overlapping detections of the same class."""
|
| 148 |
+
if not detections:
|
| 149 |
+
return []
|
| 150 |
+
|
| 151 |
+
# Sort detections by confidence
|
| 152 |
+
detections = sorted(detections, key=lambda x: x['confidence'], reverse=True)
|
| 153 |
+
merged_detections = []
|
| 154 |
+
|
| 155 |
+
while detections:
|
| 156 |
+
best_detection = detections.pop(0)
|
| 157 |
+
i = 0
|
| 158 |
+
while i < len(detections):
|
| 159 |
+
current_detection = detections[i]
|
| 160 |
+
if (current_detection['class'] == best_detection['class'] and
|
| 161 |
+
calculate_iou(current_detection['bbox'], best_detection['bbox']) >= iou_threshold):
|
| 162 |
+
# Remove the lower confidence detection
|
| 163 |
+
detections.pop(i)
|
| 164 |
+
else:
|
| 165 |
+
i += 1
|
| 166 |
+
merged_detections.append(best_detection)
|
| 167 |
+
|
| 168 |
+
return merged_detections
|
| 169 |
+
|
| 170 |
def main():
|
| 171 |
st.title("Traffic Light Detection with Protection Area")
|
| 172 |
|
|
|
|
| 318 |
'confidence': confidence,
|
| 319 |
'bbox': bbox
|
| 320 |
})
|
| 321 |
+
|
| 322 |
+
# Merge overlapping detections
|
| 323 |
+
detection_results = merge_overlapping_detections(detection_results, iou_threshold=0.5)
|
| 324 |
+
|
| 325 |
+
# Draw detections
|
| 326 |
+
for det in detection_results:
|
| 327 |
+
bbox = det['bbox']
|
| 328 |
+
# Draw detection box
|
| 329 |
+
cv2.rectangle(cv_image,
|
| 330 |
+
(int(bbox[0]), int(bbox[1])),
|
| 331 |
+
(int(bbox[2]), int(bbox[3])),
|
| 332 |
+
(0, 0, 255), 2)
|
| 333 |
+
|
| 334 |
+
# Add label
|
| 335 |
+
text = f"{det['class']}: {det['confidence']:.2%}"
|
| 336 |
+
put_text_with_background(cv_image, text,
|
| 337 |
+
(int(bbox[0]), int(bbox[1]) - 10))
|
| 338 |
|
| 339 |
# Add status text
|
| 340 |
status_text = f"Red Light: DETECTED ({red_light_prob:.1%})"
|
best_model_mobilenet_v3_v2.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a9d6dbfc5f368b8dd4f06f86e2ef088c0cd88c7bfd4f686800d2ef7b256b36f7
|
| 3 |
+
size 3850192
|