"""AttnVQ: VQQuantizedCache wired into model.generate().""" import torch from transformers import AutoModelForCausalLM, AutoTokenizer from vqkv.compressed_cache import VQQuantizedCache # load model tok = AutoTokenizer.from_pretrained("poolside/Laguna-XS.2", trust_remote_code=True, fix_mistral_regex=True) model = AutoModelForCausalLM.from_pretrained("poolside/Laguna-XS.2", torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True).eval() # load codebooks or fit and use your own CODEBOOKS_PATH = "artifacts/codebooks.pt" codebooks = torch.load(CODEBOOKS_PATH, map_location="cuda", weights_only=False) # build cache quantizers, layers = codebooks["fitted"]["productvq-32x256-2b"], codebooks["meta"]["full_layers"] cache = VQQuantizedCache(quantizers, layers) # persists uint8 codebook indices # generate ids = tok("Hello", return_tensors="pt").to(model.device) out = model.generate(**ids, max_new_tokens=32, past_key_values=cache, use_cache=True) print(tok.decode(out[0, ids["input_ids"].shape[1]:], skip_special_tokens=True)) # print memory footprint print(cache.memory_footprint())