Spaces:
Sleeping
Sleeping
File size: 2,375 Bytes
9ec3d0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
from . import heads
from . import UltralyticsModel
from . import Concat
import torch
class PatchEmbedder(torch.nn.Module):
    """
    Patch embedder: projects non-overlapping image patches to embeddings.

    Each patch becomes one visual token. This class lets an architecture
    work without a backbone.

    Args:
        imgsz: Side length of the square dummy image used to record the
            output shape in ``bb_out_shape``.
        out_dim: Number of output channels, i.e. the embedding size per patch.
        patch_size: Patch side length; used as both the kernel size and the
            stride of the projection, so patches do not overlap.
        device: Device that holds the projection layer and the dummy input.
        in_channels: Number of channels in the input images (3 by default).

    Attributes:
        bb_out_shape: ``torch.Size`` produced by the projection for an input
            of shape ``[1, in_channels, imgsz, imgsz]``.
    """

    def __init__(self, imgsz: int, out_dim: int, patch_size: int = 16,
                 device=torch.device('cpu'), in_channels: int = 3):
        super().__init__()
        self.device = device
        self.imgsz = imgsz
        self.in_channels = in_channels
        self.proj = torch.nn.Conv2d(
            in_channels, out_dim, kernel_size=patch_size, stride=patch_size
        ).to(device)
        # One throw-away pass records the output shape; no gradients are
        # needed for it, and the temporaries are released on scope exit.
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, imgsz, imgsz, device=device)
            self.bb_out_shape = self.proj(probe).shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` patch-wise; the result is ``[B, out_dim, H', W']``."""
        return self.proj(x)
class Backbone(torch.nn.Module):
    """
    Extracts the backbone of an ``ultralytics.engine.model.Model``.

    Every layer of ``model.model.model`` except the last one is reused as the
    backbone. The layers are registered in a new ``ModuleList`` rather than
    copied, so they stay shared with ``model`` and moving them to ``device``
    affects the original model as well.

    Args:
        model: Wrapper exposing ``model_name`` and the layer sequence
            ``model.model.model``.
        imgsz: Side length of the square dummy image used to record the
            output shape in ``bb_out_shape``.
        device: Device that holds the layers and the dummy input.

    Attributes:
        bb_out_shape: ``torch.Size`` of the backbone output for an input of
            shape ``[1, 3, imgsz, imgsz]``.
    """

    def __init__(self,
                 model: UltralyticsModel,
                 imgsz: int = 640,
                 device=torch.device('cpu')):
        super().__init__()
        self.model_name = model.model_name
        self.device = device
        self.layers = torch.nn.ModuleList(model.model.model[:-1]).to(self.device)
        self.imgsz = imgsz
        dummy = torch.zeros(1, 3, imgsz, imgsz, device=self.device, requires_grad=False)
        self.bb_out_shape = self.shape_infer(dummy).shape

    def _run_layers(self, x):
        """Route ``x`` through ``self.layers`` and return the last output."""
        # Every layer output is kept, so the indices stored in ``m.f``
        # (``-1`` for "previous layer" or an index into this list) all
        # resolve against the same list.
        outputs = []
        for m in self.layers:
            if isinstance(m, heads):
                # Stop in front of a head; only backbone features are returned.
                break
            if isinstance(m, Concat):
                x = m([outputs[f] for f in m.f])
            else:
                x = m(x) if m.f == -1 else m(outputs[m.f])
            outputs.append(x)
        return outputs[-1]

    def forward(self, x):
        """Return the backbone features for ``x`` (feats: [B, C, H, W])."""
        return self._run_layers(x)

    @torch.no_grad()
    def shape_infer(self, x):
        """
        Run ``x`` through the backbone without gradients, for shape probing.

        ``torch.no_grad()`` alone does not stop layers such as BatchNorm from
        updating their running statistics while in training mode. The layers
        are therefore switched to eval mode for this pass, and every module's
        original ``training`` flag is restored afterwards, so a dummy input
        cannot alter the wrapped model's state.
        """
        saved_modes = [(mod, mod.training) for mod in self.layers.modules()]
        self.layers.eval()
        try:
            return self._run_layers(x)
        finally:
            # Restore flags one module at a time; ``train()`` would recurse
            # and overwrite children that had a different mode.
            for mod, was_training in saved_modes:
                mod.training = was_training
|