import json
import torch
import torch.nn as nn
| class ImageFeatureExtractorCNN(nn.Module): | |
| def __init__(self, in_channels=3, image_size=224, embed_dim=256, base_channels=32, num_blocks=4, dropout=0.0, pool="avg", norm="bn"): | |
| super().__init__() | |
| Norm = (lambda c: nn.BatchNorm2d(c)) if norm == "bn" else (lambda c: nn.GroupNorm(8, c)) | |
| act = nn.SiLU | |
| layers = [] | |
| c = base_channels | |
| layers += [ | |
| nn.Conv2d(in_channels, c, kernel_size=3, stride=2, padding=1, bias=False), | |
| Norm(c), | |
| act(), | |
| ] | |
| for i in range(num_blocks): | |
| c_out = c * 2 if i > 0 else c | |
| stride = 2 if i > 0 else 1 | |
| block = nn.Sequential( | |
| nn.Conv2d(c, c_out, kernel_size=3, stride=stride, padding=1, bias=False), | |
| Norm(c_out), | |
| act(), | |
| nn.Conv2d(c_out, c_out, kernel_size=3, stride=1, padding=1, bias=False), | |
| Norm(c_out), | |
| act(), | |
| ) | |
| layers.append(block) | |
| if dropout and dropout > 0: | |
| layers.append(nn.Dropout2d(dropout)) | |
| c = c_out | |
| self.backbone = nn.Sequential(*layers) | |
| self.global_pool = nn.AdaptiveAvgPool2d(1) if pool == "avg" else nn.AdaptiveMaxPool2d(1) | |
| self.proj = nn.Linear(c, embed_dim) | |
| def forward(self, x): | |
| h = self.backbone(x) | |
| h = self.global_pool(h).flatten(1) | |
| return self.proj(h) | |
def load_pretrained(model_dir: str, map_location="cpu"):
    """Load an ImageFeatureExtractorCNN checkpoint from *model_dir*.

    Expects two files inside the directory:
      - ``config.json``: the constructor kwargs for ImageFeatureExtractorCNN
        (in_channels, image_size, embed_dim, base_channels, num_blocks,
        dropout, pool, norm).
      - ``pytorch_model.bin``: a state dict saved with ``torch.save``.

    Args:
        model_dir: directory containing the two checkpoint files.
        map_location: passed to ``torch.load`` (e.g. "cpu", "cuda:0").

    Returns:
        (model, cfg) where the model is in eval mode and cfg is the parsed
        config dict.

    Raises:
        FileNotFoundError: if either file is missing.
        KeyError: if a required config key is absent.
    """
    with open(f"{model_dir}/config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    m = ImageFeatureExtractorCNN(
        in_channels=cfg["in_channels"],
        image_size=cfg["image_size"],
        embed_dim=cfg["embed_dim"],
        base_channels=cfg["base_channels"],
        num_blocks=cfg["num_blocks"],
        dropout=cfg["dropout"],
        pool=cfg["pool"],
        norm=cfg["norm"],
    )
    # weights_only=True restricts unpickling to tensors and primitive
    # containers, preventing arbitrary-code execution from a tampered
    # checkpoint file — a plain state dict never needs full pickle.
    sd = torch.load(
        f"{model_dir}/pytorch_model.bin",
        map_location=map_location,
        weights_only=True,
    )
    m.load_state_dict(sd)
    m.eval()
    return m, cfg