# RadJEPA — modeling_radjepa.py
# Custom-code modeling file for the "radjepa" model on the Hugging Face Hub.
# (Uploaded by anas2908; revision 661236a, verified.)
from huggingface_hub import hf_hub_download
import os
import torch
import timm
from timm.layers import PatchEmbed
from transformers import PreTrainedModel, PretrainedConfig
class RadJEPAConfig(PretrainedConfig):
    """Configuration for the RadJEPA vision encoder.

    Stores the geometry of the ViT backbone: input resolution, patch size,
    and token embedding width. All remaining keyword arguments are passed
    through to ``PretrainedConfig``.
    """

    model_type = "radjepa"

    def __init__(self, image_size=224, patch_size=14, embed_dim=768, **kwargs):
        """Create a config.

        Args:
            image_size: Side length (px) of the square input image.
            patch_size: Side length (px) of each square patch.
            embed_dim: Dimensionality of the patch/token embeddings.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        # Let the base class consume its own kwargs first, then record
        # the RadJEPA-specific geometry on top.
        super().__init__(**kwargs)
        self.image_size, self.patch_size, self.embed_dim = (
            image_size,
            patch_size,
            embed_dim,
        )
class RadJEPAEncoder(PreTrainedModel):
    """RadJEPA vision encoder: a timm ViT-Base backbone adapted to JEPA weights.

    The backbone is created untrained; real weights are loaded in
    :meth:`from_pretrained` from a ``jepa_encoder.pth.tar`` checkpoint stored
    in the Hub repo. The [CLS] token is removed (JEPA encoders are
    patch-token-only) and the patch embedding / positional embedding are
    rebuilt to match the configured image and patch sizes.
    """

    config_class = RadJEPAConfig

    def __init__(self, config):
        super().__init__(config)
        # Skeleton only — pretrained=False because the JEPA checkpoint
        # supplies the weights in from_pretrained. num_classes=0 drops the
        # classification head.
        self.model = timm.create_model(
            "vit_base_patch16_224",
            pretrained=False,
            num_classes=0,
        )
        # Replace the stock 16px patch embed with one matching the config
        # (default: 14px patches on a 224px image -> 16x16 = 256 patches).
        self.model.patch_embed = PatchEmbed(
            img_size=config.image_size,
            patch_size=config.patch_size,
            in_chans=3,
            embed_dim=config.embed_dim,
        )
        num_patches = self.model.patch_embed.num_patches
        # JEPA uses no [CLS]/prefix tokens; drop them and size the positional
        # embedding to patch tokens only.
        self.model.cls_token = None
        self.model.num_prefix_tokens = 0
        self.model.pos_embed = torch.nn.Parameter(
            torch.zeros(1, num_patches, config.embed_dim)
        )
        torch.nn.init.trunc_normal_(self.model.pos_embed, std=0.02)

    def forward(self, pixel_values):
        """Return a mean-pooled image embedding.

        Args:
            pixel_values: Image batch tensor; presumably
                ``(batch, 3, image_size, image_size)`` — TODO confirm against
                the processor.

        Returns:
            Tensor of shape ``(batch, embed_dim)``: patch tokens averaged
            over the sequence dimension.
        """
        tokens = self.model.forward_features(pixel_values)
        return tokens.mean(dim=1)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Build the encoder and load JEPA weights from a Hub repo.

        Downloads ``jepa_encoder.pth.tar`` from the repo, unwraps the encoder
        state dict, strips DDP/JEPA key prefixes, and loads it strictly.

        Raises:
            RuntimeError: If the checkpoint holds no recognizable encoder
                state dict.
        """
        config = RadJEPAConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)
        ckpt_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="jepa_encoder.pth.tar",
            repo_type="model",
        )
        # weights_only=True: safe tensor-only deserialization (no pickle code).
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        if "encoder" in ckpt:
            state_dict = ckpt["encoder"]
        elif "state_dict" in ckpt and "encoder" in ckpt["state_dict"]:
            state_dict = ckpt["state_dict"]["encoder"]
        else:
            raise RuntimeError("Encoder weights not found in checkpoint")
        # BUGFIX: strip "module." / "encoder." only when they are *leading*
        # prefixes. The previous str.replace() removed those substrings
        # anywhere inside a key, which would corrupt any parameter name that
        # merely contains them and then break the strict load below.
        cleaned = {}
        for key, value in state_dict.items():
            for prefix in ("module.", "encoder."):
                if key.startswith(prefix):
                    key = key[len(prefix):]
            cleaned[key] = value
        model.model.load_state_dict(cleaned, strict=True)
        return model