# hermes-edge / tests / test_quantization.py
# bclermo's picture
# Upload folder using huggingface_hub
# a84640a verified
# Raw
# History Blame Contribute Delete
# 2.58 kB
"""Tests for the PTQ / fake-quant utilities (no LiteRT stack required)."""
import os
import sys
import pytest
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
torch = pytest.importorskip("torch")
import torch.nn as nn # noqa: E402
from hermes.config import HermesConfig # noqa: E402
from hermes.model import build_model # noqa: E402
from hermes.quantization import ( # noqa: E402
apply_weight_only_int4,
apply_weight_only_int8,
collect_calibration_stats,
fake_quantize_per_group,
)
def _tiny_cfg():
    """Return the miniature ``HermesConfig`` used by the tests in this module.

    The dimensions are kept very small (2 layers, hidden size 32, vocabulary
    of 64 tokens, sequences of at most 16 positions). Note that
    ``num_heads * head_dim == hidden_size`` (4 * 8 == 32).
    """
    dims = {
        "vocab_size": 64,
        "hidden_size": 32,
        "intermediate_size": 64,
        "num_layers": 2,
        "num_heads": 4,
        "num_kv_heads": 2,
        "head_dim": 8,
        "max_seq_len": 16,
    }
    return HermesConfig(**dims)
def test_int4_weight_range():
    """Check that int4 weight-only quantization keeps linear weights in range.

    After ``apply_weight_only_int4`` has run on a tiny model, the weight of
    every ``nn.Linear`` module is regrouped into chunks of ``group_size``
    along the input dimension (zero-padded on the right when the width is not
    a multiple of the group size). A per-group scale is re-derived from the
    group's absolute maximum, and the rounded codes must lie in the signed
    4-bit interval [-8, 7].

    The test also fails when no ``nn.Linear`` module is found, so it cannot
    pass without having inspected anything.

    NOTE(review): the scale is recomputed from the very tensor under test, so
    for finite weights ``|codes| <= qmax`` holds by construction; effectively
    only NaN/Inf weights can trip the range assertion. A stricter check would
    need the scales actually used by ``apply_weight_only_int4`` -- TODO
    confirm whether they are exposed.
    """
    # One definition shared by the quantization call and the regrouping
    # below, so the two can never drift apart.
    group_size = 8
    qmax = 7  # largest positive code of a signed 4-bit integer

    model = build_model(_tiny_cfg())
    apply_weight_only_int4(model, group_size=group_size)

    checked = 0
    for module in model.modules():
        if not isinstance(module, nn.Linear):
            continue
        checked += 1

        # Reconstruct the integer codes from the dequantized weights per group.
        w = module.weight.data
        out_f, in_f = w.shape
        # Pad the input dimension up to the next multiple of group_size so
        # the reshape into (out_f, n_groups, group_size) is always valid.
        pad = (group_size - in_f % group_size) % group_size
        wp = torch.nn.functional.pad(w, (0, pad)).reshape(out_f, -1, group_size)
        # The clamp keeps the division finite for all-zero groups.
        scale = (wp.abs().amax(-1, keepdim=True) / qmax).clamp(min=1e-8)
        codes = torch.round(wp / scale)
        assert codes.min() >= -8 and codes.max() <= 7

    assert checked > 0, "no nn.Linear modules found; nothing was verified"
def test_int8_weight_range():
    """Check that int8 weight-only quantization keeps linear weights in range.

    After ``apply_weight_only_int8`` has run on a tiny model, a scale is
    re-derived for each row of every ``nn.Linear`` weight (absolute maximum
    over the last dimension divided by 127), and the rounded codes must lie
    in the signed 8-bit interval [-128, 127].

    The test also fails when no ``nn.Linear`` module is found, so it cannot
    pass without having inspected anything.

    NOTE(review): the scale is recomputed from the very tensor under test, so
    for finite weights ``|codes| <= qmax`` holds by construction; effectively
    only NaN/Inf weights can trip the range assertion. A stricter check would
    need the scales actually used by ``apply_weight_only_int8`` -- TODO
    confirm whether they are exposed.
    """
    qmax = 127  # largest positive code of a signed 8-bit integer

    model = build_model(_tiny_cfg())
    apply_weight_only_int8(model)

    checked = 0
    for module in model.modules():
        if not isinstance(module, nn.Linear):
            continue
        checked += 1

        w = module.weight.data
        # One scale per row (reduction over the last dimension); the clamp
        # keeps the division finite for all-zero rows.
        scale = (w.abs().amax(-1, keepdim=True) / qmax).clamp(min=1e-8)
        codes = torch.round(w / scale)
        assert codes.min() >= -128 and codes.max() <= 127

    assert checked > 0, "no nn.Linear modules found; nothing was verified"
def test_calibration_stats_keys():
    """Validate the structure of what ``collect_calibration_stats`` returns.

    Three random integer batches of shape (1, 8) with values in [0, 64) are
    fed through a tiny model. The result must be a non-empty ``dict``, at
    least one of its keys must contain ``"q_proj"``, and every value must
    provide the keys ``min``, ``max``, ``abs_max`` and ``p99`` with
    ``max >= min``.
    """
    net = build_model(_tiny_cfg())

    batches = []
    for _ in range(3):
        batches.append(torch.randint(0, 64, (1, 8)))

    calib = collect_calibration_stats(net, batches, num_batches=3)

    assert isinstance(calib, dict)
    assert calib

    # Layer names should reference nn.Linear submodules (e.g. q_proj).
    assert any("q_proj" in layer_name for layer_name in calib)

    required_keys = {"min", "max", "abs_max", "p99"}
    for record in calib.values():
        missing = required_keys - set(record)
        assert not missing
        assert record["max"] >= record["min"]
def test_fake_quant_is_idempotent():
    """Fake-quantizing an already fake-quantized tensor must not change it.

    A random (16, 24) tensor is passed through ``fake_quantize_per_group``
    twice with 4 bits and a group size of 8; the two outputs must agree
    within an absolute tolerance of 1e-5.
    """
    quant_kwargs = {"bits": 4, "group_size": 8}

    raw = torch.randn(16, 24)
    first_pass = fake_quantize_per_group(raw, **quant_kwargs)
    second_pass = fake_quantize_per_group(first_pass, **quant_kwargs)

    assert torch.allclose(first_pass, second_pass, atol=1e-5)