"""Load Hydra BitNet model."""
import json
import sys
from pathlib import Path

import torch
from safetensors.torch import load_file
def load_hydra(model_path: str, device: str = "cpu"):
    """Load a Hydra BitNet model stored in HuggingFace-style layout.

    The model directory must contain ``config.json`` and ``model.safetensors``.
    The model class ``M2MSentinel`` is imported from the ``bitnet_moe`` module.
    If an ``aisim`` directory exists one level above this file's directory, it
    is prepended to ``sys.path`` (once) so that import can be resolved.

    Args:
        model_path: Directory holding ``config.json`` and ``model.safetensors``
            (a ``str``; any path-like value also works).
        device: Device the model is moved to after loading, e.g. ``"cpu"``.

    Returns:
        A ``(model, config)`` tuple: the ``M2MSentinel`` instance in eval mode
        on ``device``, and the parsed ``config.json`` as a dict.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist.
        KeyError: If ``config.json`` lacks any of ``vocab_size``,
            ``hidden_size``, ``num_hidden_layers`` or ``num_experts``.
        RuntimeError: From ``load_state_dict`` if the checkpoint keys/shapes
            do not match the constructed model (strict loading).
    """
    # Make the sibling ``aisim`` tree importable. Guard against re-inserting
    # on every call, which would grow sys.path without bound.
    aisim_path = Path(__file__).parent.parent / "aisim"
    if aisim_path.exists():
        aisim_entry = str(aisim_path)
        if aisim_entry not in sys.path:
            sys.path.insert(0, aisim_entry)
    # Imported late on purpose: it may only resolve after the sys.path tweak.
    from bitnet_moe import M2MSentinel

    model_dir = Path(model_path)
    config_file = model_dir / "config.json"

    # Explicit encoding: JSON is UTF-8 by spec, not the platform locale.
    with open(config_file, encoding="utf-8") as f:
        config = json.load(f)

    # Report every missing hyper-parameter at once instead of a bare KeyError.
    required_keys = ("vocab_size", "hidden_size", "num_hidden_layers", "num_experts")
    missing = [key for key in required_keys if key not in config]
    if missing:
        raise KeyError(f"{config_file} is missing required keys: {missing}")

    model = M2MSentinel(
        vocab_size=config["vocab_size"],
        dim=config["hidden_size"],
        depth=config["num_hidden_layers"],
        experts=config["num_experts"],
    )

    weights = load_file(str(model_dir / "model.safetensors"))
    model.load_state_dict(weights)
    # load_state_dict copies into the model's own parameters; drop the
    # separate checkpoint copy before moving to ``device`` to cut peak memory.
    del weights

    model = model.to(device)
    model.eval()
    return model, config
if __name__ == "__main__":
    import sys

    # Usage: python <this script> [MODEL_DIR]
    # MODEL_DIR defaults to the current working directory.
    cli_args = sys.argv[1:]
    if cli_args:
        model_path = cli_args[0]
    else:
        model_path = "."

    model, config = load_hydra(model_path)
    print(f"Loaded model: {config}")