File size: 2,106 Bytes
2a5ab99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import json
import torch
import torch.nn as nn

class ImageFeatureExtractorCNN(nn.Module):
    """Small convolutional encoder mapping an image batch to a flat embedding.

    Architecture: a stride-2 stem conv, then `num_blocks` two-conv stages
    (each stage after the first doubles the channel width and halves the
    spatial size), global pooling, and a linear projection to `embed_dim`.

    Note: `image_size` is accepted for config compatibility but is not used
    by the construction — the adaptive pooling makes the head size-agnostic.
    """

    def __init__(self, in_channels=3, image_size=224, embed_dim=256, base_channels=32, num_blocks=4, dropout=0.0, pool="avg", norm="bn"):
        super().__init__()

        def make_norm(channels):
            # "bn" selects BatchNorm2d; anything else falls back to
            # GroupNorm with 8 groups (channels must be divisible by 8).
            if norm == "bn":
                return nn.BatchNorm2d(channels)
            return nn.GroupNorm(8, channels)

        # Stem: halve spatial resolution immediately.
        modules = [
            nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=2, padding=1, bias=False),
            make_norm(base_channels),
            nn.SiLU(),
        ]

        width = base_channels
        for idx in range(num_blocks):
            # First stage keeps width/resolution; later stages double width
            # and downsample by 2.
            first = idx == 0
            next_width = width if first else width * 2
            modules.append(
                nn.Sequential(
                    nn.Conv2d(width, next_width, kernel_size=3, stride=(1 if first else 2), padding=1, bias=False),
                    make_norm(next_width),
                    nn.SiLU(),
                    nn.Conv2d(next_width, next_width, kernel_size=3, stride=1, padding=1, bias=False),
                    make_norm(next_width),
                    nn.SiLU(),
                )
            )
            if dropout and dropout > 0:
                modules.append(nn.Dropout2d(dropout))
            width = next_width

        self.backbone = nn.Sequential(*modules)
        self.global_pool = nn.AdaptiveMaxPool2d(1) if pool != "avg" else nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Linear(width, embed_dim)

    def forward(self, x):
        """Encode a (B, C, H, W) batch into a (B, embed_dim) tensor."""
        features = self.backbone(x)
        pooled = self.global_pool(features).flatten(1)
        return self.proj(pooled)

def load_pretrained(model_dir: str, map_location="cpu"):
    """Load an ImageFeatureExtractorCNN checkpoint from *model_dir*.

    Expects two files inside the directory:
      - ``config.json``: constructor kwargs for ImageFeatureExtractorCNN
      - ``pytorch_model.bin``: a state_dict saved with ``torch.save``

    Args:
        model_dir: Directory containing the checkpoint files.
        map_location: Passed to ``torch.load``; defaults to "cpu" so the
            checkpoint loads on machines without a GPU.

    Returns:
        Tuple ``(model, cfg)`` — the model in eval mode and the parsed
        config dict.
    """
    with open(f"{model_dir}/config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    # Fall back to the constructor defaults for any key missing from an
    # older/partial config.json instead of raising KeyError.
    m = ImageFeatureExtractorCNN(
        in_channels=cfg.get("in_channels", 3),
        image_size=cfg.get("image_size", 224),
        embed_dim=cfg.get("embed_dim", 256),
        base_channels=cfg.get("base_channels", 32),
        num_blocks=cfg.get("num_blocks", 4),
        dropout=cfg.get("dropout", 0.0),
        pool=cfg.get("pool", "avg"),
        norm=cfg.get("norm", "bn"),
    )
    # weights_only=True restricts unpickling to tensors and plain containers,
    # preventing arbitrary-code execution from an untrusted checkpoint file
    # (requires torch >= 1.13; a plain state_dict loads fine under it).
    sd = torch.load(
        f"{model_dir}/pytorch_model.bin",
        map_location=map_location,
        weights_only=True,
    )
    m.load_state_dict(sd)
    m.eval()
    return m, cfg