File size: 2,375 Bytes
9ec3d0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from . import heads
from . import UltralyticsModel
from . import Concat
import torch

class PatchEmbedder(torch.nn.Module):
    """
    Patch embedder: turns an image into a grid of visual tokens.

    A single convolution whose kernel size equals its stride maps every
    non-overlapping ``patch_size`` x ``patch_size`` patch of a 3-channel image
    to one ``out_dim``-dimensional token. It stands in for a backbone so the
    architecture can run without one.

    Attributes:
        device: device the projection was moved to.
        imgsz: side length of the square image used to probe the output shape.
        proj: the patch-projection ``Conv2d``.
        bb_out_shape: output shape of ``proj`` for a ``[1, 3, imgsz, imgsz]`` input.
    """

    def __init__(self, imgsz:int, out_dim:int, patch_size:int=16, device=torch.device('cpu')):
        super().__init__()
        self.device = device
        self.imgsz = imgsz
        projection = torch.nn.Conv2d(3, out_dim, kernel_size=patch_size, stride=patch_size)
        self.proj = projection.to(device)
        self.bb_out_shape = self._probe_output_shape()

    @torch.no_grad()
    def _probe_output_shape(self) -> torch.Size:
        """Push one all-zero ``[1, 3, imgsz, imgsz]`` image through ``proj`` and report its shape."""
        probe = torch.zeros(1, 3, self.imgsz, self.imgsz, device=self.device)
        return self.proj(probe).shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` (expected ``[B, 3, H, W]``) into patch tokens ``[B, out_dim, H', W']``."""
        tokens = self.proj(x)
        return tokens


class Backbone(torch.nn.Module):
    """
    Extract the backbone of an ``ultralytics.engine.model.Model``.

    Wraps every layer of ``model.model.model`` except the last one and replays
    the model's layer routing. Each layer consumes either the previous output
    (``m.f == -1``) or the outputs of earlier layers selected by ``m.f``.
    Iteration stops early at the first layer that is an instance of ``heads``.
    The output of the last layer that ran is returned as the feature map.

    Note: the wrapped layers are the same module objects as in ``model`` (not
    copies), so ``.to(device)`` also moves them inside ``model``.

    Attributes:
        model_name: copied from ``model.model_name``.
        device: device the layers were moved to.
        layers: the wrapped layers, as a ``ModuleList``.
        imgsz: side length of the square dummy image used for shape inference.
        bb_out_shape: shape of the features for a ``[1, 3, imgsz, imgsz]`` input.
    """
    def __init__(self, 
                 model: UltralyticsModel, 
                 imgsz:int=640,
                 device=torch.device('cpu')):
        super().__init__()
        self.model_name = model.model_name
        self.device = device
        self.layers = torch.nn.ModuleList(model.model.model[:-1]).to(self.device)
        self.imgsz = imgsz

        dummy = torch.zeros(1, 3, imgsz, imgsz, device=self.device)
        self.bb_out_shape = self.shape_infer(dummy).shape

    def _run_layers(self, x):
        """Route ``x`` through ``self.layers`` up to the first head; return the last output.

        Raises:
            ValueError: if no layer ran before a head (or there are no layers).
        """
        outputs = []
        for m in self.layers:
            if isinstance(m, heads):
                break
            if m.f != -1:
                # Ultralytics routing: an int selects one earlier layer's output;
                # a list (e.g. for Concat) gathers several, where -1 means the
                # current tensor.
                if isinstance(m.f, int):
                    x = outputs[m.f]
                else:
                    x = [x if j == -1 else outputs[j] for j in m.f]
            x = m(x)
            outputs.append(x)

        if not outputs:
            raise ValueError("Backbone produced no output: no layer ran before the head.")
        return outputs[-1]                  # feats: [B, C, H, W]

    def forward(self, x):
        """Return the backbone feature map for the image batch ``x``."""
        return self._run_layers(x)

    @torch.no_grad()
    def shape_infer(self, x):
        """Run ``x`` through the backbone without gradients and return the features.

        Intended for shape inference. Every module is temporarily switched to
        eval mode, so the dummy input cannot update normalization running
        statistics (e.g. BatchNorm, if present) or trigger dropout. Each
        module's original train/eval flag is restored afterwards, even on error.
        """
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            return self._run_layers(x)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training