# Heat-Vision/src/tracking/tracker.py
# TulkinRB's picture
# Add stuff
# 0bdfe9d
from __future__ import annotations
from collections import Counter
from dataclasses import dataclass
from typing import List
import numpy as np
from scipy.optimize import linear_sum_assignment
from tracking.kalman import VariableSpeedKalmanFilter
@dataclass
class Annotation:  # TODO: maybe find a better name?
    """
    One object found in one frame: its box, its class and its score.
    """

    # Bounding box of the object in the frame.
    box: Box
    # Integer class label (trailing underscore avoids the `class` keyword).
    class_: int
    # Score attached to this annotation.
    score: float
@dataclass
class TrackedAnnotation(Annotation):
    """
    An annotation that also references the tracked object it belongs to.
    """

    # The tracked object this annotation was produced for.
    obj: TrackedObject

    @property
    def object_majority_class(self):
        """Forward ``majority_class`` of the owning tracked object."""
        owner = self.obj
        return owner.majority_class

    @property
    def object_id(self):
        """Forward ``object_id`` of the owning tracked object."""
        owner = self.obj
        return owner.object_id
class Tracker:
    """
    Tracks objects over consecutive frames.

    On every frame the incoming annotations are assigned to the currently
    active tracked objects (see ``match_objects_to_annotations``). Annotations
    left unassigned start new tracked objects; objects left unassigned for more
    than ``max_missing_frames`` consecutive frames are deactivated.
    """

    tracked_objects: List[TrackedObject]

    def __init__(self, confusion_matrix, min_score_for_match=0.2, min_frames=30, max_missing_frames=8,
                 *, class_history_window=5):
        """
        :param confusion_matrix: indexed as ``confusion_matrix[class_a, class_b]``;
            the looked-up values weight the IoU in ``calculate_match_score``
            (presumably a square numpy array over class indices - confirm with caller).
        :param min_score_for_match: assignments scoring below this are rejected.
        :param min_frames: objects whose annotation history is shorter than this
            are dropped by ``finish``.
        :param max_missing_frames: number of consecutive unmatched frames an
            object may have before it is deactivated.
        :param class_history_window: how many of an object's most recent raw
            annotations take part in the class-confusion score (must be >= 1).
        :raises ValueError: if ``class_history_window`` is smaller than 1.
        """
        if class_history_window < 1:
            # A window of 0 would turn `history[-0:]` into "the whole history".
            raise ValueError("class_history_window must be at least 1")
        self.tracked_objects = []
        self.confusion_matrix = confusion_matrix
        self.min_score_for_match = min_score_for_match
        self.min_frames = min_frames
        self.max_missing_frames = max_missing_frames
        self.class_history_window = class_history_window
        self.current_frame = 0

    @property
    def active_objects(self):
        """Return a new list with the tracked objects that are still active."""
        return [obj for obj in self.tracked_objects if obj.is_active]

    def advance_frame(self, new_annotations):
        """
        Consume the annotations of the next frame and update all tracked objects.

        :param new_annotations: iterable of annotations found in the frame.
        :return: the current annotation of every object that is still active
            after this frame (same as ``get_current_annotations``).
        """
        new_annotations = list(new_annotations)

        # Every active object needs a fresh prediction before it can be scored.
        for tracked_object in self.active_objects:
            tracked_object.predict_next_box()

        matches = self.match_objects_to_annotations(new_annotations)

        # Matched items are remembered by identity, so bookkeeping does not
        # depend on how annotations or objects define equality.
        matched_object_ids = set()
        matched_annotation_ids = set()
        for tracked_object, best_match in matches:
            tracked_object.add_new_measurement(best_match)
            tracked_object.missing_frame_count = 0
            tracked_object.annotation_history.append(tracked_object.current_annotation)
            matched_object_ids.add(id(tracked_object))
            matched_annotation_ids.add(id(best_match))

        # Active objects that got no annotation this frame.
        for tracked_object in self.active_objects:
            if id(tracked_object) in matched_object_ids:
                continue
            tracked_object.annotation_history.append(tracked_object.current_annotation)
            tracked_object.missing_frame_count += 1
            if tracked_object.missing_frame_count > self.max_missing_frames:
                tracked_object.is_active = False
                # Drop the entries recorded during the trailing run of missing frames.
                del tracked_object.annotation_history[-tracked_object.missing_frame_count:]

        # Annotations that matched no object start new tracked objects.
        for annotation in new_annotations:
            if id(annotation) in matched_annotation_ids:
                continue
            box = annotation.box
            kalmanf = VariableSpeedKalmanFilter(
                x_0=box.center_x,
                y_0=box.center_y,
                w_0=box.width,
                h_0=box.height
            )
            tracked_obj = TrackedObject(
                kalmanf,
                annotation,
                start_frame=self.current_frame,
                object_id=len(self.tracked_objects),
            )
            self.tracked_objects.append(tracked_obj)

        self.current_frame += 1
        return self.get_current_annotations()

    def advance_frames(self, raw_annotations_per_frame):
        """Call ``advance_frame`` once per item; the per-frame results are discarded."""
        for raw_annotations in raw_annotations_per_frame:
            self.advance_frame(raw_annotations)

    def match_objects_to_annotations(self, annotations):
        """
        Assign annotations to active objects so that the total match score is maximal.

        :param annotations: sequence of annotations of the current frame.
        :return: list of ``(tracked_object, annotation)`` pairs whose score is at
            least ``min_score_for_match``.
        """
        active_objects = self.active_objects
        if not active_objects or len(annotations) == 0:
            # Nothing can be paired; skip building and solving an empty problem.
            return []

        score_matrix = np.zeros((len(active_objects), len(annotations)))
        for i, obj in enumerate(active_objects):
            for j, ann in enumerate(annotations):
                score_matrix[i, j] = self.calculate_match_score(ann, obj)

        # linear_sum_assignment minimises cost, hence the negated scores.
        obj_indices, ann_indices = linear_sum_assignment(-score_matrix)
        return [
            (active_objects[i], annotations[j])
            for i, j in zip(obj_indices, ann_indices)
            if score_matrix[i, j] >= self.min_score_for_match
        ]

    def finish(self):
        """
        Finalise tracking: adjust histories, then drop short-lived objects.

        Objects whose annotation history holds fewer than ``min_frames`` entries
        are removed from ``tracked_objects``.
        """
        for tracked_object in self.tracked_objects:
            # NOTE(review): this pops the last entry of active objects that were
            # matched in the final frame (missing_frame_count == 0), while active
            # objects with trailing missing frames are left untouched. That looks
            # inverted compared with the trimming done on deactivation in
            # advance_frame - behaviour kept as is, please confirm the intent.
            if tracked_object.is_active and tracked_object.missing_frame_count <= 0:
                tracked_object.annotation_history.pop()

        # remove short-lived objects
        self.tracked_objects = [
            obj for obj in self.tracked_objects
            if len(obj.annotation_history) >= self.min_frames
        ]

    def get_current_annotations(self):
        """Return the current annotation of every active tracked object."""
        return [obj.current_annotation for obj in self.active_objects]

    def get_annotations_per_frame(self, frame_index):
        """
        Return the recorded annotations for one frame.

        :param frame_index: index of the requested frame.
        :return: one history entry per tracked object whose
            ``start_frame..end_frame`` span contains ``frame_index``.
        """
        return [
            obj.annotation_history[frame_index - obj.start_frame]
            for obj in self.tracked_objects
            if obj.start_frame <= frame_index <= obj.end_frame
        ]

    def calculate_match_score(self, annotation, tracked_object):
        """
        Score how well ``annotation`` continues ``tracked_object``.

        The score is the IoU between the object's predicted next box and the
        annotation's box, multiplied by the average confusion-matrix value
        (looked up in both index orders) between the annotation's class and the
        classes of the object's last ``class_history_window`` raw annotations.
        """
        base_score = tracked_object.predicted_next_box.iou(annotation.box)
        recent_annotations = tracked_object.raw_annotation_history[-self.class_history_window:]
        average_confusion_score = np.average([
            self.confusion_matrix[annotation.class_, past_annotation.class_]
            for past_annotation in recent_annotations
        ] + [
            self.confusion_matrix[past_annotation.class_, annotation.class_]
            for past_annotation in recent_annotations
        ])
        return average_confusion_score * base_score
class TrackedObject:
    """
    A single object followed across multiple frames.

    Keeps the raw annotations that were matched to the object and, in
    parallel, a history of annotations derived from the Kalman filter state.
    """

    def __init__(self, kalman_filter, init_annotation, start_frame, object_id):
        """
        :param kalman_filter: filter holding the object's box state.
        :param init_annotation: the annotation that started this object.
        :param start_frame: index of the frame the object first appeared in.
        :param object_id: identifier assigned by the tracker.
        """
        self.object_id = object_id
        self.start_frame = start_frame
        self.is_active = True
        self.missing_frame_count = 0
        self.predicted_next_box = None
        self.kalman_filter = kalman_filter
        self.raw_annotation_history = [init_annotation]
        # current_annotation reads the filter and the raw history set just above.
        self.annotation_history = [self.current_annotation]

    def predict_next_box(self):
        """Run the filter's predict step, store the resulting box and return it."""
        state = self.kalman_filter.predict()
        center_x, center_y = state[0], state[1]
        width, height = state[2], state[3]
        next_box = Box.from_center_and_size(center_x, center_y, width, height)
        self.predicted_next_box = next_box
        return next_box

    def add_new_measurement(self, annotation):
        """Feed the annotation's box into the filter and record the raw annotation."""
        measurement = annotation.box.center_and_size
        self.kalman_filter.update(measurement)
        self.raw_annotation_history.append(annotation)

    @property
    def current_annotation(self):
        """
        Build a ``TrackedAnnotation`` from the current filter state.

        The box comes from the first four values of the filter's ``next_x``;
        class and score are copied from the most recent raw annotation.
        """
        center_and_size = self.kalman_filter.next_x.flatten()[:4]
        filtered_box = Box.from_center_and_size(*center_and_size)
        latest_raw = self.raw_annotation_history[-1]
        return TrackedAnnotation(
            box=filtered_box,
            class_=latest_raw.class_,
            score=latest_raw.score,
            obj=self,
        )

    @property
    def end_frame(self):
        """Index of the last frame covered by ``annotation_history``."""
        covered_frames = len(self.annotation_history)
        return self.start_frame + covered_frames - 1

    @property
    def majority_class(self):
        """Most frequent class among the raw annotations of this object."""
        class_counts = Counter(ann.class_ for ann in self.raw_annotation_history)
        top_class, _ = class_counts.most_common(1)[0]
        return top_class
class Box:
    """
    Represents a box with edges perpendicular to the x,y axes.
    first point is top left, second is bottom right.
    """

    def __init__(self, x1, y1, x2, y2):
        """Store the two corner points ``(x1, y1)`` and ``(x2, y2)``."""
        self.x1 = x1
        self.y1 = y1
        self.x2 = x2
        self.y2 = y2

    def __repr__(self):
        return f"{type(self).__name__}(x1={self.x1!r}, y1={self.y1!r}, x2={self.x2!r}, y2={self.y2!r})"

    @staticmethod
    def from_center_and_size(x, y, w, h):
        """Build a box from its center ``(x, y)``, width ``w`` and height ``h``."""
        return Box(x - w/2, y - h/2, x + w/2, y + h/2)

    @staticmethod
    def from_top_left_and_size(x, y, w, h):
        """Build a box from its first corner ``(x, y)``, width ``w`` and height ``h``."""
        return Box(x, y, x + w, y + h)

    @property
    def center_x(self):
        """Midpoint of the box along x."""
        return (self.x1 + self.x2)/2

    @property
    def center_y(self):
        """Midpoint of the box along y."""
        return (self.y1 + self.y2)/2

    @property
    def width(self):
        """Extent along x (``x2 - x1``)."""
        return self.x2 - self.x1

    @property
    def height(self):
        """Extent along y (``y2 - y1``)."""
        return self.y2 - self.y1

    @property
    def center_and_size(self):
        """Numpy array ``[center_x, center_y, width, height]``."""
        return np.array([self.center_x, self.center_y, self.width, self.height])

    @property
    def area(self):
        """Width times height; not clamped, so it is only meaningful for valid boxes."""
        return self.width * self.height

    @property
    def is_valid(self):
        """True when the box has strictly positive width and height."""
        return self.x2 > self.x1 and self.y2 > self.y1

    def iou(self, other):
        """
        Calculate Intersection over Union with other box.

        :param other: the box to compare with.
        :return: a float; ``0.0`` when the boxes do not overlap (touching edges
            count as no overlap) or when the union area is zero.
        """
        intersection_box = Box(
            max(self.x1, other.x1),
            max(self.y1, other.y1),
            min(self.x2, other.x2),
            min(self.y2, other.y2),
        )
        if not intersection_box.is_valid:
            return 0.0

        # The intersection is valid here, so its plain area needs no clamping.
        intersection_area = intersection_box.area
        union_area = self.area + other.area - intersection_area
        if union_area <= 0:
            # Reachable when the areas of extremely small boxes underflow to 0.
            return 0.0
        return intersection_area / union_area