"""Run object detection + tracking on a video and save annotated outputs.

Downloads the requested fine-tuned YOLO model (if not cached locally),
detects objects frame-by-frame, runs the tracker over the raw detections,
and writes a TAO-format annotation file and/or an annotated video.
"""

from argparse import ArgumentParser
from pathlib import Path
from urllib.request import urlretrieve

import numpy as np
from ultralytics import YOLO

from detect import detect_for_video
from input_output.run_config import parse_config_from_file
from input_output.tao_format_output import write_tao_format_output
from input_output.video_output import write_video_output
from tracking.tracker import Tracker

# Local cache directory for downloaded model weights and confusion matrices.
MODELS_DIR = Path(__file__).parent.parent / "models"

# Hugging Face download URLs, keyed by the model name used in run configs.
MODEL_URLS = {
    "hypertuned_yolov11xl": "https://huggingface.co/RonenRusinov/HeatVision_YOLO_finetune/resolve/main/Best_hypertuned_YOLO11.pt",
    "finetuned_yolov11xl": "https://huggingface.co/RonenRusinov/HeatVision_YOLO_finetune/resolve/main/best_yolo11.pt",
    "finetuned_yolov8xl": "https://huggingface.co/RonenRusinov/HeatVision_YOLO_finetune/resolve/main/best_yolo8.pt",
}

# Class-id -> human-readable label, in model output order.
CLASS_LABELS = [
    "person", "bike", "car", "motor", "airplane", "bus", "train", "truck",
    "boat", "light", "hydrant", "sign", "parking meter", "bench", "bird",
    "cat", "dog", "deer", "sheep", "cow", "elephant", "bear", "zebra",
    "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase",
    "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed", "dining table",
    "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
    "microwave", "oven", "toaster", "sink", "stroller", "rider", "scooter",
    "vase", "scissors", "face", "other vehicle", "license plate",
]


def main(config_file_path, input_video_path, output_video_path, output_tao_path):
    """Detect, track, and write the requested outputs for one video.

    Args:
        config_file_path: Path to the model/tracker configuration JSON file.
        input_video_path: Path to the input video (.mp4).
        output_video_path: Where to save the annotated video, or None to skip.
        output_tao_path: Where to save TAO-format annotations, or None to skip.
    """
    # NOTE(review): min_match_score/min_appearance_frames/max_missing_frames
    # are parsed but unused here — presumably consumed by Tracker elsewhere;
    # confirm against parse_config_from_file/Tracker.
    model_name, conf, min_match_score, min_appearance_frames, max_missing_frames = parse_config_from_file(config_file_path)
    model = YOLO(download_and_get_model(model_name))
    confusion_matrix = load_confusion_matrix(model_name)

    raw_detections = detect_for_video(model, input_video_path, conf)

    tracker = Tracker(confusion_matrix)
    tracker.advance_frames(raw_detections)
    tracker.finish()

    if output_tao_path is not None:
        video_name = Path(input_video_path).stem
        write_tao_format_output([(video_name, tracker)], output_tao_path)
        # BUGFIX: previously printed output_video_path here, reporting the
        # wrong location for the annotations file.
        print(f"TAO-like formatted annotations saved at {output_tao_path}")

    if output_video_path is not None:
        write_video_output(input_video_path, output_video_path, tracker, CLASS_LABELS)
        print(f"Video with bounding boxes saved at {output_video_path}")


def download_and_get_model(model_name):
    """Return the local path to the model weights, downloading if needed.

    Args:
        model_name: A key of MODEL_URLS.

    Returns:
        Path to the cached ``.pt`` weights file.
    """
    # Ensure the cache directory exists — urlretrieve does not create it,
    # and a fresh checkout would otherwise fail on first download.
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    model_file = MODELS_DIR / f"{model_name}.pt"
    if not model_file.is_file():
        print(f"Downloading {model_name} from huggingface")
        urlretrieve(MODEL_URLS[model_name], model_file)
        print("Download Complete")
    return model_file


def load_confusion_matrix(model_name):
    """Load the model's confusion matrix from the models cache directory.

    Assumes the ``.confusion_matrix`` file is in NumPy ``.npy`` format
    (np.load accepts any path-like object).
    """
    return np.load(MODELS_DIR / f"{model_name}.confusion_matrix")


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--run-config", help="Path to the model and tracker configuration json file", dest="run_config", required=True)
    parser.add_argument("--input-video", help="Path to the input video file (supported format: .mp4)", dest="input_video", required=True)
    parser.add_argument("--output-video", help="If given, the input video with bounding boxes will be saved in this path", dest="output_video")
    parser.add_argument("--output-tao-annotations", help="If given, annotations in a tao-like format will be saved in this path", dest="output_tao_annotations")
    args = parser.parse_args()

    if args.output_video is None and args.output_tao_annotations is None:
        print("No output option given, use one or both of '--output-video' or '--output-tao-annotations'")
        # Exit nonzero: no output requested is a usage error (was quit(),
        # which exits 0 and depends on the site module).
        raise SystemExit(1)

    main(args.run_config, args.input_video, args.output_video, args.output_tao_annotations)