# hermes-edge/tests/test_model.py
# Last commit by bclermo: "Upload folder using huggingface_hub" (a84640a, verified); 3.06 kB.
"""Smoke tests for the Hermes mobile model and tooling.
These tests avoid the heavy LiteRT stack so they run anywhere torch is present.
"""
import os
import sys
import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from hermes.chat_template import ( # noqa: E402
Message,
build_prompt,
format_tool_call,
parse_tool_call,
)
from hermes.config import ( # noqa: E402
HermesConfig,
get_config,
hermes_1b_config,
hermes_270m_config,
hermes_500m_config,
)
torch = pytest.importorskip("torch")
from hermes.model import build_model # noqa: E402
def test_config_invariants():
    """Check the 1B preset's head geometry and its estimated parameter count."""
    config = hermes_1b_config()

    # The hidden width must factor exactly into num_heads * head_dim.
    assert config.num_heads * config.head_dim == config.hidden_size
    # num_heads must be a whole multiple of num_kv_heads.
    assert config.num_heads % config.num_kv_heads == 0

    # ~1B target, within a sane band.
    n_params = config.estimated_parameters()
    assert n_params > 0.8e9
    assert n_params < 1.3e9
def test_config_rejects_bad_gqa():
    """Building a config with 32 heads and 5 KV heads must raise ValueError."""
    # 32 is not a multiple of 5, so this head/KV-head pairing is invalid.
    bad_kwargs = dict(hidden_size=2048, num_heads=32, num_kv_heads=5, head_dim=64)
    with pytest.raises(ValueError):
        HermesConfig(**bad_kwargs)
def test_get_config_presets():
    """Preset names resolve like their factory functions; unknown names raise."""
    named_factories = (
        ("hermes-270m", hermes_270m_config),
        ("hermes-500m", hermes_500m_config),
    )
    # A lookup by name must agree with the matching factory on layer count.
    for preset_name, factory in named_factories:
        assert get_config(preset_name).num_layers == factory().num_layers

    with pytest.raises(KeyError):
        get_config("nope")
def test_hermes_500m_config():
    """Pin the 500M preset's dimensions and its entry in PRESETS."""
    preset = hermes_500m_config()

    assert preset.hidden_size == 1536
    assert preset.num_heads == 24
    assert preset.num_kv_heads == 6
    assert preset.num_layers == 24
    assert preset.num_heads * preset.head_dim == preset.hidden_size

    # PRESETS is not imported by name in this file, so it is reached through
    # the globals of the function that get_config refers to.
    registry = get_config.__globals__["PRESETS"]
    assert "hermes-500m" in registry
def test_forward_and_loss():
    """Run a forward pass with labels on a tiny model and check its outputs.

    The torch RNG is seeded up front so the sampled token ids (and any
    torch-random weight initialisation inside ``build_model``) are the same
    on every run; a failure here is then reproducible rather than flaky.
    """
    torch.manual_seed(0)
    batch_size, seq_len = 2, 8

    cfg = HermesConfig(
        vocab_size=128, hidden_size=64, intermediate_size=128, num_layers=2,
        num_heads=4, num_kv_heads=2, head_dim=16, max_seq_len=32,
    )
    model = build_model(cfg)

    # The same ids serve as both inputs and labels.
    ids = torch.randint(0, cfg.vocab_size, (batch_size, seq_len))
    out = model(ids, labels=ids)

    assert out["logits"].shape == (batch_size, seq_len, cfg.vocab_size)
    assert out["loss"].item() > 0
def test_generate_runs():
    """Smoke-test ``generate`` on a tiny model and check the output length."""
    tiny_cfg = HermesConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_layers=2,
        num_heads=4,
        num_kv_heads=2,
        head_dim=16,
        max_seq_len=32,
    )
    net = build_model(tiny_cfg)

    prompt = torch.randint(0, tiny_cfg.vocab_size, (1, 4))
    generated = net.generate(prompt, max_new_tokens=5, temperature=0.0)

    # 9 == 4 prompt positions + 5 requested new tokens (presumably generate
    # returns the prompt with the continuation appended; confirm in model).
    assert generated.shape[1] == 9
def test_tool_call_roundtrip():
    """A formatted tool call preceded by free text parses back unchanged."""
    rendered = format_tool_call("calculator", {"expression": "2+2"})
    # Prefix some ordinary text so the parser has to find the call inside it.
    parsed = parse_tool_call("sure " + rendered)

    assert parsed["name"] == "calculator"
    assert parsed["arguments"]["expression"] == "2+2"
def test_parse_tool_call_absent():
    """Plain text without a tool call must parse to None."""
    plain_text = "just a normal answer"
    result = parse_tool_call(plain_text)
    assert result is None
def test_build_prompt_includes_tools():
    """A prompt built with tools has a system turn, the tool name, and an open assistant turn."""
    conversation = [Message("user", "hi")]
    tool_specs = [{"name": "calculator", "description": "math"}]
    rendered = build_prompt(conversation, tools=tool_specs)

    # Both the system-turn marker and the tool's name must appear somewhere.
    for expected in ("<|im_start|>system", "calculator"):
        assert expected in rendered

    # Ignoring trailing whitespace, the prompt ends on the assistant marker.
    assert rendered.rstrip().endswith("<|im_start|>assistant")