| from huggingface_hub import hf_hub_download |
| import os |
| import torch |
| import timm |
| from timm.layers import PatchEmbed |
| from transformers import PreTrainedModel, PretrainedConfig |
|
|
class RadJEPAConfig(PretrainedConfig):
    """HF configuration for the RadJEPA vision encoder.

    Stores the geometry of the ViT backbone: input resolution, patch size,
    and token embedding width.
    """

    model_type = "radjepa"

    def __init__(self, image_size=224, patch_size=14, embed_dim=768, **kwargs):
        """Create a config.

        Args:
            image_size: square input resolution in pixels.
            patch_size: side length of each square patch.
            embed_dim: width of the patch/token embeddings.
            **kwargs: forwarded to ``PretrainedConfig``.
        """
        super().__init__(**kwargs)
        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
|
|
|
|
class RadJEPAEncoder(PreTrainedModel):
    """ViT-Base encoder wrapped as a HF ``PreTrainedModel`` for RadJEPA.

    The backbone is a timm ``vit_base_patch16_224`` whose patch embedding,
    class token and positional embedding are replaced to match the JEPA
    configuration: no CLS token, and a learnable positional embedding sized
    to the configured image/patch geometry.
    """

    config_class = RadJEPAConfig

    def __init__(self, config):
        super().__init__(config)

        # Base ViT; classifier head removed (num_classes=0) — only token
        # features are needed.
        self.model = timm.create_model(
            "vit_base_patch16_224",
            pretrained=False,
            num_classes=0,
        )

        # Swap in a patch embedding matching the configured geometry
        # (e.g. 224 / 14 -> 16x16 = 256 patches vs. timm's 16-px default).
        self.model.patch_embed = PatchEmbed(
            img_size=config.image_size,
            patch_size=config.patch_size,
            in_chans=3,
            embed_dim=config.embed_dim,
        )

        # JEPA encoders use no CLS token: drop it and tell timm there are
        # zero prefix tokens so forward_features does not prepend one.
        num_patches = self.model.patch_embed.num_patches
        self.model.cls_token = None
        self.model.num_prefix_tokens = 0

        # Re-create the positional embedding for the new patch count
        # (no CLS slot) and initialise it the way timm does (trunc normal).
        self.model.pos_embed = torch.nn.Parameter(
            torch.zeros(1, num_patches, config.embed_dim)
        )
        torch.nn.init.trunc_normal_(self.model.pos_embed, std=0.02)

    def forward(self, pixel_values):
        """Encode a batch of images.

        Args:
            pixel_values: float tensor of shape (batch, 3, H, W); H and W
                are expected to match ``config.image_size`` — TODO confirm
                no resizing happens upstream.

        Returns:
            Mean-pooled patch embeddings, shape (batch, embed_dim).
        """
        tokens = self.model.forward_features(pixel_values)
        return tokens.mean(dim=1)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load the config plus JEPA encoder weights from a Hub repo.

        Downloads ``jepa_encoder.pth.tar`` from the repo, extracts the
        encoder state dict, strips DDP/namespace prefixes, and loads it
        strictly into the timm backbone.

        Args:
            pretrained_model_name_or_path: Hub repo id (must host both the
                config and the ``jepa_encoder.pth.tar`` checkpoint).

        Returns:
            A fully initialised ``RadJEPAEncoder``.

        Raises:
            RuntimeError: if the checkpoint contains no encoder weights.
        """
        config = RadJEPAConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)

        ckpt_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="jepa_encoder.pth.tar",
            repo_type="model",
        )

        # weights_only=True: safe tensor-only load, no arbitrary pickles.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)

        if "encoder" in ckpt:
            state_dict = ckpt["encoder"]
        elif "state_dict" in ckpt and "encoder" in ckpt["state_dict"]:
            state_dict = ckpt["state_dict"]["encoder"]
        else:
            raise RuntimeError("Encoder weights not found in checkpoint")

        # BUG FIX: the original used str.replace, which removes "module."
        # and "encoder." ANYWHERE inside a key, not only as a leading
        # prefix — a parameter such as "blocks.0.encoder.weight" would be
        # silently corrupted and then break the strict load. removeprefix
        # strips only the leading DDP/namespace wrappers.
        state_dict = {
            k.removeprefix("module.").removeprefix("encoder."): v
            for k, v in state_dict.items()
        }

        model.model.load_state_dict(state_dict, strict=True)
        return model
|
|