import json
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional

import torch
from huggingface_hub import hf_hub_download

from src.models.ijepa import IJEPATargetEncoder


@dataclass
class ViTConfig:
    """Architecture hyperparameters used to build the I-JEPA target encoder.

    Values read from a repository's ``config.json`` override these defaults.
    NOTE(review): the defaults (1280-dim, 32 blocks, 16 heads, 14-px patches)
    look like a ViT-H/14 layout — confirm against the published checkpoints.
    """

    img_size: int = 224  # input image side length, in pixels
    in_chans: int = 3  # number of input channels (not currently forwarded to the encoder)
    patch_size: int = 14  # patch side length, in pixels
    embed_dim: int = 1280  # token embedding width
    depth: int = 32  # number of transformer blocks
    num_heads: int = 16  # attention heads per block
    mlp_ratio: float = 4.0  # MLP hidden width as a multiple of embed_dim


def _config_from_dict(config_dict: Dict[str, Any]) -> ViTConfig:
    """Build a ViTConfig from a parsed config.json, ignoring keys it does not define.

    Passing unknown keys straight to the dataclass constructor raises
    ``TypeError``; config files may carry extra metadata (e.g. ``model_type``),
    so such keys are dropped and reported instead.
    """
    known = {f.name for f in fields(ViTConfig)}
    unknown = sorted(set(config_dict) - known)
    if unknown:
        print(f"Ignoring unrecognized config keys: {unknown}")
    return ViTConfig(**{k: v for k, v in config_dict.items() if k in known})


def load_model_from_hf(
    repo_id: str,
    device: str = "cuda",
    token: Optional[str] = None,
):
    """Download and load the I-JEPA target encoder from a Hugging Face model repository.

    Args:
        repo_id: Hub repository id that contains ``config.json`` and
            ``model_weights.pth``.
        device: Device the loaded model is moved to (e.g. ``"cuda"`` or ``"cpu"``).
        token: Optional Hugging Face access token, forwarded to every download.

    Returns:
        An ``IJEPATargetEncoder`` with the downloaded weights loaded, moved to
        ``device`` and switched to eval mode.

    Raises:
        RuntimeError: If the checkpoint's keys do not match the model
            (``load_state_dict`` is called with its default ``strict=True``).
        Download errors from ``hf_hub_download`` and unpickling errors from
        ``torch.load`` (which refuses non-tensor objects under
        ``weights_only=True``) propagate unchanged.
    """
    print(f"Fetching model files from {repo_id}...")

    # 1. Download config
    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.json",
        token=token,
    )

    # 2. Download weights
    weights_path = hf_hub_download(
        repo_id=repo_id,
        filename="model_weights.pth",
        token=token,
    )

    # 3. Initialize architecture from the downloaded config
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    config = _config_from_dict(config_dict)

    # NOTE(review): config.in_chans is not passed here; confirm whether
    # IJEPATargetEncoder accepts it or hard-codes its input channel count.
    model = IJEPATargetEncoder(
        img_size=config.img_size,
        patch_size=config.patch_size,
        embed_dim=config.embed_dim,
        depth=config.depth,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio,
    )

    # 4. Load weights
    print("Loading state dict...")
    # The checkpoint comes from a remote repository, i.e. untrusted input.
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered file cannot execute arbitrary code during load.
    # Loading onto CPU first avoids requiring the device the weights were saved on.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)

    model = model.to(device).eval()
    print("Model successfully loaded from Hugging Face.")
    return model