Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- model.pth +3 -0
- siamrpn.py +287 -0
model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:071e99722aca789d8684a12a98795c5515b448a2a2dcb9280a23b8ddbc66554a
|
| 3 |
+
size 361780834
|
siamrpn.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import absolute_import, division
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
from collections import namedtuple
|
| 9 |
+
from got10k.trackers import Tracker
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SiamRPN(nn.Module):
    """SiamRPN network: a shared AlexNet-like backbone plus RPN heads.

    The exemplar branch (`learn`) turns the template features into
    per-anchor regression/classification correlation kernels; the search
    branch (`inference`) cross-correlates those kernels with the search
    features to produce response maps.

    Args:
        anchor_num: number of anchors per spatial location (default 5).
    """

    def __init__(self, anchor_num=5):
        super(SiamRPN, self).__init__()
        self.anchor_num = anchor_num
        # Backbone shared by the exemplar (z) and search (x) branches.
        self.feature = nn.Sequential(
            # conv1
            nn.Conv2d(3, 192, 11, 2),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2),
            # conv2
            nn.Conv2d(192, 512, 5, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2),
            # conv3
            nn.Conv2d(512, 768, 3, 1),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
            # conv4
            nn.Conv2d(768, 768, 3, 1),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
            # conv5
            nn.Conv2d(768, 512, 3, 1),
            nn.BatchNorm2d(512))

        # Template-side heads emit correlation kernels (4 reg / 2 cls
        # outputs per anchor); search-side heads keep 512 channels.
        # Stride 1 written explicitly on all four for consistency
        # (it is the nn.Conv2d default, so behavior is unchanged).
        self.conv_reg_z = nn.Conv2d(512, 512 * 4 * anchor_num, 3, 1)
        self.conv_reg_x = nn.Conv2d(512, 512, 3, 1)
        self.conv_cls_z = nn.Conv2d(512, 512 * 2 * anchor_num, 3, 1)
        self.conv_cls_x = nn.Conv2d(512, 512, 3, 1)
        self.adjust_reg = nn.Conv2d(4 * anchor_num, 4 * anchor_num, 1)

    def forward(self, z, x):
        """Run the full pipeline: learn kernels from z, apply them to x.

        Args:
            z: exemplar (template) image batch, (N, 3, Hz, Wz).
            x: search image batch, (N, 3, Hx, Wx).

        Returns:
            (out_reg, out_cls) response maps from `inference`.
        """
        # BUG FIX: `learn` returns a tuple, so it must be unpacked
        # positionally with `*`.  The original `**self.learn(z)` raised
        # "TypeError: argument after ** must be a mapping" at runtime.
        return self.inference(x, *self.learn(z))

    def learn(self, z):
        """Compute regression and classification kernels from exemplar z.

        Returns:
            (kernel_reg, kernel_cls) with shapes
            (4 * anchor_num, 512, k, k) and (2 * anchor_num, 512, k, k),
            where k is the spatial size of the template head output.
        """
        z = self.feature(z)
        kernel_reg = self.conv_reg_z(z)
        kernel_cls = self.conv_cls_z(z)

        # Reinterpret the channel dimension as per-anchor conv kernels.
        k = kernel_reg.size()[-1]
        kernel_reg = kernel_reg.view(4 * self.anchor_num, 512, k, k)
        kernel_cls = kernel_cls.view(2 * self.anchor_num, 512, k, k)

        return kernel_reg, kernel_cls

    def inference(self, x, kernel_reg, kernel_cls):
        """Correlate search features with the learned kernels.

        Args:
            x: search image batch, (N, 3, Hx, Wx).
            kernel_reg: regression kernel from `learn`.
            kernel_cls: classification kernel from `learn`.

        Returns:
            out_reg: (N, 4 * anchor_num, h, w) box offsets.
            out_cls: (N, 2 * anchor_num, h, w) fg/bg scores.
        """
        x = self.feature(x)
        x_reg = self.conv_reg_x(x)
        x_cls = self.conv_cls_x(x)

        # Cross-correlation implemented as conv2d with the template
        # kernels; adjust_reg is a 1x1 calibration conv on the offsets.
        out_reg = self.adjust_reg(F.conv2d(x_reg, kernel_reg))
        out_cls = F.conv2d(x_cls, kernel_cls)

        return out_reg, out_cls
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class TrackerSiamRPN(Tracker):
    """GOT-10k tracker wrapper around the SiamRPN network.

    Implements the standard got10k Tracker interface: `init(image, box)`
    learns template kernels from the first frame, `update(image)` returns
    the tracked box for each subsequent frame.

    NOTE(review): boxes at the interface are assumed to be 1-indexed,
    left-top based [x, y, w, h] (GOT-10k convention) — the conversions in
    `init`/`update` are consistent with that; internally the state is
    0-indexed, center-based [y, x] / [h, w].
    """

    def __init__(self, net_path=None, **kargs):
        """Build the model, optionally loading weights from net_path.

        Args:
            net_path: path to a state_dict checkpoint, or None for
                randomly initialized weights.
            **kargs: overrides for the default hyper-parameters
                (see `parse_args`).
        """
        super(TrackerSiamRPN, self).__init__(
            name='SiamRPN', is_deterministic=True)
        self.parse_args(**kargs)

        # setup GPU device if available
        self.cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.cuda else 'cpu')

        # setup model
        self.net = SiamRPN()
        if net_path is not None:
            # map_location keeps CPU-only machines working with
            # GPU-saved checkpoints.
            self.net.load_state_dict(torch.load(
                net_path, map_location=lambda storage, loc: storage))
        self.net = self.net.to(self.device)

    def parse_args(self, **kargs):
        """Collect hyper-parameters into an immutable namedtuple cfg.

        Keyword overrides replace the defaults below; unknown keys are
        added as extra cfg fields.
        """
        self.cfg = {
            'exemplar_sz': 127,
            'instance_sz': 271,
            'total_stride': 8,
            'context': 0.5,
            'ratios': [0.33, 0.5, 1, 2, 3],
            'scales': [8,],
            'penalty_k': 0.055,
            'window_influence': 0.42,
            'lr': 0.295}

        for key, val in kargs.items():
            self.cfg.update({key: val})
        # Freeze into a namedtuple so fields are attribute-accessed and
        # only changed explicitly via _replace (see `init`).
        self.cfg = namedtuple('GenericDict', self.cfg.keys())(**self.cfg)

    def init(self, image, box):
        """Initialize tracking state from the first frame and its box.

        Learns the template correlation kernels, builds the anchor grid
        and the cosine (hanning) window used to smooth responses.
        """
        image = np.asarray(image)

        # convert box to 0-indexed and center based [y, x, h, w]
        box = np.array([
            box[1] - 1 + (box[3] - 1) / 2,
            box[0] - 1 + (box[2] - 1) / 2,
            box[3], box[2]], dtype=np.float32)
        self.center, self.target_sz = box[:2], box[2:]

        # for small target, use larger search region
        if np.prod(self.target_sz) / np.prod(image.shape[:2]) < 0.004:
            self.cfg = self.cfg._replace(instance_sz=287)

        # generate anchors
        # response map side length of the correlation output
        self.response_sz = (self.cfg.instance_sz - \
            self.cfg.exemplar_sz) // self.cfg.total_stride + 1
        self.anchors = self._create_anchors(self.response_sz)

        # create hanning window
        self.hann_window = np.outer(
            np.hanning(self.response_sz),
            np.hanning(self.response_sz))
        # one flattened window copy per anchor, matching the flattened
        # response layout used in `update`
        self.hann_window = np.tile(
            self.hann_window.flatten(),
            len(self.cfg.ratios) * len(self.cfg.scales))

        # exemplar and search sizes
        context = self.cfg.context * np.sum(self.target_sz)
        self.z_sz = np.sqrt(np.prod(self.target_sz + context))
        self.x_sz = self.z_sz * \
            self.cfg.instance_sz / self.cfg.exemplar_sz

        # exemplar image
        self.avg_color = np.mean(image, axis=(0, 1))
        exemplar_image = self._crop_and_resize(
            image, self.center, self.z_sz,
            self.cfg.exemplar_sz, self.avg_color)

        # classification and regression kernels
        # HWC uint8 -> NCHW float tensor
        exemplar_image = torch.from_numpy(exemplar_image).to(
            self.device).permute([2, 0, 1]).unsqueeze(0).float()
        with torch.set_grad_enabled(False):
            self.net.eval()
            self.kernel_reg, self.kernel_cls = self.net.learn(exemplar_image)

    def update(self, image):
        """Track the target in a new frame.

        Returns:
            1-indexed, left-top based [x, y, w, h] box (np.ndarray).
        """
        image = np.asarray(image)

        # search image
        instance_image = self._crop_and_resize(
            image, self.center, self.x_sz,
            self.cfg.instance_sz, self.avg_color)

        # classification and regression outputs
        instance_image = torch.from_numpy(instance_image).to(
            self.device).permute(2, 0, 1).unsqueeze(0).float()
        with torch.set_grad_enabled(False):
            self.net.eval()
            out_reg, out_cls = self.net.inference(
                instance_image, self.kernel_reg, self.kernel_cls)

        # offsets
        # flatten to (4, anchor_num * response_sz^2) and decode against
        # the anchor grid: xy are linear offsets, wh are log-scaled
        offsets = out_reg.permute(
            1, 2, 3, 0).contiguous().view(4, -1).cpu().numpy()
        offsets[0] = offsets[0] * self.anchors[:, 2] + self.anchors[:, 0]
        offsets[1] = offsets[1] * self.anchors[:, 3] + self.anchors[:, 1]
        offsets[2] = np.exp(offsets[2]) * self.anchors[:, 2]
        offsets[3] = np.exp(offsets[3]) * self.anchors[:, 3]

        # scale and ratio penalty
        penalty = self._create_penalty(self.target_sz, offsets)

        # response
        # softmax over the 2 cls channels; keep the foreground row
        response = F.softmax(out_cls.permute(
            1, 2, 3, 0).contiguous().view(2, -1), dim=0).data[1].cpu().numpy()
        response = response * penalty
        # blend with the cosine window to discourage large displacements
        response = (1 - self.cfg.window_influence) * response + \
            self.cfg.window_influence * self.hann_window

        # peak location
        best_id = np.argmax(response)
        # rescale the decoded offset from exemplar scale to image scale
        offset = offsets[:, best_id] * self.z_sz / self.cfg.exemplar_sz

        # update center
        # offset is [x, y, w, h]; [::-1] flips xy -> yx to match center
        self.center += offset[:2][::-1]
        self.center = np.clip(self.center, 0, image.shape[:2])

        # update scale
        # response-weighted learning rate for smooth size adaptation
        lr = response[best_id] * self.cfg.lr
        self.target_sz = (1 - lr) * self.target_sz + lr * offset[2:][::-1]
        self.target_sz = np.clip(self.target_sz, 10, image.shape[:2])

        # update exemplar and instance sizes
        context = self.cfg.context * np.sum(self.target_sz)
        self.z_sz = np.sqrt(np.prod(self.target_sz + context))
        self.x_sz = self.z_sz * \
            self.cfg.instance_sz / self.cfg.exemplar_sz

        # return 1-indexed and left-top based bounding box
        box = np.array([
            self.center[1] + 1 - (self.target_sz[1] - 1) / 2,
            self.center[0] + 1 - (self.target_sz[0] - 1) / 2,
            self.target_sz[1], self.target_sz[0]])

        return box

    def _create_anchors(self, response_sz):
        """Build the dense anchor grid for a response_sz^2 response map.

        Returns:
            (anchor_num * response_sz^2, 4) float32 array of
            [x_center, y_center, w, h] anchors, centered on the map.
        """
        anchor_num = len(self.cfg.ratios) * len(self.cfg.scales)
        anchors = np.zeros((anchor_num, 4), dtype=np.float32)

        # base anchor area equals one stride cell
        size = self.cfg.total_stride * self.cfg.total_stride
        ind = 0
        for ratio in self.cfg.ratios:
            w = int(np.sqrt(size / ratio))
            h = int(w * ratio)
            for scale in self.cfg.scales:
                anchors[ind, 0] = 0
                anchors[ind, 1] = 0
                anchors[ind, 2] = w * scale
                anchors[ind, 3] = h * scale
                ind += 1
        # replicate the anchor shapes over every grid position
        anchors = np.tile(
            anchors, response_sz * response_sz).reshape((-1, 4))

        # grid coordinates, centered at 0, spaced by total_stride
        begin = -(response_sz // 2) * self.cfg.total_stride
        xs, ys = np.meshgrid(
            begin + self.cfg.total_stride * np.arange(response_sz),
            begin + self.cfg.total_stride * np.arange(response_sz))
        xs = np.tile(xs.flatten(), (anchor_num, 1)).flatten()
        ys = np.tile(ys.flatten(), (anchor_num, 1)).flatten()
        anchors[:, 0] = xs.astype(np.float32)
        anchors[:, 1] = ys.astype(np.float32)

        return anchors

    def _create_penalty(self, target_sz, offsets):
        """Penalize proposals whose scale/ratio differ from the target.

        Args:
            target_sz: current [h, w] target size.
            offsets: decoded proposals, rows [x, y, w, h].

        Returns:
            per-proposal multiplicative penalty in (0, 1].
        """
        def padded_size(w, h):
            # same context padding rule as the exemplar crop
            context = self.cfg.context * (w + h)
            return np.sqrt((w + context) * (h + context))

        def larger_ratio(r):
            # symmetric ratio: max(r, 1/r) >= 1
            return np.maximum(r, 1 / r)

        src_sz = padded_size(
            *(target_sz * self.cfg.exemplar_sz / self.z_sz))
        dst_sz = padded_size(offsets[2], offsets[3])
        change_sz = larger_ratio(dst_sz / src_sz)

        # both ratios are w/h (target_sz is [h, w]; offsets are [x,y,w,h])
        src_ratio = target_sz[1] / target_sz[0]
        dst_ratio = offsets[2] / offsets[3]
        change_ratio = larger_ratio(dst_ratio / src_ratio)

        penalty = np.exp(-(change_ratio * change_sz - 1) * \
            self.cfg.penalty_k)

        return penalty

    def _crop_and_resize(self, image, center, size, out_size, pad_color):
        """Crop a size x size patch around center, padding out-of-bounds
        pixels with pad_color, then resize to out_size x out_size.

        Args:
            image: HWC image array.
            center: [y, x] patch center (0-indexed).
            size: side length of the square crop in image pixels.
            out_size: side length of the returned patch.
            pad_color: border fill value (per-channel mean color).
        """
        # convert box to corners (0-indexed)
        size = round(size)
        corners = np.concatenate((
            np.round(center - (size - 1) / 2),
            np.round(center - (size - 1) / 2) + size))
        corners = np.round(corners).astype(int)

        # pad image if necessary
        # pads = how far each corner falls outside the image
        pads = np.concatenate((
            -corners[:2], corners[2:] - image.shape[:2]))
        npad = max(0, int(pads.max()))
        if npad > 0:
            image = cv2.copyMakeBorder(
                image, npad, npad, npad, npad,
                cv2.BORDER_CONSTANT, value=pad_color)

        # crop image patch
        # shift corners into the padded image's coordinates
        corners = (corners + npad).astype(int)
        patch = image[corners[0]:corners[2], corners[1]:corners[3]]

        # resize to out_size
        patch = cv2.resize(patch, (out_size, out_size))

        return patch
|