from __future__ import annotations

from collections import Counter
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple

import numpy as np
from scipy.optimize import linear_sum_assignment

from tracking.kalman import VariableSpeedKalmanFilter


@dataclass
class Annotation:  # todo: maybe find a better name?
    """
    Represents a single object found in a single frame (box and class)
    """

    box: Box
    class_: int
    score: float


@dataclass
class TrackedAnnotation(Annotation):
    """
    An annotation that also references the tracked object it belongs to
    """

    obj: TrackedObject

    @property
    def object_majority_class(self):
        """Most frequent class among the raw detections of the owning object."""
        return self.obj.majority_class

    @property
    def object_id(self):
        """Identifier of the owning tracked object."""
        return self.obj.object_id


class Tracker:
    """
    Associates per-frame annotations with objects tracked over multiple frames.

    Each call to ``advance_frame`` predicts the next box of every active
    object, assigns the incoming annotations to the objects (optimal
    one-to-one assignment on the match score), updates matched objects,
    counts a missing frame for unmatched ones and starts a new object for
    every annotation that was not assigned.
    """

    # Number of most recent raw detections of an object whose classes are
    # compared against a candidate annotation's class.
    CONFUSION_HISTORY_WINDOW = 5

    tracked_objects: List[TrackedObject]

    def __init__(self, confusion_matrix, min_score_for_match=0.2, min_frames=30, max_missing_frames=8):
        """
        :param confusion_matrix: looked up as ``confusion_matrix[class_a, class_b]``;
            the values weight the IoU in ``calculate_match_score``
            (presumably a 2-D numpy array indexed by class id - confirm with caller)
        :param min_score_for_match: assigned pairs scoring below this are not
            treated as matches
        :param min_frames: ``finish`` drops objects whose history is shorter than this
        :param max_missing_frames: an object is deactivated once it has gone
            unmatched for more than this many consecutive frames
        """
        self.tracked_objects = []
        self.confusion_matrix = confusion_matrix
        self.min_score_for_match = min_score_for_match
        self.min_frames = min_frames
        self.max_missing_frames = max_missing_frames
        self.current_frame = 0

    @property
    def active_objects(self) -> List[TrackedObject]:
        """A new list holding the tracked objects that are still active."""
        return [obj for obj in self.tracked_objects if obj.is_active]

    def advance_frame(self, new_annotations: Iterable[Annotation]) -> List[TrackedAnnotation]:
        """
        Feed the annotations of the next frame into the tracker.

        :param new_annotations: annotations detected in the frame
        :return: the current annotation of every object that is active after
            this frame (including objects started by this frame)
        """
        annotations = list(new_annotations)

        # Inactive objects never take part in matching, so they need no prediction.
        for tracked_object in self.active_objects:
            tracked_object.predict_next_box()

        matches = self.match_objects_to_annotations(annotations)

        # Snapshot before any object is deactivated or created below.
        active_before_update = self.active_objects

        # Matched items are remembered by identity: ``Annotation`` is a
        # dataclass, so ``==`` compares field values rather than identity.
        matched_object_ids = set()
        matched_annotation_ids = set()
        for tracked_object, best_match in matches:
            tracked_object.add_new_measurement(best_match)
            tracked_object.missing_frame_count = 0
            tracked_object.annotation_history.append(tracked_object.current_annotation)
            matched_object_ids.add(id(tracked_object))
            matched_annotation_ids.add(id(best_match))

        for tracked_object in active_before_update:
            if id(tracked_object) in matched_object_ids:
                continue
            # No measurement this frame: record the filter-only annotation.
            tracked_object.annotation_history.append(tracked_object.current_annotation)
            tracked_object.missing_frame_count += 1
            if tracked_object.missing_frame_count > self.max_missing_frames:
                tracked_object.is_active = False
                # The trailing entries were added without a measurement; drop
                # them so the history ends at the last matched frame.
                del tracked_object.annotation_history[-tracked_object.missing_frame_count:]

        for annotation in annotations:
            if id(annotation) not in matched_annotation_ids:
                self._start_tracking(annotation)

        self.current_frame += 1
        return self.get_current_annotations()

    def _start_tracking(self, annotation: Annotation) -> None:
        """Create a tracked object for an annotation that matched no existing object."""
        box = annotation.box
        kalmanf = VariableSpeedKalmanFilter(
            x_0=box.center_x, y_0=box.center_y, w_0=box.width, h_0=box.height
        )
        tracked_obj = TrackedObject(
            kalmanf,
            annotation,
            start_frame=self.current_frame,
            object_id=len(self.tracked_objects),
        )
        self.tracked_objects.append(tracked_obj)

    def advance_frames(self, raw_annotations_per_frame: Iterable[Iterable[Annotation]]) -> None:
        """Call ``advance_frame`` once for every item of the given iterable, in order."""
        for raw_annotations in raw_annotations_per_frame:
            self.advance_frame(raw_annotations)

    def match_objects_to_annotations(self, annotations) -> List[Tuple[TrackedObject, Annotation]]:
        """
        Assign annotations to the active objects.

        The assignment maximises the total match score (the score matrix is
        negated because ``linear_sum_assignment`` minimises cost). Assigned
        pairs whose score is below ``min_score_for_match`` are discarded.

        :param annotations: indexable sequence of annotations
        :return: list of ``(tracked_object, annotation)`` pairs
        """
        active_objects = self.active_objects
        if len(active_objects) == 0 or len(annotations) == 0:
            # Nothing can be paired.
            return []

        score_matrix = np.zeros((len(active_objects), len(annotations)))
        for i, obj in enumerate(active_objects):
            for j, ann in enumerate(annotations):
                score_matrix[i, j] = self.calculate_match_score(ann, obj)

        obj_indices, ann_indices = linear_sum_assignment(-score_matrix)
        return [
            (active_objects[i], annotations[j])
            for i, j in zip(obj_indices, ann_indices)
            if score_matrix[i, j] >= self.min_score_for_match
        ]

    def finish(self) -> None:
        """
        Finalise tracking: clean up histories and drop short-lived objects.

        An active object that went unmatched on the most recent frame(s) has
        ``missing_frame_count`` trailing history entries that were added
        without a measurement. They are removed here, exactly as
        ``advance_frame`` does when it deactivates an object, so every history
        ends at the last frame the object was actually matched. Objects
        matched on the last frame keep their full history.

        Afterwards objects with fewer than ``min_frames`` history entries are
        removed from ``tracked_objects``.
        """
        for tracked_object in self.active_objects:
            missing = tracked_object.missing_frame_count
            # The guard matters: ``del history[-0:]`` would clear the whole list.
            if missing > 0:
                del tracked_object.annotation_history[-missing:]
                # The trailing entries are gone; resetting the counter keeps
                # it consistent with the history and makes finish() idempotent.
                tracked_object.missing_frame_count = 0

        # remove short-lived objects
        self.tracked_objects = [
            obj for obj in self.tracked_objects
            if len(obj.annotation_history) >= self.min_frames
        ]

    def get_current_annotations(self) -> List[TrackedAnnotation]:
        """Return the current annotation of every active object."""
        return [obj.current_annotation for obj in self.active_objects]

    def get_annotations_per_frame(self, frame_index: int) -> List[TrackedAnnotation]:
        """
        Return the recorded annotations of all objects whose history covers
        ``frame_index`` (``start_frame <= frame_index <= end_frame``).
        """
        return [
            obj.annotation_history[frame_index - obj.start_frame]
            for obj in self.tracked_objects
            if obj.start_frame <= frame_index <= obj.end_frame
        ]

    def calculate_match_score(self, annotation: Annotation, tracked_object: TrackedObject):
        """
        Score how well an annotation fits a tracked object.

        The score is the IoU between the object's predicted next box and the
        annotation's box, multiplied by the average of the confusion-matrix
        entries between the annotation's class and the classes of the object's
        most recent raw detections. Both index orders, ``[new, past]`` and
        ``[past, new]``, are included in the average.

        Requires ``tracked_object.predict_next_box()`` to have been called.
        """
        base_score = tracked_object.predicted_next_box.iou(annotation.box)
        recent = tracked_object.raw_annotation_history[-self.CONFUSION_HISTORY_WINDOW:]
        new_vs_past = [self.confusion_matrix[annotation.class_, past.class_] for past in recent]
        past_vs_new = [self.confusion_matrix[past.class_, annotation.class_] for past in recent]
        average_confusion_score = np.average(new_vs_past + past_vs_new)
        return average_confusion_score * base_score


class TrackedObject:
    """
    A single object tracked over multiple frames
    """

    def __init__(self, kalman_filter, init_annotation, start_frame, object_id):
        """
        :param kalman_filter: filter used for this object; this class calls
            ``predict()`` and ``update(...)`` on it and reads its ``next_x``
        :param init_annotation: the detection that started the object
        :param start_frame: index of the frame ``init_annotation`` belongs to
        :param object_id: identifier stored on the object
        """
        self.kalman_filter = kalman_filter
        self.missing_frame_count = 0
        # Detections exactly as they were handed to the tracker.
        self.raw_annotation_history = [init_annotation]
        # One entry per frame since start_frame. ``current_annotation`` reads
        # kalman_filter and raw_annotation_history, so both are set above.
        self.annotation_history = [self.current_annotation]
        self.is_active = True
        self.start_frame = start_frame
        self.predicted_next_box: Optional[Box] = None
        self.object_id = object_id

    def predict_next_box(self) -> Box:
        """
        Run the filter's ``predict`` step and store/return the resulting box.

        The first four entries of the prediction are used as
        (center x, center y, width, height).
        """
        prediction = self.kalman_filter.predict()
        self.predicted_next_box = Box.from_center_and_size(
            prediction[0], prediction[1], prediction[2], prediction[3]
        )
        return self.predicted_next_box

    def add_new_measurement(self, annotation: Annotation) -> None:
        """Update the filter with the annotation's box and record the raw annotation."""
        self.kalman_filter.update(annotation.box.center_and_size)
        self.raw_annotation_history.append(annotation)

    @property
    def current_annotation(self) -> TrackedAnnotation:
        """
        A new ``TrackedAnnotation`` for the object's present state.

        The box is built from the first four entries of the filter's flattened
        ``next_x`` (presumably center x, center y, width, height - confirm
        against VariableSpeedKalmanFilter); class and score are copied from
        the most recent raw detection.
        """
        latest_raw = self.raw_annotation_history[-1]
        return TrackedAnnotation(
            box=Box.from_center_and_size(*self.kalman_filter.next_x.flatten()[:4]),
            class_=latest_raw.class_,
            score=latest_raw.score,
            obj=self,
        )

    @property
    def end_frame(self) -> int:
        """Index of the frame the last entry of ``annotation_history`` belongs to."""
        return self.start_frame + len(self.annotation_history) - 1

    @property
    def majority_class(self):
        """The class that occurs most often among the raw detections."""
        class_counts = Counter(ann.class_ for ann in self.raw_annotation_history)
        return class_counts.most_common(1)[0][0]


class Box:
    """
    Represents a box with edges perpendicular to the x,y axes.
    first point is top left, second is bottom right.
    """

    def __init__(self, x1, y1, x2, y2):
        self.x1 = x1
        self.y1 = y1
        self.x2 = x2
        self.y2 = y2

    def __repr__(self):
        return f"{type(self).__name__}(x1={self.x1!r}, y1={self.y1!r}, x2={self.x2!r}, y2={self.y2!r})"

    @classmethod
    def from_center_and_size(cls, x, y, w, h):
        """Build a box from its center point and its width/height."""
        return cls(x - w/2, y - h/2, x + w/2, y + h/2)

    @classmethod
    def from_top_left_and_size(cls, x, y, w, h):
        """Build a box from its top-left corner and its width/height."""
        return cls(x, y, x + w, y + h)

    @property
    def center_x(self):
        return (self.x1 + self.x2)/2

    @property
    def center_y(self):
        return (self.y1 + self.y2)/2

    @property
    def width(self):
        return self.x2 - self.x1

    @property
    def height(self):
        return self.y2 - self.y1

    @property
    def center_and_size(self):
        """Numpy array ``[center_x, center_y, width, height]``."""
        return np.array([self.center_x, self.center_y, self.width, self.height])

    @property
    def area(self):
        return (self.x2 - self.x1) * (self.y2 - self.y1)

    @property
    def is_valid(self):
        """True when the box has strictly positive width and height."""
        return self.x2 > self.x1 and self.y2 > self.y1

    def iou(self, other):
        """
        Calculate Intersection over Union with other box

        Returns ``0.0`` when the boxes do not overlap with positive area.
        """
        intersection_box = Box(
            max(self.x1, other.x1),
            max(self.y1, other.y1),
            min(self.x2, other.x2),
            min(self.y2, other.y2),
        )
        if not intersection_box.is_valid:
            return 0.0
        # A valid intersection has positive extent on both axes, so its area
        # needs no clamping.
        intersection_area = intersection_box.area
        union_area = self.area + other.area - intersection_area
        if union_area <= 0:
            return 0.0
        return intersection_area / union_area