Spaces:
Running
Running
File size: 1,573 Bytes
052f26d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | import json
import torch
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from src.models.ijepa import IJEPATargetEncoder
@dataclass
class ViTConfig:
    """Hyperparameters describing the Vision Transformer encoder.

    Instances are built from the keys of a downloaded ``config.json`` and
    then used to construct the encoder. Every field has a default, so a
    partial config is accepted.
    """

    # Field meanings below are inferred from the names — confirm against
    # the encoder implementation.
    img_size: int = 224      # presumably input image side length in pixels
    in_chans: int = 3        # presumably number of input image channels
    patch_size: int = 14     # presumably side length of each image patch
    embed_dim: int = 1280    # presumably token embedding width
    depth: int = 32          # presumably number of transformer blocks
    num_heads: int = 16      # presumably attention heads per block
    mlp_ratio: float = 4.0   # presumably MLP hidden width / embed_dim
def load_model_from_hf(
    repo_id: str,
    device: str = "cuda",
    token: str = None
):
    """Download and load the I-JEPA model from a Hugging Face model repo.

    Fetches ``config.json`` and ``model_weights.pth`` from ``repo_id``,
    builds an ``IJEPATargetEncoder`` from the config, loads the weights,
    moves the model to ``device`` and switches it to eval mode.

    Args:
        repo_id: Hugging Face repository identifier to download from.
        device: Device the model is moved to after the weights are loaded.
        token: Optional access token forwarded to ``hf_hub_download``;
            ``None`` is passed through unchanged.

    Returns:
        The ``IJEPATargetEncoder`` instance in eval mode on ``device``.

    Raises:
        TypeError: If ``config.json`` does not contain a JSON object.
        RuntimeError: If the downloaded state dict does not match the
            model's parameters (raised by ``load_state_dict``).
    """
    # Local import: the module only imports ``dataclass`` at file level.
    from dataclasses import fields

    print(f"Fetching model files from {repo_id}...")

    # 1. Download Config
    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.json",
        token=token
    )

    # 2. Download Weights
    weights_path = hf_hub_download(
        repo_id=repo_id,
        filename="model_weights.pth",
        token=token
    )

    # 3. Initialize Architecture from downloaded config
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        config_dict = json.load(f)

    if not isinstance(config_dict, dict):
        raise TypeError(
            f"config.json must contain a JSON object, "
            f"got {type(config_dict).__name__}"
        )

    # Keep only keys ViTConfig declares, so extra metadata in config.json
    # does not abort loading with an unexpected-keyword TypeError.
    known_keys = {field.name for field in fields(ViTConfig)}
    ignored_keys = sorted(set(config_dict) - known_keys)
    if ignored_keys:
        print(f"Ignoring unrecognized config keys: {ignored_keys}")
    config = ViTConfig(
        **{key: value for key, value in config_dict.items() if key in known_keys}
    )

    # NOTE(review): config.in_chans is parsed but not forwarded to the
    # encoder — confirm whether IJEPATargetEncoder accepts it.
    model = IJEPATargetEncoder(
        img_size=config.img_size,
        patch_size=config.patch_size,
        embed_dim=config.embed_dim,
        depth=config.depth,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio
    )

    # 4. Load Weights
    print("Loading state dict...")
    # SECURITY: the checkpoint is downloaded from a remote repository and
    # torch.load unpickles it. weights_only=True restricts unpickling to
    # tensors and plain containers, which is all a state dict requires.
    state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)

    model = model.to(device).eval()
    print("Model successfully loaded from Hugging Face.")
    return model
|