Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| from collections import deque | |
| from datetime import datetime | |
| from ultralytics import YOLO | |
| import time | |
| import tempfile | |
| import os | |
class RotatingPadShirtCounter:
    """
    Robust shirt counter for a rotating pad system.

    Each frame is run through a YOLO model and reduced to one "instant" state
    for a circular region of interest (ROI):

    - ``EMPTY_IN_ROI``: an ``empty_pad`` box is centred inside the ROI
      (takes precedence over any other class).
    - ``OCCUPIED_IN_ROI``: a box of any other class is centred inside the ROI.
    - ``PAD_AWAY``: no box is centred inside the ROI.

    Instant states are smoothed by a majority vote over the last
    ``stability_frames`` frames. A shirt is counted when the smoothed state
    changes from ``PAD_AWAY`` directly to ``OCCUPIED_IN_ROI`` (a pad carrying
    a shirt enters the ROI). Two further conditions must hold:

    - the smoothed state had been ``PAD_AWAY`` for at least
      ``min_pad_away_frames`` processed frames;
    - at least ``min_time_between_counts`` seconds of processing time have
      passed since the previous count (or since construction).
    """

    # Height in pixels of the statistics panel that draw_visualization()
    # stacks above every frame.
    PANEL_HEIGHT = 180

    # Fraction of the smoothing buffer one state must fill to become stable.
    STABILITY_MAJORITY = 0.6

    # BGR colours for the "State:" line of the statistics panel.
    STATE_COLORS = {
        "EMPTY_IN_ROI": (0, 255, 0),
        "OCCUPIED_IN_ROI": (0, 165, 255),
        "PAD_AWAY": (255, 0, 0),
        "UNKNOWN": (128, 128, 128),
    }

    # One-letter codes used to render the smoothing buffer; anything else is '?'.
    BUFFER_CODES = {"EMPTY_IN_ROI": 'E', "OCCUPIED_IN_ROI": 'O', "PAD_AWAY": 'A'}

    def __init__(self,
                 model_path='best.pt',
                 roi_center=(320, 240),
                 roi_radius=180,
                 min_conf=0.5,
                 stability_frames=5,
                 min_time_between_counts=3.0,
                 min_pad_away_frames=80):
        """Load the detector and initialise the counting state.

        Args:
            model_path: Path to the YOLO weights file.
            roi_center: ``(x, y)`` pixel centre of the circular ROI.
            roi_radius: ROI radius in pixels.
            min_conf: Confidence threshold passed to ``YOLO.predict``.
            stability_frames: Length of the majority-vote smoothing window.
                Values below 1 are raised to 1.
            min_time_between_counts: Minimum seconds between two counts.
            min_pad_away_frames: Minimum number of consecutive frames the
                smoothed state must have been ``PAD_AWAY`` before a count.
        """
        # Load YOLO model
        print(f"Loading YOLO model from: {model_path}")
        self.model = YOLO(model_path)
        self.model_names = self.model.names
        print(f"Model classes: {self.model_names}")

        # ROI configuration
        self.roi_center = roi_center
        self.roi_radius = roi_radius
        self.min_conf = min_conf

        # State tracking. A window of at least one frame is required: with
        # maxlen=0 the majority vote would run max() on an empty buffer.
        stability_frames = max(1, int(stability_frames))
        self.current_state = "UNKNOWN"
        self.prev_state = "UNKNOWN"
        self.state_buffer = deque(maxlen=stability_frames)
        self.stability_frames = stability_frames

        # Counting logic
        self.shirt_count = 0

        # Double-count protection. time.monotonic() is used for intervals
        # because, unlike time.time(), it is unaffected by clock adjustments.
        self.last_count_time = time.monotonic()
        self.min_time_between_counts = min_time_between_counts

        # detection_history and debug_mode are not read anywhere in this
        # file; they are kept as public attributes.
        self.detection_history = deque(maxlen=30)
        self.pad_away_frames = 0
        self.min_pad_away_frames = min_pad_away_frames

        # Logging
        self.event_log = []
        self.debug_mode = True

    def detect_in_roi(self, frame):
        """Run YOLO on ``frame`` and classify every box against the ROI.

        A box is "in ROI" when its centre lies strictly inside the ROI circle.

        Returns:
            ``(has_empty_pad_in_roi, has_occupied_pad_in_roi, all_detections)``.
            ``all_detections`` is a list of dicts with keys ``bbox``,
            ``center``, ``confidence``, ``class`` and ``in_roi``. Any class
            other than ``empty_pad`` counts as occupied.
        """
        results = self.model.predict(frame, conf=self.min_conf, verbose=False)
        roi_x, roi_y = self.roi_center

        has_empty_pad_in_roi = False
        has_occupied_pad_in_roi = False
        all_detections = []

        for result in results:
            for box in result.boxes:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                conf = float(box.conf[0].cpu().numpy())
                class_name = self.model_names[int(box.cls[0].cpu().numpy())]

                center_x = (x1 + x2) / 2
                center_y = (y1 + y2) / 2
                in_roi = np.hypot(center_x - roi_x, center_y - roi_y) < self.roi_radius

                all_detections.append({
                    'bbox': [x1, y1, x2, y2],
                    'center': (center_x, center_y),
                    'confidence': conf,
                    'class': class_name,
                    'in_roi': in_roi,
                })

                if in_roi:
                    if class_name == 'empty_pad':
                        has_empty_pad_in_roi = True
                    else:
                        has_occupied_pad_in_roi = True

        return has_empty_pad_in_roi, has_occupied_pad_in_roi, all_detections

    def determine_state(self, has_empty, has_occupied):
        """Map the ROI flags to an instant state; an empty pad wins over occupied."""
        if has_empty:
            return "EMPTY_IN_ROI"
        if has_occupied:
            return "OCCUPIED_IN_ROI"
        return "PAD_AWAY"

    def update_state_buffer(self, state):
        """Append ``state`` to the smoothing window and return the stable state.

        The current stable state is kept in two cases:

        - the window is not yet full;
        - no state fills at least ``STABILITY_MAJORITY`` of the window.
        """
        buf = self.state_buffer
        buf.append(state)
        if len(buf) < self.stability_frames:
            return self.current_state

        # Most frequent state. Ties resolve to the one seen first in the
        # window, but a tied state can never reach the 60% bar anyway.
        stable_state = max(buf, key=buf.count)
        if buf.count(stable_state) >= len(buf) * self.STABILITY_MAJORITY:
            return stable_state
        return self.current_state

    def should_count(self):
        """Decide whether the latest stable-state transition is a new shirt.

        Returns:
            ``(True, reason)`` if the transition is PAD_AWAY → OCCUPIED_IN_ROI,
            the away streak was long enough, and the debounce interval has
            elapsed. Otherwise ``(False, None)``.
        """
        if self.prev_state == "PAD_AWAY" and self.current_state == "OCCUPIED_IN_ROI":
            time_since_last = time.monotonic() - self.last_count_time
            if (time_since_last >= self.min_time_between_counts
                    and self.pad_away_frames >= self.min_pad_away_frames):
                return True, f"Shirt on pad after PAD_AWAY for {self.pad_away_frames} frames"
        return False, None

    def process_frame(self, frame):
        """Detect, update state, and possibly count. Returns the annotated frame."""
        has_empty, has_occupied, detections = self.detect_in_roi(frame)
        instant_state = self.determine_state(has_empty, has_occupied)
        stable_state = self.update_state_buffer(instant_state)

        # Uses the state from *before* this frame's update. The frame on which
        # PAD_AWAY ends therefore still counts toward the streak checked below.
        if self.current_state == "PAD_AWAY":
            self.pad_away_frames += 1
        else:
            self.pad_away_frames = 0

        if stable_state != self.current_state:
            self.prev_state = self.current_state
            self.current_state = stable_state

            counted, reason = self.should_count()
            if counted:
                self.shirt_count += 1
                self.last_count_time = time.monotonic()
                self.log_event("SHIRT_COUNTED", reason)
                print(f"π― SHIRT #{self.shirt_count} COUNTED! - {reason}")
            else:
                self.log_event("STATE_CHANGE", f"{self.prev_state} -> {self.current_state}")

        return self.draw_visualization(frame, detections, instant_state)

    def draw_visualization(self, frame, detections, instant_state):
        """Return a copy of ``frame`` with ROI, boxes and a stats panel on top.

        The output is ``PANEL_HEIGHT`` pixels taller than ``frame``.
        """
        vis = frame.copy()

        # OpenCV drawing functions require integer pixel coordinates.
        center = (int(self.roi_center[0]), int(self.roi_center[1]))
        cv2.circle(vis, center, int(self.roi_radius), (0, 255, 255), 3)
        cv2.circle(vis, center, 5, (0, 255, 255), -1)

        for det in detections:
            x1, y1, x2, y2 = map(int, det['bbox'])
            cls = det['class']
            in_roi = det['in_roi']
            color = (0, 255, 0) if cls == 'empty_pad' else (0, 0, 255)
            thickness = 3 if in_roi else 2
            cv2.rectangle(vis, (x1, y1), (x2, y2), color, thickness)

            label = f"{cls} {det['confidence']:.2f}"
            if in_roi:
                label += " [ROI]"
            # Keep the label visible for boxes touching the top edge.
            cv2.putText(vis, label, (x1, max(y1 - 10, 15)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

        panel = np.zeros((self.PANEL_HEIGHT, vis.shape[1], 3), dtype=np.uint8)
        cv2.putText(panel, f"SHIRTS COUNTED: {self.shirt_count}", (20, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 255, 0), 3)

        state_color = self.STATE_COLORS.get(self.current_state, (255, 255, 255))
        cv2.putText(panel, f"State: {self.current_state}", (20, 90),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, state_color, 2)
        cv2.putText(panel, f"Instant: {instant_state}", (20, 120),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)

        buffer_str = ''.join(self.BUFFER_CODES.get(s, '?') for s in self.state_buffer)
        cv2.putText(panel, f"Buffer: [{buffer_str}]", (20, 150),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (180, 180, 180), 1)

        return np.vstack([panel, vis])

    def log_event(self, event_type, details):
        """Append an event record (wall-clock time, type, details, count, state)."""
        self.event_log.append({
            'timestamp': datetime.now().strftime('%H:%M:%S.%f')[:-3],
            'event': event_type,
            'details': details,
            'count': self.shirt_count,
            'state': self.current_state,
        })

    def get_stats(self):
        """Return the total count, current state and the (live) event log list."""
        return {
            'total_shirts': self.shirt_count,
            'current_state': self.current_state,
            'events': self.event_log,
        }
def _remove_quietly(path):
    """Best-effort deletion of a partially written output file."""
    try:
        os.remove(path)
    except OSError:
        # Missing or locked file: the processing error is what gets reported.
        pass


def _render_frames(cap, out, counter, width, total_frames, progress):
    """Annotate every remaining frame of ``cap`` and write it to ``out``.

    ``total_frames`` <= 0 means the container did not report a frame count.
    Returns the number of frames processed.
    """
    frame_count = 0
    total_label = total_frames if total_frames > 0 else "?"
    text_x = max(width - 350, 10)  # keep the counter on-screen for narrow videos
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_count += 1
        vis_frame = counter.process_frame(frame)

        if total_frames > 0:
            frame_text = (f"Frame: {frame_count}/{total_frames} "
                          f"({frame_count / total_frames * 100:.1f}%)")
        else:
            frame_text = f"Frame: {frame_count}"
        cv2.putText(vis_frame, frame_text, (text_x, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 0), 2)
        out.write(vis_frame)

        if frame_count % 30 == 0:
            # FRAME_COUNT is only an estimate, so clamp the fraction to 1.
            done = min(frame_count / total_frames, 1.0) if total_frames > 0 else 0.0
            progress(0.2 + done * 0.75,
                     desc=f"Processing: {frame_count}/{total_label} frames | Shirts: {counter.shirt_count}")
    return frame_count


def _format_results(frame_count, stats):
    """Build the summary text shown in the results box."""
    header = "\n".join([
        "",
        "β **Processing Complete!**",
        "π **Results:**",
        f"- Total Frames Processed: {frame_count:,}",
        f"- **Shirts Counted: {stats['total_shirts']}**",
        f"- Final State: {stats['current_state']}",
        "π **Event Log (Shirt Counts):**",
        "",
    ])
    count_lines = "".join(
        f"\n β [{evt['timestamp']}] Shirt #{evt['count']} - {evt['details']}"
        for evt in stats['events']
        if evt['event'] == 'SHIRT_COUNTED'
    )
    text = header + count_lines
    if stats['total_shirts'] == 0:
        text += "\n\nβ οΈ No shirts detected. Try adjusting parameters or ensure video shows the rotating pad system."
    return text


def process_video(video_path, roi_radius, min_confidence, stability_frames, progress=gr.Progress()):
    """Run the shirt counter over an uploaded video and render an annotated copy.

    Args:
        video_path: Path of the uploaded video, or ``None`` if nothing was uploaded.
        roi_radius: ROI circle radius in pixels (converted to ``int``).
        min_confidence: YOLO confidence threshold forwarded to the counter.
        stability_frames: Smoothing window length (converted to ``int``).
        progress: Gradio progress tracker (Gradio fills this in).

    Returns:
        ``(output_path, message)``. ``output_path`` is a temporary ``.mp4``
        file, or ``None`` on failure, in which case the temp file is removed.
    """
    if video_path is None:
        return None, "β οΈ Please upload a video first!"

    progress(0, desc="Opening video...")
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        return None, "β Error: Cannot open video file"

    # Some containers report 0 (or NaN) FPS, which VideoWriter cannot use.
    # Keep a float so fractional rates such as 29.97 do not drift.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not (fps > 0):  # also true for NaN
        fps = 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # An estimate only; it may be 0 or negative when the container lacks it.
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    roi_center = (width // 2, height // 2)
    # 180 = height of the stats panel the counter stacks above each frame;
    # every written frame is expected to have exactly this size.
    output_size = (width, height + 180)

    # Reserve a unique output path. The handle is closed so that VideoWriter
    # can open the file by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_output:
        output_path = temp_output.name

    out = None
    error = None
    frame_count = 0
    try:
        progress(0.1, desc="Loading model...")
        counter = RotatingPadShirtCounter(
            model_path='best.pt',
            roi_center=roi_center,
            roi_radius=int(roi_radius),
            min_conf=min_confidence,
            stability_frames=int(stability_frames),
        )
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, output_size)
        if not out.isOpened():
            error = "β Error: Cannot create output video"
        else:
            progress(0.2, desc="Processing video...")
            frame_count = _render_frames(cap, out, counter, width, total_frames, progress)
    except Exception as e:  # UI boundary: surface any failure to the user
        import traceback
        traceback.print_exc()
        error = f"β Error during processing: {e}"
    finally:
        cap.release()
        if out is not None:
            out.release()

    if error is not None:
        _remove_quietly(output_path)
        return None, error

    progress(1.0, desc="Complete!")
    return output_path, _format_results(frame_count, counter.get_stats())
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------

# Header text shown above the upload/output columns.
_INTRO_MD = """
# π Rotating Pad Shirt Counter
### Demo Showcase - Limited Training Model
**β οΈ Important Note:** This is a demonstration model trained on only **half of a single video** for showcase purposes.
Performance may vary with different videos, lighting conditions, or camera angles.
### How it works:
1. Upload a video showing a rotating pad system with shirts
2. The model detects when shirts are placed on the pad
3. System counts shirts as they rotate through the Region of Interest (ROI)
### Best Results:
- Similar camera angle and lighting to training data
- Clear view of the rotating pad
- Videos from the same or similar production line
---
"""

# Footer text shown below the main row.
_FOOTER_MD = """
---
### π Model Information:
- **Classes Detected:** `empty_pad`, `occupied_pad` (shirt on pad)
- **Training Data:** Half portion of single production video
- **Purpose:** Demonstration and proof-of-concept
- **Limitations:** May not generalize well to different environments
### π‘ Tips:
- Start with default settings
- If no shirts detected, try lowering confidence threshold
- If too many false counts, increase stability frames
- ROI radius should cover the area where pad appears
"""

# Keyword arguments for the three tuning sliders, in the order they are
# rendered and passed to process_video (radius, confidence, stability).
_SLIDER_SPECS = (
    {"minimum": 100, "maximum": 300, "value": 180, "step": 10,
     "label": "ROI Radius (pixels)",
     "info": "Detection area size around center"},
    {"minimum": 0.5, "maximum": 0.99, "value": 0.98, "step": 0.01,
     "label": "Minimum Confidence",
     "info": "Higher = more strict detection"},
    {"minimum": 3, "maximum": 30, "value": 15, "step": 1,
     "label": "Stability Frames",
     "info": "Frames needed to confirm state change"},
)

with gr.Blocks(title="Rotating Pad Shirt Counter", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        # Left column: input video, optional tuning, run button.
        with gr.Column():
            video_input = gr.Video(label="Upload Video", height=400)
            with gr.Accordion("βοΈ Advanced Settings (Optional)", open=False):
                roi_radius, min_confidence, stability_frames = (
                    gr.Slider(**spec) for spec in _SLIDER_SPECS
                )
            process_btn = gr.Button("π Process Video", variant="primary", size="lg")

        # Right column: annotated video and textual summary.
        with gr.Column():
            video_output = gr.Video(label="Processed Output", height=400)
            result_text = gr.Textbox(label="Results & Statistics", lines=10, max_lines=15)

    gr.Markdown(_FOOTER_MD)

    process_btn.click(
        fn=process_video,
        inputs=[video_input, roi_radius, min_confidence, stability_frames],
        outputs=[video_output, result_text],
    )

if __name__ == "__main__":
    demo.launch()