Spaces:
Runtime error
Runtime error
| import argparse | |
| import torch | |
| import os | |
| import json | |
| from tqdm import tqdm | |
| import project_subpath | |
| from backend.dataloader import create_dataloader_frames_only | |
| from backend.inference import setup_model, do_detection, do_suppression | |
| from backend.InferenceConfig import InferenceConfig | |
| from lib.yolov5.utils.general import clip_boxes, scale_boxes | |
def main(args, config=None, verbose=False):
    """
    Run YOLOv5 detection on every sequence of one location and save the predictions.

    Sequences are read from ``<args.frames>/<args.location>/<sequence>``, per-clip
    metadata from ``<args.metadata>/<args.location>.json``, and predictions are
    written to ``<args.output>/<args.location>/<sequence>/pred.json`` as a
    COCO-style list of ``{image_id, category_id, bbox, score}`` records.

    Args:
        args (argparse.Namespace): parsed CLI arguments with attributes
            ``frames`` (root frame directory), ``metadata`` (directory holding the
            per-location metadata JSON files), ``location`` (location
            sub-directory name), ``output`` (root output directory) and
            ``weights`` (path to the YOLOv5 weights file).
        config (InferenceConfig, optional): detection settings; ``conf_thresh``
            and ``nms_iou`` are used downstream. A fresh ``InferenceConfig()`` is
            created when omitted.
        verbose (bool): forwarded to the detection and suppression helpers.
    """
    # Build the default per call rather than in the signature: a default-argument
    # instance is created once at import time and shared by every call.
    if config is None:
        config = InferenceConfig()

    print("In task...")
    print("Cuda available in task?", torch.cuda.is_available())
    print("Config:", config.to_dict())

    model, device = setup_model(args.weights)

    in_loc_dir = os.path.join(args.frames, args.location)
    out_loc_dir = os.path.join(args.output, args.location)
    metadata_path = os.path.join(args.metadata, args.location + ".json")
    print(in_loc_dir)
    print(out_loc_dir)
    print(metadata_path)

    detect_location(in_loc_dir, out_loc_dir, metadata_path, config, model, device, verbose)
def detect_location(in_loc_dir, out_loc_dir, metadata_path, config, model, device, verbose):
    """
    Run detection on every sequence directory of one location.

    Args:
        in_loc_dir (str): directory whose entries are per-sequence frame
            directories. Entries whose name starts with "." are skipped.
        out_loc_dir (str): directory under which each sequence's output
            directory is created (``<out_loc_dir>/<sequence>``).
        metadata_path (str): JSON file containing a list of clip records, each
            with ``clip_name`` and ``x_meter_start``/``x_meter_stop``/
            ``y_meter_start``/``y_meter_stop`` keys.
        config, model, device, verbose: forwarded unchanged to ``detect_seq``.
    """
    # Load and index the metadata once instead of re-reading and re-parsing the
    # file for every sequence. On duplicate clip names the last record wins,
    # matching a full linear scan without ``break``.
    with open(metadata_path, 'r', encoding='utf-8') as f:
        clips = {record['clip_name']: record for record in json.load(f)}

    # Sorted so sequences are processed (and logged) in a deterministic order.
    seq_list = sorted(os.listdir(in_loc_dir))
    with tqdm(total=len(seq_list), desc="...", ncols=0) as pbar:
        for seq in seq_list:
            pbar.update(1)
            if seq.startswith("."):
                continue  # skip hidden entries

            record = clips.get(seq)
            if record is None:
                # Fallback when the clip has no metadata record.
                image_meter = (-1, -1)
            else:
                # (x extent, y extent) of the clip, computed from the metadata bounds.
                image_meter = (
                    record['x_meter_stop'] - record['x_meter_start'],
                    record['y_meter_stop'] - record['y_meter_start'],
                )

            pbar.set_description("Processing " + seq)
            in_seq_dir = os.path.join(in_loc_dir, seq)
            out_seq_dir = os.path.join(out_loc_dir, seq)
            os.makedirs(out_seq_dir, exist_ok=True)
            detect_seq(in_seq_dir, out_seq_dir, image_meter, config, model, device, verbose)
def detect_seq(in_seq_dir, out_seq_dir, image_meter, config, model, device, verbose):
    """
    Detect objects in one sequence and write them to ``<out_seq_dir>/pred.json``.

    Each detection ``[x1, y1, x2, y2, score, file_name]`` produced by ``detect`` is
    serialized as a dict with ``image_id`` (the frame file name), ``category_id``
    (always 0), ``bbox`` as ``[x, y, width, height]`` and ``score``. Frames that
    are ``None`` are ignored.
    """
    frames = detect(in_seq_dir, image_meter, config, model, device, verbose)

    def to_record(det):
        # Convert corner coordinates (x1, y1, x2, y2) into (x, y, width, height).
        left, top, right, bottom = det[0], det[1], det[2], det[3]
        return {
            'image_id': det[5],
            'category_id': 0,
            'bbox': [left, top, right - left, bottom - top],
            'score': det[4],
        }

    records = [to_record(det) for frame in frames if frame is not None for det in frame]

    # Serialize fully before opening the output so a serialization error leaves no file.
    payload = json.dumps(records)
    pred_path = os.path.join(out_seq_dir, 'pred.json')
    with open(pred_path, 'w') as out_file:
        out_file.write(payload)
def detect(in_dir, image_meter, config, model, device, verbose):
    """
    Run YOLOv5 detection plus suppression on every frame of ``in_dir``.

    Args:
        in_dir (str): directory of frames for one sequence.
        image_meter (tuple): ``(x extent, y extent)`` of the clip in meters, or
            ``(-1, -1)`` when no metadata was found. Only the x extent is used here.
        config (InferenceConfig): supplies ``conf_thresh`` and ``nms_iou``.
        model, device: as returned by ``setup_model``.
        verbose (bool): forwarded to ``do_detection`` / ``do_suppression``.

    Returns:
        list: one entry per frame, in dataloader order. Each entry is a list of
        detections ``[x1, y1, x2, y2, confidence, file_name]``, with box
        coordinates rescaled to the frame's original shape.
    """
    dataloader = create_dataloader_frames_only(in_dir)
    inference, image_shapes, width, height = do_detection(dataloader, model, device, verbose=verbose)

    # image_pixel_width must be a pixel width: use the detector's ``width`` (also
    # the clip bound below), not image_meter[1], which is a y extent in meters.
    outputs = do_suppression(
        inference,
        image_meter_width=image_meter[0],
        image_pixel_width=width,
        conf_thres=config.conf_thresh,
        iou_thres=config.nms_iou,
        verbose=verbose,
    )

    file_names = dataloader.files
    frame_list = []
    # Running index into file_names; avoids hard-coding the detector's batch size.
    frame_idx = 0
    for batch_i, batch in enumerate(outputs):
        batch_shapes = image_shapes[batch_i]
        for si, pred in enumerate(batch):
            image_shape, original_shape = batch_shapes[si]
            # Clip boxes to image bounds (in place), then rescale a copy to the original shape.
            clip_boxes(pred, (height, width))
            boxes = pred[:, :4].clone()  # xyxy
            confs = pred[:, 4].tolist()
            scale_boxes(image_shape, boxes, original_shape[0], original_shape[1])

            file_name = file_names[frame_idx]
            frame_idx += 1
            frame_list.append([[*bb, conf, file_name] for bb, conf in zip(boxes.tolist(), confs)])
    return frame_list
def argument_parser():
    """
    Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser with required ``--frames``,
        ``--metadata``, ``--location`` and ``--output`` options and an optional
        ``--weights`` path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--frames", required=True, help="Path to frame directory. Required.")
    parser.add_argument(
        "--metadata",
        required=True,
        help="Path to metadata directory containing <location>.json files. Required.",
    )
    parser.add_argument("--location", required=True, help="Name of location dir. Required.")
    parser.add_argument("--output", required=True, help="Path to output directory. Required.")
    # %(default)s keeps the help text in sync with the actual default value.
    parser.add_argument(
        "--weights",
        default='models/v5m_896_300best.pt',
        help="Path to saved YOLOv5 weights. Default: %(default)s",
    )
    return parser
if __name__ == "__main__":
    # Script entry point: parse CLI options and run detection for the requested location.
    main(argument_parser().parse_args())