Spaces:
Runtime error
Runtime error
| import itertools | |
| import json | |
| import os | |
| import pickle | |
| from argparse import ArgumentParser | |
| from functools import partial | |
| from multiprocessing import cpu_count, Pool | |
| import numpy as np | |
| from tqdm import tqdm | |
| from ultralytics import YOLO | |
| from input_output.tao_format_output import generate_tao_format_output | |
| from tracking.tracker import Annotation, Box, Tracker | |
# File (relative to the working directory) used to cache raw detector output.
RAW_RESULTS_PICKLE_NAME = "raw_results.pkl"

# Candidate values for each tracker hyper-parameter.
MIN_SCORE_FOR_MATCH_VALUES = list(np.arange(0.1, 1.0, 0.25))
MIN_FRAMES_VALUES = range(1, 26, 5)
MAX_MISSING_FRAMES_VALUES = range(0, 25, 5)

# Alternative, narrower grid kept for reference (disabled):
# MIN_SCORE_FOR_MATCH_VALUES = list(np.arange(0.01, 0.17, 0.025))
# MIN_FRAMES_VALUES = range(10, 17, 1)
# MAX_MISSING_FRAMES_VALUES = range(7, 13, 1)

# Every (min_score_for_match, min_frames, max_missing_frames) combination.
PARAMETER_VALUES = list(
    itertools.product(
        MIN_SCORE_FOR_MATCH_VALUES,
        MIN_FRAMES_VALUES,
        MAX_MISSING_FRAMES_VALUES,
    )
)
def predict_without_tracking_single_video(model, frames):
    """Run the detector on every frame of one video, without tracking.

    Args:
        model: Detector exposing ``predict(source, verbose=..., conf=...)``
            whose return value is indexable; only element 0 is kept.
        frames: Iterable of indexable items; element ``[1]`` of each item is
            handed to the model as the prediction source.

    Returns:
        A list holding one raw result per frame, in iteration order.
    """
    # The threshold is deliberately tiny (0.001); the confidence filter is
    # applied afterwards by filter_and_get_annotations_for_frame.
    return [
        model.predict(frame_entry[1], verbose=False, conf=0.001)[0]
        for frame_entry in tqdm(frames, leave=False)
    ]
def filter_and_get_annotations_for_video(raw_results_per_frame, confidence):
    """Filter the raw results of a whole video, frame by frame.

    Args:
        raw_results_per_frame: Iterable of per-frame raw detector results.
        confidence: Minimum confidence a detection needs to be kept.

    Returns:
        A list with one entry per frame, each being the list returned by
        ``filter_and_get_annotations_for_frame`` for that frame.
    """
    annotations_per_frame = []
    for frame_results in raw_results_per_frame:
        annotations_per_frame.append(
            filter_and_get_annotations_for_frame(frame_results, confidence)
        )
    return annotations_per_frame
def filter_and_get_annotations_for_frame(raw_results, confidence):
    """Convert one frame's raw detections into filtered ``Annotation`` objects.

    A detection is kept only when its confidence is at least ``confidence``
    and its class id is 0.

    Args:
        raw_results: Object with a ``boxes`` iterable; each box exposes
            ``xyxy``, ``cls`` and ``conf``, of which element 0 is read.
        confidence: Minimum confidence a detection needs to be kept.

    Returns:
        A list of ``Annotation`` objects, in the order of ``raw_results.boxes``.
    """
    kept_annotations = []
    for detection in raw_results.boxes:
        score = float(detection.conf[0])
        # Written as "not >=" so that a NaN score is rejected.
        if not score >= confidence:
            continue
        class_id = int(detection.cls[0])
        # tmot dataset contains only "person" annotations
        if class_id != 0:
            continue
        bounding_box = Box(*detection.xyxy[0].tolist())
        kept_annotations.append(Annotation(bounding_box, class_id, score))
    return kept_annotations
def parse_tao_annotations(tao_annotations_file_path):
    """Load a TAO annotations JSON file.

    Args:
        tao_annotations_file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform's
    # default encoding, which can mis-read non-ASCII names on some systems.
    with open(tao_annotations_file_path, encoding="utf-8") as tao_annotations_file:
        return json.load(tao_annotations_file)
def parse_video_frames_from_tao(video_id, tao_annotations, images_seq_dir_path):
    """List the frames of one video as ``(frame_index, image_path)`` tuples.

    Args:
        video_id: Id of the video, matched against ``video["id"]``.
        tao_annotations: Decoded annotations with ``"videos"`` entries
            (``"id"``, ``"name"``) and ``"images"`` entries (``"video_id"``,
            ``"frame_index"``, ``"file_name"``).
        images_seq_dir_path: Root directory of the image sequences.

    Returns:
        A list of ``(frame_index, path)`` tuples sorted by frame index, where
        ``path`` is ``<images_seq_dir_path>/<video name>/thermal/<file_name>``.

    Raises:
        ValueError: If no video with ``video_id`` exists in the annotations.
    """
    for video in tao_annotations["videos"]:
        if video["id"] == video_id:
            video_name = video["name"]
            break
    else:
        # A bare next() would leak StopIteration here, which is easy to
        # misread; fail with an explicit, descriptive error instead.
        raise ValueError(f"video id {video_id!r} not found in TAO annotations")

    # The per-video directory is identical for every frame, so build it once.
    thermal_dir = os.path.join(images_seq_dir_path, video_name, "thermal")
    frames = [
        (image["frame_index"], os.path.join(thermal_dir, image["file_name"]))
        for image in tao_annotations["images"]
        if image["video_id"] == video_id
    ]
    # Sort on the frame index only, keeping the original order of ties.
    frames.sort(key=lambda frame: frame[0])
    return frames
def track_for_params_all_videos(untracked_results_for_video_id, params):
    """Run tracking with one parameter set on every video.

    Args:
        untracked_results_for_video_id: Mapping of video id to that video's
            per-frame untracked annotations.
        params: Parameter tuple forwarded unchanged to
            ``track_for_params_single_video``.

    Returns:
        A dict mapping each video id to the value returned by
        ``track_for_params_single_video`` for that video.
    """
    tracker_per_video_id = {}
    for video_id, frame_annotations in untracked_results_for_video_id.items():
        tracker_per_video_id[video_id] = track_for_params_single_video(
            frame_annotations, params
        )
    return tracker_per_video_id
def track_for_params_single_video(untracked_results, params):
    """Feed one video's per-frame annotations through a fresh ``Tracker``.

    Args:
        untracked_results: Iterable of per-frame annotation collections, fed
            to the tracker in order.
        params: ``(min_score_for_match, min_frames, max_missing_frames)``.

    Returns:
        The ``Tracker`` instance after every frame has been advanced and
        ``finish()`` has been called.
    """
    score_threshold, required_frames, allowed_gap = params
    tracker_options = {
        "min_score_for_match": score_threshold,
        "min_frames": required_frames,
        "max_missing_frames": allowed_gap,
    }
    # First positional argument is a 1x1 array of ones; its meaning is defined
    # by Tracker (presumably a class-compatibility matrix -- TODO confirm).
    video_tracker = Tracker(np.full((1, 1), 1), **tracker_options)
    for frame_annotations in untracked_results:
        video_tracker.advance_frame(frame_annotations)
    video_tracker.finish()
    return video_tracker
def track_for_params_and_save_results(
    params,
    untracked_results_for_video_id,
    video_name_per_id,
    video_ids,
    results_dir_path,
    confidence,
):
    """Track all videos with one parameter set and write the result as JSON.

    The output is written to
    ``<results_dir_path>/<confidence>/<min_score>_<min_frames>_<max_missing>/data/results.json``;
    missing directories are created.

    Args:
        params: ``(min_score_for_match, min_frames, max_missing_frames)``.
        untracked_results_for_video_id: Mapping of video id to per-frame
            untracked annotations.
        video_name_per_id: Mapping of video id to video name.
        video_ids: Video ids, in the order they are passed to the formatter.
        results_dir_path: Root directory for all result files.
        confidence: Confidence value, used only to name the output directory.
    """
    tracker_per_video_id = track_for_params_all_videos(
        untracked_results_for_video_id, params
    )
    named_results = []
    for video_id in video_ids:
        named_results.append(
            (video_name_per_id[video_id], tracker_per_video_id[video_id])
        )
    tao_output = generate_tao_format_output(named_results)

    min_score_for_match, min_frames, max_missing_frames = params
    output_dir = os.path.join(
        results_dir_path,
        f"{confidence}",
        f"{min_score_for_match}_{min_frames}_{max_missing_frames}",
        "data",
    )
    os.makedirs(output_dir, exist_ok=True)
    results_path = os.path.join(output_dir, "results.json")
    with open(results_path, "w") as results_file:
        # Only the "annotations" section of the formatter output is persisted.
        json.dump(tao_output["annotations"], results_file, indent=4)
def main(
    model_name,
    confidence,
    images_seq_dir_path,
    tao_annotations_file_path,
    results_dir_path,
    use_pickle,
):
    """Detect on every video, then sweep the tracker parameter grid.

    Steps: load the annotations, obtain raw detections (freshly predicted or
    loaded from the pickle cache), filter them by ``confidence``, and run
    ``track_for_params_and_save_results`` for every entry of
    ``PARAMETER_VALUES`` in a process pool.

    Args:
        model_name: Value passed to ``YOLO(...)`` to build the detector.
        confidence: Minimum detection confidence kept for tracking; also used
            to name the output directory.
        images_seq_dir_path: Root directory of the image sequences.
        tao_annotations_file_path: Path of the TAO annotations JSON file.
        results_dir_path: Root directory where results are written.
        use_pickle: When true and the cache file exists, load raw detections
            from it instead of predicting again.
    """
    model = YOLO(model_name)
    tao_annotations = parse_tao_annotations(tao_annotations_file_path)
    video_ids = sorted([video["id"] for video in tao_annotations["videos"]])
    video_name_per_id = {
        video["id"]: video["name"]
        for video in tao_annotations["videos"]
    }
    frames_per_video_id = {
        video_id: parse_video_frames_from_tao(video_id, tao_annotations, images_seq_dir_path)
        for video_id in video_ids
    }

    # NOTE: the cache file name is fixed, so a cache hit reuses the stored
    # detections whatever model or dataset was given on this run.
    if not use_pickle or not os.path.isfile(RAW_RESULTS_PICKLE_NAME):
        print("predicting on videos")
        raw_results_for_video_ids = dict()
        for video_id, frames in tqdm(frames_per_video_id.items()):
            raw_results_for_video_ids[video_id] = predict_without_tracking_single_video(model, frames)
        # The cache is (re)written after every fresh prediction run.
        with open(RAW_RESULTS_PICKLE_NAME, "wb") as f:
            pickle.dump(raw_results_for_video_ids, f)
    else:
        print("loading predictions from pickle")
        # NOTE: unpickling can execute arbitrary code; only load a cache file
        # that this script produced itself.
        with open(RAW_RESULTS_PICKLE_NAME, "rb") as f:
            raw_results_for_video_ids = pickle.load(f)

    untracked_results_for_video_id = {
        video_id: filter_and_get_annotations_for_video(raw_results_per_frame, confidence)
        for video_id, raw_results_per_frame in raw_results_for_video_ids.items()
    }

    print("tracking for all parameters")
    worker = partial(
        track_for_params_and_save_results,
        untracked_results_for_video_id=untracked_results_for_video_id,
        video_name_per_id=video_name_per_id,
        video_ids=video_ids,
        results_dir_path=results_dir_path,
        confidence=confidence,
    )
    # Leave one core free but always use at least one worker. The previous
    # expression, 1 - cpu_count(), is never positive, so the pool was always
    # limited to a single process.
    with Pool(processes=max(1, cpu_count() - 1)) as pool:
        # Drain the iterator so every parameter set is processed; results are
        # written to disk by the worker, so the returned values are discarded.
        list(tqdm(pool.imap_unordered(worker, PARAMETER_VALUES), total=len(PARAMETER_VALUES)))
if __name__ == "__main__":
    # Command-line entry point: five positional arguments plus an optional
    # flag that forces re-prediction instead of loading the pickle cache.
    cli_parser = ArgumentParser()
    cli_parser.add_argument("model")
    cli_parser.add_argument("confidence", type=float)
    cli_parser.add_argument("images_seq_dir")
    cli_parser.add_argument("tao_annotations")
    cli_parser.add_argument("results_dir")
    cli_parser.add_argument("--ignore-pickle", action="store_true")
    cli_args = cli_parser.parse_args()

    main(
        model_name=cli_args.model,
        confidence=cli_args.confidence,
        images_seq_dir_path=cli_args.images_seq_dir,
        tao_annotations_file_path=cli_args.tao_annotations,
        results_dir_path=cli_args.results_dir,
        use_pickle=not cli_args.ignore_pickle,
    )