import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time, json, os, sys

# ── config ──────────────────────────────────────────
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
if MODEL_NAME not in MODEL_PATHS:
    sys.exit(f"Unknown model '{MODEL_NAME}'; choose one of {list(MODEL_PATHS)}")
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = f"~/kv-hack/results/{MODEL_NAME}"
os.makedirs(os.path.expanduser(results_dir), exist_ok=True)
# ────────────────────────────────────────────────────

print(f"Running baseline for: {MODEL_NAME}")
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,  # `dtype` is not a from_pretrained kwarg on older transformers
    device_map="cuda",
)
model.eval()

results = {}
for ctx_len in [1024, 4096, 8192]:
    print(f"\nTesting context length: {ctx_len}")
    # random token ids are fine here: timing and memory don't depend on content
    input_ids = torch.randint(1, 1000, (1, ctx_len), device="cuda")

    # warmup: let kernel selection/allocator settle before measuring
    with torch.no_grad():
        for _ in range(2):
            out = model(input_ids, use_cache=True)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    # measure prefill: a full forward pass that also builds the KV cache
    times = []
    with torch.no_grad():
        for _ in range(5):
            t0 = time.time()
            out = model(input_ids, use_cache=True)
            torch.cuda.synchronize()
            times.append(time.time() - t0)

    peak_mem = torch.cuda.max_memory_allocated() / 1e9  # GB, weights + activations + KV cache
    avg_time = sum(times) / len(times)
    results[ctx_len] = {
        "peak_memory_gb": round(peak_mem, 2),
        "avg_prefill_ms": round(avg_time * 1000, 1),
    }
    print(f"  Peak memory: {peak_mem:.2f} GB")
    print(f"  Avg prefill: {avg_time * 1000:.1f} ms")

# decode speed (note: elapsed includes the 512-token prefill, so this slightly
# understates pure per-token decode throughput)
print("\nTesting decode speed...")
input_ids = torch.randint(1, 1000, (1, 512), device="cuda")
with torch.no_grad():
    t0 = time.time()
    out = model.generate(
        input_ids,
        max_new_tokens=100,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    torch.cuda.synchronize()
    elapsed = time.time() - t0

# count tokens actually produced: greedy decoding can hit EOS before 100,
# and dividing a fixed 100 by elapsed would overstate the speed
gen_tokens = out.shape[1] - input_ids.shape[1]
tokens_per_sec = gen_tokens / elapsed
results["decode_tokens_per_sec"] = round(tokens_per_sec, 1)
print(f"  Decode speed: {tokens_per_sec:.1f} tokens/sec")

# save
out_path = os.path.expanduser(f"{results_dir}/baseline.json")
with open(out_path, "w") as f:
    json.dump(results, f, indent=2)
print(f"\n✅ Saved to {out_path}")
print(json.dumps(results, indent=2))
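
# ── optional: quick cross-model summary ─────────────
# A minimal follow-up sketch, not part of the baseline measurement above: read
# whatever baseline.json files exist under ~/kv-hack/results/ (the layout this
# script writes) and print them side by side. Assumes each model has already
# been run through this script at least once; missing files are just skipped.
print("\nBaselines recorded so far:")
for name in MODEL_PATHS:
    path = os.path.expanduser(f"~/kv-hack/results/{name}/baseline.json")
    if not os.path.exists(path):
        print(f"  {name}: (no baseline.json yet)")
        continue
    with open(path) as f:
        data = json.load(f)
    print(f"  {name}: decode {data.get('decode_tokens_per_sec', '?')} tok/s")
    for key, val in data.items():
        if isinstance(val, dict):  # per-context-length entry (keys are strings in JSON)
            print(f"    ctx {key}: {val['peak_memory_gb']} GB peak, {val['avg_prefill_ms']} ms prefill")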