Spaces:
Build error
Build error
Commit ·
e5b7f6a
1
Parent(s): 768448f
major commit with first working version
Browse files- app.py +154 -4
- data/tracking_parameters/observation_variance.npy +0 -0
- data/tracking_parameters/transition_variance.npy +0 -0
- detection/__init__.py +0 -0
- detection/detect.py +18 -0
- detection/mobilenet.py +141 -0
- detection/transforms.py +135 -0
- models/.gitignore +5 -0
- packages.txt +3 -0
- requirements.txt +13 -0
- tools/__init__.py +0 -0
- tools/misc.py +201 -0
- tools/optical_flow.py +68 -0
- tools/video_readers.py +193 -0
- tracking/__init__.py +0 -0
- tracking/postprocess_and_count_tracks.py +119 -0
- tracking/track_video.py +152 -0
- tracking/trackers.py +269 -0
- tracking/utils.py +221 -0
app.py
CHANGED
|
@@ -1,7 +1,157 @@
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
import os.path as op
|
| 4 |
+
from urllib.request import urlretrieve
|
| 5 |
|
| 6 |
+
import json
|
| 7 |
+
import multiprocessing
|
| 8 |
+
from typing import Dict, List, Tuple
|
| 9 |
|
| 10 |
+
import datetime
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
# imports for tracking
|
| 14 |
+
import torch
|
| 15 |
+
import cv2
|
| 16 |
+
import numpy as np
|
| 17 |
+
from tools.video_readers import IterableFrameReader, SimpleVideoReader
|
| 18 |
+
from detection.mobilenet import get_mobilenet_v3_small
|
| 19 |
+
|
| 20 |
+
from detection.detect import detect
|
| 21 |
+
from tracking.postprocess_and_count_tracks import filter_tracks, postprocess_for_api
|
| 22 |
+
from tracking.utils import get_detections_for_video, write_tracking_results_to_file, read_tracking_results, gather_tracklets, generate_video_with_annotations
|
| 23 |
+
from tracking.track_video import track_video
|
| 24 |
+
from tracking.trackers import get_tracker
|
| 25 |
+
|
| 26 |
+
# Configure the root logger so every module logger inherits a DEBUG-level
# stream handler writing timestamped records to stderr.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# download model
|
| 36 |
+
# Download the pretrained detector weights on first start; the file is
# cached across restarts as long as the filesystem persists.
FILE_MODEL = "mobilenet_v3_pretrained.pth"
URL_MODEL = "https://partage.imt.fr/index.php/s/sJi22N6gedN6T4q/download"
model_filename = op.realpath('./models/' + FILE_MODEL)
if not op.exists(model_filename):
    # Ensure the target directory exists before downloading, otherwise
    # urlretrieve raises FileNotFoundError on a fresh checkout.
    os.makedirs(op.dirname(model_filename), exist_ok=True)
    logger.info('---Downloading model...')
    urlretrieve(URL_MODEL, model_filename)
else:
    logger.info('---Model already downloaded.')

# Load the model once at import time; the Gradio handler reuses it.
print("loading model")
logger.info('---Loading model...')

model = get_mobilenet_v3_small(num_layers=0, heads={'hm': 1}, head_conv=256)
# CPU-only Space: force weights onto the CPU regardless of save device.
checkpoint = torch.load(model_filename, map_location="cpu")
model.load_state_dict(checkpoint['model'], strict=True)
|
| 52 |
+
|
| 53 |
+
class DotDict(dict):
    """dict subclass exposing keys as attributes.

    Missing attributes resolve to ``None`` (``dict.get`` semantics)
    instead of raising ``AttributeError``.
    """

    def __getattr__(self, name):
        # dict.get semantics: unknown attributes yield None, never raise
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]
|
| 58 |
+
|
| 59 |
+
# Default tracking configuration. Several fields (video_path, skip_frames,
# tau, kappa, max_length) are overwritten per request in run_model().
config_track = DotDict({
    "confidence_threshold": 0.5,
    "detection_threshold": 0.3,   # heatmap peak threshold used by detect()
    "downsampling_factor": 4,     # heatmap resolution divisor (see ratio_x/ratio_y in track())
    "noise_covariances_path": "data/tracking_parameters",
    "output_shape": (960,544),    # presumably (width, height) fed to the reader — confirm against IterableFrameReader
    "skip_frames": 3, #3
    "arch": "mobilenet_v3_small",
    "device": "cpu",
    "detection_batch_size": 1,
    "display": 0,
    "kappa": 4, #4
    "tau": 3, #4
    "max_length": 240             # hard cap on the number of processed frames
})
|
| 74 |
+
|
| 75 |
+
def track(args):
    """Run detection + EKF tracking on ``args.video_path``.

    Writes the raw ("unfiltered") and filtered track files next to the
    input video and returns the filtered tracks. ``args`` is a
    DotDict-like namespace (see ``config_track``).
    """
    device = torch.device("cpu")

    engine = get_tracker('EKF')

    # Closure binding the module-level model and the per-request threshold.
    detector = lambda frame: detect(frame, threshold=args.detection_threshold, model=model)

    transition_variance = np.load(os.path.join(args.noise_covariances_path, 'transition_variance.npy'))
    observation_variance = np.load(os.path.join(args.noise_covariances_path, 'observation_variance.npy'))

    logger.info(f'---Processing {args.video_path}')
    reader = IterableFrameReader(video_filename=args.video_path,
                                 skip_frames=args.skip_frames,
                                 output_shape=args.output_shape,
                                 progress_bar=True,
                                 preload=False,
                                 max_frame=args.max_length)

    # Scale factors mapping detections from (downsampled) heatmap
    # coordinates back to the original video resolution.
    input_shape = reader.input_shape
    output_shape = reader.output_shape
    ratio_y = input_shape[0] / (output_shape[0] // args.downsampling_factor)
    ratio_x = input_shape[1] / (output_shape[1] // args.downsampling_factor)

    logger.info('---Detecting...')
    detections = get_detections_for_video(reader, detector, batch_size=args.detection_batch_size, device=device)

    logger.info('---Tracking...')
    display = None
    results = track_video(reader, iter(detections), args, engine, transition_variance, observation_variance, display)

    # Store unfiltered results; the microsecond timestamp keeps concurrent
    # requests on the same video from clobbering each other's files.
    datestr = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
    output_filename = os.path.splitext(args.video_path)[0] + "_" + datestr + '_unfiltered.txt'
    write_tracking_results_to_file(results, ratio_x=ratio_x, ratio_y=ratio_y, output_filename=output_filename)
    logger.info('---Filtering...')

    # Round-trip through the file so filtering sees the rescaled results.
    results = read_tracking_results(output_filename)
    # NOTE(review): reads kappa/tau from the global config_track, not from
    # args — identical today since args IS config_track, but confirm.
    filtered_results = filter_tracks(results, config_track.kappa, config_track.tau)
    # Store filtered results alongside the unfiltered ones.
    output_filename = os.path.splitext(args.video_path)[0] + "_" + datestr + '_filtered.txt'
    write_tracking_results_to_file(filtered_results, ratio_x=ratio_x, ratio_y=ratio_y, output_filename=output_filename)

    return filtered_results
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def run_model(video_path, skip, tau, kappa, maxframes):
    """Gradio handler: track trash in the uploaded video and return the
    path of an annotated output video.

    NOTE(review): mutates the module-level ``config_track`` and writes a
    fixed /tmp output path, so concurrent requests can interfere —
    confirm the Space runs a single worker with queuing.
    """
    logger.info('---video filename: '+ video_path)

    # Copy the UI settings into the shared tracking config.
    config_track.video_path = video_path
    config_track.skip_frames = int(skip)
    config_track.tau = int(tau)
    config_track.kappa = int(kappa)
    config_track.max_length = int(maxframes)

    output_path = "/tmp/video_out.mp4"
    filtered_results = track(config_track)

    # Convert raw tracks into the API JSON structure.
    logger.info('---Postprocessing...')
    output_json = postprocess_for_api(filtered_results)
    #output_json={"detected_trash":[]}
    # Render the annotated output video at full frame rate.
    logger.info('---Generating new video...')
    video = SimpleVideoReader(video_path, skip_frames=0)
    generate_video_with_annotations(video, output_json, output_path, config_track.skip_frames, logger)

    return output_path #, output_json
|
| 145 |
+
|
| 146 |
+
# Gradio 2.x component definitions.
# NOTE(review): the gr.inputs namespace was removed in Gradio 3.x; if the
# Space image ships gradio>=3 this raises AttributeError at import time —
# confirm the pinned gradio version (may explain the build error).
video_in = gr.inputs.Video(type="mp4", source="upload", label="Video Upload", optional=False)
skip_slider = gr.inputs.Slider(minimum=0, maximum=15, step=1, default=3, label="skip frames")
tau_slider = gr.inputs.Slider(minimum=1, maximum=7, step=1, default=3, label="tau")
kappa_slider = gr.inputs.Slider(minimum=1, maximum=7, step=1, default=4, label="kappa")
maxframes_num = gr.inputs.Number(default=240, label="max frames at 24fps")


# Launch the demo; debug=True blocks and streams server logs.
gr.Interface(fn=run_model, inputs=[video_in, skip_slider, tau_slider, kappa_slider, maxframes_num],
             outputs="playable_video",
             title="Surfnet demo",
             description="Upload a video, only the first seconds will be processed (use maxframe to change that)",
             allow_screenshot=False).launch(debug=True)
|
data/tracking_parameters/observation_variance.npy
ADDED
|
Binary file (144 Bytes). View file
|
|
|
data/tracking_parameters/transition_variance.npy
ADDED
|
Binary file (144 Bytes). View file
|
|
|
detection/__init__.py
ADDED
|
File without changes
|
detection/detect.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from time import time
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def nms(heat, kernel=3):
    """Suppress non-maximum heatmap values.

    Keeps only the entries of ``heat`` that equal the maximum of their
    ``kernel`` x ``kernel`` neighbourhood; everything else is zeroed.
    """
    half = (kernel - 1) // 2
    local_max = torch.nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=half)
    is_peak = torch.eq(local_max, heat).float()
    return heat * is_peak
|
| 12 |
+
|
| 13 |
+
def detect(preprocessed_frames, threshold, model):
    """Run the heatmap head on a batch of frames and extract peak positions.

    Returns one ``(num_peaks, 2)`` array of (x, y) coordinates per frame,
    in heatmap (downsampled) pixel units.
    """
    heatmaps = torch.sigmoid(model(preprocessed_frames)[-1]['hm'])
    peak_masks = nms(heatmaps).gt(threshold).squeeze(dim=1)
    # torch.nonzero yields (row, col) = (y, x); reverse each pair to (x, y)
    return [torch.nonzero(mask).cpu().numpy()[:, ::-1] for mask in peak_masks]
|
detection/mobilenet.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------------
|
| 2 |
+
# Copyright (c) Microsoft
|
| 3 |
+
# Licensed under the MIT License.
|
| 4 |
+
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
|
| 5 |
+
# Modified by Xingyi Zhou
|
| 6 |
+
# ------------------------------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
from __future__ import absolute_import
|
| 9 |
+
from __future__ import division
|
| 10 |
+
from __future__ import print_function
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torchvision.models as models
|
| 17 |
+
|
| 18 |
+
BN_MOMENTUM = 0.1
|
| 19 |
+
|
| 20 |
+
class MobiletNetHM(nn.Module):
    """MobileNetV3-small backbone with deconvolutional upsampling and one
    convolutional output head per entry of ``heads``.

    ``heads`` maps head names (e.g. 'hm') to their output channel count;
    ``head_conv`` > 0 inserts an intermediate 3x3 conv of that width
    before each head's 1x1 output conv.
    """

    def __init__(self, heads, head_conv, **kwargs):
        # 576 = channel count of mobilenet_v3_small's last feature map
        self.inplanes = 576
        self.deconv_with_bias = False
        self.heads = heads

        super(MobiletNetHM, self).__init__()
        self.features = models.mobilenet_v3_small(pretrained=True).features

        # Three stride-2 deconv stages -> 8x spatial upsampling.
        self.deconv_layers = self._make_deconv_layer(
            3,
            [256, 256, 256],
            [4, 4, 4],
        )

        for head in sorted(self.heads):
            num_output = self.heads[head]
            if head_conv > 0:
                fc = nn.Sequential(
                    nn.Conv2d(256, head_conv,
                              kernel_size=3, padding=1, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(head_conv, num_output,
                              kernel_size=1, stride=1, padding=0))
            else:
                fc = nn.Conv2d(
                    in_channels=256,
                    out_channels=num_output,
                    kernel_size=1,
                    stride=1,
                    padding=0
                )
            # Register each head as an attribute so forward() can look it up.
            self.__setattr__(head, fc)

    def _get_deconv_cfg(self, deconv_kernel, index):
        """Return (kernel, padding, output_padding) for a stride-2 deconv.

        ``index`` is unused but kept for signature compatibility.
        Raises ValueError for unsupported kernel sizes (previously this
        fell through to an UnboundLocalError).
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(
                'unsupported deconv kernel size: {}'.format(deconv_kernel))

        return deconv_kernel, padding, output_padding

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Stack ``num_layers`` ConvTranspose2d + BN + ReLU upsampling blocks."""
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different len(num_deconv_filters)'

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i], i)

            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(
                    in_channels=self.inplanes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU(inplace=True))
            self.inplanes = planes

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return a one-element list holding a {head_name: feature_map} dict."""
        x = self.features(x)
        x = self.deconv_layers(x)
        ret = {}
        for head in self.heads:
            ret[head] = self.__getattr__(head)(x)
        return [ret]

    def init_weights(self):
        """Initialize deconv and head weights.

        CenterNet-style: the 'hm' head bias starts at -2.19 so the
        initial sigmoid output is ~0.1 everywhere.
        """
        for _, m in self.deconv_layers.named_modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, std=0.001)
                if self.deconv_with_bias:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        for head in self.heads:
            final_layer = self.__getattr__(head)
            for i, m in enumerate(final_layer.modules()):
                if isinstance(m, nn.Conv2d):
                    # Only the final (output) conv of each head is touched.
                    if m.weight.shape[0] == self.heads[head]:
                        if 'hm' in head:
                            nn.init.constant_(m.bias, -2.19)
                        else:
                            nn.init.normal_(m.weight, std=0.001)
                            nn.init.constant_(m.bias, 0)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_mobilenet_v3_small(num_layers, heads, head_conv):
    """Factory for the MobileNetV3-small heatmap network.

    ``num_layers`` is accepted for signature compatibility with the other
    backbone factories but is not used by this architecture.
    """
    net = MobiletNetHM(heads, head_conv=head_conv)
    net.init_weights()
    return net
|
detection/transforms.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import imgaug as ia
|
| 2 |
+
from imgaug import augmenters as iaa
|
| 3 |
+
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tools.misc import blob_for_bbox
|
| 6 |
+
import torch
|
| 7 |
+
ia.seed(1)
|
| 8 |
+
from torchvision.transforms import functional as F
|
| 9 |
+
import torchvision.transforms as T
|
| 10 |
+
import cv2
|
| 11 |
+
|
| 12 |
+
class Compose(object):
    """Chain (image, target) transforms, threading both values through."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target
|
| 20 |
+
|
| 21 |
+
class ToTensorBboxes(object):
    """Convert an (image, imgaug-bboxes) pair into (tensor image, blob target).

    The target has ``num_classes + 2`` channels: one Gaussian heatmap per
    class plus two regression channels holding the box height and width
    at each object center.
    """

    def __init__(self, num_classes, downsampling_factor):
        self.num_classes = num_classes
        # Heatmap resolution divisor relative to the image; None keeps full size.
        self.downsampling_factor = downsampling_factor

    def __call__(self, image, bboxes):
        h,w = image.shape[:-1]
        image = F.to_tensor(image)
        if self.downsampling_factor is not None:
            blobs = np.zeros(shape=(self.num_classes + 2, h // self.downsampling_factor, w // self.downsampling_factor))
        else:
            blobs = np.zeros(shape=(self.num_classes + 2, h, w))

        for bbox_imgaug in bboxes:
            # imgaug labels carry 1-based category ids.
            cat = bbox_imgaug.label-1
            bbox = [bbox_imgaug.x1, bbox_imgaug.y1, bbox_imgaug.width, bbox_imgaug.height]

            new_blobs, ct_int = blob_for_bbox(bbox, blobs[cat], self.downsampling_factor)
            blobs[cat] = new_blobs
            if ct_int is not None:
                ct_x, ct_y = ct_int
                # Centers may land outside the (downsampled) map after augmentation.
                if ct_x < blobs.shape[2] and ct_y < blobs.shape[1]:
                    blobs[-2, ct_y, ct_x] = bbox[3]
                    blobs[-1, ct_y, ct_x] = bbox[2]

        target = torch.from_numpy(blobs)
        return image, target
|
| 59 |
+
|
| 60 |
+
class Normalize(object):
    """Channel-wise normalization of the image tensor; the target passes
    through unchanged."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        normalized = F.normalize(image, mean=self.mean, std=self.std)
        return normalized, target
|
| 68 |
+
|
| 69 |
+
class TrainTransforms:
    """Training-time augmentation: random resize, horizontal flip, pad and
    random crop (imgaug), then tensor conversion + heatmap targets."""

    def __init__(self, base_size, crop_size, num_classes, downsampling_factor, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.num_classes = num_classes
        self.downsampling_factor = downsampling_factor
        self.base_size = base_size
        self.crop_height, self.crop_width = crop_size
        self.hflip_prob = hflip_prob
        # Random scale jitter between 1x and 2x the base size.
        self.random_size_range = (int(self.base_size),int(2.0*self.base_size))
        self.seq = iaa.Sequential([
            iaa.Resize({"height": self.random_size_range, "width": "keep-aspect-ratio"}),
            iaa.Fliplr(p=self.hflip_prob),
            iaa.PadToFixedSize(width=self.crop_width, height=self.crop_height),
            iaa.CropToFixedSize(width=self.crop_width, height=self.crop_height)
        ])
        self.last_transforms = Compose([ToTensorBboxes(num_classes, downsampling_factor),
                                        Normalize(mean=mean,std=std)])

    def __call__(self, img, target):
        # target: dict with 'bboxes' ([x, y, w, h] lists) and 'cats' lists.
        bboxes_imgaug = [BoundingBox(x1=bbox[0], y1=bbox[1], x2=bbox[0]+bbox[2], y2=bbox[1]+bbox[3], label=cat) \
                         for bbox, cat in zip(target['bboxes'],target['cats'])]
        bboxes = BoundingBoxesOnImage(bboxes_imgaug, shape=img.shape)

        img, bboxes_imgaug = self.seq(image=img, bounding_boxes=bboxes)
        return self.last_transforms(img, bboxes_imgaug)
|
| 96 |
+
|
| 97 |
+
class ValTransforms:
    """Deterministic validation pipeline: resize to ``base_size``, center
    pad/crop to ``crop_size``, then tensor conversion + heatmap targets."""

    def __init__(self, base_size, crop_size, num_classes, downsampling_factor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.num_classes = num_classes
        self.downsampling_factor = downsampling_factor
        self.base_size = base_size
        self.crop_height, self.crop_width = crop_size
        self.seq = iaa.Sequential([
            iaa.Resize({"height": int(self.base_size), "width": "keep-aspect-ratio"}),
            iaa.CenterPadToFixedSize(width=self.crop_width, height=self.crop_height),
            iaa.CenterCropToFixedSize(width=self.crop_width, height=self.crop_height)
        ])
        self.last_transforms = Compose([ToTensorBboxes(num_classes, downsampling_factor),
                                        Normalize(mean=mean,std=std)])

    def __call__(self, img, target):
        # target: dict with 'bboxes' ([x, y, w, h] lists) and 'cats' lists.
        bboxes_imgaug = [BoundingBox(x1=bbox[0], y1=bbox[1], x2=bbox[0]+bbox[2], y2=bbox[1]+bbox[3], label=cat) \
                         for bbox, cat in zip(target['bboxes'],target['cats'])]
        bboxes = BoundingBoxesOnImage(bboxes_imgaug, shape=img.shape)

        img, bboxes_imgaug = self.seq(image=img, bounding_boxes=bboxes)
        return self.last_transforms(img, bboxes_imgaug)
|
| 122 |
+
|
| 123 |
+
class TransformFrames:
    """BGR->RGB conversion, tensor conversion and ImageNet normalization,
    applied to raw OpenCV frames before detection."""

    def __init__(self):
        self.transforms = T.Compose([
            T.Lambda(lambda img: cv2.cvtColor(img, cv2.COLOR_BGR2RGB)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

    def __call__(self, img):
        return self.transforms(img)
|
models/.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore everything in this directory
|
| 2 |
+
*
|
| 3 |
+
# Except this file
|
| 4 |
+
!.gitignore
|
| 5 |
+
!download_pretrained_base.sh
|
packages.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ffmpeg
|
| 2 |
+
x264
|
| 3 |
+
libx264-dev
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
scikit-image
|
| 2 |
+
opencv-python
|
| 3 |
+
torch
|
| 4 |
+
torchvision
|
| 5 |
+
pycocotools
|
| 6 |
+
debugpy
|
| 7 |
+
scipy
|
| 8 |
+
tqdm
|
| 9 |
+
tensorboard
|
| 10 |
+
imgaug
|
| 11 |
+
psycopg2-binary
|
| 12 |
+
moviepy
|
| 13 |
+
git+https://github.com/pykalman/pykalman.git
|
tools/__init__.py
ADDED
|
File without changes
|
tools/misc.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from matplotlib.pyplot import grid
|
| 2 |
+
import numpy as np
|
| 3 |
+
import math
|
| 4 |
+
import cv2
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision.transforms.functional as F
|
| 7 |
+
|
| 8 |
+
class ResizeForCenterNet(object):
    """Resize an image so both sides are multiples of 32, as required by
    the CenterNet downsampling stack; ``fix_res`` forces 512x512 instead.
    """
    def __init__(self, fix_res=False):
        self.fix_res = fix_res

    def __call__(self, image):
        if self.fix_res:
            new_h = 512
            new_w = 512
        else:
            # assumes a PIL-style image (.size returns (w, h)) — TODO confirm.
            # (x | 31) + 1 rounds x up to the next multiple of 32.
            w, h = image.size
            new_h = (h | 31) + 1
            new_w = (w | 31) + 1
        image = F.resize(image, (new_h, new_w))
        return image
|
| 22 |
+
|
| 23 |
+
def gaussian_radius(det_size, min_overlap=0.7):
    """Radius of the target Gaussian for a box of ``det_size`` = (h, w).

    Solves the three quadratic IoU-overlap cases from the CornerNet
    reference implementation and returns the smallest root. The division
    by 2 (instead of 2a) reproduces the reference code exactly.
    """
    height, width = det_size

    # case 1: a = 1
    b1 = height + width
    c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
    r1 = (b1 + np.sqrt(b1 * b1 - 4 * 1 * c1)) / 2

    # case 2: a = 4
    b2 = 2 * (height + width)
    c2 = (1 - min_overlap) * width * height
    r2 = (b2 + np.sqrt(b2 * b2 - 4 * 4 * c2)) / 2

    # case 3: a = 4 * min_overlap
    b3 = -2 * min_overlap * (height + width)
    c3 = (min_overlap - 1) * width * height
    r3 = (b3 + np.sqrt(b3 * b3 - 4 * (4 * min_overlap) * c3)) / 2

    return min(r1, r2, r3)
|
| 44 |
+
|
| 45 |
+
def gaussian2D(shape, sigma=1):
    """Dense 2-D Gaussian of the given (odd) ``shape`` with peak value 1."""
    half_m, half_n = [(side - 1.) / 2. for side in shape]
    ys, xs = np.ogrid[-half_m:half_m + 1, -half_n:half_n + 1]

    g = np.exp(-(xs * xs + ys * ys) / (2 * sigma * sigma))
    # Zero out numerically negligible tails.
    g[g < np.finfo(g.dtype).eps * g.max()] = 0
    return g
|
| 52 |
+
|
| 53 |
+
def draw_umich_gaussian(heatmap, center, radius, k=1):
    """Splat a Gaussian of the given radius onto ``heatmap`` at ``center``
    (x, y), taking the element-wise max with existing values, in place."""
    diameter = 2 * radius + 1
    kernel = gaussian2D((diameter, diameter), sigma=diameter / 6)

    cx, cy = int(center[0]), int(center[1])
    height, width = heatmap.shape[0:2]

    # Clip the kernel to the part that falls inside the heatmap bounds.
    left, right = min(cx, radius), min(width - cx, radius + 1)
    top, bottom = min(cy, radius), min(height - cy, radius + 1)

    roi = heatmap[cy - top:cy + bottom, cx - left:cx + right]
    patch = kernel[radius - top:radius + bottom, radius - left:radius + right]
    if min(patch.shape) > 0 and min(roi.shape) > 0:  # TODO debug
        np.maximum(roi, patch * k, out=roi)

    return heatmap
|
| 70 |
+
|
| 71 |
+
def blob_for_bbox(bbox, heatmap, downsampling_factor=None):
    """Draw the Gaussian blob for one [left, top, w, h] box onto ``heatmap``.

    Coordinates are integer-divided by ``downsampling_factor`` when given.
    Returns the updated heatmap and the integer center (x, y), or ``None``
    as the center for degenerate (zero-area) boxes.
    """
    if downsampling_factor is not None:
        left, top, w, h = [bbox_coord // downsampling_factor for bbox_coord in bbox]
    else:
        left, top, w, h = [bbox_coord for bbox_coord in bbox]

    right, bottom = left+w, top+h
    ct_int = None
    if h > 0 and w > 0:
        radius = gaussian_radius((math.ceil(h), math.ceil(w)))
        radius = max(0, int(radius))
        ct = np.array([(left + right) / 2, (top + bottom) / 2], dtype=np.float32)
        ct_int = ct.astype(np.int32)
        heatmap = draw_umich_gaussian(heatmap, ct_int, radius)
    return heatmap, ct_int
|
| 86 |
+
|
| 87 |
+
def pre_process_centernet(image, meta=None, fix_res=True):
    """Resize/warp an image array into the CenterNet input tensor.

    With ``fix_res`` the image is letterboxed to 512x512 via an affine
    warp; otherwise both sides are rounded up to multiples of 32.
    Returns a (3, H, W) float tensor. ``meta`` is currently unused.
    """
    scale = 1.0
    # CenterNet's dataset statistics (not the ImageNet ones).
    mean = [0.408, 0.447, 0.47]
    std = [0.289, 0.274, 0.278]
    height, width = image.shape[0:2]
    new_height = int(height * scale)
    new_width = int(width * scale)
    if fix_res:
        inp_height, inp_width = 512, 512
        c = np.array([new_width / 2., new_height / 2.], dtype=np.float32)
        s = max(height, width) * 1.0
    else:
        # (x | 31) + 1 rounds x up to the next multiple of 32.
        inp_height = (new_height | 31) + 1
        inp_width = (new_width | 31) + 1
        c = np.array([new_width // 2, new_height // 2], dtype=np.float32)
        s = np.array([inp_width, inp_height], dtype=np.float32)

    trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
    resized_image = cv2.resize(image, (new_width, new_height))
    inp_image = cv2.warpAffine(
        resized_image, trans_input, (inp_width, inp_height),
        flags=cv2.INTER_LINEAR)
    inp_image = ((inp_image / 255. - mean) / std).astype(np.float32)

    images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height, inp_width)
    images = torch.from_numpy(images)
    return images.squeeze() #, meta
|
| 119 |
+
|
| 120 |
+
def get_affine_transform(center,
                         scale,
                         rot,
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32),
                         inv=0):
    """Build the 2x3 affine matrix mapping a crop defined by (center,
    scale, rot degrees) onto ``output_size`` = [width, height].

    Three non-collinear point correspondences (center, a rotated
    direction point, and their right-angle completion) define the
    transform; ``inv`` swaps source and destination.
    """
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        # A scalar scale means a square crop region.
        scale = np.array([scale, scale], dtype=np.float32)

    scale_tmp = scale
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir
    # Third point completes the right angle so the triple is non-collinear.
    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans
|
| 154 |
+
|
| 155 |
+
def get_dir(src_point, rot_rad):
    """Rotate ``src_point`` (x, y) by ``rot_rad`` radians about the origin;
    returns a two-element list."""
    sin_r, cos_r = np.sin(rot_rad), np.cos(rot_rad)
    return [src_point[0] * cos_r - src_point[1] * sin_r,
            src_point[0] * sin_r + src_point[1] * cos_r]
|
| 163 |
+
|
| 164 |
+
def get_3rd_point(a, b):
    """Third point completing a right angle at ``b`` (the vector b->a
    rotated 90 degrees about ``b``)."""
    delta = a - b
    perpendicular = np.array([-delta[1], delta[0]], dtype=np.float32)
    return b + perpendicular
|
| 167 |
+
|
| 168 |
+
def load_checkpoint(model, trained_model_weights_filename):
    """Load the 'model' entry of a torch checkpoint file into `model` (on CPU)
    and return the model."""
    checkpoint = torch.load(trained_model_weights_filename, map_location='cpu')
    state_dict = checkpoint['model']
    model.load_state_dict(state_dict)
    return model
|
| 172 |
+
|
| 173 |
+
def load_model(arch, model_weights, device):
    """Build a detection model for `arch`, load (pretrained) weights onto
    `device`, freeze all parameters and switch to eval mode.

    When `model_weights` is None, a bundled pretrained checkpoint is chosen
    from the architecture name.
    """
    if model_weights is None:
        # Fall back to the checkpoints shipped in models/.
        if arch == 'mobilenet_v3_small':
            model_weights = 'models/mobilenet_v3_pretrained.pth'
            arch = 'mobilenetv3small'
        elif arch == 'res_18':
            model_weights = 'models/res18_pretrained.pth'
        elif arch == 'dla_34':
            model_weights = 'models/dla_34_pretrained.pth'

    # dla_34 also predicts box width/height; the others only a heatmap.
    heads = {'hm': 1, 'wh': 2} if arch == 'dla_34' else {'hm': 1}

    model = create_base(arch, heads=heads, head_conv=256).to(device)
    model = load_checkpoint(model, model_weights)
    for parameter in model.parameters():
        parameter.requires_grad = False
    model.eval()

    return model
|
| 193 |
+
|
| 194 |
+
def _calculate_euclidean_similarity(distances, zero_distance):
|
| 195 |
+
""" Calculates the euclidean distance between two sets of detections, and then converts this into a similarity
|
| 196 |
+
measure with values between 0 and 1 using the following formula: sim = max(0, 1 - dist/zero_distance).
|
| 197 |
+
The default zero_distance of 2.0, corresponds to the default used in MOT15_3D, such that a 0.5 similarity
|
| 198 |
+
threshold corresponds to a 1m distance threshold for TPs.
|
| 199 |
+
"""
|
| 200 |
+
sim = np.maximum(0, 1 - distances/zero_distance)
|
| 201 |
+
return sim
|
tools/optical_flow.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
# params for ShiTomasi corner detection (cv2.goodFeaturesToTrack)
feature_params = dict( maxCorners = 100,
                       qualityLevel = 0.3,
                       minDistance = 2,
                       blockSize = 7 )
# Parameters for lucas kanade optical flow (cv2.calcOpticalFlowPyrLK)
# NOTE(review): maxLevel = 20 pyramid levels is unusually high (typical
# values are 2-4) -- confirm this is intentional.
lk_params = dict( winSize  = (15,15),
                  maxLevel = 20,
                  criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03))
|
| 13 |
+
|
| 14 |
+
def viz_dense_flow(frame_reader, nb_frames_to_process):
    """Visualise Farneback dense optical flow with cv2.imshow.

    For each consecutive frame pair, shows the frame stacked above an HSV
    encoding of the flow (hue = direction, value = normalised magnitude).
    Blocks on a key press per frame; requires a GUI-capable OpenCV build.
    """
    ret, frame1 = frame_reader.read_frame()
    prvs = cv2.cvtColor(frame1,cv2.COLOR_BGR2GRAY)
    hsv = np.zeros_like(frame1)
    # NOTE(review): saturation is set to 0, which renders the flow in
    # grayscale; the classic OpenCV demo uses 255 -- confirm intended.
    hsv[...,1] = 0

    while(frame_reader.nb_frames_read < nb_frames_to_process):
        ret, frame2 = frame_reader.read_frame()
        next = cv2.cvtColor(frame2,cv2.COLOR_BGR2GRAY)
        flow = cv2.calcOpticalFlowFarneback(prvs,next, None, 0.5, 3, 15, 3, 5, 1.2, 0)
        # Convert (u, v) displacement to polar coordinates for HSV rendering.
        mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1])
        hsv[...,0] = ang*180/np.pi/2
        hsv[...,2] = cv2.normalize(mag,None,0,255,cv2.NORM_MINMAX)
        bgr = cv2.cvtColor(hsv,cv2.COLOR_HSV2BGR)
        cv2.imshow('frame',np.concatenate([frame2, bgr], axis=0))
        # Wait indefinitely for a key press before advancing.
        k = cv2.waitKey(0) & 0xff
        prvs = next
    cv2.destroyAllWindows()
|
| 33 |
+
|
| 34 |
+
def flow_opencv_dense(img, img2):
    """Dense Farneback optical flow between two BGR frames.

    Returns an (H, W, 2) array of per-pixel (u, v) displacements from
    `img` to `img2`.
    """
    gray0 = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray1 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
    return cv2.calcOpticalFlowFarneback(gray0, gray1, None, 0.5, 3, 15, 3, 5, 1.2, 0)
|
| 42 |
+
|
| 43 |
+
def flow_opencv_sparse(img, img2, p0):
    """Track the sparse points `p0` from `img` to `img2` with pyramidal
    Lucas-Kanade flow and return only the successfully tracked points."""
    gray0 = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray1 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)

    p1, status, _err = cv2.calcOpticalFlowPyrLK(gray0, gray1, p0, None, **lk_params)

    # Keep only the points whose tracking status flag is 1.
    return p1[status == 1]
|
| 53 |
+
|
| 54 |
+
def dense_flow_norm(dense_flow):
    """Per-pixel Euclidean magnitude of an (H, W, 2) dense flow field."""
    u = dense_flow[..., 0]
    v = dense_flow[..., 1]
    return np.sqrt(u * u + v * v)
|
| 57 |
+
|
| 58 |
+
def compute_flow(frame0, frame1, downsampling_factor):
    """Dense optical flow between two frames, optionally computed on
    downsampled copies (each dimension divided by `downsampling_factor`).

    Note: the returned displacements are expressed in the (possibly
    downsampled) resolution at which the flow was computed.
    """
    if downsampling_factor > 1:
        height, width = frame0.shape[:-1]
        target_size = (width // downsampling_factor, height // downsampling_factor)
        frame0 = cv2.resize(frame0, target_size)
        frame1 = cv2.resize(frame1, target_size)

    return flow_opencv_dense(frame0, frame1)
|
tools/video_readers.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import torch
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from itertools import cycle
|
| 5 |
+
class AdvancedFrameReader:
    """Frame reader over cv2.VideoCapture with optional on-the-fly rescaling,
    frame skipping (read every N-th frame) and an adjustable start position.
    """

    def __init__(self, video_name, read_every, rescale_factor, init_time_min, init_time_s):
        # video_name: path to the video file; read_every: keep 1 frame out of
        # read_every; rescale_factor: divide width/height by this factor.
        self.cap = cv2.VideoCapture(video_name)

        self.original_width = self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)
        self.original_height = self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)

        if rescale_factor != 1:
            self.set_rescale_factor(rescale_factor)
        else:
            # Frames are returned untouched.
            self.original_shape_mode = True

        self.init_rescale_factor = rescale_factor

        # Number of frames dropped between two returned frames.
        self.frame_skip = read_every - 1
        # Effective output frame rate after skipping.
        self.fps = self.cap.get(cv2.CAP_PROP_FPS)/read_every
        print(f'Reading at {self.fps:.2f} fps')

        self.set_time_position(init_time_min, init_time_s)
        self.init_frame = self.cap.get(cv2.CAP_PROP_POS_FRAMES)
        self.total_num_frames = self.cap.get(cv2.CAP_PROP_FRAME_COUNT)

    def post_process(self, ret, frame):
        """Resize the frame unless in original-shape mode; on read failure
        return (ret, []) so callers can test truthiness of the frame."""
        if ret:
            if self.original_shape_mode:
                return ret, frame
            else:
                return ret, cv2.resize(frame, self.new_shape)
        else:
            return ret, []

    def skip(self):
        """Drop `frame_skip` frames; no-op before the first real read."""
        if not self.nb_frames_read:
            return
        else:
            for _ in range(self.frame_skip):
                self.cap.read()

    def read_frame(self):
        """Return (ret, frame) for the next kept frame, applying skipping
        and post-processing."""
        self.skip()
        ret, frame = self.cap.read()
        self.nb_frames_read += 1
        return self.post_process(ret, frame)

    def set_time_position(self, time_min, time_s):
        """Seek to an absolute position given in minutes + seconds and reset
        the read counter."""
        time = 60 * time_min + time_s
        self.cap.set(cv2.CAP_PROP_POS_MSEC, 1000 * time)
        self.init_frame = self.cap.get(cv2.CAP_PROP_POS_FRAMES)
        print('Reading from {}min{}sec'.format(time_min,time_s))
        self.nb_frames_read = 0

    def reset_init_frame(self):
        """Rewind to the stored initial frame and reset the read counter."""
        self.cap.set(cv2.CAP_PROP_POS_FRAMES,int(self.init_frame))
        self.nb_frames_read = 0

    def set_init_frame(self, init_frame):
        """Seek to an absolute frame index and make it the new initial frame."""
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, int(init_frame))
        self.init_frame = self.cap.get(cv2.CAP_PROP_POS_FRAMES)
        self.nb_frames_read = 0

    def set_rescale_factor(self, rescale_factor):
        """Enable resizing: frames will be returned at original/rescale_factor."""
        width = int(self.original_width/rescale_factor)
        height = int(self.original_height/rescale_factor)
        self.new_shape = (width, height)
        self.original_shape_mode = False
        print('Reading in {}x{}'.format(width, height))

    def set_original_shape_mode(self, mode):
        # mode=True disables resizing in post_process.
        self.original_shape_mode = mode

    def reset_init_rescale_factor(self):
        """Restore the rescale factor given at construction time."""
        self.set_rescale_factor(self.init_rescale_factor)
|
| 81 |
+
|
| 82 |
+
class IterableFrameReader:
    """Iterator over the frames of a video file.

    Yields frames resized to `output_shape` (defaulting to the input shape
    rounded up to the next multiple of 32), optionally skipping frames,
    preloading everything in RAM, and capping at `max_frame`. When exhausted,
    the reader rewinds to frame 0 so it can be iterated again.
    """

    def __init__(self, video_filename, skip_frames=0, output_shape=None, progress_bar=False, preload=False, max_frame=0):

        self.video = cv2.VideoCapture(video_filename)
        self.input_shape = (self.video.get(cv2.CAP_PROP_FRAME_WIDTH), self.video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.skip_frames = skip_frames
        self.preload = preload
        self.total_num_frames = self.video.get(cv2.CAP_PROP_FRAME_COUNT)
        self.max_num_frames = min(max_frame, self.total_num_frames) if max_frame != 0 else self.total_num_frames
        self.counter = 0

        if output_shape is None:
            w, h = self.input_shape
            # Round each dimension up to the next multiple of 32.
            # BUGFIX: CAP_PROP_* values are floats; bitwise OR needs ints.
            new_h = (int(h) | 31) + 1
            new_w = (int(w) | 31) + 1
            self.output_shape = (new_w, new_h)
        else:
            self.output_shape = output_shape

        self.fps = self.video.get(cv2.CAP_PROP_FPS) / (self.skip_frames+1)

        print(f'Reading video at {self.fps}fps.')
        # BUGFIX: always define self.progress_bar so the truthiness checks in
        # _load_all_frames/__next__ cannot raise AttributeError when
        # progress_bar=False.
        if progress_bar:
            self.progress_bar = tqdm(total=int(self.max_num_frames/(self.skip_frames+1)), leave=True)
            self.progress_bar_update = self.progress_bar.update
        else:
            self.progress_bar = None
            self.progress_bar_update = lambda: None

        if self.preload:
            print('Preloading frames in RAM...')
            self.frames = self._load_all_frames()

    def _load_all_frames(self):
        """Read and resize every frame of the video into a list."""
        frames = []
        while True:
            ret, frame = self._read_frame()
            if not ret:
                break
            frames.append(frame)
        if self.progress_bar:
            self.progress_bar.reset()
        return frames

    def __next__(self):
        if self.preload:
            # BUGFIX: index before incrementing, otherwise frame 0 was
            # skipped (counter was bumped before the lookup).
            if self.counter < len(self.frames):
                frame = self.frames[self.counter]
                self.counter += 1
                self.progress_bar_update()
                return frame
        else:
            self.counter += 1
            if self.counter < self.max_num_frames:
                ret, frame = self._read_frame()
                if ret:
                    return frame
            # (removed leftover debug print)

        # Exhausted: rewind so the reader can be iterated again.
        self.counter = 0
        self.video.set(cv2.CAP_PROP_POS_FRAMES, 0)
        if self.progress_bar:
            self.progress_bar.reset()
        raise StopIteration

    def _read_frame(self):
        """Read one frame, drop skipped frames, resize on success."""
        ret, frame = self.video.read()
        self._skip_frames()
        if ret:
            self.progress_bar_update()
            frame = cv2.resize(frame, self.output_shape)
        return ret, frame

    def __iter__(self):
        return self

    def _skip_frames(self):
        # Advance the capture by skip_frames frames, counting them.
        for _ in range(self.skip_frames):
            self.counter += 1
            self.video.read()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class SimpleVideoReader:
    """Thin wrapper around cv2.VideoCapture with optional frame skipping.

    `read()` returns (ret, frame, frame_nb) where frame_nb is the index of
    the frame just read (before skipping).
    """

    def __init__(self, video_filename, skip_frames=0):
        self.skip_frames = skip_frames
        self.video = cv2.VideoCapture(video_filename)
        width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(self.video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.shape = (width, height)
        # Effective frame rate after skipping.
        self.fps = self.video.get(cv2.CAP_PROP_FPS) / (skip_frames + 1)
        self.frame_nb = 0
        self.num_frames = self.video.get(cv2.CAP_PROP_FRAME_COUNT)

    def read(self):
        """Read the next kept frame; returns (ret, frame, frame_nb)."""
        ret, frame = self.video.read()
        current_nb = self.frame_nb
        self.frame_nb = current_nb + 1
        self._skip_frames()
        return ret, frame, current_nb

    def set_frame(self, frame_nb_to_set):
        """Seek to an absolute frame index."""
        self.video.set(cv2.CAP_PROP_POS_FRAMES, frame_nb_to_set)
        self.frame_nb = frame_nb_to_set

    def _skip_frames(self):
        # Discard skip_frames frames without returning them.
        for _ in range(self.skip_frames):
            self.video.read()
|
| 184 |
+
|
| 185 |
+
class TorchIterableFromReader(torch.utils.data.IterableDataset):
    """Wrap a frame reader as a torch IterableDataset, lazily applying
    `transforms` to each yielded frame."""

    def __init__(self, reader, transforms):
        self.transforms = transforms
        self.reader = reader

    def __iter__(self):
        return map(self.transforms, self.reader)
|
tracking/__init__.py
ADDED
|
File without changes
|
tracking/postprocess_and_count_tracks.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import argparse
|
| 3 |
+
from scipy.signal import convolve
|
| 4 |
+
from tracking.utils import write_tracking_results_to_file, read_tracking_results
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def filter_tracks(tracklets, kappa, tau):
    """ filters the tracks depending on params
    kappa: size of the moving average window
    tau: minimum length of tracklet
    returns raw filtered tracks
    """
    if kappa == 1:
        tracks = tracklets
    else:
        tracks = filter_by_nb_consecutive_obs(tracklets, kappa, tau)

    # Flatten the per-tracker observation lists into (frame, id, x, y) rows.
    results = [
        (observation[0], tracker_nb, observation[1], observation[2])
        for tracker_nb, observations in enumerate(tracks)
        for observation in observations
    ]
    results.sort(key=lambda row: row[0])
    return results
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def postprocess_for_api(results):
    """ Converts tracking results into json object for API

    `results` is an iterable of (frame_number, id, x, y) rows; the output is
    {"detected_trash": [...]} with one entry per distinct id, mapping frame
    numbers to degenerate [x, y, x, y] boxes (coordinates rounded to 0.1).
    """
    detected = []
    index_of_id = {}

    for frame_number, obj_id, x, y in results:
        box = [round(x, 1), round(y, 1), round(x, 1), round(y, 1)]
        if obj_id in index_of_id:
            # Known id: append this frame's box to its existing entry.
            detected[index_of_id[obj_id]]["frame_to_box"][str(frame_number)] = box
        else:
            # New id: open a fresh entry.
            index_of_id[obj_id] = len(detected)
            detected.append({"label": "fragments",
                             "id": obj_id,
                             "frame_to_box": {str(frame_number): box}})

    return {"detected_trash": detected}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def write(results, output_name):
    """ Writes the results in two files:
    - tracking in a Mathis format xxx_tracks.txt (frame, id, box_x, box_y, ...)
    - the number of detected objects in a separate file xxx_count.txt
    """
    from os.path import splitext  # local import: os.path not imported at file level

    # BUGFIX: split('.')[0] truncated paths containing dots in directory
    # names (e.g. './v1.2/out.txt'); splitext only strips the extension.
    base_name = splitext(output_name)[0]

    write_tracking_results_to_file(results, ratio_x=1, ratio_y=1,
                                   output_filename=base_name + '_tracks.txt')

    with open(base_name + '_count.txt', 'w') as out_file:
        if len(results):
            # Tracker ids are 0-based, so the count is max id + 1.
            out_file.write(f'{max(result[1]+1 for result in results)}')
        else:
            out_file.write('0')
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def threshold(tracklets, tau):
    """Keep only the tracklets with strictly more than `tau` observations."""
    kept = []
    for tracklet in tracklets:
        if len(tracklet) > tau:
            kept.append(tracklet)
    return kept
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def compute_moving_average(tracklet, kappa):
    """Moving-average observation density along a tracklet.

    Builds a binary presence vector over the frame span covered by the
    tracklet (frames ids are taken from observation[0], assumed 1-based),
    convolves it with a length-kappa box filter, and returns the density
    values at the observed frames only.
    """
    if len(tracklet)==0 or len(tracklet[0])==0:
        return tracklet
    pad = (kappa-1)//2
    # One slot per frame between the first and last observation (inclusive).
    observation_points = np.zeros(tracklet[-1][0] - tracklet[0][0] + 1)
    first_frame_id = tracklet[0][0] - 1
    for observation in tracklet:
        frame_id = observation[0] - 1
        observation_points[frame_id - first_frame_id] = 1
    density_fill = convolve(observation_points, np.ones(kappa)/kappa, mode='same')
    if pad>0 and len(observation_points) >= kappa:
        # Mirror interior values into the borders to compensate the
        # zero-padding bias of 'same' convolution at the edges.
        density_fill[:pad] = density_fill[pad:2*pad]
        density_fill[-pad:] = density_fill[-2*pad:-pad]
    # Mask out frames without an observation, then keep only observed ones.
    density_fill = observation_points * density_fill
    return density_fill[density_fill > 0]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def filter_by_nb_consecutive_obs(tracklets, kappa, tau):
    """Drop isolated observations (local density <= 0.6 under a length-kappa
    moving average), then discard tracklets shorter than tau."""
    filtered = []

    for tracklet in tracklets:
        density_fill = compute_moving_average(tracklet, kappa=kappa)
        dense_observations = [
            observation
            for observation, fill_value in zip(tracklet, density_fill)
            if fill_value > 0.6
        ]
        filtered.append(dense_observations)

    return threshold(filtered, tau)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == '__main__':
    # CLI entry point: read raw tracking results, filter them with the
    # (kappa, tau) post-processing, and write them either as an API json
    # payload or as tracks + count text files.
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file',type=str)
    parser.add_argument('--output_name',type=str)
    parser.add_argument('--kappa',type=int)
    parser.add_argument('--tau',type=int)
    parser.add_argument('--output_type',type=str,default="api")
    args = parser.parse_args()

    tracklets = read_tracking_results(args.input_file)
    filtered_results = filter_tracks(tracklets, args.kappa, args.tau)
    if args.output_type == "api":
        output = postprocess_for_api(filtered_results)
        with open(args.output_name, 'w') as f:
            json.dump(output, f)
    else:
        write(filtered_results, args.output_name)
|
tracking/track_video.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from detection.detect import detect
|
| 5 |
+
from tracking.utils import in_frame, init_trackers
|
| 6 |
+
from tools.optical_flow import compute_flow
|
| 7 |
+
from tracking.trackers import get_tracker
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
from scipy.spatial.distance import euclidean
|
| 10 |
+
from scipy.optimize import linear_sum_assignment
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Display:
    """Matplotlib debug display of trackers and detections.

    In interactive mode each rendered frame blocks until a button press;
    otherwise figures are saved under the 'plots' directory.
    """
    def __init__(self, on, interactive=True):
        self.on = on
        self.fig, self.ax = plt.subplots()
        self.interactive = interactive
        if interactive:
            plt.ion()
        self.colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
        self.legends = []
        self.plot_count = 0

    def display(self, trackers):
        """Draw the latest frame and detections, plus every enabled tracker
        (each tracker fills the display via its fill_display hook)."""
        something_to_show = False
        for tracker_nb, tracker in enumerate(trackers):
            if tracker.enabled:
                tracker.fill_display(self, tracker_nb)
                something_to_show = True

        self.ax.imshow(self.latest_frame_to_show)

        if len(self.latest_detections):
            self.ax.scatter(self.latest_detections[:, 0], self.latest_detections[:, 1], c='r', s=40)

        if something_to_show:
            self.ax.xaxis.tick_top()
            plt.legend(handles=self.legends)
            self.fig.canvas.draw()
            if self.interactive:
                plt.show()
                # Block until a key press (waitforbuttonpress returns False
                # for mouse clicks).
                while not plt.waitforbuttonpress():
                    continue
            else:
                plt.savefig(os.path.join('plots',str(self.plot_count)))
            self.ax.cla()
            self.legends = []
            self.plot_count+=1

    def update_detections_and_frame(self, latest_detections, frame):
        """Store the detections and the (resized, RGB) frame to render next."""
        self.latest_detections = latest_detections
        self.latest_frame_to_show = cv2.cvtColor(cv2.resize(frame, self.display_shape), cv2.COLOR_BGR2RGB)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def build_confidence_function_for_trackers(trackers, flow01):
    """For every enabled tracker, collect its index in `trackers` and its
    confidence function built from the optical flow `flow01`.

    Returns (tracker_nbs, confidence_functions), aligned element-wise.
    """
    enabled_nbs = []
    functions = []
    for nb, tracker in enumerate(trackers):
        if not tracker.enabled:
            continue
        enabled_nbs.append(nb)
        functions.append(tracker.build_confidence_function(flow01))
    return enabled_nbs, functions
|
| 65 |
+
|
| 66 |
+
def associate_detections_to_trackers(detections_for_frame, trackers, flow01, confidence_threshold):
    """Assign detections to enabled trackers by maximizing total confidence
    (Hungarian algorithm); scores at or below the threshold are zeroed and
    such pairs are left unassigned.

    Returns a list parallel to `detections_for_frame` holding the index of
    the assigned tracker in `trackers`, or None.
    """
    tracker_nbs, confidence_functions = build_confidence_function_for_trackers(trackers, flow01)
    assigned_trackers = [None] * len(detections_for_frame)

    if len(tracker_nbs):
        cost_matrix = np.zeros((len(detections_for_frame), len(tracker_nbs)))
        for detection_nb, detection in enumerate(detections_for_frame):
            for column, confidence_function in enumerate(confidence_functions):
                score = confidence_function(detection)
                # Gate low-confidence pairs out of the assignment.
                cost_matrix[detection_nb, column] = score if score > confidence_threshold else 0

        row_inds, col_inds = linear_sum_assignment(cost_matrix, maximize=True)
        for row_ind, col_ind in zip(row_inds, col_inds):
            if cost_matrix[row_ind, col_ind] > confidence_threshold:
                assigned_trackers[row_ind] = tracker_nbs[col_ind]

    return assigned_trackers
|
| 83 |
+
|
| 84 |
+
def track_video(reader, detections, args, engine, transition_variance, observation_variance, display):
    """Track detections across a video.

    Args:
        reader: frame iterator exposing `output_shape`.
        detections: iterator of per-frame detection arrays, aligned with reader.
        args: namespace with `downsampling_factor` and `confidence_threshold`.
        engine: tracker factory (frame_nb, detection, trans_var, obs_var, delta).
        transition_variance, observation_variance: tracker noise parameters.
        display: optional Display instance (used when display.on).

    Returns:
        List of (frame_nb, tracker_nb, x, y) tuples sorted by frame number.
    """
    init = False
    # BUGFIX: trackers is indexed and extend()ed as a list of Tracker objects
    # later on; it was initialised as dict().
    trackers = []
    frame_nb = 0
    frame0 = next(reader)
    detections_for_frame = next(detections)

    # Association gate half-width: 5% of the frame diagonal.
    max_distance = euclidean(reader.output_shape, np.array([0, 0]))
    delta = 0.05 * max_distance

    if display is not None and display.on:
        display.display_shape = (reader.output_shape[0] // args.downsampling_factor, reader.output_shape[1] // args.downsampling_factor)
        display.update_detections_and_frame(detections_for_frame, frame0)

    if len(detections_for_frame):
        trackers = init_trackers(engine, detections_for_frame, frame_nb, transition_variance, observation_variance, delta)
        init = True

    if display is not None and display.on: display.display(trackers)

    for frame_nb, (frame1, detections_for_frame) in enumerate(zip(reader, detections), start=1):

        if display is not None and display.on:
            display.update_detections_and_frame(detections_for_frame, frame1)

        if not init:
            # No tracker yet: start at the first frame with detections.
            if len(detections_for_frame):
                trackers = init_trackers(engine, detections_for_frame, frame_nb, transition_variance, observation_variance, delta)
                init = True

        else:
            new_trackers = []
            flow01 = compute_flow(frame0, frame1, args.downsampling_factor)

            if len(detections_for_frame):
                assigned_trackers = associate_detections_to_trackers(detections_for_frame, trackers,
                                                                     flow01, args.confidence_threshold)

                for detection, assigned_tracker in zip(detections_for_frame, assigned_trackers):
                    if in_frame(detection, flow01.shape[:-1]):
                        if assigned_tracker is None:
                            # Unmatched detection: spawn a new tracker.
                            new_trackers.append(engine(frame_nb, detection, transition_variance, observation_variance, delta))
                        else:
                            trackers[assigned_tracker].update(detection, flow01, frame_nb)

            # Let every tracker advance / possibly disable itself.
            for tracker in trackers:
                tracker.update_status(flow01)

            if len(new_trackers):
                trackers.extend(new_trackers)

        if display is not None and display.on:
            display.display(trackers)
        frame0 = frame1.copy()

    # Flatten the tracklets into (frame, tracker, x, y) rows sorted by frame.
    results = []
    tracklets = [tracker.tracklet for tracker in trackers]

    for tracker_nb, associated_detections in enumerate(tracklets):
        for associated_detection in associated_detections:
            results.append((associated_detection[0], tracker_nb, associated_detection[1][0], associated_detection[1][1]))

    results = sorted(results, key=lambda x: x[0])

    return results
|
tracking/trackers.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from scipy.stats import multivariate_normal
|
| 3 |
+
from tracking.utils import in_frame, exp_and_normalise, GaussianMixture
|
| 4 |
+
from pykalman import KalmanFilter, AdditiveUnscentedKalmanFilter
|
| 5 |
+
import matplotlib.patches as mpatches
|
| 6 |
+
|
| 7 |
+
class Tracker:
    """Base class for single-object trackers.

    Keeps a tracklet of (frame_nb, observation) pairs, diagonal transition /
    observation covariances, and liveness bookkeeping. Subclasses provide
    `update` and `predictive_distribution`.
    """

    def __init__(self, frame_nb, X0, transition_variance, observation_variance, delta):
        self.transition_covariance = np.diag(transition_variance)
        self.observation_covariance = np.diag(observation_variance)
        self.updated = False
        self.steps_since_last_observation = 0
        self.enabled = True
        self.tracklet = [(frame_nb, X0)]
        # Half-width of the association confidence box.
        self.delta = delta

    def store_observation(self, observation, frame_nb):
        """Record an observation for this frame and mark the tracker updated."""
        self.tracklet.append((frame_nb, observation))
        self.updated = True

    def update_status(self, flow):
        """End-of-frame bookkeeping: propagate unobserved trackers through
        the flow (possibly disabling them) and reset the updated flag."""
        missed = self.enabled and not self.updated
        if missed:
            self.steps_since_last_observation += 1
            self.enabled = self.update(None, flow)
        else:
            self.steps_since_last_observation = 0
        self.updated = False

    def build_confidence_function(self, flow):
        """Return a function coord -> probability mass that the predictive
        distribution places in the (2*delta)^2 box centered on coord."""
        distribution = self.predictive_distribution(flow)
        delta = self.delta

        def box_confidence(coord):
            x, y = coord[0], coord[1]
            # CDF inclusion-exclusion over the four box corners.
            upper_right = distribution.cdf(np.array([x + delta, y + delta]))
            lower_right = distribution.cdf(np.array([x + delta, y - delta]))
            upper_left = distribution.cdf(np.array([x - delta, y + delta]))
            lower_left = distribution.cdf(np.array([x - delta, y - delta]))
            return upper_right - lower_right - upper_left + lower_left

        return box_confidence

    def get_display_colors(self, display, tracker_nb):
        """Pick this tracker's display color and register a legend entry
        labelled with the tracklet length."""
        palette = display.colors
        color = palette[tracker_nb % len(palette)]
        display.legends.append(mpatches.Patch(color=color, label=len(self.tracklet)))
        return color
|
| 56 |
+
|
| 57 |
+
class SMC(Tracker):
    """Particle-filter (sequential Monte Carlo) tracker.

    The state is a 2-D image position; the transition model displaces each
    particle by the optical-flow value at its location plus Gaussian noise.
    """

    @staticmethod
    def set_param(param):
        """Set the class-wide particle count (``param`` may be a numeric string).

        Marked @staticmethod so it also works when called on an instance;
        previously an instance call would mis-bind ``self`` as ``param``.
        """
        SMC.n_particles = int(param)

    def __init__(self, frame_nb, X0, transition_variance, observation_variance, delta):
        super().__init__(frame_nb, X0, transition_variance, observation_variance, delta)
        # Draw the initial particle cloud around the first detection X0.
        self.particles = multivariate_normal(
            X0, cov=self.observation_covariance).rvs(SMC.n_particles)
        self.normalized_weights = np.ones(SMC.n_particles)/SMC.n_particles

    def update(self, observation, flow, frame_nb=None):
        """One filtering step: resample, propagate with the flow, reweight.

        Returns:
            bool: False when every propagated particle left the frame.
        """
        if observation is not None: self.store_observation(observation, frame_nb)
        self.resample()
        enabled = self.move_particles(flow)
        if observation is not None:
            self.importance_reweighting(observation)
        else:
            # No measurement this frame: fall back to uniform weights.
            self.normalized_weights = np.ones(
                len(self.particles))/len(self.particles)

        return enabled

    def state_transition(self, state, flow):
        """Gaussian transition centered on the flow-displaced state.

        Indices into ``flow`` are clamped at 0 to tolerate slightly
        out-of-frame states.
        """
        mean = state + \
            flow[max(0, int(state[1])),
                 max(0, int(state[0])), :]
        cov = np.diag(self.transition_covariance)
        return multivariate_normal(mean, cov)

    def observation(self, state):
        """Gaussian observation model centered on the state."""
        return multivariate_normal(state, self.observation_covariance)

    def move_particles(self, flow):
        """Propagate particles through the transition model.

        Particles landing outside the frame are discarded.

        Returns:
            bool: False if no particle survives the move.
        """
        new_particles = []
        for particle in self.particles:
            new_particle = self.state_transition(particle, flow).rvs(1)
            if in_frame(new_particle, flow.shape[:-1]):
                new_particles.append(new_particle)
        if len(new_particles):
            self.particles = np.array(new_particles)
            enabled = True
        else:
            enabled = False

        return enabled

    def importance_reweighting(self, observation):
        """Reweight particles by observation likelihood (log-space for stability)."""
        log_weights_unnormalized = np.zeros(len(self.particles))
        for particle_nb, particle in enumerate(self.particles):
            log_weights_unnormalized[particle_nb] = self.observation(
                particle).logpdf(observation)
        self.normalized_weights = exp_and_normalise(log_weights_unnormalized)

    def resample(self):
        """Multinomial resampling according to the current normalized weights."""
        resampling_indices = np.random.choice(
            a=len(self.particles), p=self.normalized_weights, size=len(self.particles))
        self.particles = self.particles[resampling_indices]

    def predictive_distribution(self, flow, nb_new_particles=5):
        """Approximate the predictive density as a Gaussian mixture.

        Each particle spawns ``nb_new_particles`` propagated samples; samples
        leaving the frame are dropped and the parent's weight is split evenly
        among the survivors.
        """
        new_particles = []
        new_weights = []

        for particle, normalized_weight in zip(self.particles, self.normalized_weights):
            new_particles_for_particle = self.state_transition(
                particle, flow).rvs(nb_new_particles)

            new_particles_for_particle = [
                particle for particle in new_particles_for_particle if in_frame(particle, flow.shape[:-1])]

            if len(new_particles_for_particle):
                new_particles.extend(new_particles_for_particle)
                new_weights.extend([normalized_weight/len(new_particles_for_particle)] *
                                   len(new_particles_for_particle))

        new_particles = np.array(new_particles)

        return GaussianMixture(new_particles, self.observation_covariance, new_weights)

    def fill_display(self, display, tracker_nb):
        """Scatter the particle cloud on the display axes."""
        color = self.get_display_colors(display, tracker_nb)
        display.ax.scatter(self.particles[:,0], self.particles[:,1], s=5, c=color)
|
| 143 |
+
|
| 144 |
+
class EKF(Tracker):
    """Extended Kalman Filter tracker.

    The transition model is linearized around the current filtered mean:
    x' = x + flow(x), whose Jacobian is I + grad(flow) evaluated at the mean.
    Uses pykalman's KalmanFilter with a per-step transition matrix/offset.
    """

    def __init__(self, frame_nb, X0, transition_variance, observation_variance, delta):
        super().__init__(frame_nb, X0, transition_variance, observation_variance, delta)
        # Identity observation model: the detection directly measures position.
        self.filter = KalmanFilter(initial_state_mean=X0,
                                   initial_state_covariance=self.observation_covariance,
                                   transition_covariance=self.transition_covariance,
                                   observation_matrices=np.eye(2),
                                   observation_covariance=self.observation_covariance)

        # Current filtered belief (mean and covariance of the 2-D position).
        self.filtered_state_mean = X0
        self.filtered_state_covariance = self.observation_covariance

    def get_update_parameters(self, flow):
        """Linearize the flow-driven transition around the current mean.

        Returns:
            (transition_matrix, transition_offset) such that the linearized
            transition is x' = transition_matrix @ x + transition_offset.
        """
        # NOTE(review): indexes flow at the unclipped integer mean — assumes the
        # mean lies inside the frame (SMC clamps with max(0, .)); confirm upstream.
        flow_value = flow[int(self.filtered_state_mean[1]),int(self.filtered_state_mean[0]), :]

        # Spatial gradient of both flow channels, sampled at the current mean.
        grad_flow_value = np.array([np.gradient(flow[:,:,0]),np.gradient(flow[:,:,1])])[:,:,int(self.filtered_state_mean[1]),int(self.filtered_state_mean[0])]
        return np.eye(2) + grad_flow_value, flow_value - grad_flow_value.dot(self.filtered_state_mean)


    def EKF_step(self, observation, flow):
        """One filter_update with the linearized transition; observation may be None."""
        transition_matrix, transition_offset = self.get_update_parameters(flow)

        return self.filter.filter_update(self.filtered_state_mean,
                                         self.filtered_state_covariance,
                                         transition_matrix=transition_matrix,
                                         transition_offset=transition_offset,
                                         observation=observation)

    def update(self, observation, flow, frame_nb=None):
        """Advance the filter by one frame.

        Returns:
            bool: False if the filtered mean moved outside the frame.
        """
        if observation is not None: self.store_observation(observation, frame_nb)

        self.filtered_state_mean, self.filtered_state_covariance = self.EKF_step(observation, flow)

        enabled=False if not in_frame(self.filtered_state_mean,flow.shape[:-1]) else True

        return enabled

    def predictive_distribution(self, flow):
        """Predictive Gaussian for the next observation (state + measurement noise)."""
        filtered_state_mean, filtered_state_covariance = self.EKF_step(None, flow)

        distribution = multivariate_normal(filtered_state_mean, filtered_state_covariance + self.observation_covariance)

        return distribution

    def fill_display(self, display, tracker_nb):
        """Draw the filtered belief as density contours plus the mean as a cross."""
        yy, xx = np.mgrid[0:display.display_shape[1]:1, 0:display.display_shape[0]:1]
        pos = np.dstack((xx, yy))
        distribution = multivariate_normal(self.filtered_state_mean, self.filtered_state_covariance)

        color = self.get_display_colors(display, tracker_nb)
        cs = display.ax.contour(distribution.pdf(pos), colors=color)
        display.ax.clabel(cs, inline=True, fontsize='large')
        display.ax.scatter(self.filtered_state_mean[0], self.filtered_state_mean[1], color=color, marker="x", s=100)
|
| 201 |
+
|
| 202 |
+
class UKF(Tracker):
    """Unscented Kalman Filter tracker (additive-noise form).

    Uses pykalman's AdditiveUnscentedKalmanFilter with a transition function
    that displaces the state by the optical-flow value at its position; no
    explicit Jacobian is needed (unlike EKF).
    """

    def __init__(self, frame_nb, X0, transition_variance, observation_variance, delta):
        super().__init__(frame_nb, X0, transition_variance, observation_variance, delta)
        # Identity observation function: detections directly measure position.
        self.filter = AdditiveUnscentedKalmanFilter(initial_state_mean=X0,
                                                    initial_state_covariance=self.observation_covariance,
                                                    observation_functions = lambda z: np.eye(2).dot(z),
                                                    transition_covariance=self.transition_covariance,
                                                    observation_covariance=self.observation_covariance)

        # Current filtered belief (mean and covariance of the 2-D position).
        self.filtered_state_mean = X0
        self.filtered_state_covariance = self.observation_covariance

    def UKF_step(self, observation, flow):
        """One filter_update step; observation may be None (prediction only)."""
        # NOTE(review): the transition indexes flow at the unclipped integer
        # state — assumes sigma points stay inside the frame; confirm upstream.
        return self.filter.filter_update(self.filtered_state_mean,
                                         self.filtered_state_covariance,
                                         transition_function=lambda x: x + flow[int(x[1]),int(x[0]),:],
                                         observation=observation)


    def update(self, observation, flow, frame_nb=None):
        """Advance the filter by one frame.

        Returns:
            bool: False if the filtered mean moved outside the frame.
        """
        if observation is not None: self.store_observation(observation, frame_nb)

        self.filtered_state_mean, self.filtered_state_covariance = self.UKF_step(observation, flow)

        enabled=False if not in_frame(self.filtered_state_mean,flow.shape[:-1]) else True

        return enabled

    def predictive_distribution(self, flow):
        """Predictive Gaussian for the next observation (state + measurement noise)."""
        filtered_state_mean, filtered_state_covariance = self.UKF_step(None, flow)

        distribution = multivariate_normal(filtered_state_mean, filtered_state_covariance + self.observation_covariance)

        return distribution

    def fill_display(self, display, tracker_nb):
        """Draw the filtered belief as density contours plus the mean as a cross."""
        yy, xx = np.mgrid[0:display.display_shape[1]:1, 0:display.display_shape[0]:1]
        pos = np.dstack((xx, yy))
        distribution = multivariate_normal(self.filtered_state_mean, self.filtered_state_covariance)

        color = self.get_display_colors(display, tracker_nb)
        cs = display.ax.contour(distribution.pdf(pos), colors=color)
        display.ax.clabel(cs, inline=True, fontsize='large')
        display.ax.scatter(self.filtered_state_mean[0], self.filtered_state_mean[1], color=color, marker="x", s=100)
|
| 248 |
+
|
| 249 |
+
# Registry mapping algorithm-name strings (as used in config/CLI, e.g. 'SMC_500')
# to their tracker classes; resolved by get_tracker below.
trackers = {'EKF': EKF,
            'SMC': SMC,
            'UKF': UKF}
|
| 252 |
+
|
| 253 |
+
def get_tracker(algorithm_and_params):
    """Resolve a tracker class from a name such as 'EKF' or 'SMC_500'.

    An optional suffix after the first underscore is forwarded to the class's
    ``set_param`` hook (e.g. the particle count for SMC).

    Args:
        algorithm_and_params: algorithm name, optionally '<name>_<param>'.

    Returns:
        The tracker class registered under the given name.

    Raises:
        KeyError: if the algorithm name is not in ``trackers``.
    """
    print(f'{algorithm_and_params} will be used for tracking.')

    # partition (instead of split) tolerates parameters that themselves
    # contain underscores, where split-then-unpack would raise ValueError.
    algorithm_name, _, param = algorithm_and_params.partition('_')
    tracker = trackers[algorithm_name]
    if param:
        tracker.set_param(param)

    return tracker
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
|
tracking/utils.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from scipy.stats import multivariate_normal
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
import cv2
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
import torch
|
| 7 |
+
from tools.video_readers import TorchIterableFromReader
|
| 8 |
+
from time import time
|
| 9 |
+
from detection.transforms import TransformFrames
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
from moviepy.editor import ImageSequenceClip
|
| 12 |
+
from skimage.transform import downscale_local_mean
|
| 13 |
+
|
| 14 |
+
class GaussianMixture(object):
    """Weighted mixture of multivariate normal components.

    All components share one covariance matrix; ``weights`` are the mixture
    coefficients (expected to sum to 1).
    """

    def __init__(self, means, covariance, weights):
        """Build one multivariate_normal per mean, with the shared covariance."""
        self.components = [multivariate_normal(
            mean=mean, cov=covariance) for mean in means]
        self.weights = weights

    def pdf(self, x):
        """Mixture density at x: sum_i w_i * N(x; mu_i, cov)."""
        result = 0
        for weight, component in zip(self.weights, self.components):
            result += weight*component.pdf(x)
        return result

    def logpdf(self, x):
        """Log mixture density at x.

        Computed in log-space (log-sum-exp via np.logaddexp.reduce) so that
        points far from every component do not underflow to log(0) = -inf,
        which the naive log(pdf(x)) did.
        """
        log_terms = np.asarray([np.log(weight) + component.logpdf(x)
                                for weight, component in zip(self.weights, self.components)])
        return np.logaddexp.reduce(log_terms, axis=0)

    def cdf(self, x):
        """Mixture CDF at x: sum_i w_i * CDF_i(x)."""
        result = 0
        for weight, component in zip(self.weights, self.components):
            result += weight*component.cdf(x)
        return result
|
| 34 |
+
|
| 35 |
+
def init_trackers(engine, detections, frame_nb, state_variance, observation_variance, delta):
    """Instantiate one tracker (via ``engine``) per initial detection."""
    return [engine(frame_nb, detection, state_variance, observation_variance, delta)
            for detection in detections]
|
| 43 |
+
|
| 44 |
+
def exp_and_normalise(lw):
    """Turn log-weights into normalized weights.

    Subtracting the maximum log-weight first keeps every exponential in a
    safe numeric range (standard stable softmax trick).
    """
    shifted = np.exp(lw - lw.max())
    return shifted / shifted.sum()
|
| 47 |
+
|
| 48 |
+
def in_frame(position, shape, border=0.02):
    """Return True if ``position`` = (x, y) lies strictly inside the frame,
    excluding a relative ``border`` margin on every side.

    ``shape`` is (height, width), numpy-style.
    """
    height, width = shape[0], shape[1]
    x, y = position[0], position[1]
    inside_horizontally = border * width < x < (1 - border) * width
    inside_vertically = border * height < y < (1 - border) * height
    return inside_horizontally and inside_vertically
|
| 57 |
+
|
| 58 |
+
def gather_filenames_for_video_in_annotations(video, images, data_dir):
    """Return full paths of the frames belonging to ``video``, ordered by frame id."""
    frames_for_video = sorted(
        (image for image in images if image['video_id'] == video['id']),
        key=lambda image: image['frame_id'])
    return [os.path.join(data_dir, image['file_name']) for image in frames_for_video]
|
| 66 |
+
|
| 67 |
+
def get_detections_for_video(reader, detector, batch_size=16, device=None):
    """Run ``detector`` over every frame served by ``reader``.

    Frames are preprocessed with TransformFrames, batched through a torch
    DataLoader, and passed to the detector without gradient tracking.

    Args:
        reader: frame source wrapped by TorchIterableFromReader.
        detector: callable mapping a batch tensor to per-frame detections.
        batch_size: frames per forward pass.
        device: torch device the batches are moved to (None = CPU default).

    Returns:
        list with one entry per frame: the detector's output for that frame,
        or an empty np.array when nothing was detected.
    """
    detections = []
    dataset = TorchIterableFromReader(reader, TransformFrames())
    loader = DataLoader(dataset, batch_size=batch_size)
    average_times = []  # per-batch forward-pass durations, for the fps report
    with torch.no_grad():
        for preprocessed_frames in loader:
            time0 = time()
            detections_for_frames = detector(preprocessed_frames.to(device))
            average_times.append(time() - time0)
            for detections_for_frame in detections_for_frames:
                if len(detections_for_frame): detections.append(detections_for_frame)
                else: detections.append(np.array([]))
    # batch_size / mean(batch time) = approximate frames per second.
    print(f'Frame-wise inference time: {batch_size/np.mean(average_times)} fps')
    return detections
|
| 83 |
+
|
| 84 |
+
def generate_video_with_annotations(video, output_detected, output_filename, skip_frames, logger):
    """Render the tracking output as an annotated video file.

    Overlays each detected trash id at its (interpolated) position on every
    frame, downscales 4x, and writes the result with moviepy.

    Args:
        video: frame reader exposing read() -> (ret, frame, frame_nb).
        output_detected: dict with a "detected_trash" list; each item holds
            an "id" and a "frame_to_box" mapping of 1-based frame-number
            strings to [x, y, ...] positions.
        output_filename: path of the video file to write.
        skip_frames: number of frames skipped between tracked frames;
            positions are linearly interpolated across them.
        logger: logger used for progress messages.
    """
    fps = 24
    logger.info("---intepreting json")
    # Map original-video frame index -> list of (object_nb, x, y) annotations.
    results = defaultdict(list)
    for trash in output_detected["detected_trash"]:
        for k, v in trash["frame_to_box"].items():
            frame_nb = int(k) - 1          # frame keys are 1-based
            object_nb = trash["id"] + 1    # displayed ids are 1-based
            center_x = v[0]
            center_y = v[1]
            results[frame_nb * (skip_frames+1)].append((object_nb, center_x, center_y))
            # append next skip_frames: if the track continues on the next
            # tracked frame, linearly interpolate positions in between.
            if str(frame_nb + 2) in trash["frame_to_box"]:
                next_trash = trash["frame_to_box"][str(frame_nb + 2)]
                next_x = next_trash[0]
                next_y = next_trash[1]
                for i in range(1, skip_frames+1):
                    new_x = center_x + (next_x - center_x) * i/(skip_frames+1)
                    new_y = center_y + (next_y - center_y) * i/(skip_frames+1)
                    results[frame_nb * (skip_frames+1) + i].append((object_nb, new_x, new_y))
    logger.info("---writing video")

    #fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # writer = cv2.VideoWriter(filename=output_filename,
    #apiPreference=cv2.CAP_FFMPEG,
    # fourcc=fourcc,
    # fps=fps,
    # frameSize=video.shape)

    font = cv2.FONT_HERSHEY_SIMPLEX
    ret, frame, frame_nb = video.read()
    frames = []
    while ret:
        detections_for_frame = results[frame_nb]
        for detection in detections_for_frame:
            # Draw the object id next to its tracked position.
            cv2.putText(frame, '{}'.format(detection[0]), (int(detection[1]), int(detection[2])+5), font, 2, (0, 0, 255), 3, cv2.LINE_AA)

        #writer.write(frame)
        # Downscale 4x and flip BGR -> RGB for moviepy.
        frame = downscale_local_mean(frame, (4,4,1))
        frames.append(frame[:,:,::-1])

        ret, frame, frame_nb = video.read()
        # NOTE(review): hard cap at ~10 seconds of output (24 fps * 10) —
        # presumably to bound memory/encode time; confirm this is intended.
        if frame_nb > 24 * 10:
            break
    #writer.release()

    # Test with another library
    clip = ImageSequenceClip(sequence=frames, fps=fps)
    clip.write_videofile(output_filename, fps=fps)
    del frames  # release the in-memory frame buffer promptly

    logger.info("---finished writing video")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def resize_external_detections(detections, ratio):
    """Convert external [x1, y1, x2, y2, score] boxes into downscaled center
    points, modifying ``detections`` in place and returning it.

    Empty per-frame entries are left untouched.
    """
    for idx, boxes in enumerate(detections):
        if len(boxes):
            boxes = np.array(boxes)[:, :-1]  # drop the confidence column
            centers_x = (boxes[:, 0] + boxes[:, 2]) / 2
            centers_y = (boxes[:, 1] + boxes[:, 3]) / 2
            boxes[:, 0] = centers_x
            boxes[:, 1] = centers_y
            detections[idx] = boxes[:, :2] / ratio
    return detections
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def write_tracking_results_to_file(results, ratio_x, ratio_y, output_filename):
    """ writes the output result of a tracking the following format:
    - frame
    - id
    - x_tl, y_tl, w=0, h=0
    - 4x unused=-1
    """
    # Frame and track ids are shifted to 1-based; coordinates are rescaled.
    rows = (
        f'{res[0]+1},{res[1]+1},{ratio_x * res[2]},{ratio_y * res[3]},0,0,-1,-1,-1,-1\n'
        for res in results
    )
    with open(output_filename, 'w') as output_file:
        output_file.writelines(rows)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def read_tracking_results(input_file):
    """ read the input filename and interpret it as tracklets
    i.e. lists of lists
    """
    rows = np.loadtxt(input_file, delimiter=',')
    if rows.ndim == 1:
        rows = rows[None, :]  # a single line loads as 1-D; promote to 2-D
    tracks = defaultdict(list)
    for row in rows:
        frame_id, track_id = int(row[0]), int(row[1])
        left, top, width, height = row[2:6]
        # Store box centers rather than the top-left corner.
        tracks[track_id].append((frame_id, left + width / 2, top + height / 2))
    return list(tracks.values())
|
| 187 |
+
|
| 188 |
+
def gather_tracklets(tracklist):
    """ Converts a list of flat tracklets into a list of lists
    """
    grouped = defaultdict(list)
    for entry in tracklist:
        # entry layout: (frame_id, track_id, center_x, center_y)
        grouped[entry[1]].append((entry[0], entry[2], entry[3]))
    return list(grouped.values())
|
| 201 |
+
|
| 202 |
+
class FramesWithInfo:
    """Iterator over an in-memory list of frames, mimicking a video reader.

    Exposes ``output_shape`` as (width, height); defaults to the first
    frame's spatial dimensions reversed.
    """

    def __init__(self, frames, output_shape=None):
        self.frames = frames
        if output_shape is not None:
            self.output_shape = output_shape
        else:
            # shape[:-1] is (height, width); reversed gives (width, height).
            self.output_shape = frames[0].shape[:-1][::-1]
        self.end = len(frames)
        self.read_head = 0

    def __next__(self):
        if self.read_head >= self.end:
            raise StopIteration
        current = self.frames[self.read_head]
        self.read_head += 1
        return current

    def __iter__(self):
        return self
|