Spaces:
Runtime error
Runtime error
mansi.modi@streebo.com committed on
Commit ·
3bacb10
1
Parent(s): 2c83160
init
Browse files- .DS_Store +0 -0
- app.py +157 -0
- requirements.txt +5 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision
|
| 5 |
+
from torchvision.transforms import functional as F
|
| 6 |
+
import time
|
| 7 |
+
import random
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Set up colors for visualization (vibrant colors for better visual impact).
# Tuples are in OpenCV's BGR channel order; visualize() cycles through them
# per detection (COLORS[i % len(COLORS)]).
COLORS = [(0, 255, 0), (0, 0, 255), (255, 0, 0), (0, 255, 255), (255, 255, 0),
          (255, 0, 255), (80, 70, 180), (250, 80, 190), (245, 145, 50)]
|
| 13 |
+
|
| 14 |
+
# COCO dataset classes (only keep the ones we want to detect).
# Index positions must match the label ids emitted by the torchvision
# Mask R-CNN model; 'N/A' entries are placeholder ids COCO never assigns.
# You can modify this list to focus only on specific objects.
CLASSES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
|
| 30 |
+
|
| 31 |
+
# Define our target classes (simplified for a more focused demo).
# get_prediction() drops any detection whose class is not in this list;
# each entry must appear verbatim in CLASSES above.
TARGET_CLASSES = ['person', 'bottle', 'cell phone', 'cup', 'laptop', 'chair']
|
| 33 |
+
|
| 34 |
+
def get_prediction(img, threshold=0.7):
    """Run Mask R-CNN on one frame and keep confident, targeted detections.

    Args:
        img: HxWx3 uint8 frame as produced by OpenCV (BGR channel order).
        threshold: minimum confidence score for a detection to be kept.

    Returns:
        Tuple ``(masks, boxes, labels, scores)`` of parallel lists — soft
        mask (HxW float array), box coordinates, class-name string and
        confidence for each kept detection. All lists are empty when
        nothing clears the threshold.
    """
    # OpenCV frames are BGR, but torchvision detection models are trained
    # on RGB images — convert before tensorizing (fixes degraded accuracy).
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    tensor = F.to_tensor(rgb)

    # Keep the input on the same device as the model; without this the
    # script crashes with a CPU/GPU tensor mismatch when CUDA is available.
    device = next(model.parameters()).device
    prediction = model([tensor.to(device)])

    # Filter predictions by confidence and target classes
    masks = []
    boxes = []
    labels = []
    scores = []

    pred = prediction[0]
    pred_classes = [CLASSES[i] for i in pred['labels']]
    pred_masks = pred['masks'].detach().cpu().numpy()
    pred_boxes = pred['boxes'].detach().cpu().numpy()
    pred_scores = pred['scores'].detach().cpu().numpy()

    for i, score in enumerate(pred_scores):
        if score > threshold and pred_classes[i] in TARGET_CLASSES:
            # masks come out as (1, H, W); drop the channel axis.
            masks.append(pred_masks[i][0])
            boxes.append(pred_boxes[i])
            labels.append(pred_classes[i])
            scores.append(score)

    return masks, boxes, labels, scores
|
| 58 |
+
|
| 59 |
+
def random_color():
    """Pick one random vibrant color from the shared COLORS palette."""
    # randrange(n) draws exactly the same values as randint(0, n - 1).
    idx = random.randrange(len(COLORS))
    return COLORS[idx]
|
| 62 |
+
|
| 63 |
+
def visualize(img, masks, boxes, labels, scores):
    """Overlay masks, boxes, captions and class tallies; return the frame.

    The four detection lists are parallel (as returned by get_prediction).
    Mask blending produces a new image array; the annotated result is
    returned to the caller.
    """
    h, w = img.shape[:2]
    mask_alpha = 0.5  # transparency factor for the colored mask overlay

    for idx, (mask, box, label, score) in enumerate(zip(masks, boxes, labels, scores)):
        color = COLORS[idx % len(COLORS)]

        # Blend a solid-color version of the soft mask into the frame.
        bin_mask = (mask > 0.5).astype(np.uint8)
        overlay = np.zeros((h, w, 3), dtype=np.uint8)
        overlay[bin_mask == 1] = color
        img = cv2.addWeighted(img, 1, overlay, mask_alpha, 0)

        # Bounding box — thicker lines for higher-confidence detections.
        x1, y1, x2, y2 = box.astype(int)
        cv2.rectangle(img, (x1, y1), (x2, y2), color, max(1, int(score * 3)))

        # Filled label background with the caption drawn in white on top.
        label_text = f"{label}: {score:.2f}"
        (text_w, text_h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        y1 = max(y1, text_h)  # keep the caption from running off the top edge
        cv2.rectangle(img, (x1, y1 - text_h - 10), (x1 + text_w, y1), color, -1)
        cv2.putText(img, label_text, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        # Trace the mask outline for a crisper silhouette effect.
        outlines, _ = cv2.findContours(bin_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(img, outlines, -1, color, 2)

    # Title banner.
    cv2.putText(img, "Mask R-CNN Detection", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

    # Per-class tallies, one row every 25 px starting at y = 60.
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    for row, (cls, count) in enumerate(counts.items()):
        cv2.putText(img, f"{cls}: {count}", (10, 60 + 25 * row),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

    return img
|
| 112 |
+
|
| 113 |
+
# Load a pre-trained Mask R-CNN model
print("Loading Mask R-CNN model...")
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()  # inference mode: freezes dropout / batch-norm statistics
if torch.cuda.is_available():
    model.cuda()
    print("Using GPU for inference")
else:
    print("Using CPU for inference")

# Start video capture
print("Starting webcam...")
cap = cv2.VideoCapture(0)  # Use 0 for default webcam
# Fail fast with a clear message instead of looping on empty reads when no
# camera exists (e.g. a headless server — the likely cause of the Space's
# "Runtime error"; requirements.txt ships gradio, which a web UI would need).
if not cap.isOpened():
    raise RuntimeError("Could not open webcam (device 0); is a camera attached?")

# Set properties for better quality
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)

print("Press 'q' to quit")

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Get predictions (no_grad avoids building autograd graphs per frame)
        with torch.no_grad():
            masks, boxes, labels, scores = get_prediction(frame)

        # Visualize results
        result = visualize(frame, masks, boxes, labels, scores)

        # Show the result
        cv2.imshow('Mask R-CNN Real-time Object Detection', result)

        # Break the loop if 'q' is pressed
        if cv2.waitKey(30) & 0xFF == ord('q'):
            break
finally:
    # Release the camera and windows even if inference raises mid-loop
    cap.release()
    cv2.destroyAllWindows()
    print("Demo finished!")
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
opencv-python
|
| 2 |
+
numpy
|
| 3 |
+
torch
|
| 4 |
+
torchvision
|
| 5 |
+
gradio
|