Spaces:
Runtime error
Runtime error
Update models/yolo.py
Browse files- models/yolo.py +110 -0
models/yolo.py
CHANGED
|
@@ -307,6 +307,93 @@ class IKeypoint(nn.Module):
|
|
| 307 |
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
|
| 308 |
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
|
| 309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
class IAuxDetect(nn.Module):
|
| 312 |
stride = None # strides computed during build
|
|
@@ -572,6 +659,16 @@ class Model(nn.Module):
|
|
| 572 |
self.stride = m.stride
|
| 573 |
self._initialize_biases_kpt() # only run once
|
| 574 |
# print('Strides: %s' % m.stride.tolist())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 575 |
|
| 576 |
# Init weights, biases
|
| 577 |
initialize_weights(self)
|
|
@@ -793,10 +890,23 @@ def parse_model(d, ch): # model_dict, input_channels(3)
|
|
| 793 |
args[1] = [list(range(args[1] * 2))] * len(f)
|
| 794 |
elif m is ReOrg:
|
| 795 |
c2 = ch[f] * 4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 796 |
elif m is Contract:
|
| 797 |
c2 = ch[f] * args[0] ** 2
|
| 798 |
elif m is Expand:
|
| 799 |
c2 = ch[f] // args[0] ** 2
|
|
|
|
|
|
|
|
|
|
| 800 |
else:
|
| 801 |
c2 = ch[f]
|
| 802 |
|
|
|
|
| 307 |
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
|
| 308 |
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
|
| 309 |
|
| 310 |
+
class MT(nn.Module):
    """Multi-output detection head with optional attention and mask-IoU branches.

    For every detection layer a 1x1 convolution produces ``na * (nc + 5)``
    channels (box, objectness and class scores per anchor).  Two optional
    1x1 branches can be enabled at construction time:

    * ``attn``     -- ``attn`` extra channels per anchor (``self.attn_m``).
    * ``mask_iou`` -- one extra channel per anchor (``self.m_iou``).

    ``forward`` returns a dict; its keys depend on the train/inference mode
    (see ``forward``).
    """

    stride = None  # strides computed during build
    export = False  # onnx export

    def __init__(self, nc=80, anchors=(), attn=None, mask_iou=False, ch=()):  # detection layer
        """Build the per-layer output convolutions.

        Args:
            nc: number of classes.
            anchors: one flat ``(w0, h0, w1, h1, ...)`` sequence per detection
                layer; must be non-empty because ``anchors[0]`` is read.
            attn: number of attention channels per anchor, or ``None`` to
                disable the attention branch.
            mask_iou: when truthy, add the per-anchor mask-IoU branch.
            ch: sequence whose first element is the iterable of input channel
                counts, one per detection layer (only ``ch[0]`` is read here).
        """
        super(MT, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid (placeholders, rebuilt lazily in forward)
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.original_anchors = anchors
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[0])  # output conv
        if mask_iou:
            self.m_iou = nn.ModuleList(nn.Conv2d(x, self.na, 1) for x in ch[0])  # mask-IoU conv
        self.mask_iou = mask_iou
        self.attn = attn
        if attn is not None:
            self.attn_m = nn.ModuleList(nn.Conv2d(x, attn * self.na, 1) for x in ch[0])  # attention conv

    def forward(self, x):
        """Run the head on ``x`` and return a dict of outputs.

        Args:
            x: indexable with three entries.  ``x[0]`` is a mutable list of
                per-layer feature maps of shape ``(bs, C_i, ny, nx)``; it is
                overwritten in place with the reshaped head outputs.  ``x[1]``
                and ``x[2]`` are passed through untouched under the keys
                ``"bases"`` and ``"sem"``.

        Returns:
            dict with keys:
                ``"bbox_and_cls"``: ``x[0]`` after the head, each entry shaped
                    ``(bs, na, ny, nx, no)``.
                ``"bases"``, ``"sem"``: ``x[1]`` and ``x[2]``.
                ``"mask_iou"``: ``None`` when the branch is disabled; the raw
                    per-layer list in training mode; the concatenated,
                    sigmoid-activated ``(bs, N)`` tensor in inference mode.
                ``"attn"``: only present when the attention branch is enabled;
                    per-layer list in training mode, concatenated
                    ``(bs, N, attn)`` tensor in inference mode.
                ``"test"``: inference mode only; decoded predictions
                    concatenated over layers, shape ``(bs, N, no)``.
        """
        feats = x[0]  # alias of the caller's list; entries are replaced in place as before
        z = []  # inference output
        za = []  # flattened attention outputs (inference)
        zi = []  # flattened mask-IoU outputs (inference)
        attn = [None] * self.nl
        iou = [None] * self.nl
        self.training |= self.export  # export mode behaves like training (raw outputs)
        use_attn = self.attn is not None
        output = dict()
        for i in range(self.nl):
            if use_attn:
                attn[i] = self.attn_m[i](feats[i])  # conv
                bs, _, ny, nx = attn[i].shape  # (bs, na*attn, ny, nx) -> (bs, na, ny, nx, attn)
                attn[i] = attn[i].view(bs, self.na, self.attn, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            if self.mask_iou:
                iou[i] = self.m_iou[i](feats[i])
            feats[i] = self.m[i](feats[i])  # conv

            bs, _, ny, nx = feats[i].shape  # (bs, na*no, ny, nx) -> (bs, na, ny, nx, no)
            feats[i] = feats[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            if self.mask_iou:
                iou[i] = iou[i].view(bs, self.na, ny, nx).contiguous()

            if not self.training:  # inference
                # Only flatten attention when the branch exists; attn[i] is
                # None otherwise and calling .view on it would raise.
                if use_attn:
                    za.append(attn[i].view(bs, -1, self.attn))
                if self.mask_iou:
                    zi.append(iou[i].view(bs, -1))
                if self.grid[i].shape[2:4] != feats[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(feats[i].device)

                y = feats[i].sigmoid()
                y[..., 0:2] = (y[..., 0:2] * 3. - 1.0 + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                z.append(y.view(bs, -1, self.no))

        output["mask_iou"] = None
        if not self.training:
            output["test"] = torch.cat(z, 1)
            if use_attn:
                output["attn"] = torch.cat(za, 1)
            if self.mask_iou:
                output["mask_iou"] = torch.cat(zi, 1).sigmoid()
        else:
            if use_attn:
                output["attn"] = attn
            if self.mask_iou:
                output["mask_iou"] = iou
        output["bbox_and_cls"] = feats
        output["bases"] = x[1]
        output["sem"] = x[2]

        return output

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Return a float ``(1, 1, ny, nx, 2)`` tensor holding the ``(x, y)`` index of every cell."""
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
|
| 396 |
+
|
| 397 |
|
| 398 |
class IAuxDetect(nn.Module):
|
| 399 |
stride = None # strides computed during build
|
|
|
|
| 659 |
self.stride = m.stride
|
| 660 |
self._initialize_biases_kpt() # only run once
|
| 661 |
# print('Strides: %s' % m.stride.tolist())
|
| 662 |
+
if isinstance(m, MT):
|
| 663 |
+
s = 256 # 2x min stride
|
| 664 |
+
temp = self.forward(torch.zeros(1, ch, s, s))
|
| 665 |
+
if isinstance(temp, list):
|
| 666 |
+
temp = temp[0]
|
| 667 |
+
m.stride = torch.tensor([s / x.shape[-2] for x in temp["bbox_and_cls"]]) # forward
|
| 668 |
+
check_anchor_order(m)
|
| 669 |
+
m.anchors /= m.stride.view(-1, 1, 1)
|
| 670 |
+
self.stride = m.stride
|
| 671 |
+
self._initialize_biases()
|
| 672 |
|
| 673 |
# Init weights, biases
|
| 674 |
initialize_weights(self)
|
|
|
|
| 890 |
args[1] = [list(range(args[1] * 2))] * len(f)
|
| 891 |
elif m is ReOrg:
|
| 892 |
c2 = ch[f] * 4
|
| 893 |
+
elif m in [Merge]:
|
| 894 |
+
c2 = args[0]
|
| 895 |
+
elif m in [MT]:
|
| 896 |
+
if len(args) == 3:
|
| 897 |
+
args.append(False)
|
| 898 |
+
#print(f)
|
| 899 |
+
#print(len(ch))
|
| 900 |
+
#for x in f:
|
| 901 |
+
# print(ch[x])
|
| 902 |
+
args.append([ch[x] for x in f])
|
| 903 |
elif m is Contract:
|
| 904 |
c2 = ch[f] * args[0] ** 2
|
| 905 |
elif m is Expand:
|
| 906 |
c2 = ch[f] // args[0] ** 2
|
| 907 |
+
elif m is Refine:
|
| 908 |
+
args.append([ch[x] for x in f])
|
| 909 |
+
c2 = args[0]
|
| 910 |
else:
|
| 911 |
c2 = ch[f]
|
| 912 |
|