agunzzm commited on
Commit
2a5ab99
·
verified ·
1 Parent(s): ada92e1

Upload 3 files

Browse files
Files changed (3) hide show
  1. config.json +35 -0
  2. model.py +59 -0
  3. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "image_feature_extractor_cnn",
3
+ "torch_dtype": "float32",
4
+ "in_channels": 3,
5
+ "image_size": 224,
6
+ "embed_dim": 256,
7
+ "base_channels": 32,
8
+ "num_blocks": 4,
9
+ "dropout": 0.0,
10
+ "pool": "avg",
11
+ "norm": "bn",
12
+ "task": "feature_extraction",
13
+ "outputs": {
14
+ "embedding": "float32[batch, embed_dim]"
15
+ },
16
+ "preprocessing": {
17
+ "expected_input": "float32 images in [0,1] normalized",
18
+ "suggested_normalize": {
19
+ "mean": [
20
+ 0.485,
21
+ 0.456,
22
+ 0.406
23
+ ],
24
+ "std": [
25
+ 0.229,
26
+ 0.224,
27
+ 0.225
28
+ ]
29
+ },
30
+ "resize": 224
31
+ },
32
+ "id2label": null,
33
+ "label2id": null,
34
+ "version": "0.1.0"
35
+ }
model.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import torch.nn as nn
4
+
5
class ImageFeatureExtractorCNN(nn.Module):
    """Small strided CNN that maps image batches to fixed-size embeddings.

    Architecture: a stride-2 stem conv, then ``num_blocks`` two-conv blocks
    (the first block keeps width and resolution; each later block doubles the
    channel width and downsamples by 2), followed by global pooling and a
    linear projection to ``embed_dim``.
    """

    def __init__(self, in_channels=3, image_size=224, embed_dim=256, base_channels=32, num_blocks=4, dropout=0.0, pool="avg", norm="bn"):
        super().__init__()

        # Normalization factory: BatchNorm for "bn", otherwise GroupNorm with
        # 8 groups (channel widths must therefore be divisible by 8 in that mode).
        if norm == "bn":
            make_norm = nn.BatchNorm2d
        else:
            def make_norm(channels):
                return nn.GroupNorm(8, channels)
        activation = nn.SiLU

        width = base_channels
        # Stem: halve spatial resolution immediately.
        modules = [
            nn.Conv2d(in_channels, width, kernel_size=3, stride=2, padding=1, bias=False),
            make_norm(width),
            activation(),
        ]
        for idx in range(num_blocks):
            next_width = width if idx == 0 else width * 2
            step = 1 if idx == 0 else 2
            modules.append(
                nn.Sequential(
                    nn.Conv2d(width, next_width, kernel_size=3, stride=step, padding=1, bias=False),
                    make_norm(next_width),
                    activation(),
                    nn.Conv2d(next_width, next_width, kernel_size=3, stride=1, padding=1, bias=False),
                    make_norm(next_width),
                    activation(),
                )
            )
            if dropout and dropout > 0:
                # Spatial dropout between blocks, only when explicitly requested.
                modules.append(nn.Dropout2d(dropout))
            width = next_width

        self.backbone = nn.Sequential(*modules)
        # Adaptive pooling collapses any spatial size to 1x1, so inputs need
        # not match image_size exactly (image_size is informational here).
        self.global_pool = nn.AdaptiveAvgPool2d(1) if pool == "avg" else nn.AdaptiveMaxPool2d(1)
        self.proj = nn.Linear(width, embed_dim)

    def forward(self, x):
        """Return a (batch, embed_dim) embedding for a (batch, C, H, W) input."""
        features = self.backbone(x)
        pooled = self.global_pool(features).flatten(1)
        return self.proj(pooled)
42
+
43
def load_pretrained(model_dir: str, map_location="cpu"):
    """Rebuild an ImageFeatureExtractorCNN from a saved model directory.

    Expects ``config.json`` (hyperparameters) and ``pytorch_model.bin``
    (a plain state dict) inside *model_dir*.

    Args:
        model_dir: Directory containing ``config.json`` and ``pytorch_model.bin``.
        map_location: Device mapping forwarded to ``torch.load`` (default CPU).

    Returns:
        Tuple ``(model, cfg)``: the model in eval mode and the parsed config dict.

    Raises:
        FileNotFoundError: if either expected file is missing.
        KeyError: if a required hyperparameter is absent from the config.
    """
    with open(f"{model_dir}/config.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    m = ImageFeatureExtractorCNN(
        in_channels=cfg["in_channels"],
        image_size=cfg["image_size"],
        embed_dim=cfg["embed_dim"],
        base_channels=cfg["base_channels"],
        num_blocks=cfg["num_blocks"],
        dropout=cfg["dropout"],
        pool=cfg["pool"],
        norm=cfg["norm"],
    )
    # weights_only=True restricts the pickle to tensor/primitive payloads,
    # preventing arbitrary code execution from an untrusted checkpoint file.
    sd = torch.load(
        f"{model_dir}/pytorch_model.bin",
        map_location=map_location,
        weights_only=True,
    )
    m.load_state_dict(sd)
    m.eval()
    return m, cfg
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:693bb19917d8e7778a3584abbdefea01909a197634292b33c2e202ee3a3e08e3
3
+ size 5019350