Upload 3 files
Browse files- config.json +35 -0
- model.py +59 -0
- pytorch_model.bin +3 -0
config.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "image_feature_extractor_cnn",
|
| 3 |
+
"torch_dtype": "float32",
|
| 4 |
+
"in_channels": 3,
|
| 5 |
+
"image_size": 224,
|
| 6 |
+
"embed_dim": 256,
|
| 7 |
+
"base_channels": 32,
|
| 8 |
+
"num_blocks": 4,
|
| 9 |
+
"dropout": 0.0,
|
| 10 |
+
"pool": "avg",
|
| 11 |
+
"norm": "bn",
|
| 12 |
+
"task": "feature_extraction",
|
| 13 |
+
"outputs": {
|
| 14 |
+
"embedding": "float32[batch, embed_dim]"
|
| 15 |
+
},
|
| 16 |
+
"preprocessing": {
|
| 17 |
+
"expected_input": "float32 images in [0,1] normalized",
|
| 18 |
+
"suggested_normalize": {
|
| 19 |
+
"mean": [
|
| 20 |
+
0.485,
|
| 21 |
+
0.456,
|
| 22 |
+
0.406
|
| 23 |
+
],
|
| 24 |
+
"std": [
|
| 25 |
+
0.229,
|
| 26 |
+
0.224,
|
| 27 |
+
0.225
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
"resize": 224
|
| 31 |
+
},
|
| 32 |
+
"id2label": null,
|
| 33 |
+
"label2id": null,
|
| 34 |
+
"version": "0.1.0"
|
| 35 |
+
}
|
model.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
class ImageFeatureExtractorCNN(nn.Module):
    """Small convolutional backbone that maps images to fixed-size embeddings.

    A strided stem conv is followed by ``num_blocks`` two-conv stages; the
    first stage keeps the channel width, each later stage doubles it and
    halves the spatial resolution. Global pooling plus a linear projection
    produce a ``(batch, embed_dim)`` embedding.

    Args:
        in_channels: Number of input image channels.
        image_size: Nominal input resolution (kept for config round-tripping;
            not used by the layers themselves, which are fully convolutional).
        embed_dim: Dimensionality of the output embedding.
        base_channels: Channel width after the stem conv.
        num_blocks: Number of conv stages.
        dropout: If > 0, a Dropout2d with this rate is inserted after each stage.
        pool: "avg" for adaptive average pooling, anything else for max pooling.
        norm: "bn" for BatchNorm2d, anything else for 8-group GroupNorm.
    """

    def __init__(self, in_channels=3, image_size=224, embed_dim=256,
                 base_channels=32, num_blocks=4, dropout=0.0, pool="avg", norm="bn"):
        super().__init__()

        # Normalization factory: BatchNorm2d for "bn", otherwise 8-group GroupNorm.
        def make_norm(channels):
            if norm == "bn":
                return nn.BatchNorm2d(channels)
            return nn.GroupNorm(8, channels)

        activation = nn.SiLU

        # Stem: halve resolution, lift to base_channels.
        width = base_channels
        stages = [
            nn.Conv2d(in_channels, width, kernel_size=3, stride=2, padding=1, bias=False),
            make_norm(width),
            activation(),
        ]

        for idx in range(num_blocks):
            # Stage 0 keeps width and stride 1; later stages double width
            # and downsample by 2.
            widen = idx > 0
            next_width = width * 2 if widen else width
            stage_stride = 2 if widen else 1
            stages.append(nn.Sequential(
                nn.Conv2d(width, next_width, kernel_size=3, stride=stage_stride,
                          padding=1, bias=False),
                make_norm(next_width),
                activation(),
                nn.Conv2d(next_width, next_width, kernel_size=3, stride=1,
                          padding=1, bias=False),
                make_norm(next_width),
                activation(),
            ))
            if dropout and dropout > 0:
                stages.append(nn.Dropout2d(dropout))
            width = next_width

        self.backbone = nn.Sequential(*stages)
        if pool == "avg":
            self.global_pool = nn.AdaptiveAvgPool2d(1)
        else:
            self.global_pool = nn.AdaptiveMaxPool2d(1)
        self.proj = nn.Linear(width, embed_dim)

    def forward(self, x):
        """Return a ``(batch, embed_dim)`` embedding for a batch of images."""
        features = self.backbone(x)
        pooled = self.global_pool(features).flatten(1)
        return self.proj(pooled)
| 42 |
+
|
| 43 |
+
def load_pretrained(model_dir: str, map_location="cpu"):
    """Instantiate an ``ImageFeatureExtractorCNN`` from a saved directory.

    Reads hyperparameters from ``config.json``, loads the weights from
    ``pytorch_model.bin``, and returns the model in eval mode.

    Args:
        model_dir: Directory containing ``config.json`` and ``pytorch_model.bin``.
        map_location: Device spec forwarded to ``torch.load`` (default "cpu").

    Returns:
        Tuple of ``(model, cfg)`` where ``cfg`` is the parsed config dict.

    Raises:
        KeyError: If a required hyperparameter is missing from the config.
        RuntimeError: If the checkpoint does not match the model architecture.
    """
    with open(f"{model_dir}/config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    m = ImageFeatureExtractorCNN(
        in_channels=cfg["in_channels"],
        image_size=cfg["image_size"],
        embed_dim=cfg["embed_dim"],
        base_channels=cfg["base_channels"],
        num_blocks=cfg["num_blocks"],
        dropout=cfg["dropout"],
        pool=cfg["pool"],
        norm=cfg["norm"],
    )
    # weights_only=True restricts unpickling to tensors and plain containers,
    # preventing arbitrary code execution from an untrusted checkpoint file.
    # A plain state_dict (as saved here) loads fine under this restriction.
    sd = torch.load(
        f"{model_dir}/pytorch_model.bin",
        map_location=map_location,
        weights_only=True,
    )
    m.load_state_dict(sd)
    m.eval()
    return m, cfg
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:693bb19917d8e7778a3584abbdefea01909a197634292b33c2e202ee3a3e08e3
|
| 3 |
+
size 5019350
|