Spaces:
Sleeping
Sleeping
import os
import subprocess

import cv2
import gradio as gr
from ultralytics import YOLO
| MODEL_PATH = "best.pt" | |
| TRACKER_FILE = "my_tracker.yaml" | |
| FOOTAGE_EXAMPLE_PATH = "drone_footage.mp4" | |
| tracker_config = """ | |
| tracker_type: bytetrack | |
| track_high_thresh: 0 | |
| track_low_thresh: 0 | |
| track_buffer: 300 | |
| fuse_score: True | |
| match_thresh: 0.9 | |
| new_track_thresh: 0.85 | |
| """ | |
| with open(TRACKER_FILE, "w") as f: | |
| f.write(tracker_config) | |
| model = YOLO(MODEL_PATH) | |
def process_video(video_path, conf_threshold, iou_threshold):
    """Track and count Oil Palm Fresh Fruit Bunches (FFB) in a video.

    Runs YOLO + ByteTrack frame by frame, counts each track exactly once
    after it has persisted long enough to be considered stable, overlays the
    running counts on every frame, and re-encodes the result to H.264 so
    browsers can play it.

    Args:
        video_path: Path to the uploaded video, or None (nothing uploaded).
        conf_threshold: Detection confidence threshold passed to the model.
        iou_threshold: NMS IoU threshold passed to the model.

    Returns:
        Path to the web-ready annotated MP4, or None if no (readable) video
        was supplied.
    """
    if video_path is None:
        return None

    # A track must survive this many frames before it is counted — filters
    # out short-lived false-positive tracks.
    MIN_FRAMES_TO_COUNT = 60

    class_names = model.names
    track_history = {}  # track_id -> {'frame_count': int, 'class_votes': {name: votes}}
    class_counts = {name: 0 for name in class_names.values()}
    stable_counted_ids = set()

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Unreadable/corrupt upload: fail gracefully instead of writing an
        # empty output video.
        return None

    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or fps != fps:  # 0 or NaN from a broken container
        fps = 30.0

    output_path = "output_counted.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

    print("Processing video...")
    try:
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break

            results = model.track(
                frame,
                persist=True,
                verbose=False,
                tracker=TRACKER_FILE,
                conf=conf_threshold,
                iou=iou_threshold,
            )
            annotated_frame = results[0].plot(line_width=2, font_size=1)

            # boxes.id is None on frames where the tracker has no active IDs.
            if results[0].boxes.id is not None:
                track_ids = results[0].boxes.id.int().tolist()
                class_indices = results[0].boxes.cls.int().tolist()
                for track_id, cls_index in zip(track_ids, class_indices):
                    class_name = class_names[cls_index]
                    history = track_history.setdefault(
                        track_id, {'frame_count': 0, 'class_votes': {}}
                    )
                    history['frame_count'] += 1
                    votes = history['class_votes']
                    votes[class_name] = votes.get(class_name, 0) + 1

                    # Count each stable track exactly once, using its
                    # majority-voted class (per-frame class can flicker).
                    if (history['frame_count'] >= MIN_FRAMES_TO_COUNT
                            and track_id not in stable_counted_ids):
                        stable_counted_ids.add(track_id)
                        stable_class = max(votes, key=votes.get)
                        class_counts[stable_class] += 1

            # --- Overlay the running counts in the top-left corner ---
            total_stable_count = len(stable_counted_ids)
            text_lines = [f'Total FFBs Counted: {total_stable_count}']
            for class_name, count in class_counts.items():
                if count > 0:
                    text_lines.append(f'{class_name}: {count}')

            font_scale = 1.0
            thickness = 2
            font = cv2.FONT_HERSHEY_SIMPLEX
            # Probe text height once; all lines share the same font metrics.
            (text_w, text_h), _ = cv2.getTextSize('Test', font, font_scale, thickness)
            line_height = text_h + 10
            x_pos, y_pos = 10, 10
            max_line_w = max(
                cv2.getTextSize(line, font, font_scale, thickness)[0][0]
                for line in text_lines
            )
            # Background box sized to the widest line, with 10px padding.
            total_block_h = 10 + (line_height * len(text_lines)) - 5
            total_block_w = 10 + max_line_w + 10
            cv2.rectangle(annotated_frame, (x_pos, y_pos),
                          (total_block_w, total_block_h), (0, 0, 0), -1)
            current_y = y_pos + text_h + 5
            for line in text_lines:
                cv2.putText(annotated_frame, line, (x_pos + 5, current_y),
                            font, font_scale, (255, 255, 255), thickness)
                current_y += line_height

            out.write(annotated_frame)
    finally:
        # Always release handles, even if tracking raises mid-video.
        cap.release()
        out.release()

    print(f"Final Count: {len(stable_counted_ids)}")
    print(f"Class Counts: {class_counts}")

    # Re-encode to H.264 for browser playback. An argument list (shell=False)
    # avoids the quoting/injection problems of the previous os.system string.
    final_output_path = "final_web_ready.mp4"
    subprocess.run(
        ["ffmpeg", "-y", "-i", output_path, "-vcodec", "libx264", final_output_path],
        check=False,
    )
    return final_output_path
| description_html = """ | |
| <p>Upload a video **(preferably drone footage)** showing Oil Palm Fresh Fruit Bunches (FFB). The model will count the detected FFBs.</p> | |
| <h3>Demo Result:</h3> | |
| <div style="display: flex; justify-content: center;"> | |
| <video width="640" height="360" controls autoplay loop muted> | |
| <source src="drone_footage_result.mp4" type="video/mp4"> | |
| Your browser does not support the video tag. | |
| </video> | |
| </div> | |
| """ | |
# Gradio UI: upload a video, tune the detection thresholds, and get back the
# annotated, counted result video.
iface = gr.Interface(
    fn=process_video,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Confidence Threshold"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.45, step=0.01, label="IoU Threshold"),
    ],
    outputs=gr.Video(label="Processed Result"),
    title="Oil Palm Fresh Fruit Bunch Classification and Counter",
    description=description_html,
    # Drone Footage Example
    examples=[
        # Format: [Video_Path, Conf_Value, IoU_Value]
        [FOOTAGE_EXAMPLE_PATH, 0.25, 0.45]
    ],
    # Pre-computes the example's output at startup so the demo loads instantly.
    cache_examples=True
)

if __name__ == "__main__":
    # NOTE(review): ssr_mode=False disables Gradio's server-side rendering —
    # presumably to work around a Spaces rendering issue; confirm before removing.
    iface.launch(ssr_mode=False)