"""AttnVQ: VQQuantizedCache wired into model.generate().

Loads the Laguna-XS.2 causal LM and its tokenizer, builds a
``VQQuantizedCache`` from fitted product-VQ codebooks stored on disk, runs
``model.generate()`` with that cache as ``past_key_values``, and prints the
decoded continuation followed by ``cache.memory_footprint()``.

Run as a script; importing the module no longer loads the model.
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vqkv.compressed_cache import VQQuantizedCache

MODEL_ID = "poolside/Laguna-XS.2"
CODEBOOKS_PATH = "artifacts/codebooks.pt"
# Key under codebooks["fitted"] selecting which fitted quantizer set to use.
QUANTIZER_KEY = "productvq-32x256-2b"


def main(prompt: str = "Hello", max_new_tokens: int = 32) -> str:
    """Generate a continuation of *prompt* using a VQ-compressed KV cache.

    Args:
        prompt: Text passed to the tokenizer as the generation prompt.
        max_new_tokens: Maximum number of new tokens ``generate()`` may emit.

    Returns:
        The decoded continuation, with the prompt tokens excluded and special
        tokens skipped. It is also printed, followed by the cache's
        ``memory_footprint()``.
    """
    # trust_remote_code=True executes modelling/tokenizer code shipped in the
    # Hub repository, so only use it with repositories you trust.
    tok = AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=True, fix_mistral_regex=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    ).eval()

    # NOTE(review): weights_only=False unpickles arbitrary Python objects. The
    # fitted quantizers are presumably custom classes that weights_only=True
    # would reject. Only load codebook files from a trusted source.
    codebooks = torch.load(CODEBOOKS_PATH, map_location="cuda", weights_only=False)

    quantizers = codebooks["fitted"][QUANTIZER_KEY]
    layers = codebooks["meta"]["full_layers"]
    cache = VQQuantizedCache(quantizers, layers)

    ids = tok(prompt, return_tensors="pt").to(model.device)
    out = model.generate(
        **ids,
        max_new_tokens=max_new_tokens,
        past_key_values=cache,
        use_cache=True,
    )
    # generate() returns prompt + continuation; drop the prompt tokens.
    prompt_len = ids["input_ids"].shape[1]
    text = tok.decode(out[0, prompt_len:], skip_special_tokens=True)
    print(text)

    print(cache.memory_footprint())
    return text


if __name__ == "__main__":
    main()
|
|