"""Load Hydra BitNet model."""

import torch
from safetensors.torch import load_file


def _add_aisim_to_path() -> None:
    """Prepend the sibling ``aisim`` directory to ``sys.path`` if it exists.

    The directory is resolved relative to this file as
    ``<this file's grandparent>/aisim``. The entry is inserted only once, so
    repeated calls (e.g. loading several models) do not keep growing
    ``sys.path`` with duplicates.
    """
    import sys
    from pathlib import Path

    aisim_path = Path(__file__).parent.parent / "aisim"
    if not aisim_path.is_dir():
        return
    entry = str(aisim_path)
    if entry not in sys.path:
        # Front of the path so the local checkout wins over any installed copy.
        sys.path.insert(0, entry)


def load_hydra(model_path: str, device: str = "cpu"):
    """Load Hydra model from HuggingFace format.

    Args:
        model_path: Directory containing ``config.json`` and
            ``model.safetensors``. Any ``os.PathLike`` also works.
        device: Target device passed to ``model.to``, e.g. ``"cpu"`` or
            ``"cuda"``.

    Returns:
        A ``(model, config)`` tuple. ``model`` is an ``M2MSentinel`` moved to
        ``device`` and put in eval mode. ``config`` is the parsed
        ``config.json`` dict.

    Raises:
        FileNotFoundError: If ``config.json`` or ``model.safetensors`` is
            missing.
        KeyError: If ``config.json`` lacks ``vocab_size``, ``hidden_size``,
            ``num_hidden_layers`` or ``num_experts``.
        RuntimeError: From ``load_state_dict`` (strict mode) when checkpoint
            keys/shapes do not match the constructed model.
    """
    import json
    from pathlib import Path

    # bitnet_moe lives in the sibling aisim checkout, not an installed package.
    _add_aisim_to_path()
    from bitnet_moe import M2MSentinel

    root = Path(model_path)

    # HuggingFace config files are UTF-8; don't depend on the platform locale.
    with open(root / "config.json", encoding="utf-8") as f:
        config = json.load(f)

    # Build the architecture from the HF-style config keys.
    model = M2MSentinel(
        vocab_size=config["vocab_size"],
        dim=config["hidden_size"],
        depth=config["num_hidden_layers"],
        experts=config["num_experts"],
    )

    # Weights are read onto the CPU (load_file's default) and copied into the
    # model's parameters. The standalone dict is then dropped so its tensors
    # are not held alive alongside the model during the device move.
    weights = load_file(str(root / "model.safetensors"))
    model.load_state_dict(weights)
    del weights

    model = model.to(device)
    model.eval()
    return model, config


if __name__ == "__main__":
    import sys

    model_path = sys.argv[1] if len(sys.argv) > 1 else "."
    model, config = load_hydra(model_path)
    print(f"Loaded model: {config}")