mansi.modi@streebo.com committed on
Commit
3bacb10
·
1 Parent(s): 2c83160
Files changed (3) hide show
  1. .DS_Store +0 -0
  2. app.py +157 -0
  3. requirements.txt +5 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ import torchvision
5
+ from torchvision.transforms import functional as F
6
+ import time
7
+ import random
8
+
9
+
10
# Palette of vibrant BGR colors for visualization (OpenCV draws in BGR order).
COLORS = [(0, 255, 0), (0, 0, 255), (255, 0, 0), (0, 255, 255), (255, 255, 0),
          (255, 0, 255), (80, 70, 180), (250, 80, 190), (245, 145, 50)]

# COCO dataset class names, indexed by the model's integer label ids.
# 'N/A' marks ids that the COCO detection models never emit.
# You can modify this list to focus only on specific objects.
CLASSES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

# Only detections of these classes are kept (simplified for a focused demo).
TARGET_CLASSES = ['person', 'bottle', 'cell phone', 'cup', 'laptop', 'chair']

def get_prediction(img, threshold=0.7, detector=None):
    """Detect TARGET_CLASSES instances in a BGR frame.

    Args:
        img: HxWx3 BGR frame as produced by OpenCV's ``VideoCapture.read``.
        threshold: minimum confidence score for a detection to be kept.
        detector: optional detection model to use instead of the module-level
            ``model``. Must accept a list of CHW float tensors and return
            torchvision's detection output format (list of dicts with
            'labels', 'masks', 'boxes', 'scores').

    Returns:
        Tuple ``(masks, boxes, labels, scores)`` of parallel lists, one entry
        per kept detection: masks are HxW float arrays in [0, 1], boxes are
        ``[x1, y1, x2, y2]`` float arrays, labels are class-name strings,
        scores are confidence floats.
    """
    net = model if detector is None else detector

    # BUG FIX: torchvision detection models are trained on RGB images, but
    # OpenCV frames arrive in BGR — flip the channel order before inference.
    rgb = np.ascontiguousarray(img[:, :, ::-1])
    tensor = torch.from_numpy(rgb).permute(2, 0, 1).float()
    if img.dtype == np.uint8:
        # Scale 0..255 pixels to [0, 1], as torchvision's to_tensor would.
        tensor = tensor / 255.0
    if torch.cuda.is_available():
        # BUG FIX: the model is moved to the GPU at startup; the input must
        # live on the same device or inference raises a device mismatch.
        tensor = tensor.cuda()

    prediction = net([tensor])

    # Filter predictions by confidence and by the classes we care about.
    masks, boxes, labels, scores = [], [], [], []

    pred_classes = [CLASSES[i] for i in prediction[0]['labels']]
    pred_masks = prediction[0]['masks'].detach().cpu().numpy()
    pred_boxes = prediction[0]['boxes'].detach().cpu().numpy()
    pred_scores = prediction[0]['scores'].detach().cpu().numpy()

    for i, score in enumerate(pred_scores):
        if score > threshold and pred_classes[i] in TARGET_CLASSES:
            masks.append(pred_masks[i][0])  # drop the singleton channel dim
            boxes.append(pred_boxes[i])
            labels.append(pred_classes[i])
            scores.append(score)

    return masks, boxes, labels, scores
58
+
59
def random_color():
    """Pick one of the preset vibrant BGR colors at random."""
    idx = random.randint(0, len(COLORS) - 1)
    return COLORS[idx]
62
+
63
def visualize(img, masks, boxes, labels, scores):
    """Draw masks, boxes, captions, and per-class counts onto the frame.

    Returns the annotated image (the input array may also be drawn on).
    """
    h, w = img.shape[:2]
    overlay_alpha = 0.5  # transparency of the colored mask overlay

    for idx, (mask, box, label, score) in enumerate(zip(masks, boxes, labels, scores)):
        color = COLORS[idx % len(COLORS)]

        # Blend a colored, semi-transparent version of the mask into the frame.
        binary = (mask > 0.5).astype(np.uint8)
        tint = np.zeros((h, w, 3), dtype=np.uint8)
        tint[binary == 1] = color
        img = cv2.addWeighted(img, 1, tint, overlay_alpha, 0)

        # Higher-confidence detections get a thicker bounding-box outline.
        x1, y1, x2, y2 = box.astype(int)
        cv2.rectangle(img, (x1, y1), (x2, y2), color, max(1, int(score * 3)))

        # Caption with class name and confidence on a filled background.
        caption = f"{label}: {score:.2f}"
        (text_w, text_h), _ = cv2.getTextSize(caption, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        y1 = max(y1, text_h)  # keep the caption inside the frame at the top edge
        cv2.rectangle(img, (x1, y1 - text_h - 10), (x1 + text_w, y1), color, -1)
        cv2.putText(img, caption, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        # Outline the mask itself for a crisper silhouette.
        contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(img, contours, -1, color, 2)

    # Title banner.
    cv2.putText(img, "Mask R-CNN Detection", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

    # Per-class detection tallies, listed down the left edge.
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1

    row = 60
    for cls, count in counts.items():
        cv2.putText(img, f"{cls}: {count}", (10, row), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
        row += 25

    return img
112
+
113
def _load_model():
    """Load the COCO-pretrained Mask R-CNN and move it to the best device."""
    print("Loading Mask R-CNN model...")
    # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13
    # (use `weights=...`); kept for compatibility with older installs.
    net = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    net.eval()
    if torch.cuda.is_available():
        net.cuda()
        print("Using GPU for inference")
    else:
        print("Using CPU for inference")
    return net


def main():
    """Run the real-time webcam segmentation demo until 'q' is pressed."""
    print("Starting webcam...")
    cap = cv2.VideoCapture(0)  # Use 0 for default webcam
    # FIX: fail loudly if the camera is unavailable instead of silently
    # looping on failed reads.
    if not cap.isOpened():
        raise RuntimeError("Could not open webcam (device 0)")

    # Request HD frames for better quality.
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)

    print("Press 'q' to quit")
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Inference only — no gradient bookkeeping needed.
            with torch.no_grad():
                masks, boxes, labels, scores = get_prediction(frame)

            result = visualize(frame, masks, boxes, labels, scores)
            cv2.imshow('Mask R-CNN Real-time Object Detection', result)

            # Break the loop if 'q' is pressed.
            if cv2.waitKey(30) & 0xFF == ord('q'):
                break
    finally:
        # FIX: release the camera and close windows even if an exception
        # escapes the loop, so the device is not left locked.
        cap.release()
        cv2.destroyAllWindows()
    print("Demo finished!")


# Loaded at module level because get_prediction() reads the global `model`.
model = _load_model()

if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ opencv-python
2
+ numpy
3
+ torch
4
+ torchvision
5
+ gradio