Spaces:
Runtime error
Runtime error
| import os | |
| import glob | |
| import uuid | |
| import gradio as gr | |
| from PIL import Image | |
| import cv2 | |
| import numpy as np | |
| import supervision as sv | |
| from ultralyticsplus import YOLO, download_from_hub | |
# Hugging Face Hub model IDs offered in the UI dropdowns.
hf_model_ids = ["chanelcolgate/rods-count-v1", "chanelcolgate/cab-v1"]

# Example rows for the image-counting tab:
# [image path, model id, image size, confidence, IOU].
image_paths = []
for _img in glob.glob("./images/*.jpg"):
    image_paths.append([_img, "chanelcolgate/rods-count-v1", 640, 0.6, 0.45])

# Example rows for the video tabs: [video path, model id].
video_paths = []
for _vid in glob.glob("./videos/*.mp4"):
    video_paths.append([_vid, "chanelcolgate/cab-v1"])
def get_center_of_bbox(bbox):
    """Return the integer (x, y) midpoint of an (x1, y1, x2, y2) box.

    Uses ``int(.../2)`` (truncation toward zero) to match OpenCV pixel
    coordinates.
    """
    left, top, right, bottom = bbox
    center_x = int((left + right) / 2)
    center_y = int((top + bottom) / 2)
    return center_x, center_y
def get_bbox_width(bbox):
    """Return the integer width (x2 - x1) of an (x1, y1, x2, y2) box."""
    x1, _, x2, _ = bbox
    return int(x2 - x1)
def draw_circle(pil_image, bbox, color, id):
    """Draw a labelled circle around the bbox center on a PIL image.

    The drawing happens in OpenCV's BGR space, so `color` is a BGR tuple.
    A new PIL image is returned; the input image itself is not modified.
    """
    # PIL (RGB) -> OpenCV (BGR) array.
    frame = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

    center = get_center_of_bbox(bbox)
    # Radius is 60% of the bbox half-width.
    radius = int(get_bbox_width(bbox) * 0.5 * 0.6)
    cv2.circle(
        frame,
        center=center,
        radius=radius,
        color=color,
        thickness=1,
    )

    # Put the id label roughly centered on the circle.
    cv2.putText(
        frame,
        f"{id}",
        (center[0] - 6, center[1] + 6),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        (255, 249, 208),
        2,
    )

    # OpenCV (BGR) -> PIL (RGB).
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
def count_predictions(
    image=None,
    hf_model_id="chanelcolgate/rods-count-v1",
    image_size=640,
    conf_threshold=0.25,
    iou_threshold=0.45,
):
    """Run detection on an image and circle each detected object.

    Args:
        image: Input PIL image.
        hf_model_id: Hugging Face Hub id of the YOLO model to use.
        image_size: Inference image size passed to the model.
        conf_threshold: Minimum detection confidence.
        iou_threshold: IOU threshold for non-max suppression.

    Returns:
        Tuple of (annotated PIL image, number of detections).
    """
    model_path = download_from_hub(hf_model_id)
    model = YOLO(model_path)
    results = model(
        image, imgsz=image_size, conf=conf_threshold, iou=iou_threshold
    )
    detections = sv.Detections.from_ultralytics(results[0])
    # Copy once so the caller's image is not mutated (previously the image
    # was re-copied on every loop iteration — one copy per detection).
    image = image.copy()
    for idx, detection in enumerate(detections, start=1):
        # detection[0] is the xyxy bbox array for this detection.
        bbox = detection[0].tolist()
        image = draw_circle(image, bbox, (90, 178, 255), idx)
    return image, len(detections)
def count_across_line(
    source_video_path=None,
    hf_model_id="chanelcolgate/cab-v1",
):
    """Track objects in a video and count crossings of a vertical line.

    Args:
        source_video_path: Path to the input video file.
        hf_model_id: Hugging Face Hub id of the YOLO model to use.

    Returns:
        Tuple of (path of the annotated output video, number of tracked
        objects that crossed the line in the "out" direction).
    """
    # Output is written to the working directory with a unique name so
    # concurrent requests do not clobber each other.
    TARGET_VIDEO_PATH = os.path.join("./", f"{uuid.uuid4()}.mp4")
    # Counting line endpoints (pixel coordinates in the video frame).
    LINE_START = sv.Point(976, 212)
    LINE_END = sv.Point(976, 1276)

    model_path = download_from_hub(hf_model_id)
    model = YOLO(model_path)
    byte_tracker = sv.ByteTrack(
        track_thresh=0.25, track_buffer=30, match_thresh=0.8, frame_rate=30
    )
    line_zone = sv.LineZone(start=LINE_START, end=LINE_END)
    box_annotator = sv.BoxAnnotator(thickness=4, text_thickness=4, text_scale=2)
    trace_annotator = sv.TraceAnnotator(thickness=4, trace_length=50)
    line_zone_annotator = sv.LineZoneAnnotator(
        thickness=4, text_thickness=4, text_scale=2
    )

    def callback(frame: np.ndarray, index: int) -> np.ndarray:
        # Detect, track, annotate and update the line counter for one frame.
        results = model.predict(frame)
        detections = sv.Detections.from_ultralytics(results[0])
        tracked = byte_tracker.update_with_detections(detections)
        annotated_frame = trace_annotator.annotate(
            scene=frame.copy(), detections=tracked
        )
        annotated_frame = box_annotator.annotate(
            scene=annotated_frame,
            detections=tracked,
            skip_label=True,
        )
        # Update the in/out crossing counters.
        line_zone.trigger(tracked)
        # Return the frame with boxes, traces and the line overlay.
        return line_zone_annotator.annotate(
            annotated_frame, line_counter=line_zone
        )

    # Process the whole video frame by frame.
    sv.process_video(
        source_path=source_video_path,
        target_path=TARGET_VIDEO_PATH,
        callback=callback,
    )
    return TARGET_VIDEO_PATH, line_zone.out_count
def count_in_zone(
    source_video_path=None,
    hf_model_id="chanelcolgate/cab-v1",
):
    """Track objects in a video and count how many sit inside each zone.

    Args:
        source_video_path: Path to the input video file.
        hf_model_id: Hugging Face Hub id of the YOLO model to use.

    Returns:
        Tuple of (path of the annotated output video, list of the final
        per-zone object counts, one entry per polygon).
    """
    # Output is written to the working directory with a unique name so
    # concurrent requests do not clobber each other.
    TARGET_VIDEO_PATH = os.path.join("./", f"{uuid.uuid4()}.mp4")
    colors = sv.ColorPalette.default()
    # Two fixed counting zones (pixel coordinates in the video frame).
    polygons = [
        np.array([[88, 292], [748, 284], [736, 1160], [96, 1148]]),
        np.array([[844, 240], [844, 1132], [1580, 1124], [1584, 264]]),
    ]

    model_path = download_from_hub(hf_model_id)
    model = YOLO(model_path)
    byte_tracker = sv.ByteTrack(
        track_thresh=0.25, track_buffer=30, match_thresh=0.8, frame_rate=30
    )
    video_info = sv.VideoInfo.from_video_path(source_video_path)
    zones = [
        sv.PolygonZone(
            polygon=polygon, frame_resolution_wh=video_info.resolution_wh
        )
        for polygon in polygons
    ]
    zone_annotators = [
        sv.PolygonZoneAnnotator(
            zone=zone,
            color=colors.by_idx(index),
            thickness=4,
            text_thickness=4,
            text_scale=2,
        )
        for index, zone in enumerate(zones)
    ]
    box_annotators = [
        sv.BoxAnnotator(
            thickness=4,
            text_thickness=4,
            text_scale=2,
            color=colors.by_idx(index),
        )
        for index in range(len(polygons))
    ]

    def callback(frame: np.ndarray, index: int) -> np.ndarray:
        # Detect, track, then update and draw every zone on this frame.
        results = model.predict(frame)
        detections = sv.Detections.from_ultralytics(results[0])
        tracked = byte_tracker.update_with_detections(detections)
        for zone, zone_annotator, box_annotator in zip(
            zones, zone_annotators, box_annotators
        ):
            zone.trigger(detections=tracked)
            frame = box_annotator.annotate(
                scene=frame, detections=tracked, skip_label=True
            )
            frame = zone_annotator.annotate(scene=frame)
        return frame

    # Process the whole video frame by frame.
    sv.process_video(
        source_path=source_video_path,
        target_path=TARGET_VIDEO_PATH,
        callback=callback,
    )
    return TARGET_VIDEO_PATH, [zone.current_count for zone in zones]
# Shared title for the tabbed demo app (now actually passed to the UI below).
title = "Demo Counting"

interface_count_predictions = gr.Interface(
    fn=count_predictions,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(hf_model_ids),
        gr.Slider(
            minimum=320, maximum=1280, value=640, step=32, label="Image Size"
        ),
        gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.25,
            step=0.05,
            label="Confidence Threshold",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.45,
            step=0.05,
            label="IOU Threshold",
        ),
    ],
    outputs=[gr.Image(type="pil"), gr.Textbox(show_label=False)],
    title="Count Predictions",
    examples=image_paths,
    # Only cache when there are example rows to cache.
    cache_examples=bool(image_paths),
)

interface_count_across_line = gr.Interface(
    fn=count_across_line,
    inputs=[
        gr.Video(label="Input Video"),
        gr.Dropdown(hf_model_ids),
    ],
    outputs=[gr.Video(label="Output Video"), gr.Textbox(show_label=False)],
    title="Count Across Line",
    examples=video_paths,
    cache_examples=bool(video_paths),
)

interface_count_in_zone = gr.Interface(
    fn=count_in_zone,
    inputs=[gr.Video(label="Input Video"), gr.Dropdown(hf_model_ids)],
    outputs=[gr.Video(label="Output Video"), gr.Textbox(show_label=False)],
    title="Count in Zone",
    examples=video_paths,
    cache_examples=bool(video_paths),
)

gr.TabbedInterface(
    [
        interface_count_predictions,
        interface_count_across_line,
        interface_count_in_zone,
    ],
    tab_names=["Count Predictions", "Count Across Line", "Count in Zone"],
    title=title,
).queue().launch()