File size: 2,916 Bytes
954cf8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
Compute a Fixed Corpus Basis (FCB) for cross-document and
cross-model stable state vector extraction.

The FCB is the principal subspace of the key manifold computed
from a diverse reference corpus. Unlike per-document SVD,
the FCB is document-independent — all documents projected
with the same FCB exist in the same coordinate system.
"""

from __future__ import annotations

import argparse
import gc
import sys
from pathlib import Path

import torch
from llama_cpp import Llama

from kvcos.core.blob_parser import parse_state_blob
from kvcos.core.state_extractor import MARStateExtractor
from scripts.generate_alignment_dataset import DOCUMENTS


def _kv_geometry(meta: dict[str, str]) -> tuple[int, int]:
    """Return ``(n_kv_heads, head_dim)`` derived from GGUF metadata.

    GGUF namespaces hyperparameters by the value of ``general.architecture``
    (``llama.*``, ``qwen2.*``, ``gemma.*`` ...), so the prefix is resolved
    from the metadata instead of being hard-coded. The legacy ``llama.``
    keys and the historical defaults (8 KV heads, 4096 / 32 head_dim) are
    kept as fallbacks so existing llama-architecture models behave exactly
    as before.

    Args:
        meta: Metadata mapping as exposed by ``Llama.metadata``.

    Returns:
        Tuple of the KV head count and the per-head key dimension.
    """
    arch = meta.get("general.architecture", "llama")

    def lookup(suffix: str, default=None):
        # Architecture-specific key first, then the legacy "llama." key.
        for prefix in (arch, "llama"):
            value = meta.get(f"{prefix}.{suffix}")
            if value is not None:
                return value
        return default

    n_kv = int(lookup("attention.head_count_kv", "8"))

    # When present, key_length is the explicit key head size; only fall
    # back to embedding_length // head_count if the model does not set it.
    key_length = lookup("attention.key_length")
    if key_length is not None:
        head_dim = int(key_length)
    else:
        head_dim = int(lookup("embedding_length", "4096")) // int(
            lookup("attention.head_count", "32")
        )
    return n_kv, head_dim


def main() -> int:
    """Compute a Fixed Corpus Basis from the reference corpus and save it.

    Steps:
      1. Parse the command line (``--model``, ``--layer-range``,
         ``--gate-start``, ``--rank``, ``--output``).
      2. Load the model and read its KV head geometry from the metadata.
      3. Run every entry of ``DOCUMENTS`` through the model, snapshot the
         llama state and collect the parsed key tensors.
      4. Hand the key tensors to ``MARStateExtractor.compute_corpus_basis``.
      5. Write the basis, the run parameters and the raw key tensors to
         ``--output`` with ``torch.save``.

    Returns:
        ``0`` on success; the value is used as the process exit status.
    """
    parser = argparse.ArgumentParser(description="Compute Fixed Corpus Basis")
    parser.add_argument("--model", required=True)
    parser.add_argument("--layer-range", type=int, nargs=2, default=[8, 24])
    parser.add_argument("--gate-start", type=int, default=6)
    parser.add_argument("--rank", type=int, default=122)
    parser.add_argument("--output", required=True)
    args = parser.parse_args()

    llm = Llama(model_path=args.model, n_ctx=2048, n_gpu_layers=-1, verbose=False)
    meta = llm.metadata
    n_kv, head_dim = _kv_geometry(meta)
    model_name = meta.get("general.name", "unknown")

    print(f"Model: {model_name} ({n_kv} KV heads, {head_dim} head_dim)")
    print(f"Layer range: {args.layer_range}, gate_start: {args.gate_start}")
    print(f"Collecting key tensors from {len(DOCUMENTS)} documents...")

    key_tensors: list[torch.Tensor] = []
    for i, doc in enumerate(DOCUMENTS):
        # Start from a clean context so each document is evaluated on its own.
        llm.reset()
        # A single generated token is enough: only the resulting state is used.
        llm(doc.strip(), max_tokens=1, temperature=0.0)
        s = llm.save_state()
        parsed = parse_state_blob(
            bytes(s.llama_state), n_kv_heads=n_kv, head_dim=head_dim
        )
        key_tensors.append(parsed.keys)
        if (i + 1) % 10 == 0:
            print(f"  {i + 1}/{len(DOCUMENTS)}")

    # Drop the model before the SVD so its memory can be reclaimed.
    del llm
    gc.collect()

    print("Computing corpus SVD...")
    basis = MARStateExtractor.compute_corpus_basis(
        key_tensors=key_tensors,
        layer_range=tuple(args.layer_range),
        gate_start=args.gate_start,
        rank=args.rank,
    )

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "basis": basis,
            "model_name": model_name,
            "layer_range": args.layer_range,
            "gate_start": args.gate_start,
            "rank": args.rank,
            "n_corpus_docs": len(DOCUMENTS),
            # Raw per-document key tensors are stored alongside the basis.
            "key_tensors": key_tensors,
        },
        str(output_path),
    )

    print(f"Basis shape: {basis.shape}")
    print(f"Saved: {output_path}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())