# hydra / load_model.py
# icxcn's picture
# Upload folder using huggingface_hub
# 93145cf verified
"""Load Hydra BitNet model."""
import torch
from safetensors.torch import load_file
def load_hydra(model_path: str, device: str = "cpu"):
    """Load a Hydra model and its config from a HuggingFace-format directory.

    Reads ``config.json`` and ``model.safetensors`` from ``model_path``,
    builds an ``M2MSentinel`` from the config values, loads the weights into
    it, moves it to ``device`` and switches it to eval mode.

    Args:
        model_path: Directory containing ``config.json`` and
            ``model.safetensors`` (a string or any path-like object).
        device: Target passed to ``model.to()``; defaults to ``"cpu"``.

    Returns:
        A ``(model, config)`` tuple: the constructed model in eval mode and
        the parsed contents of ``config.json``.

    Raises:
        ImportError: If ``bitnet_moe`` cannot be imported.
        FileNotFoundError: If ``config.json`` is missing from ``model_path``.
        KeyError: If the config lacks ``vocab_size``, ``hidden_size``,
            ``num_hidden_layers`` or ``num_experts``.
    """
    import json
    import sys
    from pathlib import Path

    # Make the sibling "aisim" directory importable when it exists
    # (presumably where bitnet_moe lives -- confirm against the repo layout).
    # Skip the insert if the entry is already present so repeated calls do
    # not keep growing sys.path with duplicates.
    aisim_path = Path(__file__).parent.parent / "aisim"
    aisim_entry = str(aisim_path)
    if aisim_path.exists() and aisim_entry not in sys.path:
        sys.path.insert(0, aisim_entry)

    # Must come after the sys.path adjustment above.
    from bitnet_moe import M2MSentinel

    model_dir = Path(model_path)

    # Load config. JSON is UTF-8 by specification, so do not rely on the
    # platform's locale-dependent default encoding.
    with open(model_dir / "config.json", encoding="utf-8") as f:
        config = json.load(f)

    # Create model
    model = M2MSentinel(
        vocab_size=config["vocab_size"],
        dim=config["hidden_size"],
        depth=config["num_hidden_layers"],
        experts=config["num_experts"],
    )

    # Load weights, then move the populated model to the requested device.
    weights = load_file(str(model_dir / "model.safetensors"))
    model.load_state_dict(weights)
    model = model.to(device)
    model.eval()
    return model, config
if __name__ == "__main__":
    import sys

    # Script entry point: the first command-line argument, when given, is the
    # model directory; otherwise fall back to the current directory.
    model_path = "." if len(sys.argv) <= 1 else sys.argv[1]
    model, config = load_hydra(model_path)
    print(f"Loaded model: {config}")