import json

import torch
import torch.nn as nn


class ImageFeatureExtractorCNN(nn.Module):
    """Small convolutional backbone mapping an image batch to a fixed-size embedding.

    Architecture: one stride-2 stem conv, then ``num_blocks`` two-conv blocks
    (channel width doubles and spatial resolution halves from the second block
    onward), a global pool, and a linear projection to ``embed_dim``.

    Input:  float tensor of shape (B, in_channels, H, W).
    Output: float tensor of shape (B, embed_dim).
    """

    def __init__(self, in_channels=3, image_size=224, embed_dim=256,
                 base_channels=32, num_blocks=4, dropout=0.0,
                 pool="avg", norm="bn"):
        """
        Args:
            in_channels: number of input image channels.
            image_size: nominal input resolution. NOTE: currently unused — the
                adaptive global pool makes the network resolution-agnostic; the
                parameter is kept only for config/checkpoint compatibility.
            embed_dim: dimensionality of the output embedding.
            base_channels: channel width of the stem and the first block.
            num_blocks: number of conv blocks after the stem.
            dropout: if > 0, applies ``nn.Dropout2d(dropout)`` after every block.
            pool: "avg" for adaptive average pooling, anything else → max pooling.
            norm: "bn" for BatchNorm2d; anything else selects GroupNorm(8, c),
                which requires every channel width to be divisible by 8.
        """
        super().__init__()
        # Normalization factory chosen once, applied at each width `c`.
        if norm == "bn":
            make_norm = nn.BatchNorm2d
        else:
            make_norm = lambda c: nn.GroupNorm(8, c)
        act = nn.SiLU

        layers = []
        c = base_channels
        # Stem: halves resolution, lifts channels to base_channels.
        layers += [
            nn.Conv2d(in_channels, c, kernel_size=3, stride=2, padding=1, bias=False),
            make_norm(c),
            act(),
        ]
        for i in range(num_blocks):
            # First block keeps width/resolution; later blocks double width
            # and halve resolution.
            c_out = c * 2 if i > 0 else c
            stride = 2 if i > 0 else 1
            block = nn.Sequential(
                nn.Conv2d(c, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
                make_norm(c_out),
                act(),
                nn.Conv2d(c_out, c_out, kernel_size=3, stride=1, padding=1, bias=False),
                make_norm(c_out),
                act(),
            )
            layers.append(block)
            if dropout and dropout > 0:
                layers.append(nn.Dropout2d(dropout))
            c = c_out

        self.backbone = nn.Sequential(*layers)
        self.global_pool = nn.AdaptiveAvgPool2d(1) if pool == "avg" else nn.AdaptiveMaxPool2d(1)
        # `c` is the final block's width after the loop.
        self.proj = nn.Linear(c, embed_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: tensor of shape (B, in_channels, H, W).

        Returns:
            Embedding tensor of shape (B, embed_dim).
        """
        h = self.backbone(x)
        h = self.global_pool(h).flatten(1)  # (B, C, 1, 1) -> (B, C)
        return self.proj(h)


def load_pretrained(model_dir: str, map_location="cpu"):
    """Load a saved ``ImageFeatureExtractorCNN`` and its config.

    Expects ``{model_dir}/config.json`` and ``{model_dir}/pytorch_model.bin``.
    Missing config keys fall back to the constructor defaults (backward
    compatible: a complete config behaves exactly as before).

    Args:
        model_dir: directory containing the config and weight files.
        map_location: passed through to ``torch.load``.

    Returns:
        Tuple ``(model, cfg)`` where ``model`` is in eval mode.

    Raises:
        FileNotFoundError: if either file is missing.
        RuntimeError: if the state dict does not match the configured model.
    """
    with open(f"{model_dir}/config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    m = ImageFeatureExtractorCNN(
        in_channels=cfg.get("in_channels", 3),
        image_size=cfg.get("image_size", 224),
        embed_dim=cfg.get("embed_dim", 256),
        base_channels=cfg.get("base_channels", 32),
        num_blocks=cfg.get("num_blocks", 4),
        dropout=cfg.get("dropout", 0.0),
        pool=cfg.get("pool", "avg"),
        norm=cfg.get("norm", "bn"),
    )
    # weights_only=True: a state dict needs no arbitrary pickle execution, and
    # this protects against maliciously crafted checkpoint files.
    sd = torch.load(f"{model_dir}/pytorch_model.bin",
                    map_location=map_location, weights_only=True)
    m.load_state_dict(sd)
    m.eval()
    return m, cfg