Spaces:
Runtime error
Runtime error
| import project_path | |
| import torch | |
| from tqdm import tqdm | |
| from functools import partial | |
| import numpy as np | |
| import json | |
| from unittest.mock import patch | |
| import math | |
| import torch | |
| import torchvision | |
| # assumes yolov5 on sys.path | |
| from lib.yolov5.models.experimental import attempt_load | |
| from lib.yolov5.utils.torch_utils import select_device | |
| from lib.yolov5.utils.general import clip_boxes, scale_boxes, xywh2xyxy | |
| from lib.yolov5.utils.metrics import box_iou | |
| from lib.fish_eye.tracker import Tracker | |
| from lib.fish_eye.tracker_bytetrack import Associate | |
| from backend.InferenceConfig import InferenceConfig, TrackerType | |
### Configuration options

# Default checkpoint loaded by setup_model().
WEIGHTS = 'models/v5m_896_300best.pt'

# Images per batch; tune this to the available GPU memory.
BATCH_SIZE = 32

# --- detection / suppression defaults ---
CONF_THRES = 0.05  # minimum detection confidence
NMS_IOU = 0.2  # IoU threshold used by non-max suppression

# --- tracking defaults ---
MAX_AGE = 14  # frames a fish may go missing before it is given a new id
MIN_HITS = 16  # minimum number of frames a fish must appear in to be counted
MIN_LENGTH = 0.3  # minimum fish length, in meters
IOU_THRES = 0.01  # IoU threshold for associating detections with tracks
MIN_TRAVEL = -1  # minimum distance a track has to travel (NOTE(review): -1 presumably disables the check — confirm in Tracker.finalize)
###
def norm(bbox, w, h):
    """
    Return a copy of a bounding box with its coordinates divided by the image size.

    Elements 0 and 2 are divided by ``w``; elements 1 and 3 are divided by ``h``.
    Any further elements are copied unchanged. The input is not modified.

    Args:
        bbox: sequence of length >= 4 supporting ``.copy()``, e.g. [x,y,w,h] or [x0,y0,x1,y1]
        w: image width
        h: image height
    """
    scaled = bbox.copy()
    for position, divisor in enumerate((w, h, w, h)):
        scaled[position] /= divisor
    return scaled
def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None, config=InferenceConfig(), verbose=True):
    """Run detection and then tracking on every batch produced by ``dataloader``.

    Args:
        dataloader: iterable of ``(img, _, shapes)`` batches (see do_detection).
        image_meter_width: image width in meters (used for size filtering and fish length).
        image_meter_height: image height in meters.
        gp: optional progress callback, called as ``gp(fraction, message)``.
        config: InferenceConfig with the model weights path and the tracking parameters.
        verbose: show tqdm progress bars.

    Returns:
        The tracking result produced by do_full_tracking.
    """
    # Set up model
    model, device = setup_model(config.weights)

    # Detect boxes in frames
    inference, image_shapes, width, height = do_detection(dataloader, model, device, gp=gp, verbose=verbose)

    # Forward the caller's config and progress callback. The previous code passed a fresh
    # InferenceConfig() and gp=None here, so custom tracking settings were silently ignored.
    return do_full_tracking(
        inference, image_shapes, image_meter_width, image_meter_height, width, height,
        config=config, gp=gp, verbose=verbose)
def do_full_tracking(inference, image_shapes, image_meter_width, image_meter_height, width, height, config=InferenceConfig(), gp=None, verbose=True):
    """Turn raw detector output into tracks, using the tracker selected by ``config``.

    * ``TrackerType.BYTETRACK``: two detection sets are built (thresholds
      ``config.byte_low_conf`` and ``config.byte_high_conf``) and passed to
      do_associative_tracking.
    * Any other tracker type: detections above ``config.conf_thresh`` go to SORT
      (do_tracking). For ``TrackerType.CONF_BOOST``, the raw confidences in
      ``inference`` are first boosted in place around confident detections, and
      suppression is then run again.

    Args:
        inference: per-batch raw model output from do_detection.
        image_shapes: per-batch shape metadata from do_detection.
        image_meter_width, image_meter_height: image size in meters.
        width, height: network input size in pixels.
        config: InferenceConfig holding thresholds and tracker settings.
        gp: optional progress callback ``gp(fraction, message)``.
        verbose: show tqdm progress bars.

    Returns:
        The result of do_associative_tracking or do_tracking.
    """
    def suppress(threshold):
        # Confidence + size filtering followed by NMS at the given confidence threshold.
        return do_suppression(inference, image_meter_width, width, conf_thres=threshold, iou_thres=config.nms_iou, max_length=config.max_length, gp=gp, verbose=verbose)

    def to_tracker_input(outputs):
        # format_predictions also reports the original frame size; it is not needed here.
        preds, _, _ = format_predictions(image_shapes, outputs, width, height, gp=gp, verbose=verbose)
        return preds

    if config.associative_tracker == TrackerType.BYTETRACK:
        low_preds = to_tracker_input(suppress(config.byte_low_conf))
        high_preds = to_tracker_input(suppress(config.byte_high_conf))
        return do_associative_tracking(
            low_preds, high_preds, image_meter_width, image_meter_height,
            reverse=False, min_length=config.min_length, min_travel=config.min_travel,
            max_age=config.max_age, min_hits=config.min_hits,
            gp=gp, verbose=verbose)

    confident = suppress(config.conf_thresh)
    if config.associative_tracker == TrackerType.CONF_BOOST:
        # Mutates `inference` in place, then re-selects detections from the boosted scores.
        do_confidence_boost(inference, confident, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp, verbose=verbose)
        confident = suppress(config.conf_thresh)

    sort_input = to_tracker_input(confident)
    return do_tracking(
        sort_input, image_meter_width, image_meter_height,
        min_length=config.min_length, min_travel=config.min_travel,
        max_age=config.max_age, min_hits=config.min_hits,
        gp=gp, verbose=verbose)
def setup_model(weights_fp=WEIGHTS, imgsz=896, batch_size=32):
    """Load the detector onto the best available device and prepare it for inference.

    Args:
        weights_fp: path to the weights file passed to ``attempt_load``.
        imgsz: side length, in pixels, of the square dummy image used for the GPU warm-up pass.
        batch_size: forwarded to ``select_device``.

    Returns:
        ``(model, device)``: the model in eval mode (converted to fp16 on non-CPU
        devices) and the selected torch device.
    """
    if torch.cuda.is_available():
        device = select_device('0', batch_size=batch_size)
    else:
        print("CUDA not available. Using CPU inference.")
        device = select_device('cpu', batch_size=batch_size)

    # Setup model for inference
    model = attempt_load(weights_fp, device=device)
    half = device.type != 'cpu'  # half precision only supported on CUDA
    if half:
        model.half()
    model.eval()

    # Run one dummy forward pass on the GPU before real data arrives. It is skipped
    # on CPU, exactly as before, but the dummy tensor is now only allocated when it
    # is used, and no autograd graph is built for the throw-away output.
    if half:
        dummy = torch.zeros((1, 3, imgsz, imgsz), device=device).half()
        with torch.no_grad():
            model(dummy)

    return model, device
def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE, verbose=True):
    """Run the detector over every batch yielded by ``dataloader``.

    Args:
        dataloader: sized iterable yielding ``(img, _, shapes)`` tuples. ``img`` is a
            uint8 image batch laid out as (batch, channels, height, width), and
            ``shapes[i]`` holds the original-shape metadata for image ``i``, which is
            stored unchanged for format_predictions.
        model: detector returned by setup_model.
        device: torch device returned by setup_model.
        gp: optional progress callback, called as ``gp(fraction_complete, message)``.
        batch_size: only used to size the progress bar.
        verbose: show a tqdm progress bar.

    Returns:
        ``(inference, image_shapes, width, height)``:
            * ``inference``: raw model output, one entry per batch;
            * ``image_shapes``: per batch, a list of
              ``(network_input_shape, original_shape_info)`` pairs;
            * ``width``, ``height``: network input size in pixels, taken from the last batch.

    Raises:
        ValueError: if ``dataloader`` yields no batches. Previously this case failed
            with an UnboundLocalError on ``width``.
    """
    if gp:
        gp(0, "Detection...")

    inference = []
    image_shapes = []
    width = height = None

    with tqdm(total=len(dataloader)*batch_size, desc="Running detection", ncols=0, disable=not verbose) as pbar:
        for batch_i, (img, _, shapes) in enumerate(dataloader):
            if gp:
                gp(batch_i / len(dataloader), str(pbar))

            img = img.to(device, non_blocking=True)
            # Match the model precision: setup_model casts it to fp16 on non-CPU devices.
            img = img.half() if device.type != 'cpu' else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            _, _, height, width = tuple(img.shape)  # batch size, channels, height, width

            with torch.no_grad():
                inf_out, _ = model(img, augment=False)

            # Store the network-input shape with the original-shape metadata so boxes
            # can later be rescaled to the source frame.
            image_shapes.append([(img[si].shape[1:], shapes[si]) for si in range(len(inf_out))])
            inference.append(inf_out)
            pbar.update(batch_size)

    if width is None:
        raise ValueError("do_detection: dataloader yielded no batches")

    return inference, image_shapes, width, height
def do_suppression(inference, image_meter_width, image_pixel_width, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU, max_length=1.5, verbose=True):
    """Filter and non-max-suppress the raw detections of every batch.

    Args:
        inference: raw model output, one entry per batch (from do_detection).
        image_meter_width: image width in meters, used by the box-size filter.
        image_pixel_width: image width in pixels, used by the box-size filter.
        gp: optional progress callback, called as ``gp(fraction_complete, message)``.
        batch_size: only used to size the progress bar.
        conf_thres: minimum confidence kept.
        iou_thres: IoU threshold for NMS.
        max_length: maximum box width in meters (see non_max_suppression).
        verbose: show a tqdm progress bar.

    Returns:
        List with one non_max_suppression result per batch, i.e. a list of
        per-image detection tensors.
    """
    if gp:
        gp(0, "Suppression...")

    n_batches = len(inference)
    suppressed = []
    with tqdm(total=n_batches * batch_size, desc="Running suppression", ncols=0, disable=not verbose) as pbar:
        for idx, raw_batch in enumerate(inference):
            if gp:
                gp(idx / n_batches, str(pbar))
            with torch.no_grad():
                kept = non_max_suppression(raw_batch, image_meter_width, image_pixel_width, conf_thres=conf_thres, iou_thres=iou_thres, max_length=max_length)
            suppressed.append(kept)
            pbar.update(batch_size)
    return suppressed
def format_predictions(image_shapes, outputs, width, height, gp=None, batch_size=BATCH_SIZE, verbose=True):
    """Convert per-image NMS output into the normalized box arrays the tracker consumes.

    Args:
        image_shapes: per-batch lists of ``(network_input_shape, original_shape_info)``
            pairs, as produced by do_detection. ``original_shape_info[0]`` is indexed
            as ``[height, width]`` of the source frame.
        outputs: per-batch lists of per-image detection tensors, as produced by
            do_suppression, with xyxy in columns 0-3 and confidence in column 4.
            NOTE: the boxes in ``outputs`` are clipped in place.
        width, height: network input size in pixels, used for clipping.
        gp: optional progress callback, called as ``gp(fraction_complete, message)``.
        batch_size: only used to size the progress bar.
        verbose: show a tqdm progress bar.

    Returns:
        ``(all_preds, real_width, real_height)``:
            * ``all_preds`` maps ``(batch_index, image_index)`` to an ``(n, 5)`` numpy
              array of ``[x0, y0, x1, y1, conf]`` normalized by the source-frame size,
              or to ``None`` when no box survived;
            * ``real_width``, ``real_height`` are the source-frame dimensions.
    """
    first_info = image_shapes[0][0][1]
    real_width = first_info[0][1]
    real_height = first_info[0][0]

    if gp:
        gp(0, "Formatting...")

    all_preds = {}
    n_batches = len(image_shapes)
    with tqdm(total=n_batches * batch_size, desc="Running formatting", ncols=0, disable=not verbose) as pbar:
        for b_idx, batch in enumerate(outputs):
            if gp:
                gp(b_idx / n_batches, str(pbar))
            shapes_for_batch = image_shapes[b_idx]

            for img_idx, pred in enumerate(batch):
                net_shape, orig_info = shapes_for_batch[img_idx]

                # Clip to the network input first, then map boxes back to the source frame.
                clip_boxes(pred, (height, width))
                xyxy = pred[:, :4].clone()
                scores = pred[:, 4].tolist()
                scale_boxes(net_shape, xyxy, orig_info[0], orig_info[1])

                frame_boxes = None
                if xyxy.shape[0]:
                    frame_w = orig_info[0][1]
                    frame_h = orig_info[0][0]
                    real_width, real_height = frame_w, frame_h
                    # Tracker input rows: normalized xyxy followed by the confidence score.
                    rows = [[*norm(coords, w=frame_w, h=frame_h), score]
                            for coords, score in zip(xyxy[:, :4].tolist(), scores)]
                    frame_boxes = np.stack(rows)

                all_preds[(b_idx, img_idx)] = frame_boxes
            pbar.update(batch_size)

    return all_preds, real_width, real_height
# ---------------------------------------- TRACKING ------------------------------------------
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, min_travel=MIN_TRAVEL, verbose=True):
    """Run SORT tracking over formatted per-frame detections.

    Args:
        all_preds: dict mapping frame keys (e.g. ``(batch_index, image_index)``) to
            an array of normalized ``[x0, y0, x1, y1, conf]`` rows, or ``None`` for a
            frame without detections (as built by format_predictions). Frames are
            processed in sorted key order.
        image_meter_width, image_meter_height: image size in meters.
        gp: optional progress callback, called as ``gp(fraction_complete, message)``.
        max_age, iou_thres: forwarded to the tracker as ``max_age`` / ``iou_threshold``.
        min_hits: passed to Tracker as its ``min_hits`` keyword.
        min_length, min_travel: forwarded to ``Tracker.finalize``.
        verbose: show a tqdm progress bar.

    Returns:
        Whatever ``Tracker.finalize`` returns.
    """
    if gp:
        gp(0, "Tracking...")

    frame_count = len(all_preds)
    clip_info = {
        'start_frame': 0,
        'end_frame': frame_count,
        'image_meter_width': image_meter_width,
        'image_meter_height': image_meter_height,
    }
    # The algorithm's own min_hits is 0; the min_hits limit is handed to Tracker
    # separately (NOTE(review): presumably applied at finalize time — confirm).
    tracker = Tracker(
        clip_info,
        args={'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres},
        min_hits=min_hits,
    )

    with tqdm(total=frame_count, desc="Running tracking", ncols=0, disable=not verbose) as pbar:
        for step, frame_key in enumerate(sorted(all_preds)):
            if gp:
                gp(step / frame_count, str(pbar))
            frame_boxes = all_preds[frame_key]
            if frame_boxes is None:
                tracker.update()
            else:
                tracker.update(frame_boxes)
            pbar.update(1)

    return tracker.finalize(min_length=min_length, min_travel=min_travel)
def do_confidence_boost(inference, safe_preds, gp=None, batch_size=BATCH_SIZE, boost_power=1, boost_decay=1, verbose=True):
    """Boost confidences of raw detections that lie near confident detections in nearby frames.

    For every frame with at least one confident detection in ``safe_preds``, the raw
    detections in ``inference`` for frames up to ``boost_range`` frames before and
    after it get their confidence raised via boost_frame. Neighbouring frames may lie
    in the previous or next batch. ``inference`` is modified in place.

    Args:
        inference: raw model output per batch (from do_detection); modified in place.
        safe_preds: per batch, a list of per-image NMS-kept detections (from do_suppression).
        gp: optional progress callback, called as ``gp(fraction_complete, message)``.
        batch_size: only used to size the progress bar.
        boost_power: maximum boost strength. At or below the 0.01 cutoff, no frame is boosted.
        boost_decay: how quickly the boost decays with frame distance
            (NOTE(review): must be > 0; it is used as a divisor).
        verbose: show a tqdm progress bar.
    """
    if gp:
        gp(0, "Confidence Boost...")

    boost_cutoff = 0.01
    # Largest |dt| for which boost_power * exp(-boost_decay * dt^2) is still >= boost_cutoff.
    # When boost_power <= boost_cutoff no offset qualifies. Clamp to 0 instead of letting
    # math.log / math.sqrt raise a domain error.
    if boost_power > boost_cutoff:
        boost_range = math.floor(math.sqrt(1/boost_decay * math.log(boost_power / boost_cutoff)))
    else:
        boost_range = 0
    boost_scale = boost_power * math.exp(-boost_decay)

    n_batches = len(inference)
    # The bar advances by batch_size per batch, so size it the same way as the other stages.
    with tqdm(total=n_batches*batch_size, desc="Running confidence boost", ncols=0, disable=not verbose) as pbar:
        for batch_i in range(n_batches):
            if gp:
                gp(batch_i / n_batches, str(pbar))

            infer = inference[batch_i]
            # Neighbouring batches do not depend on the frame, so look them up once per batch.
            prev_batch = inference[batch_i - 1] if batch_i > 0 else None
            next_batch = inference[batch_i + 1] if batch_i + 1 < n_batches else None

            for i, safe_frame in enumerate(safe_preds[batch_i]):
                if len(safe_frame) == 0:
                    continue
                for dt in range(-boost_range, boost_range + 1):
                    if dt == 0:
                        continue
                    idx = i + dt
                    temp_frame = None
                    if 0 <= idx < len(infer):
                        temp_frame = infer[idx]
                    elif idx < 0 and prev_batch is not None and -idx <= len(prev_batch):
                        # A negative idx counts back from the end of the previous batch
                        # (idx == -1 is its last frame). The old check was inverted
                        # (`-idx >= len`), so this branch never fired for valid offsets.
                        temp_frame = prev_batch[idx]
                    elif idx >= len(infer) and next_batch is not None and idx - len(infer) < len(next_batch):
                        temp_frame = next_batch[idx - len(infer)]
                    if temp_frame is not None:
                        boost_frame(safe_frame, temp_frame, dt, power=boost_scale, decay=boost_decay)
            pbar.update(batch_size)
def boost_frame(safe_frame, base_frame, dt, power=1, decay=1):
    """Raise the confidences in ``base_frame`` that overlap confident boxes in ``safe_frame``.

    The confidence (column 4) of every raw detection is multiplied by
    ``1 + power * score * exp(-decay * (dt*dt - 1))``, where ``score`` is the sum,
    over safe boxes, of IoU(raw box, safe box) * safe confidence.

    Args:
        safe_frame: NMS-kept detections, with xyxy in columns 0-3 and confidence in column 4.
        base_frame: raw detections, with center-x, center-y, width, height in columns 0-3
            and confidence in column 4. Modified in place.
        dt: frame offset between the two frames.
        power: boost strength.
        decay: temporal decay of the boost.

    Returns:
        ``base_frame``, the same tensor after the in-place update.
    """
    kept_xyxy = safe_frame[:, :4]
    raw_xyxy = xywh2xyxy(base_frame[:, :4])  # (center_x, center_y, width, height) -> (x1, y1, x2, y2)

    if not torch.cuda.is_available():
        # On CPU the IoU has to be computed in double precision (box_iou's .prod()
        # call needs it, per the original note), then cast back to float.
        overlap = box_iou(raw_xyxy.double(), kept_xyxy).float()
    else:
        overlap = box_iou(raw_xyxy, kept_xyxy)

    # score = iou(safe_box, base_box) * confidence(safe_box), summed over safe boxes
    weighted = torch.matmul(overlap, safe_frame[:, 4])
    temporal = math.exp(-decay * (dt * dt - 1))
    base_frame[:, 4] *= 1 + power * weighted * temporal
    return base_frame
# ByteTrack
def do_associative_tracking(low_preds, high_preds, image_meter_width, image_meter_height, reverse=False, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, min_travel=MIN_TRAVEL, verbose=True):
    """Track fish with the associative (ByteTrack-style) ``Associate`` algorithm.

    Args:
        low_preds: frame key -> low-confidence-threshold detections (array or None),
            as built by format_predictions.
        high_preds: frame key -> high-confidence-threshold detections (array or None),
            with the same keys as ``low_preds``.
        image_meter_width, image_meter_height: image size in meters.
        reverse: process frames in descending key order when True.
        gp: optional progress callback, called as ``gp(fraction_complete, message)``.
        max_age, iou_thres: forwarded to the tracker as ``max_age`` / ``iou_threshold``.
        min_hits: passed to Tracker as its ``min_hits`` keyword.
        min_length, min_travel: forwarded to ``Tracker.finalize``.
        verbose: show a tqdm progress bar.

    Returns:
        Whatever ``Tracker.finalize`` returns.
    """
    if gp:
        gp(0, "Tracking...")

    # Initialize tracker
    clip_info = {
        'start_frame': 0,
        'end_frame': len(low_preds),
        'image_meter_width': image_meter_width,
        'image_meter_height': image_meter_height
    }
    print("Tracking using Associate")
    tracker = Tracker(clip_info, algorithm=Associate, reverse=reverse, args={'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)

    # Run tracking
    with tqdm(total=len(low_preds), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
        for i, key in enumerate(sorted(low_preds.keys(), reverse=reverse)):
            if gp:
                gp(i / len(low_preds), str(pbar))
            low_boxes = low_preds[key]
            high_boxes = high_preds[key]
            # Replace each missing set with an empty (0, 5) array independently. Before,
            # both sets were discarded when only the high-confidence set was empty, which
            # threw away low-confidence detections in exactly the frames where the
            # associative tracker needs them.
            if low_boxes is None:
                low_boxes = np.empty((0, 5))
            if high_boxes is None:
                high_boxes = np.empty((0, 5))
            tracker.update((low_boxes, high_boxes))
            pbar.update(1)

    json_data = tracker.finalize(min_length=min_length, min_travel=min_travel)
    return json_data
def json_dump_round_float(some_object, out_path, num_digits=4):
    """Write a json file to disk with a specified level of precision.

    Every finite float is written with exactly ``num_digits`` decimal places, and the
    output is indented by 2 spaces. Non-finite floats (NaN, +/-Infinity) are still
    written by the stock formatter as ``NaN`` / ``Infinity`` / ``-Infinity``, so
    ``json.loads`` can read the file back.

    How it works: ``json.encoder._make_iterencode`` is patched for the duration of
    the call so its float formatter (5th positional argument) is replaced. The patch
    is process-global, so this function is not thread-safe.

    See: https://gist.github.com/Sukonnik-Illia/ed9b2bec1821cad437d1b8adb17406a3
    """
    # saving original method
    of = json.encoder._make_iterencode
    fmt_str = '{:.' + str(num_digits) + 'f}'

    def inner(*args, **kwargs):
        args = list(args)
        # fifth argument is the float formatter, which we replace
        default_floatstr = args[4]
        args[4] = lambda o: fmt_str.format(o) if math.isfinite(o) else default_floatstr(o)
        return of(*args, **kwargs)

    # Open the file with a context manager so the handle is closed (and flushed)
    # even if encoding fails partway through.
    with patch('json.encoder._make_iterencode', wraps=inner), open(out_path, 'w') as fp:
        return json.dump(some_object, fp, indent=2)
def filter_detection_size(inference, image_meter_width, width, max_length):
    """Drop raw detections whose physical width is not below ``max_length`` meters.

    This applies the same size filter as non_max_suppression. Each box width in
    pixels (column 2, xywh layout) is converted to meters with
    ``image_meter_width / width``. When ``max_length <= 0`` the condition becomes
    ``width > max_length``, mirroring non_max_suppression.

    The previous implementation printed debug output, overwrote ``width`` with a
    tensor inside the loop (so later batches used a wrong scale), and crashed when
    writing N surviving rows into a 0-row slot. It now returns one list per batch
    holding one filtered tensor per image (per-image counts differ, so they cannot
    share a dense tensor).

    Args:
        inference: sequence of raw prediction tensors, one per batch, indexed as
            (image, candidate, column).
        image_meter_width: image width in meters.
        width: image width in pixels.
        max_length: maximum box width in meters.

    Returns:
        list (per batch) of lists (per image) of tensors with the surviving rows.
    """
    # Meters per pixel is the same for every batch.
    pix2width = image_meter_width / width
    outputs = []
    for batch in inference:
        box_widths = batch[..., 2] * pix2width
        if max_length > 0:
            keep = box_widths < max_length
        else:
            # If max_length is 0 (or negative), effectively ignore the limit.
            keep = box_widths > max_length
        outputs.append([image_preds[image_keep] for image_preds, image_keep in zip(batch, keep)])
    return outputs
def non_max_suppression(
    prediction,
    image_meter_width,
    image_pixel_width,
    max_length=1.5,
    conf_thres=0.25,
    iou_thres=0.45,
    max_det=300
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections.

    NOTE: SIMPLIFIED FOR SINGLE CLASS DETECTION

    Before NMS, candidates are kept only if their confidence (column 4) exceeds
    ``conf_thres`` and their physical width is below ``max_length`` meters. The width
    is box width in pixels (column 2, xywh) times ``image_meter_width /
    image_pixel_width``. If ``max_length <= 0`` the width check becomes
    ``width > max_length``, which effectively disables it.

    Args:
        prediction: raw model output, indexed as (image, candidate, column) with
            xywh in columns 0-3, objectness in 4 and class score(s) from 5. A
            list/tuple is accepted, in which case its first element is used.
        image_meter_width: image width in meters.
        image_pixel_width: image width in pixels.
        max_length: maximum box width in meters.
        conf_thres: confidence threshold in [0, 1].
        iou_thres: NMS IoU threshold in [0, 1].
        max_det: maximum number of detections kept per image.

    Returns:
        list of detections, one (n,6) tensor per image [xyxy, conf, cls], on the
        device of ``prediction``.
    """
    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    xc = prediction[..., 4] > conf_thres  # candidates

    # width filter
    pix2width = image_meter_width / image_pixel_width
    width = prediction[..., 2] * pix2width
    if max_length > 0:
        wc = width < max_length
    else:
        # If max_length is 0, ignore
        wc = width > max_length

    # Settings
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    redundant = True  # require redundant detections
    merge = False  # use merge-NMS

    # Allocate empty results on the caller's device. With MPS, `prediction` now lives
    # on the CPU, and non-empty results are moved back below, so the empty ones must
    # be created on `device` as well.
    output = [torch.zeros((0, 6), device=device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Keep boxes that pass both the confidence and the size filter
        x = x[xc[xi] & wc[xi]]

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box/Mask
        box = xywh2xyxy(x[:, :4])  # center_x, center_y, width, height) to (x1, y1, x2, y2)
        mask = x[:, 6:]  # zero columns if no masks

        # Detections matrix nx6 (xyxy, conf, cls)
        conf, j = x[:, 5:6].max(1, keepdim=True)
        x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        boxes = x[:, :4]  # boxes (offset by class), scores
        scores = x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)

    return output