import torch
import subprocess
import numpy as np
import supervision as sv
import cv2
import os
import math
from glob import glob
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

def plot_predictions(
    image: str,
    labels: list[str],
    scores: list[float],
    boxes: list[list[float]],
    opacity: float = 1.0,
) -> np.ndarray:
    """Draws labelled bounding boxes on an image and returns the annotated frame in BGR."""
    image_source = cv2.imread(image)
    image_source_rgb = cv2.cvtColor(image_source, cv2.COLOR_BGR2RGB)

    # supervision expects an (N, 4) array of xyxy boxes
    detections = sv.Detections(xyxy=np.asarray(boxes, dtype=np.float32))
    box_labels = [
        f"{phrase} {logit:.2f}"
        for phrase, logit in zip(labels, scores)
    ]

    # Scale line thickness and text size with the image width
    height, width, _ = image_source_rgb.shape
    thickness = math.ceil(width / 200)
    text_scale = width / 1500
    text_thickness = math.ceil(text_scale * 1.5)

    bbox_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=thickness)
    label_annotator = sv.LabelAnnotator(
        color_lookup=sv.ColorLookup.INDEX,
        text_scale=text_scale,
        text_thickness=text_thickness,
    )

    # Create a semi-transparent overlay
    overlay = image_source_rgb.copy()
    # Apply bounding box and label annotations to the overlay
    overlay = bbox_annotator.annotate(scene=overlay, detections=detections)
    overlay = label_annotator.annotate(scene=overlay, detections=detections, labels=box_labels)

    # Blend overlay with original image using the specified opacity
    annotated_frame = cv2.addWeighted(overlay, opacity, image_source_rgb, 1 - opacity, 0)
    annotated_frame_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
    return annotated_frame_bgr

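# Example call (hypothetical values; assumes a detector has already produced phrase
# labels, confidence scores, and pixel-space xyxy boxes for "frame.png"):
#   annotated = plot_predictions(
#       image="frame.png",
#       labels=["dog", "ball"],
#       scores=[0.92, 0.78],
#       boxes=[[34, 50, 210, 300], [240, 310, 300, 370]],
#       opacity=0.8,
#   )
#   cv2.imwrite("frame_annotated.png", annotated)
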
def mp4_to_png(input_path: str, save_path: str, scale_factor: float) -> int:
    """Converts an mp4 into one png per frame of the video.

    Args:
        input_path: path to the mp4 file.
        save_path: directory in which to save the frames.
        scale_factor: factor by which to scale the frame width and height.

    Returns:
        fps: the frame rate of the input video.
    """
    # Get the frames per second of the input video
    cap = cv2.VideoCapture(input_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    cap.release()
    # Run ffmpeg to convert the mp4 to pngs
    os.system(f"ffmpeg -i {input_path} -vf 'fps={fps},scale=iw*{scale_factor}:ih*{scale_factor}' {save_path}/frame%08d.png")
    # subprocess.run(["ffmpeg", "-i", input_path, "-vf", f"scale=iw*{scale_factor}:ih*{scale_factor}, fps={fps}", f"{save_path}/frame%08d.png"])
    return fps

def vid_stitcher(frames_dir: str, output_path: str, fps: int = 30) -> str:
    """
    Reads the frame*.png files in frames_dir and stitches them into a video at output_path.
    """
    # Get the sorted list of frame paths
    frame_list = sorted(glob(os.path.join(frames_dir, 'frame*.png')))

    # Prepare the VideoWriter using the dimensions of the first frame
    frame = cv2.imread(frame_list[0])
    height, width, _ = frame.shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    # Use multithreading to read frames faster
    with ThreadPoolExecutor() as executor:
        frames = list(executor.map(cv2.imread, frame_list))

    # Write frames to the video
    with tqdm(total=len(frame_list), desc='Stitching frames') as pbar:
        for frame in frames:
            out.write(frame)
            pbar.update(1)

    # Finalize the video file
    out.release()
    return output_path

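# Typical round trip (a sketch; "input.mp4", the frames/ directory, and "output.mp4"
# are placeholders, and frames/ must already exist before mp4_to_png writes into it):
#   fps = mp4_to_png("input.mp4", "frames", scale_factor=0.5)
#   ... annotate or otherwise process the pngs in frames/ ...
#   vid_stitcher("frames", "output.mp4", fps=fps)
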
def count_pos(phrases, text_target):
    """
    Takes a list of lists of phrases and counts how many sublists contain at least one entry equal to the target phrase.
    """
    num_pos = 0
    for sublist in phrases:
        if sublist is None:
            continue
        for phrase in sublist:
            if phrase == text_target:
                num_pos += 1
                break
    return num_pos
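
# Example: count_pos([["dog", "cat"], None, ["cat"]], "dog") returns 1, since only
# the first sublist contains the target phrase; None sublists are skipped.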