"""Tests for the PTQ / fake-quant utilities (no LiteRT stack required)."""

import os
import sys

import pytest

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

torch = pytest.importorskip("torch")

import torch.nn as nn  # noqa: E402

from hermes.config import HermesConfig  # noqa: E402
from hermes.model import build_model  # noqa: E402
from hermes.quantization import (  # noqa: E402
    apply_weight_only_int4,
    apply_weight_only_int8,
    collect_calibration_stats,
    fake_quantize_per_group,
)


def _tiny_cfg():
    return HermesConfig(
        vocab_size=64, hidden_size=32, intermediate_size=64, num_layers=2,
        num_heads=4, num_kv_heads=2, head_dim=8, max_seq_len=16,
    )


def test_int4_weight_range():
    """Dequantized int4 weights must map back to integer codes in [-8, 7]."""
    group_size = 8
    qmax = 7  # largest positive code for signed 4-bit values
    model = build_model(_tiny_cfg())
    apply_weight_only_int4(model, group_size=group_size)
    for module in model.modules():
        if isinstance(module, nn.Linear):
            # Reconstruct the integer codes from the dequantized weights per group.
            w = module.weight.data
            out_f, in_f = w.shape
            # Zero-pad the input dimension so it splits evenly into groups.
            pad = (group_size - in_f % group_size) % group_size
            wp = torch.nn.functional.pad(w, (0, pad)).reshape(out_f, -1, group_size)
            scale = (wp.abs().amax(-1, keepdim=True) / qmax).clamp(min=1e-8)
            codes = torch.round(wp / scale)
            assert codes.min() >= -8 and codes.max() <= 7


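def test_int8_weight_range():
    """Dequantized int8 weights must map back to integer codes in [-128, 127]."""
    model = build_model(_tiny_cfg())
    apply_weight_only_int8(model)
    for module in model.modules():
        if isinstance(module, nn.Linear):
            w = module.weight.data
            qmax = 127  # largest positive code for signed 8-bit values
            # Recompute one scale per output row from the dequantized weights.
            scale = (w.abs().amax(-1, keepdim=True) / qmax).clamp(min=1e-8)
            codes = torch.round(w / scale)
            assert codes.min() >= -128 and codes.max() <= 127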


def test_calibration_stats_keys():
    model = build_model(_tiny_cfg())
    data = [torch.randint(0, 64, (1, 8)) for _ in range(3)]
    stats = collect_calibration_stats(model, data, num_batches=3)
    assert isinstance(stats, dict) and stats
    # Layer names should reference nn.Linear submodules (e.g. q_proj).
    assert any("q_proj" in name for name in stats)
    for entry in stats.values():
        assert {"min", "max", "abs_max", "p99"} <= set(entry)
        assert entry["max"] >= entry["min"]


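def test_fake_quant_is_idempotent():
    """Fake-quantizing an already fake-quantized tensor should leave it unchanged."""
    torch.manual_seed(0)  # fixed seed so a failure is reproducible
    w = torch.randn(16, 24)
    once = fake_quantize_per_group(w, bits=4, group_size=8)
    twice = fake_quantize_per_group(once, bits=4, group_size=8)
    assert torch.allclose(once, twice, atol=1e-5)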