# attnvq / generate.py
# adirik's picture
# update stale files
# 4e235ae
# Raw
# History Blame Contribute Delete
# 1.11 kB
"""AttnVQ: VQQuantizedCache wired into model.generate()."""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vqkv.compressed_cache import VQQuantizedCache
# --- Tokenizer and model ---------------------------------------------------
# Both come from the same Hub repository, so the id is spelled out once.
MODEL_ID = "poolside/Laguna-XS.2"
tok = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    fix_mistral_regex=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,
)
model = model.eval()  # switch to eval mode; this script only runs inference

# --- Codebooks ---------------------------------------------------------------
# Point CODEBOOKS_PATH at the shipped artifact, or at codebooks you fitted
# yourself.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so only load codebook files from a source you trust.
CODEBOOKS_PATH = "artifacts/codebooks.pt"
codebooks = torch.load(CODEBOOKS_PATH, map_location="cuda", weights_only=False)

# --- Cache -------------------------------------------------------------------
# The checkpoint is indexed as a nested mapping: the fitted quantizers for one
# named configuration, plus the layer list stored under "meta".
quantizers = codebooks["fitted"]["productvq-32x256-2b"]
layers = codebooks["meta"]["full_layers"]
cache = VQQuantizedCache(quantizers, layers)  # persists uint8 codebook indices

# --- Generation --------------------------------------------------------------
ids = tok("Hello", return_tensors="pt").to(model.device)
prompt_len = ids["input_ids"].shape[1]
out = model.generate(
    **ids,
    max_new_tokens=32,
    past_key_values=cache,
    use_cache=True,
)
# Skip the first prompt_len positions of the output so that only the tokens
# after the prompt are decoded and printed.
completion = tok.decode(out[0, prompt_len:], skip_special_tokens=True)
print(completion)

# --- Memory footprint --------------------------------------------------------
print(cache.memory_footprint())