| |
|
| | """
|
| | Coherence evaluation for language models.
|
| |
|
| | Measures what standard benchmarks can't see:
|
Tier 1 — Generation diversity (repetition, collapse detection)

Tier 2 — Multi-distance prediction (context utilization, skip accuracy)

Tier 3 — Semantic consistency (chunk similarity over long generations)
|
| |
|
| | Usage:
|
| | # Custom checkpoint
|
| | python -m circuits.coherence_eval --checkpoint circuits/checkpoints/model/best.pt
|
| |
|
| | # HuggingFace model
|
| | python -m circuits.coherence_eval --model gpt2
|
| |
|
| | # Compare models
|
| | python -m circuits.coherence_eval --model EleutherAI/pythia-160m --gpu 0
|
| |
|
| | # Quick test (fewer prompts, shorter generation)
|
| | python -m circuits.coherence_eval --checkpoint path/to/model.pt --num-prompts 5 --gen-length 256
|
| |
|
| | # Run specific tiers
|
| | python -m circuits.coherence_eval --checkpoint path/to/model.pt --tiers 1,3
|
| | """
|
| |
|
| | import argparse
|
| | import json
|
| | import math
|
| | import sys
|
| | import time
|
| | from pathlib import Path
|
| |
|
| | import torch
|
| | import torch.nn.functional as F
|
| |
|
| |
|
| |
|
| |
|
| |
|
# Default generation prompts for Tier 1 (diversity) and Tier 3 (consistency):
# a mix of narrative, expository, and technical openings so the metrics are
# not dominated by a single register.
DEFAULT_PROMPTS = [
    "A thought observing itself discovers that it",
    "The history of science shows that",
    "In the middle of the night, the old house",
    "The relationship between language and thought has been",
    "When the first settlers arrived, they found",
    "The mathematical proof begins by assuming",
    "She opened the door to find",
    "The economic implications of this policy",
    "Deep beneath the ocean surface, researchers discovered",
    "The most important lesson from this experiment is",
    "According to recent studies, the human brain",
    "The old library contained books that",
    "As the temperature continued to rise, the effects on",
    "The development of artificial intelligence has raised questions about",
    "In the small village at the foot of the mountain",
    "The fundamental principles of democracy require",
    "Looking through the telescope, the astronomer noticed",
    "The relationship between music and emotion",
    "During the industrial revolution, working conditions",
    "The ancient manuscript revealed secrets about",
]
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
class ModelWrapper:
    """Unified interface for custom circuit models and HuggingFace models.

    Attributes:
        model: the underlying torch module.
        tokenizer: matching tokenizer (encode/decode).
        device: device the model lives on.
        model_type: "hf" for HuggingFace models, "circuit" for local checkpoints.
        skip_head: optional auxiliary head predicting the token at t+skip_k.
        skip_k: trained skip distance for skip_head (0 = no skip head).
        max_seq_len: maximum context length in tokens.
        name: human-readable model description.
    """

    def __init__(self, model, tokenizer, device, model_type="hf",
                 skip_head=None, skip_k=0, max_seq_len=1024, name="unknown"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model_type = model_type
        self.skip_head = skip_head
        self.skip_k = skip_k
        self.max_seq_len = max_seq_len
        self.name = name

    @classmethod
    def from_checkpoint(cls, path, device):
        """Load a custom circuit model from checkpoint."""
        from .config import CircuitConfig
        from .model import CircuitTransformer
        from .mirrored import MirroredConfig, MirroredTransformer
        from .slotted_mirrored import SlotMirroredConfig, SlotMirroredTransformer
        from .data import get_tokenizer

        checkpoint = torch.load(path, map_location="cpu", weights_only=False)
        model_type = checkpoint.get("model_type", "standard")

        # Instantiate whichever architecture the checkpoint recorded.
        if model_type == "slot_mirrored":
            config = SlotMirroredConfig.from_dict(checkpoint["config"])
            model = SlotMirroredTransformer(config).to(device)
            arch_desc = f"SlotMirrored ({config.n_slots} slots)"
        elif model_type == "mirrored":
            config = MirroredConfig.from_dict(checkpoint["config"])
            model = MirroredTransformer(config).to(device)
            arch_desc = "Mirrored"
        else:
            config = CircuitConfig.from_dict(checkpoint["config"])
            model = CircuitTransformer(config).to(device)
            arch_desc = "Standard"

        # torch.compile saves parameters under an "_orig_mod." prefix; strip
        # it so the weights load into the un-compiled module.
        state_dict = checkpoint["model"]
        if any(k.startswith("_orig_mod.") for k in state_dict):
            state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        model.eval()

        tokenizer = get_tokenizer()
        # Idiom fix: getattr with default replaces hasattr + conditional.
        skip_head = getattr(model, 'skip_head', None)
        skip_k = getattr(config, 'aux_skip_k', 0)
        max_seq_len = config.max_seq_len

        params = sum(p.numel() for p in model.parameters()) / 1e6
        name = f"{Path(path).parent.name}/{Path(path).stem} ({arch_desc}, {params:.1f}M)"

        return cls(model, tokenizer, device, model_type="circuit",
                   skip_head=skip_head, skip_k=skip_k,
                   max_seq_len=max_seq_len, name=name)

    @classmethod
    def from_pretrained(cls, model_name, device):
        """Load a HuggingFace model."""
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True,
            torch_dtype=torch.float32,
        ).to(device)
        model.eval()

        max_seq_len = getattr(model.config, 'max_position_embeddings', 1024)
        # generate() requires a pad token; fall back to EOS when undefined.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        params = sum(p.numel() for p in model.parameters()) / 1e6
        name = f"{model_name} ({params:.1f}M)"

        return cls(model, tokenizer, device, model_type="hf",
                   max_seq_len=max_seq_len, name=name)

    @property
    def has_skip_head(self):
        """True when the model carries a usable auxiliary skip head."""
        return self.skip_head is not None and self.skip_k > 0

    def generate(self, prompt_text, max_new_tokens=512):
        """Sample a continuation (temperature 0.8, top-k 50, top-p 0.9,
        repetition penalty 1.2).

        BUG FIX (doc): the previous docstring claimed greedy temperature-0
        decoding, contradicting the sampling parameters actually passed below.

        Returns:
            (prompt_ids, gen_ids): 1-D tensors of the prompt tokens and the
            newly generated tokens (prompt excluded).
        """
        prompt_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)

        with torch.no_grad():
            if self.model_type == "hf":
                output_ids = self.model.generate(
                    prompt_ids,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    temperature=0.8,
                    top_k=50,
                    top_p=0.9,
                    repetition_penalty=1.2,
                )
            else:
                output_ids = self.model.generate(
                    prompt_ids,
                    max_new_tokens=max_new_tokens,
                    temperature=0.8,
                    top_k=50,
                    top_p=0.9,
                    repetition_penalty=1.2,
                )

        # Slice off the prompt so callers get only the new tokens.
        gen_ids = output_ids[0, prompt_ids.shape[1]:]
        return prompt_ids[0], gen_ids

    def forward_with_hidden(self, input_ids):
        """Forward pass returning (logits, hidden_states, skip_logits_or_None).

        input_ids: [1, L] tensor.
        """
        with torch.no_grad():
            if self.model_type == "hf":
                outputs = self.model(input_ids, output_hidden_states=True)
                logits = outputs.logits
                hidden = outputs.hidden_states[-1]
                return logits, hidden, None
            else:
                # Circuit models don't return hidden states directly; capture
                # the final norm layer's output with a forward hook.
                hidden_capture = {}

                def hook_fn(module, inp, output):
                    hidden_capture['h'] = output.detach()

                handle = self.model.norm.register_forward_hook(hook_fn)
                output = self.model(input_ids)
                handle.remove()

                logits = output['logits']
                hidden = hidden_capture['h']

                skip_logits = None
                if self.has_skip_head:
                    skip_logits = self.skip_head(hidden)

                return logits, hidden, skip_logits

    def forward(self, input_ids):
        """Forward pass returning logits only. input_ids: [1, L] tensor."""
        with torch.no_grad():
            if self.model_type == "hf":
                return self.model(input_ids).logits
            else:
                return self.model(input_ids)['logits']
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def generate_all(wrapper, prompts, gen_length):
    """Generate from all prompts. Returns list of (prompt_text, prompt_ids, gen_ids)."""
    outputs = []
    total = len(prompts)
    for done, prompt_text in enumerate(prompts, start=1):
        p_ids, g_ids = wrapper.generate(prompt_text, max_new_tokens=gen_length)
        outputs.append((prompt_text, p_ids, g_ids))
        # Overwrite the same console line with running progress.
        print(f" [{done}/{total}] {len(g_ids)} tokens", end="\r")
    print()
    return outputs
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def ngrams(tokens, n):
    """Extract n-grams from token list.

    Returns a list of len(tokens) - n + 1 tuples (empty when the list is
    shorter than n).
    """
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def compute_diversity(gen_ids):
    """Compute diversity metrics for a single generation.

    Args:
        gen_ids: 1-D tensor of generated token IDs.

    Returns:
        dict with unique n-gram ratios (n = 1..4), the longest run of a
        single repeated token (``max_repeat``), the longest span covered by
        consecutively repeated n-grams (``max_ngram_repeat_span``), and a
        boolean ``collapsed`` flag.
    """
    tokens = gen_ids.tolist()
    n = len(tokens)
    if n < 4:
        # BUG FIX: the early return previously omitted
        # "max_ngram_repeat_span", which eval_diversity aggregates over,
        # causing a KeyError for any generation shorter than 4 tokens.
        return {"unique_1g": 0, "unique_2g": 0, "unique_3g": 0, "unique_4g": 0,
                "max_repeat": n, "max_ngram_repeat_span": n, "collapsed": True}

    results = {}
    for k in [1, 2, 3, 4]:
        grams = ngrams(tokens, k)
        results[f"unique_{k}g"] = len(set(grams)) / len(grams) if grams else 0.0

    # Longest run of a single token repeated back-to-back.
    max_repeat = 1
    current = 1
    for i in range(1, n):
        if tokens[i] == tokens[i - 1]:
            current += 1
            max_repeat = max(max_repeat, current)
        else:
            current = 1
    results["max_repeat"] = max_repeat

    # Longest token span covered by consecutively repeated n-grams
    # (detects loops like "the cat the cat the cat ...").
    max_ngram_repeat = 1
    for ng_size in [2, 3, 4, 5, 8]:
        grams = ngrams(tokens, ng_size)
        streak = 1
        for i in range(1, len(grams)):
            if grams[i] == grams[i - 1]:
                streak += 1
                max_ngram_repeat = max(max_ngram_repeat, streak * ng_size)
            else:
                streak = 1
    results["max_ngram_repeat_span"] = max_ngram_repeat

    # Collapse = low 4-gram diversity, or a repeat loop covering >25% of output.
    results["collapsed"] = (results["unique_4g"] < 0.5) or (max_ngram_repeat > n * 0.25)

    return results
|
| |
|
| |
|
def eval_diversity(generations, tokenizer, show_samples=3):
    """Tier 1: Compute diversity metrics from pre-generated text.

    Args:
        generations: list of (prompt_text, prompt_ids, gen_ids) tuples.
        tokenizer: tokenizer used to decode sample generations for display.
        show_samples: number of generations to print as text previews.

    Returns:
        {"per_prompt": [...], "aggregate": {...}}, or {} when there are no
        generations to evaluate.
    """
    print("\n" + "=" * 60)
    print("TIER 1: Generation Diversity")
    print("=" * 60)

    all_metrics = []
    sample_texts = []

    for i, (prompt, prompt_ids, gen_ids) in enumerate(generations):
        metrics = compute_diversity(gen_ids)
        metrics["prompt"] = prompt
        metrics["gen_length"] = len(gen_ids)
        all_metrics.append(metrics)

        # Keep decoded text for the first few generations only.
        if i < show_samples:
            text = tokenizer.decode(gen_ids, skip_special_tokens=True)
            sample_texts.append((prompt, text))

    n = len(all_metrics)
    if n == 0:
        print(" No generations to evaluate.")
        return {}

    # Aggregate mean/min/max for each metric across prompts.
    agg = {}
    for key in ["unique_1g", "unique_2g", "unique_3g", "unique_4g",
                "max_repeat", "max_ngram_repeat_span"]:
        values = [m[key] for m in all_metrics]
        agg[key] = {"mean": sum(values) / n, "min": min(values), "max": max(values)}

    collapse_count = sum(1 for m in all_metrics if m["collapsed"])
    agg["collapse_rate"] = collapse_count / n
    avg_len = sum(m["gen_length"] for m in all_metrics) / n

    # Summary table.
    print(f"\n Prompts evaluated: {n}")
    print(f" Avg generation length: {avg_len:.0f} tokens")
    print()
    print(f" {'Metric':<24} {'Mean':>8} {'Min':>8} {'Max':>8}")
    print(f" {'-' * 50}")
    for key in ["unique_1g", "unique_2g", "unique_3g", "unique_4g"]:
        m = agg[key]
        print(f" {key:<24} {m['mean']:>8.3f} {m['min']:>8.3f} {m['max']:>8.3f}")
    for key in ["max_repeat", "max_ngram_repeat_span"]:
        m = agg[key]
        print(f" {key:<24} {m['mean']:>8.1f} {int(m['min']):>8d} {int(m['max']):>8d}")
    print(f"\n Collapse rate: {collapse_count}/{n} ({agg['collapse_rate']:.1%})")

    # Text previews of the first few generations, flattened to one line each.
    if sample_texts:
        print(f"\n --- Sample generations (first {len(sample_texts)}) ---")
        for prompt, text in sample_texts:
            print(f"\n Prompt: \"{prompt}\"")
            preview = text[:400].replace("\n", " ")
            if len(text) > 400:
                preview += "..."
            print(f" Output: {preview}")

    return {"per_prompt": all_metrics, "aggregate": agg}
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def prepare_eval_sequences(wrapper, num_sequences=50, data_source=None):
    """Prepare ground truth sequences for Tier 2.

    Args:
        wrapper: ModelWrapper providing `tokenizer` and `max_seq_len`.
        num_sequences: maximum number of sequences to return.
        data_source: optional path to a plain-text file; when absent, falls
            back to WikiText-103 validation via the `datasets` package.

    Returns:
        list of 1-D long tensors of length wrapper.max_seq_len, or None when
        no usable evaluation text could be loaded.
    """
    max_len = wrapper.max_seq_len

    if data_source and Path(data_source).exists():
        with open(data_source) as f:
            text = f.read()
        all_ids = wrapper.tokenizer.encode(text)
    else:
        try:
            from datasets import load_dataset
            print(" Loading WikiText-103 validation...")
            ds = load_dataset("wikitext", "wikitext-103-raw-v1",
                              split="validation", trust_remote_code=True)
            text = "\n".join(row["text"] for row in ds if row["text"].strip())
            all_ids = wrapper.tokenizer.encode(text)
        except Exception as e:
            print(f" Could not load eval data: {e}")
            # (was an f-string with no placeholders)
            print(" Install 'datasets' or use --eval-data to provide a text file.")
            return None

    # Slice the token stream into non-overlapping windows of exactly max_len.
    sequences = []
    for i in range(0, len(all_ids) - max_len, max_len):
        seq = torch.tensor(all_ids[i:i + max_len], dtype=torch.long)
        sequences.append(seq)
        if len(sequences) >= num_sequences:
            break

    if len(sequences) < 2:
        print(" Not enough text for evaluation sequences.")
        return None

    # BUG FIX: this previously printed a hard-coded "55,736 tokens" instead
    # of the actual sequence length.
    print(f" Prepared {len(sequences)} sequences of {max_len:,} tokens")
    return sequences
|
| |
|
| |
|
def eval_context_utilization(wrapper, sequences):
    """Tier 2a: Per-position perplexity grouped by depth bucket.

    Args:
        wrapper: ModelWrapper exposing forward(), device and max_seq_len.
        sequences: list of equal-length 1-D long tensors.

    Returns:
        dict mapping "start-end" depth bucket -> {"loss", "ppl", "n_tokens"}.
    """
    max_len = wrapper.max_seq_len

    # Fixed depth boundaries truncated to the context size, with the final
    # boundary always at max_len. (CLEANUP: the original if/else appended
    # max_len in both branches; a single unconditional entry is equivalent.)
    bucket_bounds = [0, 64, 128, 256, 512, max_len]
    bucket_bounds = sorted(set(b for b in bucket_bounds if b <= max_len))
    if bucket_bounds[-1] < max_len:
        bucket_bounds.append(max_len)
    buckets = [(bucket_bounds[i], bucket_bounds[i + 1])
               for i in range(len(bucket_bounds) - 1)]

    # Per-token next-token loss for every sequence.
    all_losses = []
    for seq in sequences:
        input_ids = seq.unsqueeze(0).to(wrapper.device)
        logits = wrapper.forward(input_ids)

        shift_logits = logits[0, :-1]
        shift_labels = input_ids[0, 1:]
        per_token_loss = F.cross_entropy(shift_logits, shift_labels, reduction='none')
        all_losses.append(per_token_loss.cpu())
        print(f" [{len(all_losses)}/{len(sequences)}]", end="\r")
    print()

    # Average loss within each depth bucket.
    stacked = torch.stack(all_losses)
    bucket_results = {}
    for start, end in buckets:
        s = min(start, stacked.shape[1])
        e = min(end, stacked.shape[1])
        if s >= e:
            continue
        bucket_losses = stacked[:, s:e]
        avg_loss = bucket_losses.mean().item()
        bucket_results[f"{start}-{end}"] = {
            "loss": avg_loss,
            # Clamp before exp so a diverged model cannot overflow to inf.
            "ppl": math.exp(min(avg_loss, 20)),
            "n_tokens": bucket_losses.numel(),
        }

    return bucket_results
|
| |
|
| |
|
def eval_skip_accuracy(wrapper, sequences, distances):
    """Tier 2b: Skip head prediction accuracy at various distances.

    Args:
        wrapper: ModelWrapper; must expose a skip head (has_skip_head).
        sequences: list of 1-D long tensors.
        distances: positive int distances K to probe (predicting token t+K).

    Returns:
        {"t+K": {"top1": float, "top5": float}} averaged over sequences
        (distances that never fit the context are dropped), or None when
        the wrapper has no skip head.
    """
    if not wrapper.has_skip_head:
        return None

    results = {f"t+{K}": {"top1": [], "top5": []} for K in distances}

    for seq in sequences:
        input_ids = seq.unsqueeze(0).to(wrapper.device)
        _, hidden, _ = wrapper.forward_with_hidden(input_ids)

        # PERF: the skip head does not depend on K, so compute its logits
        # once per sequence (this was previously inside the distance loop).
        skip_logits = wrapper.skip_head(hidden)

        for K in distances:
            if K >= input_ids.shape[1]:
                continue

            targets = input_ids[0, K:]
            preds = skip_logits[0, :-K]

            top1 = (preds.argmax(-1) == targets).float().mean().item()
            top5_indices = preds.topk(min(5, preds.shape[-1]), dim=-1).indices
            top5 = (top5_indices == targets.unsqueeze(-1)).any(-1).float().mean().item()

            results[f"t+{K}"]["top1"].append(top1)
            results[f"t+{K}"]["top5"].append(top5)

        print(f" [{len(results['t+' + str(distances[0])]['top1'])}/{len(sequences)}]", end="\r")
    print()

    # Average per distance; keep only distances that produced measurements.
    avg_results = {}
    for key in sorted(results.keys(), key=lambda x: int(x.split("+")[1])):
        vals = results[key]
        if vals["top1"]:
            avg_results[key] = {
                "top1": sum(vals["top1"]) / len(vals["top1"]),
                "top5": sum(vals["top5"]) / len(vals["top5"]),
            }

    return avg_results
|
| |
|
| |
|
def eval_structural(wrapper, eval_data, distances, num_sequences):
    """Run Tier 2 evaluation.

    Args:
        wrapper: ModelWrapper under evaluation.
        eval_data: optional path to a plain-text evaluation file.
        distances: skip distances (ints) probed in Tier 2b.
        num_sequences: number of ground-truth sequences to prepare.

    Returns:
        {"context_utilization": dict or None, "skip_accuracy": dict or None}.
    """
    print("\n" + "=" * 60)
    print("TIER 2: Structural Prediction")
    print("=" * 60)

    sequences = prepare_eval_sequences(wrapper, num_sequences, eval_data)
    if sequences is None:
        # No evaluation text available; skip the whole tier.
        return {"context_utilization": None, "skip_accuracy": None}

    # --- 2a: perplexity by position depth ---
    print("\n --- 2a: Context Utilization (PPL by position depth) ---")
    ctx_results = eval_context_utilization(wrapper, sequences)

    if ctx_results:
        print(f"\n {'Depth':<12} {'Loss':>8} {'PPL':>10} {'Tokens':>10}")
        print(f" {'-' * 42}")
        for bucket, vals in ctx_results.items():
            print(f" {bucket:<12} {vals['loss']:>8.3f} {vals['ppl']:>10.2f} {vals['n_tokens']:>10}")

        # Ratio of shallow-context PPL to deep-context PPL.
        buckets_list = list(ctx_results.values())
        if len(buckets_list) >= 2:
            ratio = buckets_list[0]["ppl"] / buckets_list[-1]["ppl"]
            print(f"\n Context utilization ratio (first/last): {ratio:.2f}x")
            print(f" (Higher = model benefits more from additional context)")

    # --- 2b: runs only for models exposing a trained skip head ---
    skip_results = None
    if wrapper.has_skip_head:
        print(f"\n --- 2b: Skip Head Accuracy (trained for t+{wrapper.skip_k}) ---")
        skip_results = eval_skip_accuracy(wrapper, sequences, distances)

        if skip_results:
            print(f"\n {'Distance':<12} {'Top-1':>8} {'Top-5':>8}")
            print(f" {'-' * 30}")
            for key, vals in skip_results.items():
                # Mark the distance the skip head was actually trained on.
                trained = " *" if int(key.split("+")[1]) == wrapper.skip_k else ""
                print(f" {key:<12} {vals['top1']:>8.4f} {vals['top5']:>8.4f}{trained}")
            print(f"\n * = trained distance")
    else:
        print("\n Skip head: not available")

    return {"context_utilization": ctx_results, "skip_accuracy": skip_results}
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def compute_chunk_similarity(hidden_states, chunk_size=128):
    """Cosine-similarity statistics between fixed-size chunks of hidden states.

    Args:
        hidden_states: [L, D] tensor.
        chunk_size: tokens pooled into each chunk vector.

    Returns:
        dict with mean/min pairwise similarity, adjacent-chunk similarity,
        first-vs-last-quarter ("distant") similarity and the chunk count, or
        None when fewer than two chunks fit.
    """
    seq_len, _ = hidden_states.shape
    n_chunks = seq_len // chunk_size
    if n_chunks < 2:
        return None

    # Mean-pool each chunk, then L2-normalize so dot products are cosines.
    pooled = torch.stack([
        hidden_states[c * chunk_size:(c + 1) * chunk_size].mean(dim=0)
        for c in range(n_chunks)
    ])
    pooled = F.normalize(pooled, dim=-1)
    sim_matrix = pooled @ pooled.T

    # All distinct chunk pairs: strict upper triangle of the matrix.
    upper = torch.triu(torch.ones_like(sim_matrix, dtype=torch.bool), diagonal=1)
    pairwise_sims = sim_matrix[upper]

    # Neighboring chunks.
    adjacent = [sim_matrix[c, c + 1].item() for c in range(n_chunks - 1)]

    # First quarter of chunks vs last quarter of chunks.
    quarter = max(1, n_chunks // 4)
    distant = [sim_matrix[a, b].item()
               for a in range(quarter)
               for b in range(n_chunks - quarter, n_chunks)
               if a < b]

    return {
        "mean_sim": pairwise_sims.mean().item(),
        "min_sim": pairwise_sims.min().item(),
        "adjacent_sim": sum(adjacent) / len(adjacent),
        "distant_sim": sum(distant) / len(distant) if distant else 0.0,
        "n_chunks": n_chunks,
    }
|
| |
|
| |
|
def eval_consistency(wrapper, generations, chunk_size=128):
    """Tier 3: Semantic consistency of generated text via hidden state similarity.

    Args:
        wrapper: ModelWrapper used to re-encode prompt + generation.
        generations: list of (prompt_text, prompt_ids, gen_ids) tuples;
            generations shorter than 2 * chunk_size are skipped.
        chunk_size: tokens per similarity chunk.

    Returns:
        {"per_prompt": [...], "aggregate": {...}}, or {} when no generation
        was long enough to evaluate.
    """
    print("\n" + "=" * 60)
    print("TIER 3: Semantic Consistency")
    print("=" * 60)

    all_metrics = []

    for i, (prompt, prompt_ids, gen_ids) in enumerate(generations):
        # Need at least two chunks of generated text to compare.
        if gen_ids.shape[0] < chunk_size * 2:
            continue

        # Re-run the model over prompt + generation to obtain hidden states.
        full_ids = torch.cat([prompt_ids, gen_ids]).unsqueeze(0).to(wrapper.device)

        # Truncate to the model's context window.
        if full_ids.shape[1] > wrapper.max_seq_len:
            full_ids = full_ids[:, :wrapper.max_seq_len]

        _, hidden, _ = wrapper.forward_with_hidden(full_ids)

        # Measure similarity over the generated portion only.
        gen_start = prompt_ids.shape[0]
        gen_hidden = hidden[0, gen_start:]

        metrics = compute_chunk_similarity(gen_hidden, chunk_size)
        if metrics is not None:
            metrics["prompt"] = prompt
            all_metrics.append(metrics)

        print(f" [{len(all_metrics)}/{len(generations)}]", end="\r")
    print()

    if not all_metrics:
        print(" No valid generations for consistency evaluation.")
        return {}

    # Aggregate mean/min/max per similarity metric.
    n = len(all_metrics)
    agg = {}
    for key in ["mean_sim", "min_sim", "adjacent_sim", "distant_sim"]:
        values = [m[key] for m in all_metrics]
        agg[key] = {"mean": sum(values) / n, "min": min(values), "max": max(values)}

    # Topic drift: how much more similar adjacent chunks are than distant ones.
    drift_vals = [m["adjacent_sim"] - m["distant_sim"] for m in all_metrics]
    agg["topic_drift"] = {"mean": sum(drift_vals) / n,
                          "min": min(drift_vals), "max": max(drift_vals)}

    # Summary table.
    print(f"\n Generations evaluated: {n}")
    print(f" Chunk size: {chunk_size} tokens")
    avg_chunks = sum(m["n_chunks"] for m in all_metrics) / n
    print(f" Avg chunks per generation: {avg_chunks:.1f}")
    print()
    print(f" {'Metric':<24} {'Mean':>8} {'Min':>8} {'Max':>8}")
    print(f" {'-' * 50}")
    for key in ["mean_sim", "min_sim", "adjacent_sim", "distant_sim", "topic_drift"]:
        m = agg[key]
        print(f" {key:<24} {m['mean']:>8.3f} {m['min']:>8.3f} {m['max']:>8.3f}")

    return {"per_prompt": all_metrics, "aggregate": agg}
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def print_summary(results):
    """Print composite summary scores; also stored under results["summary"]."""
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)

    scores = {}

    # Tier 1: 4-gram diversity, annotated with the collapse rate.
    tier1 = results.get("tier1_diversity", {})
    if tier1 and "aggregate" in tier1:
        t1_agg = tier1["aggregate"]
        div_score = t1_agg.get("unique_4g", {}).get("mean", None)
        collapse = t1_agg.get("collapse_rate", None)
        if div_score is not None:
            scores["diversity"] = div_score
            print(f" Diversity (unique 4-gram): {div_score:.3f}", end="")
            if collapse is not None:
                print(f" (collapse: {collapse:.0%})", end="")
            print()

    # Tier 2: context-utilization ratio and skip-head accuracy.
    tier2 = results.get("tier2_structural", {})
    if tier2:
        ctx = tier2.get("context_utilization")
        if ctx:
            buckets = list(ctx.values())
            if len(buckets) >= 2:
                ratio = buckets[0]["ppl"] / buckets[-1]["ppl"]
                scores["context_util"] = ratio
                print(f" Context utilization: {ratio:.2f}x")

        skip = tier2.get("skip_accuracy")
        if skip:
            # Report the first entry (smallest probed distance).
            trained_key = next(iter(skip), None)
            if trained_key:
                top5 = skip[trained_key]["top5"]
                scores["skip_top5"] = top5
                print(f" Skip accuracy ({trained_key} top-5): {top5:.4f}")

    # Tier 3: mean chunk similarity, annotated with topic drift.
    tier3 = results.get("tier3_consistency", {})
    if tier3 and "aggregate" in tier3:
        t3_agg = tier3["aggregate"]
        coh_score = t3_agg.get("mean_sim", {}).get("mean", None)
        drift = t3_agg.get("topic_drift", {}).get("mean", None)
        if coh_score is not None:
            scores["coherence"] = coh_score
            print(f" Coherence (chunk sim): {coh_score:.3f}", end="")
            if drift is not None:
                print(f" (drift: {drift:.3f})", end="")
            print()

    results["summary"] = scores
    return scores
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def parse_args():
    """Build the command-line interface and parse sys.argv."""
    parser = argparse.ArgumentParser(
        description="Coherence evaluation for language models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Exactly one model source must be supplied.
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument("--checkpoint", type=str, help="Path to circuit model checkpoint")
    source.add_argument("--model", type=str, help="HuggingFace model name or path")

    # Evaluation knobs.
    parser.add_argument("--prompts", type=str, help="File with prompts (one per line)")
    parser.add_argument("--num-prompts", type=int, default=20,
                        help="Number of prompts to use (default: 20)")
    parser.add_argument("--gen-length", type=int, default=512,
                        help="Tokens to generate per prompt (default: 512)")
    parser.add_argument("--eval-data", type=str,
                        help="Text file for Tier 2 (default: WikiText-103 validation)")
    parser.add_argument("--num-sequences", type=int, default=50,
                        help="Number of sequences for Tier 2 (default: 50)")
    parser.add_argument("--chunk-size", type=int, default=128,
                        help="Chunk size for Tier 3 similarity (default: 128)")
    parser.add_argument("--distances", type=str, default="2,5,10,25,50,100",
                        help="Skip distances for Tier 2b (default: 2,5,10,25,50,100)")
    parser.add_argument("--tiers", type=str, default="1,2,3",
                        help="Which tiers to run (default: 1,2,3)")

    # Hardware selection.
    parser.add_argument("--gpu", type=int, default=0, help="GPU index (default: 0)")

    # Output options.
    parser.add_argument("--output", type=str, help="Save results to JSON file")
    parser.add_argument("--samples", type=int, default=3,
                        help="Number of sample generations to display (default: 3)")

    return parser.parse_args()
|
| |
|
| |
|
def main():
    """Entry point: load the model, run the requested tiers, print and save results."""
    args = parse_args()

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    tiers = [int(t) for t in args.tiers.split(",")]
    distances = [int(d) for d in args.distances.split(",")]

    print("=" * 60)
    print("Coherence Evaluation")
    print("=" * 60)

    # Load either a local circuit checkpoint or a HuggingFace model.
    if args.checkpoint:
        print(f"Loading: {args.checkpoint}")
        wrapper = ModelWrapper.from_checkpoint(args.checkpoint, device)
    else:
        print(f"Loading: {args.model}")
        wrapper = ModelWrapper.from_pretrained(args.model, device)

    print(f"Model: {wrapper.name}")
    print(f"Device: {device}")
    print(f"Max seq len: {wrapper.max_seq_len}")
    if wrapper.has_skip_head:
        print(f"Skip head: t+{wrapper.skip_k}")
    print(f"Tiers: {tiers}")

    # Prompt set: user-supplied file or built-in defaults, capped at --num-prompts.
    if args.prompts:
        with open(args.prompts) as f:
            prompts = [line.strip() for line in f if line.strip()]
    else:
        prompts = DEFAULT_PROMPTS
    prompts = prompts[:args.num_prompts]
    print(f"Prompts: {len(prompts)}")

    results = {"model": wrapper.name}
    t0 = time.time()

    # Tiers 1 and 3 share one set of generations, produced once up front.
    generations = None
    if 1 in tiers or 3 in tiers:
        print(f"\nGenerating {args.gen_length} tokens from {len(prompts)} prompts...")
        generations = generate_all(wrapper, prompts, args.gen_length)

    if 1 in tiers and generations:
        results["tier1_diversity"] = eval_diversity(
            generations, wrapper.tokenizer, show_samples=args.samples)

    if 2 in tiers:
        results["tier2_structural"] = eval_structural(
            wrapper, args.eval_data, distances, args.num_sequences)

    if 3 in tiers and generations:
        results["tier3_consistency"] = eval_consistency(
            wrapper, generations, args.chunk_size)

    print_summary(results)

    elapsed = time.time() - t0
    print(f"\nTotal time: {elapsed:.0f}s")

    # Optionally dump results as JSON, converting tensors and non-finite
    # floats into serializable forms first.
    if args.output:
        def make_serializable(obj):
            # Recursively convert tensors to lists and NaN/inf floats to
            # strings; every other value passes through unchanged.
            if isinstance(obj, dict):
                return {k: make_serializable(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [make_serializable(v) for v in obj]
            elif isinstance(obj, torch.Tensor):
                return obj.tolist()
            elif isinstance(obj, float):
                if math.isnan(obj) or math.isinf(obj):
                    return str(obj)
            return obj

        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "w") as f:
            json.dump(make_serializable(results), f, indent=2)
        print(f"Results saved to {args.output}")


if __name__ == "__main__":
    main()
|
| |
|