Meme-Recommender / utils /model_loading_util.py
Diwakar Basnet
feat: integrate I-JEPA manager and HF model repository loading
052f26d
import json
import torch
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from src.models.ijepa import IJEPATargetEncoder
@dataclass(init=True, repr=True, eq=True)
class ViTConfig:
    """Hyper-parameters describing the Vision Transformer backbone.

    Instances are built from the ``config.json`` of a model repository; any
    field missing from that file keeps the default declared here. The
    defaults look like a ViT-H/14 layout (NOTE(review): confirm against the
    published checkpoint).
    """

    # Input image side length in pixels.
    img_size: int = 224
    # Number of input image channels.
    in_chans: int = 3
    # Side length of each square patch in pixels.
    patch_size: int = 14
    # Token embedding width.
    embed_dim: int = 1280
    # Number of transformer blocks.
    depth: int = 32
    # Attention heads per block.
    num_heads: int = 16
    # Hidden width of the MLP relative to embed_dim.
    mlp_ratio: float = 4.0
def load_model_from_hf(
    repo_id: str,
    device: str = "cuda",
    token: "str | None" = None,
) -> "IJEPATargetEncoder":
    """Download and load the I-JEPA target encoder from a Hugging Face model repository.

    The repository must contain ``config.json`` (a JSON object whose keys are
    ``ViTConfig`` field names) and ``model_weights.pth`` (a ``state_dict`` for
    ``IJEPATargetEncoder``).

    Args:
        repo_id: Hugging Face repository id, e.g. ``"user/model-name"``.
        device: Device the model is moved to after loading (default ``"cuda"``).
        token: Optional Hugging Face access token, for private or gated repos.

    Returns:
        The ``IJEPATargetEncoder`` with the downloaded weights loaded, placed
        on ``device`` and switched to eval mode.

    Raises:
        TypeError: If ``config.json`` does not contain a JSON object.
        RuntimeError: From ``load_state_dict`` if the weights do not match
            the architecture described by the config (strict loading).
        Errors raised by ``hf_hub_download`` (missing repo/file, auth,
        network) propagate unchanged.
    """
    from dataclasses import fields

    print(f"Fetching model files from {repo_id}...")

    # 1. Download and parse the config first so that a broken config fails
    #    fast, before the (large) weights file is fetched.
    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.json",
        token=token,
    )
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    if not isinstance(config_dict, dict):
        raise TypeError(
            f"config.json in {repo_id} must contain a JSON object, "
            f"got {type(config_dict).__name__}"
        )

    # config.json files on the Hub often carry extra metadata keys; passing
    # them straight to ViTConfig(**...) would raise TypeError. Keep only the
    # known fields (missing ones fall back to ViTConfig defaults).
    known_keys = {fld.name for fld in fields(ViTConfig)}
    ignored_keys = sorted(set(config_dict) - known_keys)
    if ignored_keys:
        print(f"Ignoring unrecognized config keys: {ignored_keys}")
    config = ViTConfig(
        **{key: value for key, value in config_dict.items() if key in known_keys}
    )

    # 2. Initialize the architecture from the config.
    # NOTE(review): config.in_chans is not forwarded; confirm whether
    # IJEPATargetEncoder accepts/needs it.
    model = IJEPATargetEncoder(
        img_size=config.img_size,
        patch_size=config.patch_size,
        embed_dim=config.embed_dim,
        depth=config.depth,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio,
    )

    # 3. Download the weights.
    weights_path = hf_hub_download(
        repo_id=repo_id,
        filename="model_weights.pth",
        token=token,
    )

    # 4. Load the weights. The file comes from a remote repository, and full
    #    pickle loading can execute arbitrary code, so restrict unpickling to
    #    tensors and plain containers (sufficient for a state_dict).
    print("Loading state dict...")
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device).eval()

    print("Model successfully loaded from Hugging Face.")
    return model