radjepa
custom_code
anas2908 commited on
Commit
c4b4651
·
verified ·
1 Parent(s): e00ff54

Initial RadJEPA encoder release

Browse files
Files changed (3) hide show
  1. config.json +6 -0
  2. jepa_encoder.pth.tar +3 -0
  3. modeling_radjepa.py +76 -0
config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "radjepa",
3
+ "image_size": 224,
4
+ "patch_size": 14,
5
+ "embed_dim": 768
6
+ }
jepa_encoder.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:afce5c46e600354b58033a53f88ecdc0da4a09308c5d0062f142465090e4e2aa
3
+ size 1633156351
modeling_radjepa.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import timm
3
+ from timm.layers import PatchEmbed
4
+ from transformers import PreTrainedModel, PretrainedConfig
5
+
6
+ class RadJEPAConfig(PretrainedConfig):
7
+ model_type = "radjepa"
8
+
9
+ def __init__(
10
+ self,
11
+ image_size=224,
12
+ patch_size=14,
13
+ embed_dim=768,
14
+ **kwargs
15
+ ):
16
+ super().__init__(**kwargs)
17
+ self.image_size = image_size
18
+ self.patch_size = patch_size
19
+ self.embed_dim = embed_dim
20
+
21
+
22
+ class RadJEPAEncoder(PreTrainedModel):
23
+ config_class = RadJEPAConfig
24
+
25
+ def __init__(self, config):
26
+ super().__init__(config)
27
+
28
+ self.model = timm.create_model(
29
+ "vit_base_patch16_224",
30
+ pretrained=False,
31
+ num_classes=0
32
+ )
33
+
34
+ self.model.patch_embed = PatchEmbed(
35
+ img_size=config.image_size,
36
+ patch_size=config.patch_size,
37
+ in_chans=3,
38
+ embed_dim=config.embed_dim,
39
+ )
40
+
41
+ num_patches = self.model.patch_embed.num_patches
42
+
43
+ self.model.cls_token = None
44
+ self.model.num_prefix_tokens = 0
45
+
46
+ self.model.pos_embed = torch.nn.Parameter(
47
+ torch.zeros(1, num_patches, config.embed_dim)
48
+ )
49
+ torch.nn.init.trunc_normal_(self.model.pos_embed, std=0.02)
50
+
51
+ def forward(self, pixel_values):
52
+ tokens = self.model.forward_features(pixel_values)
53
+ return tokens.mean(dim=1)
54
+
55
+ @classmethod
56
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
57
+ config = RadJEPAConfig.from_pretrained(pretrained_model_name_or_path)
58
+ model = cls(config)
59
+
60
+ ckpt_path = f"{pretrained_model_name_or_path}/jepa_encoder.pth.tar"
61
+ ckpt = torch.load(ckpt_path, map_location="cpu")
62
+
63
+ if "encoder" in ckpt:
64
+ state_dict = ckpt["encoder"]
65
+ elif "state_dict" in ckpt and "encoder" in ckpt["state_dict"]:
66
+ state_dict = ckpt["state_dict"]["encoder"]
67
+ else:
68
+ raise RuntimeError("Encoder weights not found")
69
+
70
+ state_dict = {
71
+ k.replace("module.", "").replace("encoder.", ""): v
72
+ for k, v in state_dict.items()
73
+ }
74
+
75
+ model.model.load_state_dict(state_dict, strict=True)
76
+ return model