Spaces:
Running
Running
demo letr
Browse files- app.py +57 -0
- checkpoint0024.pth +3 -0
- demo.png +0 -0
- models/__init__.py +0 -0
- models/__pycache__/__init__.cpython-38.pyc +0 -0
- models/__pycache__/backbone.cpython-38.pyc +0 -0
- models/__pycache__/letr.cpython-38.pyc +0 -0
- models/__pycache__/letr_stack.cpython-38.pyc +0 -0
- models/__pycache__/matcher.cpython-38.pyc +0 -0
- models/__pycache__/misc.cpython-38.pyc +0 -0
- models/__pycache__/multi_head_attention.cpython-38.pyc +0 -0
- models/__pycache__/position_encoding.cpython-38.pyc +0 -0
- models/__pycache__/preprocessing.cpython-38.pyc +0 -0
- models/__pycache__/transformer.cpython-38.pyc +0 -0
- models/backbone.py +120 -0
- models/letr.py +371 -0
- models/letr_stack.py +376 -0
- models/matcher.py +81 -0
- models/misc.py +467 -0
- models/multi_head_attention.py +537 -0
- models/position_encoding.py +89 -0
- models/preprocessing.py +71 -0
- models/transformer.py +297 -0
- requirements.txt +5 -0
- tappeto-per-calibrazione.jpg +0 -0
- test.py +67 -0
app.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image, ImageDraw

import torch
from torchvision import transforms
import torch.nn.functional as F

import gradio as gr

# import sys
# sys.path.insert(0, './')
# NOTE(review): `test` here is this repo's local test.py, but the name shadows
# Python's stdlib `test` package — consider renaming the module to avoid
# import ambiguity.
from test import create_letr, draw_fig
from models.preprocessing import *
from models.misc import nested_tensor_from_tensor_list


# Build the LETR model once at startup (presumably loads the bundled
# checkpoint — see create_letr in test.py; confirm).
model = create_letr()

# PREPARE PREPROCESSING
# Target size passed to Resize below (shorter side, presumably — defined in
# models.preprocessing; confirm).
test_size = 256
# transform_test = transforms.Compose([
#     transforms.Resize((test_size)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
# ])
# Compose / ToTensor / Normalize / Resize come from models.preprocessing via
# the star import above, not from torchvision.
normalize = Compose([
    ToTensor(),
    Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
    Resize([test_size]),
])
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def predict(inp):
    """Run LETR on an RGB image array and return the image with lines drawn.

    `inp` is a numpy array as delivered by the Gradio Image input; the result
    is the same picture as a PIL image with the detected line segments
    rendered onto it.
    """
    pil_img = Image.fromarray(inp.astype('uint8'), 'RGB')
    height, width = pil_img.height, pil_img.width
    orig_size = torch.as_tensor([int(height), int(width)])

    tensor_img = normalize(pil_img)
    batch = nested_tensor_from_tensor_list([tensor_img])

    with torch.no_grad():
        outputs = model(batch)[0]

    # draw_fig renders the predictions onto pil_img in place.
    draw_fig(pil_img, outputs, orig_size)
    return pil_img
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Wire up the Gradio demo: image in -> image with detected lines drawn on it.
# NOTE(review): gr.inputs / gr.outputs is the deprecated pre-3.x Gradio
# namespace; on Gradio >= 3 use gr.Image() directly — confirm the version
# pinned in requirements.txt.
inputs = gr.inputs.Image()
outputs = gr.outputs.Image()
gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    # Bundled sample images shown as clickable examples in the UI.
    examples=["demo.png", "tappeto-per-calibrazione.jpg"],
    title="LETR",
    description="Model for line detection..."
).launch()
|
checkpoint0024.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:26725e48335937731ac968a3fbde602d296ca3edcf93b79f4f76f356ad3a4ff9
|
| 3 |
+
size 380893769
|
demo.png
ADDED
|
models/__init__.py
ADDED
|
File without changes
|
models/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (143 Bytes). View file
|
|
|
models/__pycache__/backbone.cpython-38.pyc
ADDED
|
Binary file (4.75 kB). View file
|
|
|
models/__pycache__/letr.cpython-38.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
models/__pycache__/letr_stack.cpython-38.pyc
ADDED
|
Binary file (12.2 kB). View file
|
|
|
models/__pycache__/matcher.cpython-38.pyc
ADDED
|
Binary file (4.12 kB). View file
|
|
|
models/__pycache__/misc.cpython-38.pyc
ADDED
|
Binary file (14.6 kB). View file
|
|
|
models/__pycache__/multi_head_attention.cpython-38.pyc
ADDED
|
Binary file (19.7 kB). View file
|
|
|
models/__pycache__/position_encoding.cpython-38.pyc
ADDED
|
Binary file (3.65 kB). View file
|
|
|
models/__pycache__/preprocessing.cpython-38.pyc
ADDED
|
Binary file (2.98 kB). View file
|
|
|
models/__pycache__/transformer.cpython-38.pyc
ADDED
|
Binary file (9 kB). View file
|
|
|
models/backbone.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LETR Backbone modules.
|
| 3 |
+
modified based on https://github.com/facebookresearch/detr/blob/master/models/backbone.py
|
| 4 |
+
"""
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torchvision.models._utils import IntermediateLayerGetter
|
| 12 |
+
from typing import Dict, List
|
| 13 |
+
|
| 14 |
+
from .misc import NestedTensor, is_main_process
|
| 15 |
+
|
| 16 |
+
from .position_encoding import build_position_encoding
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d with batch statistics and affine parameters frozen as buffers.

    Copy-paste from torchvision.misc.ops with eps added before rsqrt; without
    it, models other than torchvision.models.resnet[18,34,50,101] produce NaNs.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        # All four BatchNorm tensors are buffers (never trained).
        for buf_name, init_val in (("weight", torch.ones(n)),
                                   ("bias", torch.zeros(n)),
                                   ("running_mean", torch.zeros(n)),
                                   ("running_var", torch.ones(n))):
            self.register_buffer(buf_name, init_val)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Regular BatchNorm checkpoints carry num_batches_tracked; this frozen
        # variant has no such buffer, so drop the key before delegating.
        tracked_key = prefix + 'num_batches_tracked'
        if tracked_key in state_dict:
            del state_dict[tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # Reshape the buffers up front so the whole op becomes a pair of
        # elementwise multiplies/adds (fuser-friendly).
        eps = 1e-5
        gamma = self.weight.reshape(1, -1, 1, 1)
        beta = self.bias.reshape(1, -1, 1, 1)
        var = self.running_var.reshape(1, -1, 1, 1)
        mean = self.running_mean.reshape(1, -1, 1, 1)
        scale = gamma * (var + eps).rsqrt()
        shift = beta - mean * scale
        return x * scale + shift
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class BackboneBase(nn.Module):
    """Wraps a torchvision ResNet: freezes parameters and returns the selected
    stage outputs as NestedTensors with masks resized to each feature map."""

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        trainable_stages = ('layer2', 'layer3', 'layer4')
        for param_name, param in backbone.named_parameters():
            # Only layer2-4 train, and only when the backbone trains at all.
            # (Parentheses make the original operator precedence explicit.)
            if (not train_backbone) or all(stage not in param_name for stage in trainable_stages):
                param.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        feature_maps = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, feat in feature_maps.items():
            pad_mask = tensor_list.mask
            assert pad_mask is not None
            # Downsample the padding mask to this feature map's resolution.
            resized_mask = F.interpolate(pad_mask[None].float(), size=feat.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(feat, resized_mask)
        return out
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""

    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        # Look the constructor up by name, e.g. "resnet50" -> torchvision.models.resnet50.
        resnet_ctor = getattr(torchvision.models, name)
        net = resnet_ctor(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        # resnet18/34 end at 512 channels; deeper variants end at 2048.
        out_channels = 2048 if name not in ('resnet18', 'resnet34') else 512
        super().__init__(net, train_backbone, out_channels, return_interm_layers)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class Joiner(nn.Sequential):
    """Sequential pair (backbone, position_embedding): runs the backbone, then
    computes a position encoding for each returned feature map."""

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        feature_maps = self[0](tensor_list)
        out: List[NestedTensor] = []
        encodings = []
        for _, feat in feature_maps.items():
            out.append(feat)
            # position encoding, cast to match the feature dtype
            encodings.append(self[1](feat).to(feat.tensors.dtype))

        return out, encodings
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def build_backbone(args):
    """Assemble the ResNet backbone + positional encoding from parsed args."""
    pos_encoding = build_position_encoding(args)
    # The backbone only trains when it gets a positive learning rate.
    resnet = Backbone(args.backbone, args.lr_backbone > 0, True, args.dilation)
    joined = Joiner(resnet, pos_encoding)
    joined.num_channels = resnet.num_channels
    return joined
|
models/letr.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file provides coarse stage LETR definition
|
| 3 |
+
Modified based on https://github.com/facebookresearch/detr/blob/master/models/backbone.py
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
from .misc import (NestedTensor, nested_tensor_from_tensor_list,
|
| 10 |
+
accuracy, get_world_size, interpolate,
|
| 11 |
+
is_dist_avail_and_initialized)
|
| 12 |
+
|
| 13 |
+
from .backbone import build_backbone
|
| 14 |
+
from .matcher import build_matcher
|
| 15 |
+
from .transformer import build_transformer
|
| 16 |
+
from .letr_stack import LETRstack
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
class LETR(nn.Module):
    """Coarse-stage LETR: backbone + transformer + line/label prediction heads."""

    def __init__(self, backbone, transformer, num_classes, num_queries, args, aux_loss=False):
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        d_model = transformer.d_model
        # Classification head: num_classes real classes plus a "no object" slot.
        self.class_embed = nn.Linear(d_model, num_classes + 1)

        # 3-layer MLP regressing the 4 endpoint coordinates of each line.
        self.lines_embed = MLP(d_model, d_model, 4, 3)
        self.query_embed = nn.Embedding(num_queries, d_model)

        # Output channels of ResNet stages layer1..layer4.
        stage_channels = [256, 512, 1024, 2048]
        self.input_proj = nn.Conv2d(stage_channels[args.layer1_num], d_model, kernel_size=1)

        self.backbone = backbone
        self.aux_loss = aux_loss
        self.args = args

    def forward(self, samples, postprocessors=None, targets=None, criterion=None):
        """Return {'pred_logits', 'pred_lines'} (normalized coords via sigmoid),
        plus 'aux_outputs' for intermediate decoder layers when aux_loss is on."""
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)

        features, pos = self.backbone(samples)

        stage = self.args.layer1_num
        src, mask = features[stage].decompose()
        assert mask is not None
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[stage])[0]

        logits = self.class_embed(hs)
        coords = self.lines_embed(hs).sigmoid()
        result = {'pred_logits': logits[-1], 'pred_lines': coords[-1]}
        if self.aux_loss:
            result['aux_outputs'] = self._set_aux_loss(logits, coords)
        return result

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # One output dict per intermediate decoder layer (all but the last).
        return [{'pred_logits': a, 'pred_lines': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
|
| 59 |
+
|
| 60 |
+
class SetCriterion(nn.Module):
    """Loss computation for LETR.

    Matches predictions to ground-truth lines with the Hungarian matcher, then
    computes the classification loss (cross-entropy or focal) and the L1 loss
    on line endpoints, for the final layer and each auxiliary decoder layer.
    """

    def __init__(self, num_classes, weight_dict, eos_coef, losses, args, matcher=None):
        """
        Args:
            num_classes: number of real object classes (a no-object class is appended).
            weight_dict: loss-name -> relative weight (applied by the caller).
            eos_coef: down-weighting factor for the no-object class.
            losses: names of the losses to compute (keys of the map in get_loss).
            args: parsed CLI args (label_loss_func, label_loss_params, ...).
            matcher: module computing the prediction/target assignment.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        # Per-class CE weights; the trailing no-object class is down-weighted.
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)
        self.args = args
        try:
            # SECURITY NOTE: eval() executes arbitrary code from the config
            # string; prefer ast.literal_eval if this is always a dict literal.
            self.args.label_loss_params = eval(self.args.label_loss_params)  # Convert the string to dict.
        except Exception:
            # Best-effort: the value may already be a dict or be absent.
            # (Was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt.)
            pass

    def loss_lines_labels(self, outputs, targets, num_items, log=False, origin_indices=None):
        """Classification loss (NLL).
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_lines]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(origin_indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, origin_indices)])
        # Default every query to the no-object class, then fill in matched ones.
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        if self.args.label_loss_func == 'cross_entropy':
            loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        elif self.args.label_loss_func == 'focal_loss':
            loss_ce = self.label_focal_loss(src_logits.transpose(1, 2), target_classes, self.empty_weight, **self.args.label_loss_params)
        else:
            raise ValueError()

        losses = {'loss_ce': loss_ce}
        return losses

    def label_focal_loss(self, input, target, weight, gamma=2.0):
        """Focal loss for label prediction.

        Here target has 2 classes: 0 for foreground (line) and 1 for
        background; `weight` plays the role of the alpha hyperparameter.

        Ref: https://github.com/facebookresearch/DETR/blob/699bf53f3e3ecd4f000007b8473eda6a08a8bed6/models/segmentation.py#L190
        Ref: https://medium.com/visionwizard/understanding-focal-loss-a-quick-read-b914422913e7
        """
        # input shape: [batch size, #classes, #queries]
        # target shape: [batch size, #queries]
        # weight shape: [#classes]

        prob = F.softmax(input, 1)  # Shape: [batch size, #classes, #queries].
        ce_loss = F.cross_entropy(input, target, weight, reduction='none')  # Shape: [batch size, #queries].
        # Probability of the true class; prob[:,0,:] + prob[:,1,:] should be 1.
        p_t = prob[:, 1, :] * target + prob[:, 0, :] * (1 - target)
        loss = ce_loss * ((1 - p_t) ** gamma)
        # Original label loss (cross entropy) does not normalize by #lines,
        # so neither does this one.
        loss = loss.mean()
        return loss

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, num_items, origin_indices=None):
        """Compute the cardinality error, i.e. the absolute error in the number
        of predicted non-empty lines.

        This is not really a loss — it is intended for logging only and does
        not propagate gradients.
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_lines_POST(self, outputs, targets, num_items, origin_indices=None):
        """L1 endpoint loss on the fine-stage ('POST') line predictions.

        NOTE(review): this method is not referenced by the loss map in
        get_loss ('POST_lines' maps to loss_lines) — confirm whether that is
        intentional.
        """
        assert 'POST_pred_lines' in outputs

        if outputs['POST_pred_lines'].shape[1] == 1000:
            # Full query set: select the matched predictions.
            idx = self._get_src_permutation_idx(origin_indices)
            src_lines = outputs['POST_pred_lines'][idx]
        else:
            src_lines = outputs['POST_pred_lines'].squeeze(0)

        target_lines = torch.cat([t['lines'][i] for t, (_, i) in zip(targets, origin_indices)], dim=0)

        loss_line = F.l1_loss(src_lines, target_lines, reduction='none')

        losses = {}
        losses['loss_line'] = loss_line.sum() / num_items

        return losses

    def loss_lines(self, outputs, targets, num_items, origin_indices=None):
        """L1 loss between matched predicted and ground-truth line endpoints,
        normalized by the (world-averaged) number of target items."""
        assert 'pred_lines' in outputs

        idx = self._get_src_permutation_idx(origin_indices)

        src_lines = outputs['pred_lines'][idx]
        target_lines = torch.cat([t['lines'][i] for t, (_, i) in zip(targets, origin_indices)], dim=0)

        loss_line = F.l1_loss(src_lines, target_lines, reduction='none')

        losses = {}
        losses['loss_line'] = loss_line.sum() / num_items

        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices:
        # returns (batch indices, query indices) of all matched predictions.
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices:
        # returns (batch indices, target indices) of all matched targets.
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, num_items, **kwargs):
        """Dispatch `loss` by name to the corresponding loss method."""
        loss_map = {
            'POST_lines_labels': self.loss_lines_labels,
            'POST_lines': self.loss_lines,
            'lines_labels': self.loss_lines_labels,
            'cardinality': self.loss_cardinality,
            'lines': self.loss_lines,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, num_items, **kwargs)

    def forward(self, outputs, targets, origin_indices=None):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Hungarian assignment between final-layer predictions and targets.
        origin_indices = self.matcher(outputs_without_aux, targets)

        # Average the number of target items across all distributed nodes,
        # for normalization purposes.
        num_items = sum(len(t["labels"]) for t in targets)
        num_items = torch.as_tensor([num_items], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_items)
        num_items = torch.clamp(num_items / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, num_items, origin_indices=origin_indices))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        aux_name = 'aux_outputs'
        if aux_name in outputs:
            for i, aux_outputs in enumerate(outputs[aux_name]):
                # Re-match for each intermediate layer's predictions.
                origin_indices = self.matcher(aux_outputs, targets)

                for loss in self.losses:
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer.
                        # NOTE(review): 'labels' is never in self.losses
                        # ('lines_labels' is) — this branch looks dead; confirm.
                        kwargs = {'log': False}

                    l_dict = self.get_loss(loss, aux_outputs, targets, num_items, origin_indices=origin_indices, **kwargs)
                    # Suffix each aux loss with its decoder-layer index.
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class PostProcess_Line(nn.Module):
|
| 246 |
+
|
| 247 |
+
""" This module converts the model's output into the format expected by the coco api"""
|
| 248 |
+
@torch.no_grad()
|
| 249 |
+
def forward(self, outputs, target_sizes, output_type):
|
| 250 |
+
""" Perform the computation
|
| 251 |
+
Parameters:
|
| 252 |
+
outputs: raw outputs of the model
|
| 253 |
+
target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
|
| 254 |
+
For evaluation, this must be the original image size (before any data augmentation)
|
| 255 |
+
For visualization, this should be the image size after data augment, but before padding
|
| 256 |
+
"""
|
| 257 |
+
if output_type == "prediction":
|
| 258 |
+
out_logits, out_line = outputs['pred_logits'], outputs['pred_lines']
|
| 259 |
+
|
| 260 |
+
assert len(out_logits) == len(target_sizes)
|
| 261 |
+
assert target_sizes.shape[1] == 2
|
| 262 |
+
|
| 263 |
+
prob = F.softmax(out_logits, -1)
|
| 264 |
+
scores, labels = prob[..., :-1].max(-1)
|
| 265 |
+
|
| 266 |
+
# convert to [x0, y0, x1, y1] format
|
| 267 |
+
img_h, img_w = target_sizes.unbind(1)
|
| 268 |
+
|
| 269 |
+
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
| 270 |
+
lines = out_line * scale_fct[:, None, :]
|
| 271 |
+
|
| 272 |
+
results = [{'scores': s, 'labels': l, 'lines': b} for s, l, b in zip(scores, labels, lines)]
|
| 273 |
+
elif output_type == "prediction_POST":
|
| 274 |
+
out_logits, out_line = outputs['pred_logits'], outputs['POST_pred_lines']
|
| 275 |
+
|
| 276 |
+
assert len(out_logits) == len(target_sizes)
|
| 277 |
+
assert target_sizes.shape[1] == 2
|
| 278 |
+
|
| 279 |
+
prob = F.softmax(out_logits, -1)
|
| 280 |
+
scores, labels = prob[..., :-1].max(-1)
|
| 281 |
+
|
| 282 |
+
# convert to [x0, y0, x1, y1] format
|
| 283 |
+
img_h, img_w = target_sizes.unbind(1)
|
| 284 |
+
|
| 285 |
+
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
| 286 |
+
lines = out_line * scale_fct[:, None, :]
|
| 287 |
+
|
| 288 |
+
results = [{'scores': s, 'labels': l, 'lines': b} for s, l, b in zip(scores, labels, lines)]
|
| 289 |
+
elif output_type == "ground_truth":
|
| 290 |
+
results = []
|
| 291 |
+
for dic in outputs:
|
| 292 |
+
lines = dic['lines']
|
| 293 |
+
img_h, img_w = target_sizes.unbind(1)
|
| 294 |
+
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
| 295 |
+
scaled_lines = lines * scale_fct
|
| 296 |
+
results.append({'labels': dic['labels'], 'lines': scaled_lines, 'image_id': dic['image_id']})
|
| 297 |
+
else:
|
| 298 |
+
assert False
|
| 299 |
+
return results
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN): Linear layers
    with ReLU between them, no activation after the last layer."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Layer widths: input -> hidden * (num_layers - 1) -> output.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x):
        last = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last:
                x = F.relu(x)
        return x
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def build(args):
    """Construct the LETR model, its SetCriterion, and the postprocessors
    from parsed command-line args."""
    num_classes = 1  # single foreground class: "line"

    device = torch.device(args.device)

    backbone = build_backbone(args)
    transformer = build_transformer(args)

    model = LETR(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        args=args,
        aux_loss=args.aux_loss,
    )

    if args.LETRpost:
        # Fine stage: wrap the coarse model with the second transformer stack.
        model = LETRstack(model, args=args)

    matcher = build_matcher(args, type='origin_line')

    # Loss names carry a POST_ prefix in the fine stage; weights are the same.
    if args.LETRpost:
        losses = ['POST_lines_labels', 'POST_lines']
        aux_layer = args.second_dec_layers
    else:
        losses = ['lines_labels', 'lines']
        aux_layer = args.dec_layers
    weight_dict = {'loss_ce': 1, 'loss_line': args.line_loss_coef}

    if args.aux_loss:
        # Replicate each weight for every intermediate decoder layer.
        aux_weight_dict = {}
        for i in range(aux_layer - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    criterion = SetCriterion(num_classes, weight_dict=weight_dict, eos_coef=args.eos_coef,
                             losses=losses, args=args, matcher=matcher)
    criterion.to(device)

    postprocessors = {'line': PostProcess_Line()}

    return model, criterion, postprocessors
|
models/letr_stack.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file provides fine stage LETR definition
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
import io
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from .misc import NestedTensor, nested_tensor_from_tensor_list
|
| 15 |
+
import copy
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class LETRstack(nn.Module):
    """Coarse-to-fine LETR: wraps a trained coarse-stage LETR model and adds
    a second transformer stage, fed from a different backbone feature level,
    to refine line predictions.

    The coarse stage's final decoder output (``hs1[-1]``) is used as the
    query embedding of the fine-stage decoder.
    """

    def __init__(self, letr, args):
        super().__init__()
        self.letr = letr  # coarse-stage model (provides backbone + transformer)
        self.backbone = self.letr.backbone

        if args.layer1_frozen:
            # freeze backbone, encoder, decoder
            # Only coarse-stage parameters exist at this point, so the
            # fine-stage modules created below remain trainable.
            for n, p in self.named_parameters():
                p.requires_grad_(False)

        hidden_dim, nheads = letr.transformer.d_model, letr.transformer.nhead

        # add new input proj layer
        # Channel counts per backbone feature level (ResNet-style stages),
        # indexed by which level feeds the fine stage.
        channel = [256, 512, 1024, 2048]
        self.input_proj = nn.Conv2d(channel[args.layer2_num], hidden_dim, kernel_size=1)

        # add new transformer encoder decoder
        # NOTE(review): the output heads below are sized with hidden_dim, so
        # this assumes args.second_hidden_dim == letr.transformer.d_model —
        # confirm against the argument parser defaults.
        self.transformer = Transformer(d_model=args.second_hidden_dim, dropout=args.second_dropout, nhead=args.second_nheads,
                                       dim_feedforward=args.second_dim_feedforward, num_encoder_layers=args.second_enc_layers,
                                       num_decoder_layers=args.second_dec_layers, normalize_before=args.second_pre_norm, return_intermediate_dec=True,)

        # output layer
        self.class_embed = nn.Linear(hidden_dim, 1 + 1)  # line vs. no-object logits
        self.lines_embed = MLP(hidden_dim, hidden_dim, 4, 3)  # 4 = two (x, y) endpoints

        self.aux_loss = args.aux_loss
        self.args = args

    def forward(self, samples, postprocessors=None, targets=None, criterion=None):
        """Run both stages on ``samples``.

        Args:
            samples: NestedTensor, or a list/Tensor of images (converted here).
            postprocessors, targets, criterion: accepted for interface
                compatibility with other LETR variants; unused in this path.

        Returns:
            (out, None) where ``out`` maps
            'pred_logits' -> last-layer class logits,
            'pred_lines'  -> last-layer sigmoid-normalized line coordinates,
            and optionally 'aux_outputs' for per-layer auxiliary losses.
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)

        # backbone
        features, pos = self.letr.backbone(samples)

        # layer 1
        l1_num = self.args.layer1_num
        src1, mask1 = features[l1_num].decompose()
        assert mask1 is not None

        # layer 1 transformer (coarse stage; frozen when layer1_frozen is set)
        hs1, _ = self.letr.transformer(self.letr.input_proj(src1), mask1, self.letr.query_embed.weight, pos[l1_num])

        # layer 2
        l2_num = self.args.layer2_num
        src2, mask2 = features[l2_num].decompose()
        src2 = self.input_proj(src2)

        # layer 2 transformer: the coarse stage's final decoder output acts
        # as the fine-stage query embedding.
        hs2, memory, _ = self.transformer(src2, mask2, hs1[-1], pos[l2_num])

        outputs_class = self.class_embed(hs2)
        outputs_coord = self.lines_embed(hs2).sigmoid()
        out = {}
        out["pred_logits"] = outputs_class[-1]
        out["pred_lines"] = outputs_coord[-1]

        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)

        return out, None

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_lines': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

    @torch.jit.unused
    def _set_aux_loss_POST(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'POST_pred_lines': b} for b in outputs_coord[:-1]]
|
| 96 |
+
|
| 97 |
+
def _expand(tensor, length: int):
|
| 98 |
+
return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
|
| 99 |
+
|
| 100 |
+
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN).

    ``num_layers`` linear layers; ReLU after every layer except the last.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Layer sizes: input -> hidden (num_layers-1 times) -> output.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x):
        last = len(self.layers) - 1
        for idx, fc in enumerate(self.layers):
            x = fc(x)
            if idx != last:
                x = F.relu(x)
        return x
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class Transformer(nn.Module):
    """DETR-style encoder-decoder transformer for the fine stage.

    Unlike the standard DETR transformer, ``forward`` receives the decoder
    queries as a full per-image tensor (the coarse stage's decoder output)
    rather than a learned [num_queries, d_model] embedding table.
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
        # Pre-norm encoders need a final LayerNorm; post-norm ones do not.
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        # Xavier-initialize every weight matrix; 1-D parameters (biases,
        # LayerNorm scales) keep their module defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        """Encode ``src`` and decode against it.

        Args:
            src: [bs, c, h, w] projected feature map.
            mask: [bs, h, w] padding mask.
            query_embed: per-image decoder queries; permuted below, so
                presumably [bs, num_queries, d_model] — confirm with caller.
            pos_embed: [bs, c, h, w] positional encoding.

        Returns:
            (hs transposed to [layers, bs, num_queries, d_model],
             encoder memory [hw, bs, d_model],
             decoder cross-attention weights)
        """
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        mask = mask.flatten(1)

        # Make the queries sequence-first to match nn.MultiheadAttention.
        query_embed = query_embed.permute(1, 0, 2)
        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs, attn_output_weights = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory, attn_output_weights
|
| 154 |
+
|
| 155 |
+
class TransformerEncoder(nn.Module):
    """A stack of identical encoder layers with an optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Apply each layer in sequence, then the final norm if present."""
        out = src
        for enc_layer in self.layers:
            out = enc_layer(out, src_mask=mask,
                            src_key_padding_mask=src_key_padding_mask, pos=pos)
        return out if self.norm is None else self.norm(out)
|
| 177 |
+
|
| 178 |
+
class TransformerDecoder(nn.Module):
    """A stack of decoder layers; optionally returns per-layer outputs.

    Each layer also yields its cross-attention weights; with
    ``return_intermediate`` the weights of every layer are collected.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        # NOTE: norm must be provided when return_intermediate=True, since
        # the loop below normalizes every intermediate output with it.
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Run the layer stack.

        Returns:
            With return_intermediate: (stacked per-layer outputs, list of
            per-layer attention weights).  Otherwise: (final output with a
            leading singleton layer dim, last layer's attention weights).
        """
        output = tgt

        intermediate = []
        attn_output_weights_list = []
        for layer in self.layers:
            output, attn_output_weights = layer(output, memory, tgt_mask=tgt_mask,
                                                memory_mask=memory_mask,
                                                tgt_key_padding_mask=tgt_key_padding_mask,
                                                memory_key_padding_mask=memory_key_padding_mask,
                                                pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))
                attn_output_weights_list.append(attn_output_weights)
        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # Replace the last intermediate entry so the final output is
                # not normalized twice.
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate), attn_output_weights_list

        return output.unsqueeze(0), attn_output_weights
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention + feedforward, in either post-norm
    (original Transformer) or pre-norm order."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        # DETR convention: positions are added to queries/keys only, never
        # to values.
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        """Post-norm order: sublayer -> residual add -> LayerNorm."""
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        """Pre-norm order: LayerNorm -> sublayer -> residual add."""
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        # Dispatch on the norm placement chosen at construction time.
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention over the encoder
    memory, and a feedforward block, in post-norm or pre-norm order.

    Both orders return ``(tgt, attn_output_weights)`` where the weights come
    from the cross-attention.  (Fix: ``forward_pre`` previously returned only
    ``tgt``, so callers that unpack two values — e.g. TransformerDecoder —
    broke whenever ``normalize_before=True``.)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        # DETR convention: positions are added to queries/keys only, never
        # to values.
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm order: sublayer -> residual add -> LayerNorm."""
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2, attn_output_weights = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                                        key=self.with_pos_embed(memory, pos),
                                                        value=memory, attn_mask=memory_mask,
                                                        key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt, attn_output_weights

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm order: LayerNorm -> sublayer -> residual add."""
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        # Keep the cross-attention weights so the return signature matches
        # forward_post (the decoder stack unpacks two values).
        tgt2, attn_output_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                                        key=self.with_pos_embed(memory, pos),
                                                        value=memory, attn_mask=memory_mask,
                                                        key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt, attn_output_weights

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        # Dispatch on the norm placement chosen at construction time.
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def _get_clones(module, N):
|
| 365 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
| 366 |
+
|
| 367 |
+
def _get_activation_fn(activation):
|
| 368 |
+
"""Return an activation function given a string"""
|
| 369 |
+
if activation == "relu":
|
| 370 |
+
return F.relu
|
| 371 |
+
if activation == "gelu":
|
| 372 |
+
return F.gelu
|
| 373 |
+
if activation == "glu":
|
| 374 |
+
return F.glu
|
| 375 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
| 376 |
+
|
models/matcher.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Modules to compute the matching cost and solve the corresponding LSAP.
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
from scipy.optimize import linear_sum_assignment
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
class HungarianMatcher_Line(nn.Module):
    """Bipartite matcher between predicted and ground-truth lines.

    For each batch element, computes the optimal 1-to-1 assignment between
    the network's predictions and the targets under a combined
    classification + L1-endpoint cost, solved with the Hungarian algorithm.
    Targets never include the no-object class, so there are generally more
    predictions than targets and the surplus predictions stay unmatched.
    """

    def __init__(self, cost_class: float = 1, cost_line: float = 1):
        """Creates the matcher.

        Params:
            cost_class: relative weight of the classification error
            cost_line: relative weight of the L1 endpoint error
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_line = cost_line
        assert cost_class != 0 or cost_line != 0, "all costs cant be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Performs the matching.

        Params:
            outputs: dict with at least
                "pred_logits": [batch_size, num_queries, num_classes]
                "pred_lines":  [batch_size, num_queries, 4]
            targets: list (len = batch_size) of dicts with
                "labels": [num_target_lines] class indices
                "lines":  [num_target_lines, 4] endpoint coordinates

        Returns:
            A list of (pred_indices, target_indices) int64 tensor pairs, one
            per batch element, each of length
            min(num_queries, num_target_lines).
        """
        batch, queries = outputs["pred_logits"].shape[:2]

        # Flatten predictions so the whole batch's cost is one big matrix.
        prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)
        pred_lines = outputs["pred_lines"].flatten(0, 1)

        # Concatenate all targets across the batch.
        gt_lines = torch.cat([t["lines"] for t in targets])
        gt_labels = torch.cat([t["labels"] for t in targets])

        # Classification cost: -P(target class).  This approximates the NLL
        # up to a constant, which does not affect the assignment.
        class_cost = -prob[:, gt_labels]

        # Endpoint cost: pairwise L1 distance between predictions and targets.
        line_cost = torch.cdist(pred_lines, gt_lines, p=1)

        # Weighted sum, reshaped to per-image matrices on CPU for scipy.
        cost = self.cost_line * line_cost + self.cost_class * class_cost
        cost = cost.view(batch, queries, -1).cpu()

        # Solve one linear-sum-assignment problem per image, restricted to
        # that image's slice of the target axis.
        counts = [len(t["lines"]) for t in targets]
        pairs = [linear_sum_assignment(chunk[b])
                 for b, chunk in enumerate(cost.split(counts, -1))]

        return [(torch.as_tensor(rows, dtype=torch.int64), torch.as_tensor(cols, dtype=torch.int64))
                for rows, cols in pairs]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def build_matcher(args, type=None):
    """Construct the Hungarian line matcher from parsed command-line args.

    The ``type`` argument is accepted for interface compatibility with other
    matcher builders but is unused here.
    """
    class_weight = args.set_cost_class
    line_weight = args.set_cost_line
    return HungarianMatcher_Line(cost_class=class_weight, cost_line=line_weight)
|
models/misc.py
ADDED
|
@@ -0,0 +1,467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
"""
|
| 3 |
+
Misc functions, including distributed helpers.
|
| 4 |
+
|
| 5 |
+
Mostly copy-paste from torchvision references.
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
import subprocess
|
| 9 |
+
import time
|
| 10 |
+
from collections import defaultdict, deque
|
| 11 |
+
import datetime
|
| 12 |
+
import pickle
|
| 13 |
+
from typing import Optional, List
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.distributed as dist
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
|
| 19 |
+
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
| 20 |
+
import torchvision
# Compare (major, minor) as integers: the old float(version[:3]) trick
# misreads two-digit minors (e.g. "0.10" -> 0.1, wrongly "< 0.7").
if tuple(int(p) for p in torchvision.__version__.split("+")[0].split(".")[:2]) < (0, 7):
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SmoothedValue(object):
    """Track a series of values and expose smoothed statistics over a
    sliding window as well as the global average of everything seen."""

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record ``value`` with multiplicity ``n`` (window gets one entry)."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(stats)
        count, total = stats.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        # Median over the current window only.
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        # Mean over the current window only.
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        # Mean over everything ever recorded, not just the window.
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank

    NOTE(review): tensors are hard-coded to "cuda", so this assumes a
    CUDA-capable process-group backend (e.g. NCCL) — confirm before using
    with gloo/CPU.
    """
    world_size = get_world_size()
    if world_size == 1:
        # Single process: nothing to gather.
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        # Strip each rank's padding before unpickling its payload.
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.

    NOTE(review): values are stacked with torch.stack, so they are assumed to
    be same-shaped tensors (typically scalars) — confirm with callers.
    """
    world_size = get_world_size()
    if world_size < 2:
        # Single process: nothing to reduce; return the input unchanged.
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class MetricLogger(object):
    """Collects named SmoothedValue meters and prints periodic progress
    lines (ETA, timing, memory) while iterating a data loader."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names get a default SmoothedValue automatically.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Update one meter per keyword argument; tensors are unwrapped."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Expose meters as attributes (e.g. logger.loss) before falling back
        # to instance attributes.  Only called when normal lookup fails.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Reduce every meter's count/total across distributed processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register a pre-configured meter (e.g. custom format string)."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from ``iterable``, printing a progress line every
        ``print_freq`` iterations and a summary at the end.

        Tracks data-loading time (between yields) and iteration time
        separately; includes GPU memory when CUDA is available.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Width of the iteration counter, padded to the total length.
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time since the last yield returned = data-loading time.
            data_time.update(time.time() - end)
            yield obj
            # Time for the caller's loop body + data load = iteration time.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def get_sha():
    """Return a one-line summary of the git state: commit sha, dirty flag, branch.

    Best-effort: outside a git checkout (or without git installed) the
    placeholders 'N/A' / "clean" are reported instead of raising.
    """
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        # Run a git command in this file's directory, return its trimmed output.
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()

    sha, diff, branch = 'N/A', "clean", 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = "has uncommited changes" if _run(['git', 'diff-index', 'HEAD']) else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def collate_fn(batch):
    """Collate (image, target) samples: pad the images into one NestedTensor,
    keep the remaining fields as-is, and return everything as a tuple."""
    columns = list(zip(*batch))
    columns[0] = nested_tensor_from_tensor_list(columns[0])
    return tuple(columns)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def _max_by_axis(the_list):
|
| 275 |
+
# type: (List[List[int]]) -> List[int]
|
| 276 |
+
maxes = the_list[0]
|
| 277 |
+
for sublist in the_list[1:]:
|
| 278 |
+
for index, item in enumerate(sublist):
|
| 279 |
+
maxes[index] = max(maxes[index], item)
|
| 280 |
+
return maxes
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Pad a list of CHW image tensors to a common size and wrap them as a
    NestedTensor: a zero-padded batch plus a bool mask that is True over the
    padded region of each image."""
    # TODO make this more general
    if tensor_list[0].ndim != 3:
        raise ValueError('not supported')

    if torchvision._is_tracing():
        # nested_tensor_from_tensor_list() does not export well to ONNX;
        # delegate to the tracing-friendly variant instead.
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    # TODO make it support different-sized images
    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
    batch = len(tensor_list)
    _, max_h, max_w = max_size
    dtype = tensor_list[0].dtype
    device = tensor_list[0].device
    tensor = torch.zeros([batch] + max_size, dtype=dtype, device=device)
    mask = torch.ones((batch, max_h, max_w), dtype=torch.bool, device=device)
    for img, pad_img, m in zip(tensor_list, tensor, mask):
        c_i, h_i, w_i = img.shape
        pad_img[:c_i, :h_i, :w_i].copy_(img)
        m[:h_i, :w_i] = False
    return NestedTensor(tensor, mask)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
| 309 |
+
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
| 310 |
+
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list):
    """ONNX-traceable variant of nested_tensor_from_tensor_list().

    Builds the padded batch and mask with pad/stack ops only, because the
    slice-assignment used in the eager version is not supported by the ONNX
    exporter. Assumes all tensors share rank (CHW images in practice —
    presumably; confirm against the eager caller).
    """
    # Per-dimension maximum size across the list; computed via stack/max so
    # the tracing exporter records it, then converted back to int64.
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        # Amount of padding needed on each axis to reach the common size.
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        # F.pad takes (left, right) pairs from the last dimension backwards;
        # only the trailing side of each axis is padded here.
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        # Mask: 0 over the real image, 1 (True after cast) over the padding.
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class NestedTensor(object):
    """A batched tensor paired with an optional padding mask.

    ``tensors`` holds the (padded) batch; ``mask``, when present, marks which
    positions are padding.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors  # batched data tensor
        self.mask = mask        # optional bool padding mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with data (and mask, if any) on ``device``."""
        moved_tensors = self.tensors.to(device)
        if self.mask is None:
            return NestedTensor(moved_tensors, None)
        return NestedTensor(moved_tensors, self.mask.to(device))

    def decompose(self):
        """Split back into the underlying (tensors, mask) pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    original_print = __builtin__.print

    def print(*args, **kwargs):
        # Worker ranks stay silent unless the caller passes force=True.
        force = kwargs.pop('force', False)
        if force or is_master:
            original_print(*args, **kwargs)

    __builtin__.print = print
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def is_dist_avail_and_initialized():
    """True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def get_world_size():
    """Number of participating processes; 1 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def get_rank():
    """Rank of this process; 0 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def is_main_process():
    """True exactly on the rank-0 (main) process."""
    rank = get_rank()
    return rank == 0
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def save_on_master(*args, **kwargs):
    """torch.save, executed on the main process only (no-op on workers)."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def init_distributed_mode(args):
    """Initialise torch.distributed from environment variables.

    Reads RANK/WORLD_SIZE/LOCAL_RANK (torch.distributed.launch / torchrun
    style) or SLURM_PROCID, fills in ``args.rank``, ``args.gpu``,
    ``args.distributed`` and ``args.dist_backend``, selects the CUDA device,
    creates the NCCL process group and silences print on non-master ranks.
    Falls back to single-process mode when neither set of variables is set.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Launched via torch.distributed.launch / torchrun.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # Launched under SLURM. NOTE(review): args.world_size is not assigned
        # on this path — presumably supplied elsewhere (argparse default?);
        # confirm before relying on SLURM launches.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    # Synchronise all ranks before continuing, then disable printing on
    # every rank except 0.
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: (batch, num_classes) scores/logits.
        target: (batch,) integer class labels.
        topk: ks to evaluate; one percentage tensor is returned per k.

    Returns:
        List of scalar tensors, each the top-k accuracy in percent.
    """
    if target.numel() == 0:
        # No ground truth at all: report a single 0% so callers still get a list.
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape (not view): slicing the transposed tensor yields a
        # non-contiguous tensor, and .view(-1) raises on recent PyTorch.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this
    class can go away.
    """
    # Compare (major, minor) numerically. The previous check,
    # float(torchvision.__version__[:3]) < 0.7, misparses versions such as
    # "0.10.0" (reads "0.1") and any release with major >= 1.
    version = tuple(int(part) for part in torchvision.__version__.split('+')[0].split('.')[:2])
    if version < (0, 7):
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )

        # Empty batch: compute the output shape by hand and return an empty tensor.
        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    else:
        # torchvision >= 0.7 handles empty batches natively.
        return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
models/multi_head_attention.py
ADDED
|
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file provides definition of multi head attention
|
| 3 |
+
|
| 4 |
+
borrowed from https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#MultiheadAttention
|
| 5 |
+
"""
|
| 6 |
+
import warnings
|
| 7 |
+
from typing import Tuple, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
from torch.nn.modules.linear import _LinearWithBias
|
| 12 |
+
from torch.nn.init import xavier_uniform_
|
| 13 |
+
from torch.nn.init import constant_
|
| 14 |
+
from torch.nn.init import xavier_normal_
|
| 15 |
+
from torch.nn.parameter import Parameter
|
| 16 |
+
from torch.nn.modules.module import Module
|
| 17 |
+
from torch.nn import functional as F
|
| 18 |
+
from torch.overrides import has_torch_function, handle_torch_function
|
| 19 |
+
from torch import _VF
|
| 20 |
+
|
| 21 |
+
# Activation functions
|
| 22 |
+
def dropout(input, p=0.5, training=True, inplace=False):
    # type: (Tensor, float, bool, bool) -> Tensor
    r"""Randomly zero elements of ``input`` with probability :attr:`p` during
    training, using samples from a Bernoulli distribution.

    See :class:`~torch.nn.Dropout` for details.

    Args:
        p: probability of an element to be zeroed. Default: 0.5
        training: apply dropout if is ``True``. Default: ``True``
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
    """
    if not torch.jit.is_scripting():
        if type(input) is not Tensor and has_torch_function((input,)):
            return handle_torch_function(
                dropout, (input,), input, p=p, training=training, inplace=inplace)
    if p < 0. or p > 1.:
        raise ValueError("dropout probability has to be between 0 and 1, "
                         "but got {}".format(p))
    op = _VF.dropout_ if inplace else _VF.dropout
    return op(input, p, training)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _get_softmax_dim(name, ndim, stacklevel):
|
| 49 |
+
# type: (str, int, int) -> int
|
| 50 |
+
warnings.warn("Implicit dimension choice for {} has been deprecated. "
|
| 51 |
+
"Change the call to include dim=X as an argument.".format(name), stacklevel=stacklevel)
|
| 52 |
+
if ndim == 0 or ndim == 1 or ndim == 3:
|
| 53 |
+
ret = 0
|
| 54 |
+
else:
|
| 55 |
+
ret = 1
|
| 56 |
+
return ret
|
| 57 |
+
|
| 58 |
+
def softmax(input, dim=None, _stacklevel=3, dtype=None):
    # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
    r"""Applies a softmax function.

    Softmax rescales the slices along ``dim`` so elements lie in ``[0, 1]``
    and sum to 1: :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`.
    See :class:`~torch.nn.Softmax` for more details.

    Args:
        input (Tensor): input
        dim (int): A dimension along which softmax will be computed.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            If specified, the input tensor is casted to :attr:`dtype` before the operation
            is performed. This is useful for preventing data type overflows. Default: None.

    .. note::
        This function doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use log_softmax instead (it's faster and has better numerical properties).
    """
    if not torch.jit.is_scripting():
        if type(input) is not Tensor and has_torch_function((input,)):
            return handle_torch_function(
                softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
    if dim is None:
        # Legacy implicit-dim behaviour (deprecated; warns).
        dim = _get_softmax_dim('softmax', input.dim(), _stacklevel)
    if dtype is None:
        return input.softmax(dim)
    return input.softmax(dim, dtype=dtype)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def linear(input, weight, bias=None):
    # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
    r"""
    Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.
    This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.
    Shape:
        - Input: :math:`(N, *, in\_features)` N is the batch size, `*` means any number of
          additional dimensions
        - Weight: :math:`(out\_features, in\_features)`
        - Bias: :math:`(out\_features)`
        - Output: :math:`(N, *, out\_features)`
    """
    tens_ops = (input, weight)
    if not torch.jit.is_scripting():
        if any(type(t) is not Tensor for t in tens_ops) and has_torch_function(tens_ops):
            return handle_torch_function(linear, tens_ops, input, weight, bias=bias)
    if bias is not None and input.dim() == 2:
        # fused op is marginally faster
        return torch.addmm(bias, input, weight.t())
    out = input.matmul(weight.t())
    if bias is not None:
        out += bias
    return out
|
| 115 |
+
|
| 116 |
+
def multi_head_attention_forward(query: Tensor, key: Tensor, value: Tensor, embed_dim_to_check: int, num_heads: int,
                                 in_proj_weight: Tensor, in_proj_bias: Tensor, bias_k: Optional[Tensor], bias_v: Optional[Tensor], add_zero_attn: bool,
                                 dropout_p: float, out_proj_weight: Tensor, out_proj_bias: Tensor, training: bool = True, key_padding_mask: Optional[Tensor] = None,
                                 need_weights: bool = True, attn_mask: Optional[Tensor] = None, use_separate_proj_weight: bool = False, q_proj_weight: Optional[Tensor] = None,
                                 k_proj_weight: Optional[Tensor] = None, v_proj_weight: Optional[Tensor] = None, static_k: Optional[Tensor] = None, static_v: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is an binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        use_separate_proj_weight: the function accept the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """
    if not torch.jit.is_scripting():
        tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
                    out_proj_weight, out_proj_bias)
        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
            return handle_torch_function(
                multi_head_attention_forward, tens_ops, query, key, value,
                embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
                bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
                out_proj_bias, training=training, key_padding_mask=key_padding_mask,
                need_weights=need_weights, attn_mask=attn_mask,
                use_separate_proj_weight=use_separate_proj_weight,
                q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
                v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    # --- input projections -------------------------------------------------
    if not use_separate_proj_weight:
        if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)):
            # self-attention
            q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif (key is value or torch.equal(key, value)):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:
                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
            v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
        else:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    # --- mask normalisation ------------------------------------------------
    if attn_mask is not None:
        assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
            attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
            'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 2D attn_mask is not correct.')
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError('The size of the 3D attn_mask is not correct.')
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    # convert ByteTensor key_padding_mask to bool
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            # FIX: the original called bare ``pad`` which is never imported in
            # this module (NameError on this path); use F.pad as in upstream.
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = F.pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    # --- split heads: (len, bsz, E) -> (bsz*num_heads, len, head_dim) ------
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        # FIX: same missing-import problem as above — use F.pad.
        if attn_mask is not None:
            attn_mask = F.pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = F.pad(key_padding_mask, (0, 1))

    # --- scaled dot-product attention --------------------------------------
    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float('-inf'))
        else:
            attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = softmax(
        attn_output_weights, dim=-1)
    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
|
| 380 |
+
|
| 381 |
+
class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

    Note: if kdim and vdim are None, they will be set to embed_dim such that
    query, key, and value have the same number of features.

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """
    # NOTE(review): this class is copied from torch.nn.MultiheadAttention so the
    # project can route through the local multi_head_attention_forward; keep it
    # in sync with the pinned torch version (1.8.1 per requirements.txt).
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        # True only when key/value feature sizes match embed_dim; selects the
        # fused single in_proj_weight path below.
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            # Separate q/k/v projection matrices when kdim/vdim differ.
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            # Fused q/k/v projection stored as a single (3*E, E) matrix.
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = _LinearWithBias(embed_dim, embed_dim)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier-uniform for projections, zeros for biases, Xavier-normal for
        # the optional bias_k/bias_v — matches upstream torch initialization.
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. When given a binary mask and a value is True,
            the corresponding value on the attention layer will be ignored. When given
            a byte mask and a value is non-zero, the corresponding value on the attention
            layer will be ignored
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.

    Shape:
        - Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the position
          with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.

        - Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
        """
        # Both branches call the module-level multi_head_attention_forward
        # defined earlier in this file; the first passes separate q/k/v
        # projection weights, the second the fused in_proj_weight.
        if not self._qkv_same_embed_dim:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
|
models/position_encoding.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Various positional encodings for the transformer.
|
| 3 |
+
borrowed from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
|
| 4 |
+
"""
|
| 5 |
+
import math
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
from .misc import NestedTensor
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        """
        Args:
            num_pos_feats: channels per axis; output has 2 * num_pos_feats channels.
            temperature: base of the sinusoid frequency geometric progression.
            normalize: if True, scale coordinates to [0, scale] before encoding.
            scale: coordinate scale (defaults to 2*pi); only valid with normalize.
        """
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: "NestedTensor"):
        """Return (N, 2*num_pos_feats, H, W) sinusoidal positions for the
        unpadded (mask == False) pixels of ``tensor_list``."""
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        # Cumulative counts of valid pixels give 1-based row/col coordinates.
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        # 2 * (i // 2) pairs each sin channel with its cos channel.
        # torch.div(..., rounding_mode='floor') replaces the deprecated float
        # tensor floor-division `dim_t // 2`; values are identical since
        # dim_t is non-negative, but no deprecation warning is emitted.
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
| 50 |
+
|
| 51 |
+
class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.

    Rows and columns each get a learned embedding (up to 50 positions); the
    two are concatenated per pixel into a (N, 2*num_pos_feats, H, W) map.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform [0, 1) init for both embedding tables.
        for table in (self.row_embed, self.col_embed):
            nn.init.uniform_(table.weight)

    def forward(self, tensor_list: NestedTensor):
        feats = tensor_list.tensors
        height, width = feats.shape[-2:]
        col_ids = torch.arange(width, device=feats.device)
        row_ids = torch.arange(height, device=feats.device)
        cols = self.col_embed(col_ids)   # (W, F)
        rows = self.row_embed(row_ids)   # (H, F)
        # Broadcast column embeddings down rows and row embeddings across
        # columns, then concatenate along the feature axis.
        grid = torch.cat(
            [
                cols.unsqueeze(0).repeat(height, 1, 1),
                rows.unsqueeze(1).repeat(1, width, 1),
            ],
            dim=-1,
        )
        return grid.permute(2, 0, 1).unsqueeze(0).repeat(feats.shape[0], 1, 1, 1)
| 78 |
+
|
| 79 |
+
def build_position_encoding(args):
    """Instantiate the positional encoding selected by ``args.position_embedding``.

    Accepts 'v2'/'sine' (sinusoidal, normalized) or 'v3'/'learned'; raises
    ValueError for anything else. Half of ``args.hidden_dim`` goes to each axis.
    """
    N_steps = args.hidden_dim // 2
    kind = args.position_embedding
    if kind in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        return PositionEmbeddingSine(N_steps, normalize=True)
    if kind in ('v3', 'learned'):
        return PositionEmbeddingLearned(N_steps)
    raise ValueError(f"not supported {kind}")
models/preprocessing.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchvision.transforms.functional as functional
|
| 2 |
+
|
| 3 |
+
class Compose(object):
    """Chain callables: each transform receives the previous one's output."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image):
        result = image
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self):
        # One transform per line, indented, inside ClassName( ... ).
        body = "".join("\n    {0}".format(t) for t in self.transforms)
        return self.__class__.__name__ + "(" + body + "\n)"
|
| 20 |
+
class Normalize(object):
    """Per-channel normalization: (image - mean) / std, via torchvision."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        return functional.normalize(image, mean=self.mean, std=self.std)
|
| 29 |
+
class ToTensor(object):
    """Callable wrapper around ``torchvision.transforms.functional.to_tensor``."""
    def __call__(self, img):
        # PIL Image / ndarray -> CxHxW float tensor scaled to [0, 1].
        return functional.to_tensor(img)
|
| 33 |
+
def resize(image, size, max_size=None):
    """Resize a PIL image.

    ``size`` is either a scalar minimum side length (aspect ratio preserved,
    optionally capped so the longer side stays <= ``max_size``) or a (w, h)
    pair, which is applied directly.
    """

    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            # Shrink the target so the scaled long side does not exceed max_size.
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))
        # Already at the target on the short side: keep current dimensions.
        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)
        if w < h:
            return (int(size * h / w), size)
        return (size, int(size * w / h))

    def get_size(image_size, size, max_size=None):
        # Explicit (w, h) pair -> (h, w); scalar -> aspect-preserving target.
        if isinstance(size, (list, tuple)):
            return size[::-1]
        return get_size_with_aspect_ratio(image_size, size, max_size)

    target = get_size(image.size, size, max_size)
    return functional.resize(image, target)
|
| 63 |
+
class Resize(object):
    """Transform wrapper around :func:`resize` with fixed target sizes."""

    def __init__(self, sizes, max_size=None):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes
        self.max_size = max_size

    def __call__(self, img):
        return resize(img, self.sizes, self.max_size)
|
models/transformer.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
"""
|
| 3 |
+
DETR Transformer class.
|
| 4 |
+
|
| 5 |
+
Copy-paste from torch.nn.Transformer with modifications:
|
| 6 |
+
* positional encodings are passed in MHattention
|
| 7 |
+
* extra LN at the end of encoder is removed
|
| 8 |
+
* decoder returns a stack of activations from all decoding layers
|
| 9 |
+
"""
|
| 10 |
+
import copy
|
| 11 |
+
from typing import Optional, List
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch import nn, Tensor
|
| 16 |
+
from .multi_head_attention import MultiheadAttention
|
| 17 |
+
|
| 18 |
+
class Transformer(nn.Module):
    """DETR-style transformer: an encoder over flattened image features and a
    decoder driven by learned query embeddings (see module docstring)."""

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        # Final encoder norm only in pre-norm mode (post-norm already ends in LN).
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        # Xavier-init all weight matrices; dim() > 1 skips biases and norms.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        """Run encoder + decoder.

        Args:
            src: (N, C, H, W) feature map.
            mask: (N, H, W) padding mask.
            query_embed: (num_queries, C) learned query embeddings.
            pos_embed: (N, C, H, W) positional encoding.
        Returns:
            (decoder outputs transposed to (layers, N, num_queries, C), encoder memory)
        """
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        mask = mask.flatten(1)

        # Decoder input starts at zero; query identity enters via query_pos.
        tgt = torch.zeros_like(query_embed)

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)

        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory#.permute(1, 2, 0).view(bs, c, h, w)
+
|
| 64 |
+
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` cloned encoder layers with an optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        out = src
        for enc_layer in self.layers:
            out = enc_layer(out, src_mask=mask,
                            src_key_padding_mask=src_key_padding_mask, pos=pos)
        # Apply the optional final normalization (used in pre-norm setups).
        return out if self.norm is None else self.norm(out)
|
| 87 |
+
class TransformerDecoder(nn.Module):
    """Stack of cloned decoder layers; optionally returns every layer's output
    (stacked along a new leading dim) for auxiliary losses."""

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                # Each intermediate output is normalized before being stored.
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # Replace the last stored entry with the final normalized
                # output so the stack's last slice equals the final result.
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        # Keep a leading "layers" dim of size 1 for a uniform return shape.
        return output.unsqueeze(0)
| 127 |
+
|
| 128 |
+
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention + feed-forward, each with residual,
    dropout and LayerNorm; supports both pre-norm and post-norm ordering."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        # Positional encodings are added to queries/keys only, not values.
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        # Post-norm: attn -> add -> norm -> FFN -> add -> norm.
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        # Pre-norm: norm -> attn -> add -> norm -> FFN -> add.
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
| 186 |
+
|
| 187 |
+
class TransformerDecoderLayer(nn.Module):
    """One decoder block: query self-attention, cross-attention over encoder
    memory, then feed-forward; pre-norm or post-norm ordering."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        # Positional encodings are added to queries/keys only, not values.
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        # Self-attention among queries (query_pos added to q and k).
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # Cross-attention: queries attend over encoder memory (pos on keys).
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        # Same three sub-blocks, but LayerNorm applied before each one.
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
| 271 |
+
|
| 272 |
+
def _get_clones(module, N):
|
| 273 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def build_transformer(args):
    """Construct a :class:`Transformer` from an argparse-style config.

    Always requests intermediate decoder outputs (needed for DETR/LETR
    auxiliary losses).
    """
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )
|
| 289 |
+
def _get_activation_fn(activation):
|
| 290 |
+
"""Return an activation function given a string"""
|
| 291 |
+
if activation == "relu":
|
| 292 |
+
return F.relu
|
| 293 |
+
if activation == "gelu":
|
| 294 |
+
return F.gelu
|
| 295 |
+
if activation == "glu":
|
| 296 |
+
return F.glu
|
| 297 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==1.8.1
|
| 2 |
+
torchvision
|
| 3 |
+
gradio
|
| 4 |
+
jinja2
|
| 5 |
+
scipy
|
tappeto-per-calibrazione.jpg
ADDED
|
test.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image, ImageDraw
|
| 2 |
+
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from models.letr import build
|
| 7 |
+
from models.misc import nested_tensor_from_tensor_list
|
| 8 |
+
from models.preprocessing import Compose, ToTensor, Resize, Normalize
|
| 9 |
+
|
| 10 |
+
def create_letr():
    """Load the pretrained LETR model from ``checkpoint0024.pth``.

    Builds the model on CPU from the args stored in the checkpoint, loads the
    weights, and returns it in eval mode.
    """
    # obtain checkpoints
    checkpoint = torch.load('checkpoint0024.pth', map_location='cpu')

    # load model
    args = checkpoint['args']
    # Force CPU inference regardless of the device the checkpoint trained on.
    args.device = 'cpu'
    model, _, _ = build(args)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model
|
| 22 |
+
def draw_fig(image, outputs, orig_size, threshold=0.7):
    """Draw predicted line segments onto ``image`` in place.

    Args:
        image: PIL image to draw on (modified in place).
        outputs: model output dict with 'pred_logits' and 'pred_lines'.
            # assumes batch size 1 and normalized xyxy line endpoints — TODO
            # confirm against the LETR postprocessing convention.
        orig_size: tensor (h, w) with the original image size.
        threshold: minimum class score for a line to be drawn (default 0.7,
            matching the previous hard-coded value).
    """
    # find lines
    out_logits, out_line = outputs['pred_logits'], outputs['pred_lines']
    prob = F.softmax(out_logits, -1)
    # Best non-background score per query (the last channel is "no object").
    scores = prob[..., :-1].max(-1)[0]
    img_h, img_w = orig_size.unbind(0)
    scale_fct = torch.unsqueeze(torch.stack(
        [img_w, img_h, img_w, img_h], dim=0), dim=0)
    lines = out_line * scale_fct[:, None, :]
    # Derive the query count from the tensor instead of hard-coding 1000,
    # so checkpoints with a different num_queries still work.
    lines = lines.view(-1, 2, 2)
    lines = lines.flip([-1])  # this is yxyx format
    scores = scores.detach().numpy()
    keep = (scores >= threshold).squeeze()
    lines = lines[keep]
    if len(lines) != 0:
        lines = lines.reshape(lines.shape[0], -1)

    # draw lines
    draw = ImageDraw.Draw(image)
    for line in lines:
        y1, x1, y2, x2 = line
        draw.line((x1, y1, x2, y2), fill=500)
|
| 46 |
+
if __name__ == '__main__':
    # Smoke test: run LETR on demo.png and save the annotated image.
    model = create_letr()

    test_size = 256
    # NOTE(review): Resize runs AFTER Normalize here, so normalization is done
    # at full resolution — presumably intentional; verify against training
    # preprocessing if results look off.
    normalize = Compose([
        ToTensor(),
        Normalize([0.538, 0.494, 0.453], [0.257, 0.263, 0.273]),
        Resize([test_size]),
    ])

    image = Image.open('demo.png')
    h, w = image.height, image.width
    orig_size = torch.as_tensor([int(h), int(w)])

    img = normalize(image)
    inputs = nested_tensor_from_tensor_list([img])

    with torch.no_grad():
        outputs = model(inputs)[0]
        draw_fig(image, outputs, orig_size)

    image.save('output.png')
|