import cv2
from ultralytics import YOLO
import random
import gradio as gr
from tqdm import tqdm


class yolo_model():
    def __init__(self, model_name: str):
        """
        Initialize the YOLO-World model.

        Args:
            model_name (str): The name of the model file.
        """
        # Initialize a YOLO-World model
        self.model = YOLO(model_name)

    def load(self, model_name: str):
        """
        Load a YOLO model.

        Args:
            model_name (str): The name of the model file.
        """
        try:
            # Load the model
            self.model = YOLO(model_name)
        except Exception as e:
            print(e)

    # Process a video and return the path of the annotated output
    def process(self, video_path: str, prompt: str, confidence: float,
                iou: float, progress=gr.Progress(track_tqdm=True)) -> str:
        """
        Process a video with YOLO-World.

        Args:
            video_path (str): The input video path.
            prompt (str): Comma-separated list of classes to detect.
            confidence (float): The confidence threshold.
            iou (float): The IoU threshold.

        Returns:
            str: The output video path.
        """
        try:
            # Create a list of classes from the prompt; classes are separated by commas
            classes = [c.strip() for c in prompt.split(",")] if prompt else None

            # Define a random color for each class
            rgb_colors = [(random.randint(0, 255), random.randint(0, 255),
                           random.randint(0, 255)) for _ in range(len(classes))]

            # Define custom classes
            self.model.set_classes(classes)

            # Open the video file
            video_capture = cv2.VideoCapture(video_path)

            # Get the video properties
            frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = int(video_capture.get(cv2.CAP_PROP_FPS))
            n_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

            # Define the output video path
            output_video_path = 'output.mp4'

            # Define the codec and create a VideoWriter object
            fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
            video_writer = cv2.VideoWriter(
                output_video_path, fourcc, fps,
                (frame_width, frame_height), isColor=True)

            # Process each frame in the video; with track_tqdm=True, Gradio
            # mirrors this tqdm bar in the UI automatically
            for _ in tqdm(range(n_frames), desc="Processing video"):
                ret, frame = video_capture.read()
                if not ret:
                    break  # Stop when no frames are left

                # Run inference on the frame with the requested confidence
                # and IoU thresholds to detect the custom classes
                results = self.model.predict(frame, conf=confidence, iou=iou)

                if len(results) > 0:
                    # Extract the bounding boxes (x1, y1, x2, y2, conf, class_id)
                    boxes = results[0].boxes.data.cpu().numpy()
                    class_names = self.model.names  # Class id -> name mapping

                    for box in boxes:
                        x1, y1, x2, y2, conf, class_id = box.tolist()
                        # Pixel coordinates and class id come back as floats; cast to int
                        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                        class_id = int(class_id)
                        label = f'{class_names[class_id]}: {conf:.2f}'

                        # Draw the bounding box and label
                        cv2.rectangle(frame, (x1, y1), (x2, y2),
                                      rgb_colors[class_id], 2)
                        cv2.putText(frame, label, (x1, y1 - 10),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                    rgb_colors[class_id], 2)

                # Write the annotated frame to the output video
                video_writer.write(frame)

            # Release resources
            video_capture.release()
            video_writer.release()

            # Return the output video path
            return output_video_path
        except Exception as e:
            print(e)
            return None
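
# --- Minimal usage sketch (not part of the class above) ---
# This shows one way to wire yolo_model.process into a Gradio app so the
# gr.Progress(track_tqdm=True) argument is picked up automatically. The
# checkpoint name "yolov8s-world.pt", the default prompt, and the widget
# layout below are illustrative assumptions, not fixed by the code above.
if __name__ == "__main__":
    detector = yolo_model("yolov8s-world.pt")  # assumed YOLO-World checkpoint

    demo = gr.Interface(
        fn=detector.process,
        inputs=[
            gr.Video(label="Input video"),
            gr.Textbox(label="Classes (comma-separated)", value="person, car"),
            gr.Slider(0.0, 1.0, value=0.25, label="Confidence threshold"),
            gr.Slider(0.0, 1.0, value=0.7, label="IoU threshold"),
        ],
        outputs=gr.Video(label="Annotated video"),
        title="YOLO-World video detection",
    )
    demo.launch()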