import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time, json, os, sys
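# Model selection comes from argv; model checkpoints and results live under ~/kv-hack.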
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
if MODEL_NAME not in MODEL_PATHS:
    sys.exit(f"Unknown model '{MODEL_NAME}'; expected one of: {', '.join(MODEL_PATHS)}")
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = f"~/kv-hack/results/{MODEL_NAME}"
os.makedirs(os.path.expanduser(results_dir), exist_ok=True)
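# Load tokenizer and model in fp16 on the GPU; eval mode disables dropout.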
print(f"Running baseline for: {MODEL_NAME}")
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="cuda",
)
model.eval()
results = {}
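# Prefill benchmark: peak memory and forward-pass latency at several context lengths.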
for ctx_len in [1024, 4096, 8192]:
    print(f"\nTesting context length: {ctx_len}")
    input_ids = torch.randint(1, 1000, (1, ctx_len)).cuda()  # random IDs: memory/latency don't depend on token content
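    # Warm-up: two untimed passes so one-time CUDA kernel/allocator costs aren't measured.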
    with torch.no_grad():
        for _ in range(2):
            out = model(input_ids, use_cache=True)
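    # Synchronize, then reset the peak-memory counter so only the timed runs below count.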
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
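    # Timed runs: five forward passes over the full context, CUDA-synchronized before each stop.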
    times = []
    with torch.no_grad():
        for _ in range(5):
            t0 = time.time()
            out = model(input_ids, use_cache=True)
            torch.cuda.synchronize()
            times.append(time.time() - t0)
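    # Summarize: peak allocated memory (bytes -> GB) and mean prefill time.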
    peak_mem = torch.cuda.max_memory_allocated() / 1e9
    avg_time = sum(times) / len(times)
    results[ctx_len] = {
        "peak_memory_gb": round(peak_mem, 2),
        "avg_prefill_ms": round(avg_time * 1000, 1),
    }
    print(f"  Peak memory: {peak_mem:.2f} GB")
    print(f"  Avg prefill: {avg_time*1000:.1f} ms")
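# Decode benchmark: greedy generation from a 512-token prompt.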
print("\nTesting decode speed...")
input_ids = torch.randint(1, 1000, (1, 512)).cuda()
attention_mask = torch.ones_like(input_ids)  # explicit mask: pad_token_id == eos, and random IDs may collide with it
with torch.no_grad():
    t0 = time.time()
    out = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=100,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    torch.cuda.synchronize()
    elapsed = time.time() - t0
gen_tokens = out.shape[1] - input_ids.shape[1]  # count actual output: greedy decode can stop early at EOS
tokens_per_sec = gen_tokens / elapsed
results["decode_tokens_per_sec"] = round(tokens_per_sec, 1)
print(f"  Decode speed: {tokens_per_sec:.1f} tokens/sec")
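# Save results as JSON in the per-model results directory.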
out_path = os.path.expanduser(f"{results_dir}/baseline.json")
with open(out_path, "w") as f:
    json.dump(results, f, indent=2)
print(f"\nSaved to {out_path}")
print(json.dumps(results, indent=2))