Spaces:
Runtime error
Runtime error
| import argparse | |
| import torch | |
| import os | |
| import json | |
| from tqdm import tqdm | |
| import project_subpath | |
| from backend.InferenceConfig import InferenceConfig | |
| from backend.inference import do_full_tracking | |
def main(args, config=None, verbose=True):
    """Convert raw detections for one location into tracks and save them as text files.

    Results are written to ``<output>/<location>/<tracker>/data/<seq>.txt``,
    one file per sequence found under ``<detections>/<location>``.

    Args:
        args: parsed command-line namespace (see ``argument_parser``) with:
            detections (str): path to the raw detections directory. Required.
            location (str): name of the location sub-directory to process. Required.
            output (str): root directory where tracking results are stored. Required.
            metadata (str): directory containing ``<location>.json``. Required.
            tracker (str): name of the tracker folder trajectories are saved under.
        config: tracking configuration. When omitted, a fresh
            ``InferenceConfig()`` is created for this call.
        verbose: forwarded to the per-sequence tracking.
    """
    # A default of ``InferenceConfig()`` in the signature would be built once at
    # definition time and shared (and possibly mutated) across every call that
    # omits ``config``; build a fresh instance per call instead.
    if config is None:
        config = InferenceConfig()

    print("running detections_to_tracks.py with:", config.to_dict())

    loc = args.location
    in_loc_dir = os.path.join(args.detections, loc)
    out_loc_dir = os.path.join(args.output, loc, args.tracker, "data")
    os.makedirs(out_loc_dir, exist_ok=True)
    metadata_path = os.path.join(args.metadata, loc + ".json")

    print(in_loc_dir)
    print(out_loc_dir)
    print(metadata_path)

    track_location(in_loc_dir, out_loc_dir, metadata_path, config, verbose)
def track_location(in_loc_dir, out_loc_dir, metadata_path, config, verbose):
    """Run tracking on every sequence entry inside ``in_loc_dir``.

    Entries whose names start with "." are skipped. Each remaining entry is
    passed to ``track``, which writes its result into ``out_loc_dir``.

    Args:
        in_loc_dir: directory containing one entry per sequence.
        out_loc_dir: directory the per-sequence results are written to.
        metadata_path: path to the location's metadata JSON file.
        config: tracking configuration forwarded to ``track``.
        verbose: forwarded to ``track``.
    """
    # Drop hidden entries before sizing the progress bar so its total matches
    # the work actually done; sort because os.listdir order is arbitrary.
    seq_list = sorted(
        seq for seq in os.listdir(in_loc_dir) if not seq.startswith(".")
    )
    with tqdm(total=len(seq_list), desc="...", ncols=0) as pbar:
        for seq in seq_list:
            pbar.set_description("Processing " + seq)
            track(in_loc_dir, out_loc_dir, metadata_path, seq, config, verbose)
            # Advance only once the sequence is finished, so the bar never
            # reports completion for work still in progress.
            pbar.update(1)
def _read_meter_extent(metadata_path, seq):
    """Return ``(meter_width, meter_height)`` for clip ``seq``; ``(-1, -1)`` if absent."""
    image_meter_width = -1
    image_meter_height = -1
    found = False
    with open(metadata_path, 'r', encoding='utf-8') as f:
        clips = json.load(f)
    # No early break: as before, the last record matching ``seq`` wins.
    for clip in clips:
        if clip['clip_name'] == seq:
            found = True
            image_meter_width = clip['x_meter_stop'] - clip['x_meter_start']
            # y is computed start - stop (the reverse of x); presumably the y
            # axis is oriented the other way in the metadata -- confirm.
            image_meter_height = clip['y_meter_start'] - clip['y_meter_stop']
    if not found:
        # Previously this fell through silently with -1 extents.
        print("WARNING: no metadata entry for clip", seq, "in", metadata_path,
              "- using -1 for the meter width/height")
    return image_meter_width, image_meter_height


def _to_mot_rows(results, real_width, real_height):
    """Convert tracking ``results`` into comma-separated MOT-style text rows.

    Each row is ``frame,id,left,top,width,height,-1,-1,-1,-1``, with frame
    number and fish id shifted to 1-based. Box coordinates are scaled by the
    image size, which implies ``bbox`` holds fractional ``[x1, y1, x2, y2]``.
    Pixel values are truncated with ``int``.
    """
    mot_rows = []
    for frame in results['frames']:
        for fish in frame['fish']:
            bbox = fish['bbox']
            left = bbox[0] * real_width
            top = bbox[1] * real_height
            w = bbox[2] * real_width - bbox[0] * real_width
            h = bbox[3] * real_height - bbox[1] * real_height
            row = [
                str(frame['frame_num'] + 1),
                str(fish['fish_id'] + 1),
                str(int(left)),
                str(int(top)),
                str(int(w)),
                str(int(h)),
                "-1",  # four placeholder columns
                "-1",
                "-1",
                "-1",
            ]
            mot_rows.append(",".join(row))
    return mot_rows


def track(in_loc_dir, out_loc_dir, metadata_path, seq, config, verbose):
    """Track one sequence and write the result to ``<out_loc_dir>/<seq>.txt``.

    Reads ``<in_loc_dir>/<seq>/inference.pt`` and ``<in_loc_dir>/<seq>/pred.json``,
    looks up the clip's meter width/height in the metadata JSON, runs
    ``do_full_tracking`` and writes one comma-separated row per tracked box.

    Args:
        in_loc_dir: directory holding one sub-directory per sequence.
        out_loc_dir: directory the ``<seq>.txt`` result is written to.
        metadata_path: JSON file holding a list of per-clip records keyed by
            ``clip_name``.
        seq: sequence name; also the sub-directory name under ``in_loc_dir``.
        config: tracking configuration passed to ``do_full_tracking``.
        verbose: forwarded to ``do_full_tracking``.
    """
    seq_dir = os.path.join(in_loc_dir, seq)
    json_path = os.path.join(seq_dir, 'pred.json')
    inference_path = os.path.join(seq_dir, 'inference.pt')
    out_path = os.path.join(out_loc_dir, seq + ".txt")

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code; only run this on inference files from a trusted source.
    inference = torch.load(inference_path, map_location=device)

    # Read detection metadata produced alongside the inference tensor.
    with open(json_path, 'r', encoding='utf-8') as f:
        detection = json.load(f)
    image_shapes = detection['image_shapes']
    width = detection['width']
    height = detection['height']

    image_meter_width, image_meter_height = _read_meter_extent(metadata_path, seq)

    # Assume all images in the sequence share one shape. This indexing treats
    # image_shapes[0][0][0] as (height, width, ...) -- TODO confirm the nesting.
    real_width = image_shapes[0][0][0][1]
    real_height = image_shapes[0][0][0][0]

    results = do_full_tracking(
        inference, image_shapes, image_meter_width, image_meter_height,
        width, height, config=config, gp=None, verbose=verbose,
    )

    mot_text = "\n".join(_to_mot_rows(results, real_width, real_height))
    with open(out_path, 'w') as f:
        f.write(mot_text)
def argument_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser with required ``--detections``, ``--location``,
        ``--output`` and ``--metadata`` options, plus an optional ``--tracker``
        option that defaults to ``'tracker'``.
    """
    parser = argparse.ArgumentParser(
        description="Convert raw detections to tracks and save the tracking results."
    )
    # The help previously called this a "frame directory", but main() reads
    # raw detections (per-sequence pred.json / inference.pt) from it.
    parser.add_argument("--detections", required=True, help="Path to raw detections directory. Required.")
    parser.add_argument("--location", required=True, help="Name of location dir. Required.")
    parser.add_argument("--output", required=True, help="Path to output directory. Required.")
    parser.add_argument("--metadata", required=True, help="Path to metadata directory. Required.")
    parser.add_argument("--tracker", default='tracker', help="Name of the tracker folder results are saved under.")
    return parser
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run tracking with main()'s default config.
    main(argument_parser().parse_args())