File size: 2,106 Bytes
2a5ab99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import json
import torch
import torch.nn as nn
class ImageFeatureExtractorCNN(nn.Module):
    """Small convolutional encoder: image batch -> fixed-size embedding.

    Layout: a stride-2 stem conv, then ``num_blocks`` two-conv blocks
    (the first keeps width/resolution, each later one doubles channels
    and halves resolution), global pooling, and a linear projection to
    ``embed_dim``.  ``image_size`` is accepted for config round-tripping
    but does not affect the architecture (pooling is adaptive).
    """

    def __init__(self, in_channels=3, image_size=224, embed_dim=256, base_channels=32, num_blocks=4, dropout=0.0, pool="avg", norm="bn"):
        super().__init__()
        # Norm factory: BatchNorm2d for "bn", otherwise 8-group GroupNorm
        # (channel counts must stay divisible by 8 in that case).
        if norm == "bn":
            def make_norm(ch):
                return nn.BatchNorm2d(ch)
        else:
            def make_norm(ch):
                return nn.GroupNorm(8, ch)
        activation = nn.SiLU

        width = base_channels
        # Stem: halve spatial resolution immediately.
        modules = [
            nn.Conv2d(in_channels, width, kernel_size=3, stride=2, padding=1, bias=False),
            make_norm(width),
            activation(),
        ]
        for idx in range(num_blocks):
            # Block 0 preserves width and resolution; later blocks double
            # channels and downsample by 2.
            next_width = width if idx == 0 else width * 2
            step = 1 if idx == 0 else 2
            modules.append(
                nn.Sequential(
                    nn.Conv2d(width, next_width, kernel_size=3, stride=step, padding=1, bias=False),
                    make_norm(next_width),
                    activation(),
                    nn.Conv2d(next_width, next_width, kernel_size=3, stride=1, padding=1, bias=False),
                    make_norm(next_width),
                    activation(),
                )
            )
            if dropout and dropout > 0:
                modules.append(nn.Dropout2d(dropout))
            width = next_width

        self.backbone = nn.Sequential(*modules)
        if pool == "avg":
            self.global_pool = nn.AdaptiveAvgPool2d(1)
        else:
            self.global_pool = nn.AdaptiveMaxPool2d(1)
        self.proj = nn.Linear(width, embed_dim)

    def forward(self, x):
        """Encode a (B, C, H, W) batch into a (B, embed_dim) tensor."""
        features = self.backbone(x)
        pooled = self.global_pool(features).flatten(1)
        return self.proj(pooled)
def load_pretrained(model_dir: str, map_location="cpu"):
    """Rebuild an ``ImageFeatureExtractorCNN`` from a saved directory.

    Expects ``model_dir`` to contain:
      - ``config.json``: the constructor kwargs for the model
      - ``pytorch_model.bin``: a state_dict saved with ``torch.save``

    Args:
        model_dir: Directory holding ``config.json`` and ``pytorch_model.bin``.
        map_location: Device spec forwarded to ``torch.load`` (default ``"cpu"``).

    Returns:
        A ``(model, cfg)`` tuple: the model in eval mode and the parsed
        config dict.

    Raises:
        FileNotFoundError: if either file is missing.
        KeyError: if a required config key is absent.
        RuntimeError: if the state_dict does not match the architecture.
    """
    with open(f"{model_dir}/config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    model = ImageFeatureExtractorCNN(
        in_channels=cfg["in_channels"],
        image_size=cfg["image_size"],
        embed_dim=cfg["embed_dim"],
        base_channels=cfg["base_channels"],
        num_blocks=cfg["num_blocks"],
        dropout=cfg["dropout"],
        pool=cfg["pool"],
        norm=cfg["norm"],
    )
    # weights_only=True restricts unpickling to tensors and plain containers;
    # a default torch.load executes arbitrary pickle bytecode from the
    # checkpoint file, which is unsafe for files obtained from elsewhere.
    sd = torch.load(
        f"{model_dir}/pytorch_model.bin",
        map_location=map_location,
        weights_only=True,
    )
    model.load_state_dict(sd)
    model.eval()
    return model, cfg
|