Spaces:
Sleeping
Sleeping
| import json | |
| import torch | |
| from dataclasses import dataclass | |
| from huggingface_hub import hf_hub_download | |
| from src.models.ijepa import IJEPATargetEncoder | |
@dataclass
class ViTConfig:
    """Architecture hyperparameters for the ViT-based I-JEPA target encoder.

    Instances are built from the keys of a repository's ``config.json``
    (``ViTConfig(**config_dict)``), so field names must match those keys.
    The ``@dataclass`` decorator is required: it generates the keyword
    ``__init__`` that construction relies on.
    """

    img_size: int = 224  # input image side length in pixels (presumably square input — confirm)
    in_chans: int = 3  # number of input channels (NOTE(review): not forwarded to the encoder by load_model_from_hf)
    patch_size: int = 14  # patch side length in pixels
    embed_dim: int = 1280  # token embedding width
    depth: int = 32  # number of transformer blocks
    num_heads: int = 16  # attention heads per block
    mlp_ratio: float = 4.0  # MLP hidden width as a multiple of embed_dim (presumably) — confirm in IJEPATargetEncoder
def load_model_from_hf(
    repo_id: str,
    device: str = "cuda",
    token: "str | None" = None,
):
    """Download and load the I-JEPA target encoder from a Hugging Face model repo.

    The repository must contain two files:

    * ``config.json`` -- keyword arguments used to build a ``ViTConfig``.
    * ``model_weights.pth`` -- a state dict for ``IJEPATargetEncoder``.

    Args:
        repo_id: Hugging Face repository id (e.g. ``"user/model-name"``).
        device: Device the loaded model is moved to. The default ``"cuda"``
            requires a CUDA-capable torch install; pass ``"cpu"`` otherwise.
        token: Optional Hugging Face access token, forwarded unchanged to
            ``hf_hub_download`` (e.g. for private repositories).

    Returns:
        The ``IJEPATargetEncoder`` with loaded weights, on ``device``, in eval mode.
    """
    print(f"Fetching model files from {repo_id}...")

    # 1. Download the config; hf_hub_download returns a local file path.
    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.json",
        token=token,
    )

    # 2. Download the weights.
    weights_path = hf_hub_download(
        repo_id=repo_id,
        filename="model_weights.pth",
        token=token,
    )

    # 3. Build the architecture from the downloaded config.
    # An explicit encoding makes JSON decoding independent of the platform locale.
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    config = ViTConfig(**config_dict)

    # NOTE(review): config.in_chans is not passed here; confirm whether
    # IJEPATargetEncoder accepts it or always assumes 3 channels.
    model = IJEPATargetEncoder(
        img_size=config.img_size,
        patch_size=config.patch_size,
        embed_dim=config.embed_dim,
        depth=config.depth,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio,
    )

    # 4. Load the weights. The checkpoint comes from a remote repository and
    # torch.load is pickle-based. weights_only=True restricts unpickling to
    # tensors and plain containers, so a malicious checkpoint cannot run
    # arbitrary code. Loading onto CPU first avoids requiring the device
    # the checkpoint was saved from.
    print("Loading state dict...")
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device).eval()

    print("Model successfully loaded from Hugging Face.")
    return model