"""
Gemma4 Prometheus evaluation script.

Tests:
1. Coherent text generation (GPTQ model, 2-GPU pipeline parallel)
2. Max context length search with FP16 KV cache
3. Max context length search with FP8 KV cache (software quantization)
4. Perplexity on WikiText-2 (GPTQ model)
5. KL divergence: GPTQ-4bit vs merged model (bnb-8bit reference)

GPU setup: GPU-828df6fd (phys 0) + GPU-89c6bfdc (phys 4) → logical 0,1
Both fully free, 24 GB each, 48 GB total.
"""

import os, gc, json, math

# Pin the two target GPUs by UUID and tune the CUDA allocator. This must
# happen before torch is imported for the settings to take effect.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "GPU-828df6fd-3fd0-ed25-0b2b-2b6d9d8dca47,GPU-89c6bfdc-6f42-d312-de77-a9fb1ae370d8",
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7",
    "HF_HUB_DISABLE_PROGRESS_BARS": "1",
    "TOKENIZERS_PARALLELISM": "false",
})

import torch
import torch.nn.functional as F
import numpy as np

GPTQ_DIR = "/home/op/outputs/gemma4-prometheus/gptq-4bit"
MERGED_DIR = "/home/op/outputs/gemma4-prometheus/merged-model"
RESULTS_DIR = "/home/op/outputs/gemma4-prometheus/eval"
os.makedirs(RESULTS_DIR, exist_ok=True)

RESULTS = {}


def log(msg): print(f"[EVAL] {msg}", flush=True)


def free_vram():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()


def vram_used():
    used = []
    for i in range(torch.cuda.device_count()):
        used.append(torch.cuda.memory_allocated(i) // 1024**2)
    return used


class Fp8DynamicCache:
    """
    FP8 KV cache: wraps a transformers DynamicCache and stores K/V tensors in
    torch.float8_e4m3fn format (half the memory of FP16/BF16), dequantizing
    back to the model dtype before returning them to attention.

    Per-tensor symmetric quantization: scale = max(|T|) / 448, where 448 is
    the largest finite e4m3fn value. RTX 3090 (Ampere, sm86) has no FP8
    tensor cores, so FP8 is storage-only and compute stays in BF16
    (software FP8). See _fp8_roundtrip_demo below for a round-trip sketch.
    """

    def __init__(self):
        from transformers import DynamicCache

        self._dc = DynamicCache()
        self._fp8_key = []
        self._fp8_val = []
        self._scale_k = []
        self._scale_v = []

    def __getattr__(self, name):
        # Delegate anything not overridden here to the wrapped DynamicCache.
        # __getattr__ only fires when normal lookup fails, so it cannot
        # shadow our own attributes or recurse.
        return getattr(object.__getattribute__(self, "_dc"), name)

    @staticmethod
    def _to_fp8(t: torch.Tensor):
        # Per-tensor symmetric scale; the epsilon guards all-zero tensors.
        scale = t.detach().abs().max().float() / 448.0 + 1e-12
        q = (t.float() / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
        return q, scale

    @staticmethod
    def _from_fp8(q: torch.Tensor, scale: torch.Tensor, dtype):
        return q.to(dtype) * scale.to(dtype)

    @staticmethod
    def _rescale(q: torch.Tensor, old_scale, new_scale):
        # Re-express an FP8 tensor quantized under old_scale in units of
        # new_scale, so chunks quantized at different steps dequantize
        # consistently under one shared scale.
        if float(old_scale) == float(new_scale):
            return q
        return (q.float() * (old_scale / new_scale)).to(torch.float8_e4m3fn)

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        dtype = key_states.dtype
        qk, sk = self._to_fp8(key_states)
        qv, sv = self._to_fp8(value_states)

        if len(self._fp8_key) <= layer_idx:
            # First chunk for this layer (prefill).
            self._fp8_key.append(qk)
            self._fp8_val.append(qv)
            self._scale_k.append(sk)
            self._scale_v.append(sv)
        else:
            # Decode step: grow the running per-layer scale, rescale whichever
            # chunk was quantized under the smaller scale, then append along
            # the sequence dimension.
            new_sk = torch.maximum(self._scale_k[layer_idx], sk)
            new_sv = torch.maximum(self._scale_v[layer_idx], sv)
            old_k = self._rescale(self._fp8_key[layer_idx], self._scale_k[layer_idx], new_sk)
            old_v = self._rescale(self._fp8_val[layer_idx], self._scale_v[layer_idx], new_sv)
            self._fp8_key[layer_idx] = torch.cat([old_k, self._rescale(qk, sk, new_sk)], dim=-2)
            self._fp8_val[layer_idx] = torch.cat([old_v, self._rescale(qv, sv, new_sv)], dim=-2)
            self._scale_k[layer_idx] = new_sk
            self._scale_v[layer_idx] = new_sv

        # DynamicCache counts tokens once per forward pass, on layer 0.
        if layer_idx == 0:
            self._dc._seen_tokens += key_states.shape[-2]

        k_out = self._from_fp8(self._fp8_key[layer_idx], self._scale_k[layer_idx], dtype)
        v_out = self._from_fp8(self._fp8_val[layer_idx], self._scale_v[layer_idx], dtype)
        return k_out, v_out

    def get_seq_length(self, layer_idx=0):
        if not self._fp8_key:
            return 0
        return self._fp8_key[0].shape[-2]

    def get_max_length(self):
        return None

    def __len__(self):
        return len(self._fp8_key)

    @property
    def seen_tokens(self):
        return self._dc._seen_tokens
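

# A minimal, self-contained sanity sketch (not wired into main): round-trips a
# random BF16 tensor through the same _to_fp8/_from_fp8 helpers the cache uses,
# showing the per-tensor scale scheme and its typical quantization error. The
# function name and the 0.1 error threshold are ours, purely illustrative.
def _fp8_roundtrip_demo():
    t = torch.randn(2, 8, 128, 64, dtype=torch.bfloat16)  # (batch, heads, seq, dim)
    q, scale = Fp8DynamicCache._to_fp8(t)
    back = Fp8DynamicCache._from_fp8(q, scale, torch.bfloat16)
    rel_err = ((back.float() - t.float()).abs().mean() / t.float().abs().mean()).item()
    log(f"FP8 e4m3 round-trip mean relative error: {rel_err:.4f}")  # typically a few percent
    assert rel_err < 0.1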


def load_gptq_model(device_map="balanced"):
    from gptqmodel import GPTQModel
    from transformers import AutoTokenizer

    log(f"Loading GPTQ model from {GPTQ_DIR} [device_map={device_map}]")
    max_mem = {0: "22GiB", 1: "22GiB", "cpu": "40GiB"}
    tok = AutoTokenizer.from_pretrained(GPTQ_DIR)
    model = GPTQModel.load(
        GPTQ_DIR,
        device_map=device_map,
        max_memory=max_mem,
    )
    model.eval()
    log(f"GPTQ model loaded. VRAM used (MiB): {vram_used()}")
    return model, tok
def load_merged_model_bnb8():
    from transformers import AutoModelForImageTextToText, AutoTokenizer, BitsAndBytesConfig

    log(f"Loading merged model (bnb-8bit) from {MERGED_DIR}")
    bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
    max_mem = {0: "23GiB", 1: "23GiB", "cpu": "60GiB"}
    tok = AutoTokenizer.from_pretrained(MERGED_DIR)
    model = AutoModelForImageTextToText.from_pretrained(
        MERGED_DIR,
        quantization_config=bnb_cfg,
        device_map="auto",
        max_memory=max_mem,
    )
    model.eval()
    log(f"Merged model loaded (bnb-8bit). VRAM: {vram_used()}")
    return model, tok


COHERENCE_PROMPTS = [
    "Explain how neural networks learn from data.",
    "What is the difference between supervised and unsupervised learning?",
    "Describe the concept of gradient descent in machine learning.",
    "What are transformers in the context of natural language processing?",
    "Explain what quantization means for neural network models.",
]


def test_coherence(model, tok):
    log("=== Coherence Test ===")
    results = []
    for prompt in COHERENCE_PROMPTS:
        messages = [{"role": "user", "content": prompt}]
        text = tok.apply_chat_template(
            messages, tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        ids = tok(text, return_tensors="pt").input_ids
        dev = next(model.parameters()).device
        ids = ids.to(dev)
        with torch.no_grad():
            out = model.generate(
                ids,
                max_new_tokens=256,
                do_sample=False,
                temperature=None,  # explicit None silences sampling-flag
                top_p=None,        # warnings under greedy decoding
                pad_token_id=tok.eos_token_id,
            )
        new = out[0, ids.shape[1]:]
        response = tok.decode(new, skip_special_tokens=True).strip()
        ok = len(response.split()) >= 15  # crude heuristic: non-degenerate answer
        results.append({"prompt": prompt, "response": response[:600], "ok": ok})
        log(f"  Q: {prompt[:60]}...")
        log(f"  A: {response[:200]}...")
        log(f"  OK: {ok}")

    RESULTS["coherence"] = {
        "passed": sum(r["ok"] for r in results),
        "total": len(results),
        "samples": results,
    }
    return results
def _try_context(model, tok, seq_len, use_fp8_cache=False):
    """Return (ok, actual_len); ok is True if a forward pass over ~seq_len tokens fits without OOM."""
    free_vram()
    prompt = "The quick brown fox jumps over the lazy dog. " * (seq_len // 10 + 1)
    ids = tok(prompt, return_tensors="pt", truncation=True, max_length=seq_len).input_ids
    actual_len = ids.shape[1]
    dev = next(model.parameters()).device
    ids = ids.to(dev)
    try:
        with torch.no_grad():
            if use_fp8_cache:
                past = Fp8DynamicCache()
                _ = model(input_ids=ids, past_key_values=past, use_cache=True)
            else:
                # Default FP16/BF16 DynamicCache allocated by the model.
                _ = model(input_ids=ids, use_cache=True)
        free_vram()
        return True, actual_len
    except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
        # Deliberately broad: any CUDA-flavored RuntimeError near the memory
        # limit is treated as "does not fit" rather than crashing the search.
        if "out of memory" in str(e).lower() or "CUDA" in str(e).upper():
            free_vram()
            return False, actual_len
        raise

def search_max_context(model, tok, use_fp8_cache=False, lo=1024, hi=200_000, label=""):
    """Coarse ladder, then binary search, for the max context length that fits in VRAM."""
    log(f"=== Context search: {label} ===")

    ok, _ = _try_context(model, tok, lo, use_fp8_cache)
    if not ok:
        log(f"  Even {lo} tokens failed!")
        return lo

    last_ok = lo
    # Coarse ladder: climb until the first failure to bracket the limit.
    candidates = [2048, 4096, 8192, 16384, 32768, 65536, 100000, 131072, 160000, 200000]
    for c in candidates:
        if c <= lo:
            continue  # already verified at or below the starting point
        if c > hi:
            break
        log(f"  Trying {c} tokens...")
        ok, _ = _try_context(model, tok, c, use_fp8_cache)
        if ok:
            last_ok = c
        else:
            hi = c
            break

    # Binary search inside the bracket, down to a 512-token tolerance.
    lo = last_ok
    while lo < hi - 512:
        mid = (lo + hi) // 2
        log(f"  Binary search: lo={lo} mid={mid} hi={hi}")
        ok, _ = _try_context(model, tok, mid, use_fp8_cache)
        if ok:
            lo = mid
            last_ok = mid
        else:
            hi = mid

    log(f"  Max context ({label}): {last_ok} tokens")
    return last_ok
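
# A toy sketch of the refinement step above (not called by main): against a
# synthetic "fits in memory" predicate with a known limit, the binary-search
# loop converges to within the 512-token tolerance. The function name and the
# 70_000 limit are invented for illustration.
def _binary_search_demo():
    limit, lo, hi = 70_000, 65_536, 100_000
    fits = lambda n: n <= limit
    while lo < hi - 512:
        mid = (lo + hi) // 2
        if fits(mid):
            lo = mid
        else:
            hi = mid
    assert limit - 512 < lo <= limit
    log(f"binary search demo: found {lo} for true limit {limit}")
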

def compute_perplexity(model, tok, stride=512, max_tokens=4096, dataset_name="wikitext-2-raw-v1"):
    """Sliding-window perplexity on WikiText-2: windows of 2*stride tokens, scoring the last stride of each."""
    log("=== Perplexity (WikiText-2) ===")
    from datasets import load_dataset
    data = load_dataset("wikitext", dataset_name, split="test")
    text = "\n\n".join(data["text"])

    encodings = tok(text, return_tensors="pt")
    input_ids = encodings.input_ids[0][:max_tokens]
    seq_len = input_ids.shape[0]

    nlls = []
    dev = next(model.parameters()).device
    for begin in range(0, seq_len, stride):
        end = min(begin + stride * 2, seq_len)
        chunk = input_ids[begin:end].unsqueeze(0).to(dev)
        target_len = min(stride, end - begin)
        labels = chunk.clone()
        # Score only the trailing target_len tokens; the rest is context.
        labels[0, :-target_len] = -100
        with torch.no_grad():
            out = model(input_ids=chunk, labels=labels)
        nll = out.loss
        if not torch.isnan(nll) and not torch.isinf(nll):
            nlls.append(nll.item())
        if nlls and len(nlls) % 5 == 0:
            log(f"  Progress: {begin}/{seq_len}, current ppl={math.exp(sum(nlls)/len(nlls)):.2f}")
        if end == seq_len:
            break  # avoid re-scoring the tail with a shorter context

    ppl = math.exp(sum(nlls) / len(nlls))
    log(f"  Perplexity: {ppl:.4f}")
    return ppl
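
# Worked example of the aggregation above (not called by main): perplexity is
# exp(mean NLL) over the scored windows, each window weighted equally. The
# NLL values here are made up purely for illustration.
def _ppl_aggregation_demo():
    nlls = [2.0, 2.2, 1.8]                 # per-window mean NLLs (nats/token)
    ppl = math.exp(sum(nlls) / len(nlls))  # mean NLL is 2.0, so ppl = e^2 ≈ 7.389
    assert abs(ppl - math.exp(2.0)) < 1e-12
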

KL_PROMPTS = [
    "Explain the concept of entropy in information theory.",
    "What is backpropagation and how does it work?",
    "Describe the attention mechanism in transformer models.",
    "What are the main differences between RNNs and transformers?",
    "How does weight quantization affect model accuracy?",
    "Explain the curse of dimensionality in machine learning.",
    "What is transfer learning and when is it useful?",
    "Describe how a convolutional neural network processes images.",
]


def get_logits(model, tok, prompts, max_len=512):
    """Return the next-token logits (full vocab, on CPU) for each prompt."""
    dev = next(model.parameters()).device
    all_logits = []
    for p in prompts:
        ids = tok(p, return_tensors="pt", truncation=True, max_length=max_len).input_ids.to(dev)
        with torch.no_grad():
            out = model(input_ids=ids)
        # Logits at the final position, i.e. the next-token distribution.
        logits = out.logits[0, -1, :].float().cpu()
        all_logits.append(logits)
    return all_logits

def compute_kl_divergence(logits_ref, logits_cmp, top_k=1000):
    """KL(ref || cmp) averaged over prompts, renormalized over ref's top-k tokens."""
    kl_vals = []
    for lr, lc in zip(logits_ref, logits_cmp):
        # Restrict both distributions to the reference model's top-k tokens;
        # the softmax over the subset renormalizes the truncated distributions.
        vals, idx = lr.topk(top_k)
        lc_sub = lc[idx]
        p = F.softmax(vals, dim=-1).double()
        q = F.softmax(lc_sub, dim=-1).double()
        q = q.clamp(min=1e-10)
        kl = (p * (p.log() - q.log())).sum().item()
        kl_vals.append(kl)
    return float(np.mean(kl_vals)), float(np.std(kl_vals))
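
# Quick self-check for the KL helper (not called by main): identical logits
# must give KL ≈ 0, and a perturbed copy must give KL > 0. The vocab size and
# noise level are arbitrary illustration values.
def _kl_selfcheck_demo():
    ref = [torch.randn(32_000) for _ in range(3)]
    same_mean, _ = compute_kl_divergence(ref, [l.clone() for l in ref], top_k=100)
    noisy_mean, _ = compute_kl_divergence(ref, [l + 0.5 * torch.randn_like(l) for l in ref], top_k=100)
    assert same_mean < 1e-8 and noisy_mean > same_mean
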

def main():
    log("=" * 60)
    log("Gemma4 Prometheus Evaluation Suite")
    log(f"CUDA devices visible: {os.environ.get('CUDA_VISIBLE_DEVICES', '')}")
    log(f"GPU count: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        log(f"  GPU {i}: {props.name}, {props.total_memory // 1024**2} MiB")
    log("=" * 60)

    log("\n[Phase 1] Loading GPTQ model (2 GPU pipeline parallel)...")
    gptq_model, tok = load_gptq_model()

    RESULTS["setup"] = {
        "gptq_model_path": GPTQ_DIR,
        "merged_model_path": MERGED_DIR,
        "gpus": [
            {"index": i,
             "name": torch.cuda.get_device_properties(i).name,
             "total_mib": torch.cuda.get_device_properties(i).total_memory // 1024**2}
            for i in range(torch.cuda.device_count())
        ],
        "parallelism": "pipeline_parallel_device_map",
        "note": "True tensor-parallelism requires vLLM, which does not yet support the Gemma4 architecture",
    }
| log("\n[Phase 1a] Coherence test...") |
| test_coherence(gptq_model, tok) |
|
|
| |
| log("\n[Phase 1b] Context search: FP16 KV cache...") |
| try: |
| max_fp16 = search_max_context(gptq_model, tok, use_fp8_cache=False, |
| lo=2048, hi=200_000, label="FP16-KV") |
| RESULTS["max_context_fp16"] = max_fp16 |
| except Exception as e: |
| log(f"Context search FP16 failed: {e}") |
| RESULTS["max_context_fp16"] = f"ERROR: {e}" |
|
|
| |
| log("\n[Phase 1c] Context search: FP8 KV cache...") |
| try: |
| max_fp8 = search_max_context(gptq_model, tok, use_fp8_cache=True, |
| lo=2048, hi=200_000, label="FP8-KV") |
| RESULTS["max_context_fp8"] = max_fp8 |
| except Exception as e: |
| log(f"Context search FP8 failed: {e}") |
| RESULTS["max_context_fp8"] = f"ERROR: {e}" |
|
|
| |
| log("\n[Phase 1d] Perplexity (WikiText-2)...") |
| try: |
| ppl = compute_perplexity(gptq_model, tok) |
| RESULTS["perplexity_gptq"] = {"value": ppl, "dataset": "wikitext-2-raw-v1"} |
| except Exception as e: |
| log(f"Perplexity failed: {e}") |
| RESULTS["perplexity_gptq"] = {"error": str(e)} |
|
|
| |
| log("\n[Phase 1e] Computing GPTQ logits for KL divergence...") |
| try: |
| gptq_logits = get_logits(gptq_model, tok, KL_PROMPTS) |
| log(f" Got {len(gptq_logits)} logit tensors from GPTQ model") |
| except Exception as e: |
| log(f"GPTQ logits failed: {e}") |
| gptq_logits = None |
|
|
| |
| log("\n[Phase 1] Unloading GPTQ model...") |
| del gptq_model |
| free_vram() |
|
|
| |
| log("\n[Phase 2] Loading merged model (bnb-8bit) for KL reference...") |
| try: |
| ref_model, ref_tok = load_merged_model_bnb8() |
|
|
| |
| log("\n[Phase 2a] Perplexity of merged model (bnb-8bit)...") |
| try: |
| ppl_ref = compute_perplexity(ref_model, ref_tok) |
| RESULTS["perplexity_merged_bnb8"] = { |
| "value": ppl_ref, |
| "dataset": "wikitext-2-raw-v1", |
| "note": "bnb-8bit quantized for memory", |
| } |
| except Exception as e: |
| log(f"Merged perplexity failed: {e}") |
| RESULTS["perplexity_merged_bnb8"] = {"error": str(e)} |
|
|
| |
| log("\n[Phase 2b] Computing merged model logits for KL divergence...") |
| if gptq_logits is not None: |
| ref_logits = get_logits(ref_model, ref_tok, KL_PROMPTS) |
| kl_mean, kl_std = compute_kl_divergence(ref_logits, gptq_logits) |
| log(f" KL(merged || gptq-4bit) mean={kl_mean:.4f} std={kl_std:.4f}") |
| RESULTS["kl_divergence"] = { |
| "mean": kl_mean, |
| "std": kl_std, |
| "direction": "KL(merged_bnb8 || gptq_4bit)", |
| "num_prompts": len(KL_PROMPTS), |
| "top_k_tokens": 1000, |
| "note": "Merged model loaded in bnb-8bit; adds small reference noise", |
| } |
| else: |
| log(" Skipping KL: GPTQ logits unavailable") |
|
|
| del ref_model |
| free_vram() |
| except Exception as e: |
| log(f"Phase 2 failed: {e}") |
| import traceback; traceback.print_exc() |
| RESULTS["phase2_error"] = str(e) |
|
|
| |
    out_path = os.path.join(RESULTS_DIR, "eval_results.json")
    with open(out_path, "w") as f:
        json.dump(RESULTS, f, indent=2, default=str)
    log(f"\n=== Results saved to {out_path} ===")

    log("\n" + "=" * 60)
    log("EVALUATION SUMMARY")
    log("=" * 60)
    if "coherence" in RESULTS:
        c = RESULTS["coherence"]
        log(f"Coherence:       {c['passed']}/{c['total']} prompts OK")
    if "max_context_fp16" in RESULTS:
        log(f"Max ctx FP16:    {RESULTS['max_context_fp16']} tokens")
    if "max_context_fp8" in RESULTS:
        log(f"Max ctx FP8:     {RESULTS['max_context_fp8']} tokens")
    if "perplexity_gptq" in RESULTS:
        pv = RESULTS["perplexity_gptq"].get("value")
        # pv is None when the phase errored; don't apply a float format to it.
        log(f"PPL GPTQ-4bit:   {pv:.4f}" if pv is not None else "PPL GPTQ-4bit:   ERROR")
    if "perplexity_merged_bnb8" in RESULTS:
        pv = RESULTS["perplexity_merged_bnb8"].get("value")
        log(f"PPL merged-8bit: {pv:.4f}" if pv is not None else "PPL merged-8bit: ERROR")
    if "kl_divergence" in RESULTS:
        kl = RESULTS["kl_divergence"]
        log(f"KL divergence:   mean={kl['mean']:.4f} std={kl['std']:.4f}")
    log("=" * 60)


if __name__ == "__main__":
    main()