Spaces:
Runtime error
Runtime error
| import os | |
| import shutil | |
| from tqdm import tqdm | |
| import cv2 | |
| import gradio as gr | |
| import pandas as pd | |
| import torch | |
| from PIL import Image | |
| from transformers import Owlv2Processor, Owlv2ForObjectDetection | |
| import math | |
| import zipfile | |
| from utils import plot_predictions, mp4_to_png, vid_stitcher | |
def owl_batch_video(
    input_vids: list[str],
    target_prompt: list[str],
    species_prompt: str,
    threshold: float,
    fps_processed: int = 1,
    scaling_factor: float = 0.5,
    batch_size: int = 8,
    save_dir: str = "temp/",
    progress=gr.Progress()
):
    """
    Run OWLv2 detection on several videos and bundle all outputs into a zip.

    Each video is passed to ``owl_video_detection``. A per-video summary CSV
    (``video name``, ``detection?``) is written to ``save_dir``. The whole of
    ``save_dir`` is then zipped into ``results.zip`` in the current working
    directory. The zip therefore also contains anything else already present
    in ``save_dir``.

    Args:
        input_vids: paths of the videos to process.
        target_prompt: phrases counted as a positive detection. Also
            interpolated verbatim into output file names.
        species_prompt: comma-separated text queries sent to OWLv2.
        threshold: minimum detection score.
        fps_processed: forwarded to ``owl_video_detection``.
        scaling_factor: forwarded to ``owl_video_detection``.
        batch_size: number of frames per OWLv2 batch.
        save_dir: directory that receives all intermediate results.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Path of the written zip file.
    """
    rows = []
    for vid in progress.tqdm(input_vids, desc="Processing videos"):
        detected = owl_video_detection(
            vid,
            target_prompt,
            species_prompt,
            threshold,
            fps_processed=fps_processed,
            scaling_factor=scaling_factor,
            batch_size=batch_size,
            save_dir=save_dir,
        )
        # Stored as the strings "True"/"False" to keep the CSV format stable.
        rows.append({
            "video name": os.path.basename(vid),
            "detection?": "True" if detected else "False",
        })
    # Build the frame once instead of concatenating row by row.
    df = pd.DataFrame(rows, columns=["video name", "detection?"])
    # save_dir may not exist yet, e.g. when no videos were supplied.
    os.makedirs(save_dir, exist_ok=True)
    df.to_csv(f"{save_dir}/{target_prompt}_detection_results.csv")
    zip_file = "results.zip"
    zip_directory(save_dir, zip_file)
    return zip_file
def zip_directory(folder_path, output_zip_path):
    """
    Write every file below ``folder_path`` into a deflate-compressed zip.

    Archive member names are relative to ``folder_path``. The zip therefore
    reproduces the folder's internal layout without its leading path.
    """
    def _iter_files():
        # Yield each file path in os.walk order.
        for current_dir, _subdirs, names in os.walk(folder_path):
            for name in names:
                yield os.path.join(current_dir, name)

    with zipfile.ZipFile(output_zip_path, mode='w', compression=zipfile.ZIP_DEFLATED) as archive:
        for full_path in _iter_files():
            archive.write(full_path, arcname=os.path.relpath(full_path, start=folder_path))
def preprocess_text(text_prompt: str, num_prompts: int = 1) -> list[list[str]]:
    """
    Split a comma-separated prompt string into one query list per image.

    Example:
        preprocess_text("a, b, c", 2) -> [["a", "b", "c"], ["a", "b", "c"]]

    Args:
        text_prompt: comma-separated phrases. Surrounding whitespace is stripped
            from each phrase.
        num_prompts: how many copies to produce (one per image in a batch).

    Returns:
        ``num_prompts`` lists of phrases. Each list is an independent copy, so
        mutating one does not affect the others.
    """
    phrases = [s.strip() for s in text_prompt.split(",")]
    # Build a fresh list per image; ``[phrases] * n`` would alias one list n times.
    return [list(phrases) for _ in range(num_prompts)]
def owl_batch_prediction(
    images: "list[Image.Image]",
    text_queries: list[list[str]],  # assuming that every image is queried with the same text prompt
    threshold: float,
    processor,
    model,
    device: str = 'cuda'
):
    """
    Run OWLv2 on a batch of images and return thresholded detections.

    Args:
        images: PIL images (``img.size`` is read as ``(width, height)``).
        text_queries: one list of query phrases per image, e.g. the output of
            ``preprocess_text``.
        threshold: minimum score a detection must have to be kept.
        processor: an OWLv2 processor (e.g. ``Owlv2Processor``).
        model: an OWLv2 detection model already placed on ``device``.
        device: torch device that the processed inputs are moved to.

    Returns:
        The list produced by ``processor.post_process_object_detection``: one
        dict per image holding ``"boxes"``, ``"scores"`` and ``"labels"``
        tensors. Each label is an index into that image's query list.
    """
    inputs = processor(text=text_queries, images=images, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # OWLv2's processor pads every image to a square (bottom/right) before
    # resizing, so predicted boxes are normalised to that padded square.
    # Rescaling by the square's side, max(width, height), maps them back onto
    # the original image's pixels. Rescaling by (height, width) would stretch
    # boxes on non-square frames in transformers versions that do not already
    # correct for the padding. Newer versions rescale by the max side
    # internally, so passing it explicitly is equivalent there.
    target_sizes = torch.tensor(
        [[max(img.size)] * 2 for img in images], dtype=torch.float32, device=device
    )
    results = processor.post_process_object_detection(
        outputs=outputs, target_sizes=target_sizes, threshold=threshold
    )
    return results
def count_pos(phrases: list[str], text_targets: list[str]) -> int:
    """
    Count how many phrases exactly match any of the target phrases.

    Args:
        phrases: strings to evaluate (e.g. the labels detected in one frame).
        text_targets: target strings to match against. A single bare string
            is treated as one target, not as a sequence of characters.

    Returns:
        The number of entries in ``phrases`` that equal some target (0 when
        either input is empty).
    """
    if isinstance(text_targets, str):
        # set("cat") would be {"c", "a", "t"}; wrap it so the whole phrase is the target.
        text_targets = [text_targets]
    target_set = set(text_targets)
    return sum(phrase in target_set for phrase in phrases)
_OWL_CHECKPOINT = "google/owlv2-base-patch16-ensemble"
# Default (processor, model) pairs keyed by device, filled on first use so that
# importing this module does not download weights or require CUDA.
_OWL_DEFAULTS: dict = {}


def _default_owl(device: str):
    """Return the default OWLv2 (processor, model) for ``device``, loading it once."""
    if device not in _OWL_DEFAULTS:
        processor = Owlv2Processor.from_pretrained(_OWL_CHECKPOINT)
        model = Owlv2ForObjectDetection.from_pretrained(_OWL_CHECKPOINT).to(device)
        _OWL_DEFAULTS[device] = (processor, model)
    return _OWL_DEFAULTS[device]


def _frame_sort_key(name: str):
    """Natural-sort key so that e.g. 'frame_2.png' sorts before 'frame_10.png'."""
    import re

    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", name)]


def _load_frame(path: str) -> "Image.Image":
    """Read an image fully into memory and close its file handle."""
    with Image.open(path) as img:
        return img.copy()


def owl_video_detection(
    vid_path: str,
    text_target: list[str],
    text_prompt: str,
    threshold: float,
    fps_processed: int = 1,
    scaling_factor: float = 0.5,
    processor=None,
    model=None,
    device: str = 'cuda',
    batch_size: int = 8,
    save_dir: str = "temp/",
):
    """
    Run OWLv2 over a video's frames and report whether ``text_target`` is found.

    The video is split into frames with ``mp4_to_png``. Every
    ``fps_processed``-th frame, in natural filename order, is run through OWLv2
    in batches. A frame counts as positive when at least one detected label is
    in ``text_target``; positive frames are overwritten with an annotated
    version. Processing stops at the first batch in which more than 2/3 of the
    frames are positive.

    Args:
        vid_path: path of the input video.
        text_target: phrases that count as a positive detection. Also
            interpolated verbatim into the output directory names.
        text_prompt: comma-separated queries sent to OWLv2.
        threshold: minimum detection score.
        fps_processed: frame stride. Despite the name, it keeps frames whose
            index is divisible by this value.
        scaling_factor: forwarded to ``mp4_to_png`` (presumably a resize
            factor; confirm in utils).
        processor: OWLv2 processor. When None, a default is loaded lazily.
        model: OWLv2 detection model. When None, a default is loaded lazily
            onto ``device``.
        device: torch device for inference.
        batch_size: number of frames per OWLv2 batch.
        save_dir: root directory for results.

    Returns:
        True if ``text_target`` was detected, False otherwise.

    Side effects:
        Writes a per-frame CSV to ``{save_dir}/{text_target}_positives`` or
        ``{save_dir}/{text_target}_negatives``. On a positive result it also
        stitches an annotated MP4 there. The temporary frames directory is
        always removed.
    """
    if processor is None or model is None:
        default_processor, default_model = _default_owl(device)
        processor = default_processor if processor is None else processor
        model = default_model if model is None else model

    positives_dir = f"{save_dir}/{text_target}_positives"
    negatives_dir = f"{save_dir}/{text_target}_negatives"
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(positives_dir, exist_ok=True)
    os.makedirs(negatives_dir, exist_ok=True)

    filename = os.path.splitext(os.path.basename(vid_path))[0]
    frames_dir = f"{save_dir}/{filename}_frames"
    os.makedirs(frames_dir, exist_ok=True)

    columns = ["frame", "boxes", "scores", "labels", "count"]
    records = []  # one dict per processed frame; turned into a DataFrame once
    try:
        fps = mp4_to_png(vid_path, frames_dir, scaling_factor)
        # os.listdir order is arbitrary; sort so sampling/stitching is chronological
        # (assuming frame filenames encode the frame index).
        frame_names = sorted(os.listdir(frames_dir), key=_frame_sort_key)
        frame_paths = [
            os.path.join(frames_dir, name)
            for i, name in enumerate(frame_names)
            if i % fps_processed == 0
        ]

        for start in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
            batch_paths = frame_paths[start:start + batch_size]
            images = [_load_frame(path) for path in batch_paths]
            text_queries = preprocess_text(text_prompt, len(batch_paths))
            results = owl_batch_prediction(images, text_queries, threshold, processor, model, device)
            query_phrases = text_queries[0]  # every image in the batch uses the same queries

            batch_pos = 0
            for frame_path, result in zip(batch_paths, results):
                # Label ids index into the query list; map them back to phrases.
                labels = [query_phrases[label_id] for label_id in result["labels"].tolist()]
                boxes = result["boxes"].cpu().numpy()
                scores = result["scores"].cpu().numpy()
                count = count_pos(labels, text_target)
                records.append({
                    "frame": os.path.basename(frame_path),
                    "boxes": boxes,
                    "scores": scores,
                    "labels": labels,
                    "count": count,
                })
                # Replace positive frames on disk with their annotated version.
                if count > 0:
                    annotated_frame = plot_predictions(frame_path, labels, scores, boxes)
                    cv2.imwrite(frame_path, annotated_frame)
                    batch_pos += 1

            # Stop early when more than 2/3 of THIS batch is positive. Integer
            # arithmetic avoids float rounding, and using the real batch length
            # lets a short final batch (or a short video) qualify.
            if 3 * batch_pos > 2 * len(batch_paths):
                out_stem = f"{positives_dir}/{filename}_{threshold}"
                vid_stitcher(frames_dir, f"{out_stem}.mp4", fps)
                pd.DataFrame(records, columns=columns).to_csv(f"{out_stem}.csv", index=False)
                return True

        pd.DataFrame(records, columns=columns).to_csv(
            f"{negatives_dir}/{filename}_{threshold}.csv", index=False
        )
        return False
    finally:
        # Always delete the extracted frames to save space, even on error.
        shutil.rmtree(frames_dir, ignore_errors=True)