Spaces:
Runtime error
Runtime error
File size: 8,221 Bytes
0bdfe9d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 | from __future__ import annotations
from collections import Counter
from dataclasses import dataclass
from typing import List
import numpy as np
from scipy.optimize import linear_sum_assignment
from tracking.kalman import VariableSpeedKalmanFilter
@dataclass()
class Annotation:  # TODO: maybe find a better name?
    """
    One detection in one frame: where it is (``box``), what it is (``class_``)
    and the score it was reported with.
    """

    box: Box  # bounding box of the detection
    class_: int  # class index; trailing underscore avoids the ``class`` keyword
    score: float  # score attached to the detection (range not established here)
@dataclass()
class TrackedAnnotation(Annotation):
    """
    An ``Annotation`` bound to the ``TrackedObject`` it was produced for.

    Adds the owning object as ``obj`` and exposes two of its attributes
    directly on the annotation.
    """

    obj: TrackedObject  # the tracked object this annotation belongs to

    @property
    def object_majority_class(self):
        """Majority class of the owning object (``obj.majority_class``)."""
        owner = self.obj
        return owner.majority_class

    @property
    def object_id(self):
        """Identifier of the owning object (``obj.object_id``)."""
        owner = self.obj
        return owner.object_id
class Tracker:
    """
    Follow objects across frames by associating each frame's annotations with
    previously seen objects.

    Per frame (see ``advance_frame``): every active object predicts its next
    box; predictions are assigned to the new annotations by maximizing the
    total match score with ``scipy.optimize.linear_sum_assignment``; matched
    objects receive their annotation as a measurement, unmatched annotations
    start new objects, and objects left unmatched for more than
    ``max_missing_frames`` consecutive frames are deactivated.
    """

    tracked_objects: List[TrackedObject]

    def __init__(
        self,
        confusion_matrix,
        min_score_for_match=0.2,
        min_frames=30,
        max_missing_frames=8,
        match_history_length=5,
    ):
        """
        :param confusion_matrix: indexable as ``confusion_matrix[class_a, class_b]``;
            its entries form the class-compatibility factor of the match score.
        :param min_score_for_match: assignments scoring below this are discarded.
        :param min_frames: ``finish`` drops objects whose annotation history is
            shorter than this.
        :param max_missing_frames: an object left unmatched for more than this
            many consecutive frames is deactivated.
        :param match_history_length: number of an object's most recent raw
            annotations used for the class-compatibility factor (default 5,
            the previously hard-coded value). Must be at least 1.
        :raises ValueError: if ``match_history_length`` is less than 1.
        """
        if match_history_length < 1:
            # history[-0:] would select the whole history, so 0 is not a window size.
            raise ValueError("match_history_length must be at least 1")
        self.tracked_objects = []
        self.confusion_matrix = confusion_matrix
        self.min_score_for_match = min_score_for_match
        self.min_frames = min_frames
        self.max_missing_frames = max_missing_frames
        self.match_history_length = match_history_length
        self.current_frame = 0
        # Monotonic id source: unlike len(self.tracked_objects) it never repeats
        # after finish() prunes the list. Before any pruning it yields the same ids.
        self._next_object_id = 0

    @property
    def active_objects(self):
        """Tracked objects whose ``is_active`` flag is still set."""
        return [obj for obj in self.tracked_objects if obj.is_active]

    def advance_frame(self, new_annotations):
        """
        Process the annotations of one frame.

        :param new_annotations: iterable of annotations detected in this frame.
        :returns: the current annotation of every active object after the update.
        """
        new_annotations = list(new_annotations)
        active_objects = self.active_objects
        for tracked_object in active_objects:
            tracked_object.predict_next_box()

        matches = self.match_objects_to_annotations(new_annotations)
        # Identity sets: O(1) membership, and no reliance on dataclass __eq__.
        matched_object_ids = set()
        matched_annotation_ids = set()
        for tracked_object, best_match in matches:
            tracked_object.add_new_measurement(best_match)
            tracked_object.missing_frame_count = 0
            tracked_object.annotation_history.append(tracked_object.current_annotation)
            matched_object_ids.add(id(tracked_object))
            matched_annotation_ids.add(id(best_match))

        for tracked_object in active_objects:
            if id(tracked_object) in matched_object_ids:
                continue
            # No measurement this frame: still record current_annotation so the
            # history keeps one entry per frame.
            tracked_object.annotation_history.append(tracked_object.current_annotation)
            tracked_object.missing_frame_count += 1
            if tracked_object.missing_frame_count > self.max_missing_frames:
                tracked_object.is_active = False
                # Drop the trailing entries that were recorded without a measurement.
                del tracked_object.annotation_history[-tracked_object.missing_frame_count:]

        for annotation in new_annotations:
            if id(annotation) not in matched_annotation_ids:
                self._start_new_object(annotation)

        self.current_frame += 1
        return self.get_current_annotations()

    def _start_new_object(self, annotation):
        """Create, register and return a new TrackedObject seeded with ``annotation``."""
        box = annotation.box
        kalman_filter = VariableSpeedKalmanFilter(
            x_0=box.center_x,
            y_0=box.center_y,
            w_0=box.width,
            h_0=box.height,
        )
        tracked_object = TrackedObject(
            kalman_filter,
            annotation,
            start_frame=self.current_frame,
            object_id=self._next_object_id,
        )
        self._next_object_id += 1
        self.tracked_objects.append(tracked_object)
        return tracked_object

    def advance_frames(self, raw_annotations_per_frame):
        """Call ``advance_frame`` on each frame's annotations, in order."""
        for raw_annotations in raw_annotations_per_frame:
            self.advance_frame(raw_annotations)

    def match_objects_to_annotations(self, annotations):
        """
        Assign active objects to ``annotations`` so the total match score is maximal.

        :param annotations: indexable sequence of annotations.
        :returns: list of ``(tracked_object, annotation)`` pairs. Each object and
            each annotation appears at most once, and pairs scoring below
            ``min_score_for_match`` are omitted.
        """
        active_objects = self.active_objects
        if not active_objects or len(annotations) == 0:
            return []
        score_matrix = np.zeros((len(active_objects), len(annotations)))
        for i, obj in enumerate(active_objects):
            for j, ann in enumerate(annotations):
                score_matrix[i, j] = self.calculate_match_score(ann, obj)
        # linear_sum_assignment minimizes cost, so negate to maximize the score.
        obj_indices, ann_indices = linear_sum_assignment(-score_matrix)
        return [
            (active_objects[i], annotations[j])
            for i, j in zip(obj_indices, ann_indices)
            if score_matrix[i, j] >= self.min_score_for_match
        ]

    def finish(self):
        """
        Finalize tracking: trim histories, then drop short-lived objects.

        Every object that is still active with ``missing_frame_count == 0``
        (matched or created in the last processed frame) loses its last
        history entry. Afterwards, objects whose history is shorter than
        ``min_frames`` are removed from ``tracked_objects``.

        NOTE(review): still-active objects that are currently missing keep
        their trailing unmatched entries, while matched ones lose their final
        (matched) entry. Confirm this asymmetry is intended.
        """
        for tracked_object in self.tracked_objects:
            if tracked_object.is_active and tracked_object.missing_frame_count <= 0:
                tracked_object.annotation_history.pop()
        # remove short-lived objects
        self.tracked_objects = [
            obj for obj in self.tracked_objects
            if len(obj.annotation_history) >= self.min_frames
        ]

    def get_current_annotations(self):
        """Return the current annotation of every active object."""
        return [
            obj.current_annotation
            for obj in self.tracked_objects
            if obj.is_active
        ]

    def get_annotations_per_frame(self, frame_index):
        """Return the recorded annotation of each object whose history covers ``frame_index``."""
        return [
            obj.annotation_history[frame_index - obj.start_frame]
            for obj in self.tracked_objects
            if obj.start_frame <= frame_index <= obj.end_frame
        ]

    def calculate_match_score(self, annotation, tracked_object):
        """
        Score how well ``annotation`` fits ``tracked_object``.

        The score is the IoU between the object's predicted box and the
        annotation's box. It is multiplied by the mean of
        ``confusion_matrix[annotation.class_, past.class_]`` and
        ``confusion_matrix[past.class_, annotation.class_]`` over the object's
        last ``match_history_length`` raw annotations.

        Relies on ``tracked_object.predicted_next_box`` having been set
        (``advance_frame`` predicts before matching).
        """
        base_score = tracked_object.predicted_next_box.iou(annotation.box)
        recent = tracked_object.raw_annotation_history[-self.match_history_length:]
        confusion_scores = [
            self.confusion_matrix[annotation.class_, past.class_] for past in recent
        ] + [
            self.confusion_matrix[past.class_, annotation.class_] for past in recent
        ]
        average_confusion_score = np.average(confusion_scores)
        return average_confusion_score * base_score
class TrackedObject:
    """
    One object followed over several frames.

    It keeps the Kalman filter whose state supplies the object's box, the raw
    annotations matched to it (``raw_annotation_history``), and the recorded
    ``TrackedAnnotation`` history starting at ``start_frame``
    (``annotation_history``, one entry per frame as reflected by ``end_frame``).
    """

    def __init__(self, kalman_filter, init_annotation, start_frame, object_id):
        self.object_id = object_id
        self.start_frame = start_frame
        self.is_active = True
        self.missing_frame_count = 0
        self.predicted_next_box = None  # filled in by predict_next_box()
        self.kalman_filter = kalman_filter
        self.raw_annotation_history = [init_annotation]
        # current_annotation reads kalman_filter and raw_annotation_history,
        # so the history can only be seeded after both are in place.
        self.annotation_history = [self.current_annotation]

    def predict_next_box(self):
        """
        Call the filter's ``predict()`` and turn the first four entries of its
        result (center x, center y, width, height) into ``predicted_next_box``,
        which is also returned.
        """
        predicted = self.kalman_filter.predict()
        cx, cy, w, h = (predicted[k] for k in range(4))
        self.predicted_next_box = Box.from_center_and_size(cx, cy, w, h)
        return self.predicted_next_box

    def add_new_measurement(self, annotation):
        """Feed ``annotation``'s box (center and size) to the filter and log the annotation."""
        measurement = annotation.box.center_and_size
        self.kalman_filter.update(measurement)
        self.raw_annotation_history.append(annotation)

    @property
    def current_annotation(self):
        """
        Build a ``TrackedAnnotation`` for this object.

        The box comes from the first four values of the filter's flattened
        ``next_x`` (center x, center y, width, height). Class and score come
        from the latest raw annotation.
        """
        state = self.kalman_filter.next_x.flatten()
        latest = self.raw_annotation_history[-1]
        return TrackedAnnotation(
            box=Box.from_center_and_size(*state[:4]),
            class_=latest.class_,
            score=latest.score,
            obj=self,
        )

    @property
    def end_frame(self):
        """Frame index of the last entry in ``annotation_history`` (inclusive)."""
        frames_recorded = len(self.annotation_history)
        return self.start_frame + frames_recorded - 1

    @property
    def majority_class(self):
        """Most frequent ``class_`` among the raw annotations (first seen wins ties)."""
        class_counts = Counter(ann.class_ for ann in self.raw_annotation_history)
        top_class, _count = class_counts.most_common(1)[0]
        return top_class
class Box:
    """
    Axis-aligned box: its edges are parallel to the x and y axes.

    ``(x1, y1)`` is the top-left corner and ``(x2, y2)`` the bottom-right
    corner. The box counts as valid (``is_valid``) only when both its width
    and its height are strictly positive.
    """

    def __init__(self, x1, y1, x2, y2):
        self.x1 = x1
        self.y1 = y1
        self.x2 = x2
        self.y2 = y2

    def __repr__(self):
        return f"{type(self).__name__}({self.x1!r}, {self.y1!r}, {self.x2!r}, {self.y2!r})"

    @classmethod
    def from_center_and_size(cls, x, y, w, h):
        """Build a box from its center ``(x, y)``, width ``w`` and height ``h``."""
        return cls(x - w / 2, y - h / 2, x + w / 2, y + h / 2)

    @classmethod
    def from_top_left_and_size(cls, x, y, w, h):
        """Build a box from its top-left corner ``(x, y)``, width ``w`` and height ``h``."""
        return cls(x, y, x + w, y + h)

    @property
    def center_x(self):
        return (self.x1 + self.x2) / 2

    @property
    def center_y(self):
        return (self.y1 + self.y2) / 2

    @property
    def width(self):
        return self.x2 - self.x1

    @property
    def height(self):
        return self.y2 - self.y1

    @property
    def center_and_size(self):
        """``np.array([center_x, center_y, width, height])``."""
        return np.array([self.center_x, self.center_y, self.width, self.height])

    @property
    def area(self):
        """``width * height``; only meaningful for valid boxes."""
        return self.width * self.height

    @property
    def is_valid(self):
        return self.x2 > self.x1 and self.y2 > self.y1

    def iou(self, other):
        """
        Return the Intersection over Union of this box and ``other``.

        Returns ``0.0`` when the boxes do not overlap with positive area. This
        includes the case where either box is invalid, because the
        intersection is then invalid too. Otherwise the result is in (0, 1].
        """
        intersection = Box(
            max(self.x1, other.x1),
            max(self.y1, other.y1),
            min(self.x2, other.x2),
            min(self.y2, other.y2),
        )
        if not intersection.is_valid:
            return 0.0
        # Valid intersection => positive width and height, so no clamping needed.
        intersection_area = intersection.area
        union_area = self.area + other.area - intersection_area
        if union_area <= 0:
            return 0.0
        return intersection_area / union_area
|