Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| from collections import Counter | |
| from dataclasses import dataclass | |
| from typing import List | |
| import numpy as np | |
| from scipy.optimize import linear_sum_assignment | |
| from tracking.kalman import VariableSpeedKalmanFilter | |
@dataclass
class Annotation:  # TODO: maybe find a better name?
    """
    Represents a single object found in a single frame (its box and class).

    Built with keyword arguments (``box=..., class_=..., score=...``), so this
    class must be a dataclass.
    """

    box: Box  # location of the object in the frame
    class_: int  # integer class index; trailing underscore avoids the `class` keyword
    score: float  # presumably the detector confidence for this box — TODO confirm
@dataclass
class TrackedAnnotation(Annotation):
    """
    An Annotation that additionally records the TrackedObject it belongs to.
    """

    obj: TrackedObject  # the tracked object this annotation was produced for

    # NOTE(review): the decorators of this copy appear to have been lost; these
    # accessors are restored as properties, consistent with
    # ``self.obj.majority_class`` being read without a call below. Confirm
    # against callers.
    @property
    def object_majority_class(self):
        """The owning object's ``majority_class``."""
        return self.obj.majority_class

    @property
    def object_id(self):
        """The owning object's ``object_id``."""
        return self.obj.object_id
class Tracker:
    """
    Links per-frame annotations into TrackedObjects across a sequence of frames.

    On every frame:

    1. Each active object predicts its next box.
    2. Predictions are paired one-to-one with the frame's annotations by
       maximising the total match score (scipy ``linear_sum_assignment``).
    3. Unmatched annotations start new objects.
    4. Objects left unmatched for more than ``max_missing_frames``
       consecutive frames are deactivated.
    """

    tracked_objects: List[TrackedObject]

    # Number of most recent raw annotations of an object consulted for class
    # agreement in calculate_match_score().
    class_history_window = 5

    def __init__(self, confusion_matrix, min_score_for_match=0.2, min_frames=30, max_missing_frames=8):
        """
        :param confusion_matrix: 2-D matrix indexed ``[class_a, class_b]`` by
            integer class indices. Its entries multiply the IoU in
            calculate_match_score, so higher values make a class pairing more
            likely to match. Presumably a numpy array — TODO confirm.
        :param min_score_for_match: assignments scoring below this are rejected.
        :param min_frames: finish() discards objects whose annotation history
            is shorter than this.
        :param max_missing_frames: number of consecutive unmatched frames an
            object may coast through before it is deactivated.
        """
        self.tracked_objects = []
        self.confusion_matrix = confusion_matrix
        self.min_score_for_match = min_score_for_match
        self.min_frames = min_frames
        self.max_missing_frames = max_missing_frames
        self.current_frame = 0

    @property
    def active_objects(self):
        """A fresh list of the tracked objects that are still active."""
        return [obj for obj in self.tracked_objects if obj.is_active]

    def advance_frame(self, new_annotations):
        """
        Process one frame of annotations.

        :param new_annotations: iterable of Annotation for the current frame;
            it is copied, so the caller's container is not modified.
        :returns: get_current_annotations() after the update.
        """
        new_annotations = list(new_annotations)

        # 1. Every active object predicts its box for this frame.
        for tracked_object in self.active_objects:
            tracked_object.predict_next_box()

        # 2. Associate predictions with this frame's annotations.
        matches = self.match_objects_to_annotations(new_annotations)
        unmatched_objects = self.active_objects
        for tracked_object, annotation in matches:
            tracked_object.add_new_measurement(annotation)
            tracked_object.missing_frame_count = 0
            new_annotations.remove(annotation)
            unmatched_objects.remove(tracked_object)
            tracked_object.annotation_history.append(tracked_object.current_annotation)

        # 3. Objects without a match still get an entry for this frame; after
        #    too many consecutive misses they are deactivated and the entries
        #    appended while they were unmatched are discarded.
        for tracked_object in unmatched_objects:
            tracked_object.annotation_history.append(tracked_object.current_annotation)
            tracked_object.missing_frame_count += 1
            if tracked_object.missing_frame_count > self.max_missing_frames:
                tracked_object.is_active = False
                del tracked_object.annotation_history[-tracked_object.missing_frame_count:]

        # 4. Every leftover annotation starts a new tracked object.
        for annotation in new_annotations:
            box = annotation.box
            kalman_filter = VariableSpeedKalmanFilter(
                x_0=box.center_x,
                y_0=box.center_y,
                w_0=box.width,
                h_0=box.height,
            )
            object_id = len(self.tracked_objects)
            self.tracked_objects.append(
                TrackedObject(kalman_filter, annotation, start_frame=self.current_frame, object_id=object_id)
            )

        self.current_frame += 1
        return self.get_current_annotations()

    def advance_frames(self, raw_annotations_per_frame):
        """Call advance_frame() once for each frame's annotations, in order."""
        for raw_annotations in raw_annotations_per_frame:
            self.advance_frame(raw_annotations)

    def match_objects_to_annotations(self, annotations):
        """
        Pair active objects with annotations one-to-one.

        The assignment maximises the summed calculate_match_score; pairs scoring
        below ``min_score_for_match`` are dropped.

        :returns: list of ``(tracked_object, annotation)`` tuples.
        """
        active_objects = self.active_objects
        if len(active_objects) == 0 or len(annotations) == 0:
            # Nothing to assign; avoid handing an empty matrix to scipy.
            return []

        score_matrix = np.zeros((len(active_objects), len(annotations)))
        for i, obj in enumerate(active_objects):
            for j, ann in enumerate(annotations):
                score_matrix[i, j] = self.calculate_match_score(ann, obj)

        # linear_sum_assignment minimises cost, so negate to maximise score.
        obj_indices, ann_indices = linear_sum_assignment(-score_matrix)
        return [
            (active_objects[i], annotations[j])
            for i, j in zip(obj_indices, ann_indices)
            if score_matrix[i, j] >= self.min_score_for_match
        ]

    def finish(self):
        """
        Finalise tracking.

        Drop the most recent history entry of every active object that was
        matched in the last frame, then discard objects whose history is
        shorter than ``min_frames``.
        """
        for tracked_object in self.tracked_objects:
            # NOTE(review): the reason for this pop is not visible here —
            # presumably related to current_annotation being derived from the
            # filter's ``next_x``. Confirm before changing.
            if tracked_object.is_active and tracked_object.missing_frame_count <= 0:
                tracked_object.annotation_history.pop()
        # Remove short-lived objects.
        self.tracked_objects = [
            obj for obj in self.tracked_objects if len(obj.annotation_history) >= self.min_frames
        ]

    def get_current_annotations(self):
        """The current annotation of every active object."""
        return [obj.current_annotation for obj in self.tracked_objects if obj.is_active]

    def get_annotations_per_frame(self, frame_index):
        """The annotation for ``frame_index`` of every object whose history covers that frame."""
        return [
            obj.annotation_history[frame_index - obj.start_frame]
            for obj in self.tracked_objects
            if obj.start_frame <= frame_index <= obj.end_frame
        ]

    def calculate_match_score(self, annotation, tracked_object):
        """
        Score how well ``annotation`` continues ``tracked_object``.

        The score is the IoU between the object's predicted box and the
        annotation's box, multiplied by the confusion-matrix entries averaged
        in both directions over the object's last ``class_history_window``
        raw annotations.
        """
        base_score = tracked_object.predicted_next_box.iou(annotation.box)
        recent = tracked_object.raw_annotation_history[-self.class_history_window:]
        forward = [self.confusion_matrix[annotation.class_, past.class_] for past in recent]
        backward = [self.confusion_matrix[past.class_, annotation.class_] for past in recent]
        average_confusion_score = np.average(forward + backward)
        return average_confusion_score * base_score
class TrackedObject:
    """
    A single object tracked over multiple frames.

    Holds the object's Kalman filter, the raw annotations matched to it
    (``raw_annotation_history``) and one TrackedAnnotation per tracked frame
    (``annotation_history``, starting at ``start_frame``).
    """

    def __init__(self, kalman_filter, init_annotation, start_frame, object_id):
        self.kalman_filter = kalman_filter
        self.missing_frame_count = 0  # consecutive frames without a matched annotation
        self.raw_annotation_history = [init_annotation]
        # current_annotation reads kalman_filter and raw_annotation_history,
        # so both must be assigned before this line.
        self.annotation_history = [self.current_annotation]
        self.is_active = True
        self.start_frame = start_frame
        self.predicted_next_box = None  # set by predict_next_box()
        self.object_id = object_id

    def predict_next_box(self):
        """Advance the Kalman filter, store the predicted Box, and return it."""
        # Flatten like current_annotation does for next_x, so that a column-
        # vector state (presumably what the filter returns — TODO confirm)
        # still yields scalar coordinates.
        prediction = np.ravel(self.kalman_filter.predict())
        self.predicted_next_box = Box.from_center_and_size(*prediction[:4])
        return self.predicted_next_box

    def add_new_measurement(self, annotation):
        """Feed the annotation's box (center and size) to the filter and record the annotation."""
        self.kalman_filter.update(annotation.box.center_and_size)
        self.raw_annotation_history.append(annotation)

    @property
    def current_annotation(self):
        """
        A TrackedAnnotation built from the filter's ``next_x`` box.

        Its class and score come from the latest raw annotation.
        """
        return TrackedAnnotation(
            box=Box.from_center_and_size(*self.kalman_filter.next_x.flatten()[:4]),
            class_=self.raw_annotation_history[-1].class_,
            score=self.raw_annotation_history[-1].score,
            obj=self,
        )

    @property
    def end_frame(self):
        """Index of the last frame covered by ``annotation_history``."""
        return self.start_frame + len(self.annotation_history) - 1

    @property
    def majority_class(self):
        """Most frequent class among the raw annotations (ties go to the first one seen)."""
        return Counter(ann.class_ for ann in self.raw_annotation_history).most_common(1)[0][0]
class Box:
    """
    Represents a box with edges perpendicular to the x, y axes.

    ``(x1, y1)`` is the top-left corner and ``(x2, y2)`` the bottom-right one.
    """

    def __init__(self, x1, y1, x2, y2):
        self.x1 = x1
        self.y1 = y1
        self.x2 = x2
        self.y2 = y2

    @classmethod
    def from_center_and_size(cls, x, y, w, h):
        """Build a box from its center ``(x, y)``, width ``w`` and height ``h``."""
        return cls(x - w / 2, y - h / 2, x + w / 2, y + h / 2)

    @classmethod
    def from_top_left_and_size(cls, x, y, w, h):
        """Build a box from its top-left corner ``(x, y)``, width ``w`` and height ``h``."""
        return cls(x, y, x + w, y + h)

    def __repr__(self):
        return f"{type(self).__name__}({self.x1!r}, {self.y1!r}, {self.x2!r}, {self.y2!r})"

    @property
    def center_x(self):
        return (self.x1 + self.x2) / 2

    @property
    def center_y(self):
        return (self.y1 + self.y2) / 2

    @property
    def width(self):
        return self.x2 - self.x1

    @property
    def height(self):
        return self.y2 - self.y1

    @property
    def center_and_size(self):
        """``np.array([center_x, center_y, width, height])``."""
        return np.array([self.center_x, self.center_y, self.width, self.height])

    @property
    def area(self):
        return (self.x2 - self.x1) * (self.y2 - self.y1)

    @property
    def is_valid(self):
        """True when the box has strictly positive width and height."""
        return self.x2 > self.x1 and self.y2 > self.y1

    def iou(self, other):
        """
        Calculate Intersection over Union with ``other``.

        Returns 0.0 when the boxes do not overlap with positive area.
        """
        intersection = Box(
            max(self.x1, other.x1),
            max(self.y1, other.y1),
            min(self.x2, other.x2),
            min(self.y2, other.y2),
        )
        if not intersection.is_valid:
            return 0.0
        intersection_area = intersection.area
        union_area = self.area + other.area - intersection_area
        if union_area == 0:
            return 0.0
        return intersection_area / union_area