"""
TurboQuant inference with Qwen models.
Demonstrates TurboQuant KV cache compression as a drop-in replacement
for the default DynamicCache during model.generate().
"""
import sys
sys.path.insert(0, "/home/azureuser/turboquant")
import argparse
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from turboquant.cache import TurboQuantCache
def load_model(model_name: str, load_in_4bit: bool = True):
    """Load a causal LM and its matching tokenizer.

    Args:
        model_name: Hugging Face model identifier.
        load_in_4bit: When True, quantize weights on load with bitsandbytes NF4.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    load_kwargs = {
        "device_map": "auto",
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
    }
    if load_in_4bit:
        # NF4 weight quantization with BF16 compute for the matmuls.
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
        )

    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded. Parameters: {n_params / 1e9:.1f}B")
    return model, tokenizer
def generate_with_cache(model, tokenizer, prompt: str, cache_type: str = "turboquant",
                        max_new_tokens: int = 100, nbits: int = 4,
                        skip_layers: set[int] | None = None):
    """Generate text using the specified KV-cache type and report stats.

    Args:
        model: Loaded causal LM (from ``load_model``).
        tokenizer: Tokenizer matching ``model``.
        prompt: Input prompt text.
        cache_type: ``"turboquant"`` to use ``TurboQuantCache``; any other
            value passes ``past_key_values=None`` so transformers falls back
            to its default DynamicCache.
        max_new_tokens: Number of new tokens to generate (greedy decoding).
        nbits: Quantization bit-width for the TurboQuant cache.
        skip_layers: Layer indices the cache keeps uncompressed (BF16).

    Returns:
        ``(generated_text, elapsed_seconds, peak_gpu_memory_bytes)``; peak
        memory is reported as 0 when CUDA is unavailable.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_len = inputs.input_ids.shape[1]

    # Build the requested cache; None lets transformers use DynamicCache.
    if cache_type == "turboquant":
        cache = TurboQuantCache(
            model.config,
            nbits=nbits,
            residual_length=128,
            device=str(model.device),
            skip_layers=skip_layers,
        )
    else:
        cache = None

    # Fix: torch.cuda memory APIs raise on CPU-only hosts — guard them so
    # the function also runs without a GPU (memory stats reported as 0).
    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        torch.cuda.reset_peak_memory_stats()
        mem_before = torch.cuda.memory_allocated()
    else:
        mem_before = 0

    start = time.time()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            past_key_values=cache,
            do_sample=False,
        )
    elapsed = time.time() - start

    mem_peak = torch.cuda.max_memory_allocated() if cuda_ok else 0
    mem_used = (torch.cuda.memory_allocated() - mem_before) if cuda_ok else 0

    # Decode only the newly generated continuation, not the prompt.
    generated = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
    n_tokens = outputs.shape[1] - input_len
    print(f"\n Cache: {cache_type}")
    print(f" Tokens: {n_tokens} in {elapsed:.2f}s ({n_tokens/elapsed:.1f} tok/s)")
    print(f" Peak GPU memory: {mem_peak / 1024**3:.2f} GB")
    print(f" Cache memory delta: {mem_used / 1024**2:.1f} MB")
    print(f" Output: {generated[:200]}...")
    return generated, elapsed, mem_peak
def main():
    """CLI entry point: run TurboQuant inference, optionally vs the default cache.

    Parses arguments, loads the model, auto-calibrates which layers the
    TurboQuant cache should keep in BF16, then either runs a single
    TurboQuant generation or a side-by-side comparison with the default
    DynamicCache (``--compare``).
    """
    parser = argparse.ArgumentParser(description="TurboQuant inference")
    parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct",
                        help="Model name (default: Qwen2.5-1.5B for testing)")
    parser.add_argument("--prompt", default="Explain quantum computing in simple terms.",
                        help="Input prompt")
    parser.add_argument("--max-tokens", type=int, default=100)
    parser.add_argument("--nbits", type=int, default=4, choices=[2, 4])
    parser.add_argument("--no-4bit", action="store_true", help="Load in BF16 instead of 4-bit")
    parser.add_argument("--compare", action="store_true", help="Compare TurboQuant vs default cache")
    args = parser.parse_args()

    model, tokenizer = load_model(args.model, load_in_4bit=not args.no_4bit)

    # Auto-calibrate skip layers (kept in BF16 due to outlier KV norms).
    skip = TurboQuantCache.calibrate_skip_layers(model, tokenizer)
    print(f"Auto-detected skip layers: {skip} (kept in BF16 due to outlier KV norms)")

    if args.compare:
        print("\n" + "=" * 60)
        print("COMPARISON: Default DynamicCache vs TurboQuantCache")
        print("=" * 60)
        # Baseline run with the stock DynamicCache.
        gen_default, t_default, mem_default = generate_with_cache(
            model, tokenizer, args.prompt, "default", args.max_tokens
        )
        # Fix: empty_cache() raises on CPU-only torch builds — guard it.
        # Releasing cached allocator blocks keeps the second run's peak
        # memory comparable to the first.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # TurboQuant-compressed KV cache run.
        gen_tq, t_tq, mem_tq = generate_with_cache(
            model, tokenizer, args.prompt, "turboquant", args.max_tokens, args.nbits,
            skip_layers=skip,
        )
        # max(mem_tq, 1) avoids a ZeroDivisionError when no CUDA stats exist.
        print(f"\n Memory savings: {(mem_default - mem_tq) / 1024**2:.1f} MB "
              f"({mem_default/max(mem_tq, 1):.2f}x)")
        print(f" Outputs match: {gen_default == gen_tq}")
    else:
        generate_with_cache(
            model, tokenizer, args.prompt, "turboquant", args.max_tokens, args.nbits,
            skip_layers=skip,
        )
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()