Spaces: Sleeping
import random

import cv2
import gradio as gr
from tqdm import tqdm
from ultralytics import YOLO
class yolo_model():
    """Thin wrapper around an Ultralytics YOLO-World model for annotating videos."""

    def __init__(self, model_name: str):
        """
        Initialize the YOLO-World model.

        Args:
            model_name (str): The name of the model file.
        """
        self.model = YOLO(model_name)

    def load(self, model_name: str):
        """
        Load (replace) the underlying YOLO model.

        Args:
            model_name (str): The name of the model file.
        """
        try:
            self.model = YOLO(model_name)
        except Exception as e:
            # Keep the previously loaded model on failure; just report the error.
            print(e)

    def process(self, video_path: str, prompt: str, confidence: float, iou: float,
                progress=gr.Progress(track_tqdm=True)) -> str:
        """
        Run YOLO-World detection on every frame of a video and write an
        annotated copy to 'output.mp4'.

        Args:
            video_path (str): The input video path.
            prompt (str): Comma-separated class names to detect, e.g. "person,dog".
            confidence (float): The confidence threshold.
            iou (float): The IoU threshold.
            progress: Gradio progress tracker; track_tqdm=True hooks the tqdm
                loop below automatically.

        Returns:
            str: The output video path, or None on failure.
        """
        video_capture = None
        video_writer = None
        try:
            # Fix: strip whitespace and drop empty entries so prompts like
            # "person, dog" work; the original passed raw split results and
            # crashed with len(None) when the prompt was empty.
            classes = [c.strip() for c in prompt.split(",") if c.strip()] if prompt else []
            if not classes:
                print("No classes provided in prompt")
                return None
            # One random BGR color per class for drawing.
            rgb_colors = [(random.randint(0, 255), random.randint(0, 255),
                           random.randint(0, 255)) for _ in classes]
            # Restrict the open-vocabulary model to the requested classes.
            self.model.set_classes(classes)
            # Open the input video and read its properties.
            video_capture = cv2.VideoCapture(video_path)
            frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = int(video_capture.get(cv2.CAP_PROP_FPS))
            n_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
            output_video_path = 'output.mp4'
            # mp4v codec for .mp4 output.
            fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
            video_writer = cv2.VideoWriter(
                output_video_path, fourcc, fps, (frame_width, frame_height), isColor=True)
            # Process each frame; gr.Progress(track_tqdm=True) tracks this tqdm
            # loop itself, so no file= argument is needed (gr.Progress is not a
            # file-like object anyway).
            for _ in tqdm(range(n_frames), desc="Processing video"):
                ret, frame = video_capture.read()
                if not ret:
                    break  # No frames left.
                # Fix: assigning self.model.conf / self.model.iou has no effect
                # in ultralytics — thresholds must be passed to predict().
                results = self.model.predict(frame, conf=confidence, iou=iou)
                if results:
                    # Each row: x1, y1, x2, y2, confidence, class_id.
                    boxes = results[0].boxes.cpu().numpy().data
                    class_names = self.model.names
                    for box in boxes:
                        x1, y1, x2, y2, conf, class_id = box.tolist()
                        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                        # tolist() yields floats; normalize once so the names
                        # mapping and the color list index cleanly.
                        class_id = int(class_id)
                        label = f'{class_names[class_id]}: {conf:.2f}'
                        # Draw bounding box and label in the class color.
                        cv2.rectangle(frame, (x1, y1), (x2, y2),
                                      rgb_colors[class_id], 2)
                        cv2.putText(frame, label, (x1, y1 - 10),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                    rgb_colors[class_id], 2)
                # Write the annotated frame to the output video.
                video_writer.write(frame)
            return output_video_path
        except Exception as e:
            print(e)
            return None
        finally:
            # Fix: always release resources — the original leaked the capture
            # and writer whenever an exception was raised mid-processing.
            if video_capture is not None:
                video_capture.release()
            if video_writer is not None:
                video_writer.release()