#!/usr/bin/env python3
"""
Pairwise BTRM comparison script.

Compare text against pairs of heads for intuitive interpretation.
Uses fully fine-tuned models (base LLM + heads trained together).

Usage:
    python compare.py --text "Your text here"
    python compare.py --text "Your text" --heads oblivion,skyrim
    python compare.py --file input.txt --model gemma
    echo "Some text" | python compare.py --stdin
"""

import argparse
import sys
from pathlib import Path
from typing import Optional

import torch
import yaml

# Head descriptions for pretty printing
HEAD_INFO = {
    "skyrim": "Nordic fantasy RPG",
    "oblivion": "Imperial fantasy RPG",
    "fonv": "Post-apocalyptic Western",
    "gallia": "Franco-Roman bureaucratic",
    "marmotte": "Alpine corporate dystopia",
    "multiturn_dialogue": "Raw quoted dialogue",
    "brainrot_aesop": "Vocabulary teaching",
}


class BTRMModel:
    """Load and run fully fine-tuned BTRM model."""

    def __init__(self, model_name: str = "gemma", device: Optional[str] = None):
        """
        Load the fine-tuned base model, its tokenizer, and the per-style
        linear scoring heads from the model directory next to this script.

        Args:
            model_name: "gemma" or "qwen"
            device: "cuda", "cpu", or None for auto

        Raises:
            ValueError: if model_name is not a known model.
            FileNotFoundError: if the model directory does not exist.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name

        # Determine paths relative to this script, not the CWD.
        script_dir = Path(__file__).parent
        if model_name == "gemma":
            model_dir = script_dir / "gemma_btrm"
        elif model_name == "qwen":
            model_dir = script_dir / "qwen_btrm"
        else:
            raise ValueError(f"Unknown model: {model_name}. Use 'gemma' or 'qwen'")

        if not model_dir.exists():
            raise FileNotFoundError(f"Model directory not found: {model_dir}")

        # Load config
        with open(model_dir / "config.yaml") as f:
            self.config = yaml.safe_load(f)

        # Load the FINE-TUNED base model (not the original!)
        base_model_path = model_dir / "base_model"
        print(f"Loading fine-tuned {model_name} from {base_model_path}...",
              file=sys.stderr)

        # Imported lazily so --help etc. stay fast and dependency-light.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(
            str(base_model_path), trust_remote_code=True
        )
        self.base_model = AutoModelForCausalLM.from_pretrained(
            str(base_model_path),
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        ).to(self.device)
        self.base_model.eval()

        # Load heads.
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints you trust (consider weights_only=True on torch >= 1.13).
        heads_path = model_dir / "btrm_heads.pt"
        heads_data = torch.load(heads_path, map_location=self.device)
        self.head_names = heads_data["head_names"]

        self.heads = torch.nn.ModuleDict()
        hidden_dim = self.base_model.config.hidden_size
        for name in self.head_names:
            head = torch.nn.Linear(hidden_dim, 1)
            head.load_state_dict(heads_data["heads"][name])
            self.heads[name] = head.to(self.device)

        print(f"Loaded {len(self.head_names)} heads: {', '.join(self.head_names)}",
              file=sys.stderr)

    def score(self, text: str) -> dict[str, float]:
        """Get scores for all heads.

        Returns a mapping of head name -> scalar score for the last-token,
        last-layer hidden state of the fine-tuned base model.
        """
        # Tokenize (truncate long inputs so we never exceed the context window)
        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=2048
        ).to(self.device)

        # Get hidden state from fine-tuned model
        with torch.no_grad():
            outputs = self.base_model(**inputs, output_hidden_states=True)
            hidden = outputs.hidden_states[-1][:, -1, :]  # Last token, last layer

        # Score each head. The base model runs in bfloat16 while the heads are
        # plain nn.Linear (float32 after load_state_dict), so cast the hidden
        # state to each head's weight dtype — a no-op when they already match,
        # and avoids a dtype-mismatch RuntimeError when they don't.
        scores = {}
        for name, head in self.heads.items():
            scores[name] = head(hidden.to(head.weight.dtype)).item()

        return scores

    def compare(self, text: str, head_a: str, head_b: str) -> dict:
        """Compare two specific heads.

        Returns a dict with both raw scores, their delta (a - b), and the
        name of the higher-scoring head under "winner".

        Raises:
            ValueError: if either head name is unknown.
        """
        if head_a not in self.head_names:
            raise ValueError(f"Unknown head: {head_a}. Available: {self.head_names}")
        if head_b not in self.head_names:
            raise ValueError(f"Unknown head: {head_b}. Available: {self.head_names}")

        scores = self.score(text)
        return {
            head_a: scores[head_a],
            head_b: scores[head_b],
            "delta": scores[head_a] - scores[head_b],
            "winner": head_a if scores[head_a] > scores[head_b] else head_b,
        }

    def top_contrasts(self, text: str, n: int = 5) -> list[dict]:
        """Find the pairs with largest score differences.

        Scores the text once, then ranks all unordered head pairs by the
        absolute difference of their scores and returns the top n.
        """
        scores = self.score(text)
        pairs = []
        names = list(scores.keys())
        for i, a in enumerate(names):
            for b in names[i + 1:]:
                pairs.append({
                    "head_a": a,
                    "head_b": b,
                    "score_a": scores[a],
                    "score_b": scores[b],
                    "delta": abs(scores[a] - scores[b]),
                    "higher": a if scores[a] > scores[b] else b,
                })
        return sorted(pairs, key=lambda x: x["delta"], reverse=True)[:n]


def format_scores(scores: dict[str, float]) -> str:
    """Pretty-print scores with bar chart.

    Each line shows the head name, signed score, a 21-char bar marking the
    score's position on a fixed -2..+2 scale, and the head's description.
    """
    if not scores:  # max() below would raise on an empty mapping
        return ""
    lines = []
    sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    max_name_len = max(len(name) for name in scores.keys())
    for name, score in sorted_scores:
        # Create bar (scale: -2 to +2 mapped to 0-20 chars)
        bar_pos = int((score + 2) / 4 * 20)
        bar_pos = max(0, min(20, bar_pos))  # clamp out-of-range scores
        bar = "─" * bar_pos + "│" + "─" * (20 - bar_pos)
        desc = HEAD_INFO.get(name, "")
        lines.append(f"  {name:<{max_name_len}}  {score:+.3f}  [{bar}]  {desc}")
    return "\n".join(lines)


def format_comparison(result: dict, head_a: str, head_b: str) -> str:
    """Pretty-print a pairwise comparison produced by BTRMModel.compare()."""
    delta = result["delta"]
    winner = result["winner"]
    lines = [
        f"  {head_a}: {result[head_a]:+.3f}  ({HEAD_INFO.get(head_a, '')})",
        f"  {head_b}: {result[head_b]:+.3f}  ({HEAD_INFO.get(head_b, '')})",
        f"  ─────────────────────",
        f"  Δ = {delta:+.3f}  →  leans **{winner}**",
    ]
    return "\n".join(lines)


def main():
    """CLI entry point: parse args, gather input text, run the chosen mode."""
    parser = argparse.ArgumentParser(
        description="Compare text using multi-head BTRM (fully fine-tuned models)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Score all heads
  python compare.py --text "The Dragonborn climbed to High Hrothgar..."

  # Compare specific pair
  python compare.py --text "Breaking news today" --heads skyrim,fonv

  # Use Qwen model instead of Gemma
  python compare.py --text "Your text" --model qwen

  # Read from file
  python compare.py --file story.txt

  # Pipe from stdin
  echo "Your text" | python compare.py --stdin

Available heads: skyrim, oblivion, fonv, gallia, marmotte,
                 multiturn_dialogue, brainrot_aesop
""",
    )
    parser.add_argument("--text", "-t", help="Text to analyze")
    parser.add_argument("--file", "-f", help="Read text from file")
    parser.add_argument("--stdin", action="store_true", help="Read from stdin")
    parser.add_argument("--model", "-m", default="gemma", choices=["gemma", "qwen"],
                        help="Model to use (default: gemma, recommended for better discrimination)")
    parser.add_argument("--heads",
                        help="Comma-separated pair of heads to compare (e.g., oblivion,skyrim)")
    parser.add_argument("--contrasts", type=int, metavar="N",
                        help="Show top N pairwise contrasts")
    args = parser.parse_args()

    # Get input text
    if args.text:
        text = args.text
    elif args.file:
        # Pin the encoding so behavior doesn't vary with the platform default.
        text = Path(args.file).read_text(encoding="utf-8")
    elif args.stdin or not sys.stdin.isatty():
        text = sys.stdin.read()
    else:
        parser.print_help()
        print("\n\nError: Provide text via --text, --file, or --stdin", file=sys.stderr)
        sys.exit(1)

    text = text.strip()
    if not text:
        print("Error: Empty input", file=sys.stderr)
        sys.exit(1)

    # Load model
    try:
        model = BTRMModel(args.model)
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        print("Make sure you're running from the repo directory with model folders.",
              file=sys.stderr)
        sys.exit(1)

    print(f"\n{'='*60}", file=sys.stderr)
    print(f"Text: {text[:70]}{'...' if len(text) > 70 else ''}", file=sys.stderr)
    print(f"{'='*60}\n", file=sys.stderr)

    if args.heads:
        # Pairwise comparison
        parts = args.heads.split(",")
        if len(parts) != 2:
            print("Error: --heads requires exactly 2 comma-separated head names",
                  file=sys.stderr)
            print(f"Available: {', '.join(HEAD_INFO.keys())}", file=sys.stderr)
            sys.exit(1)

        head_a, head_b = parts[0].strip(), parts[1].strip()
        try:
            result = model.compare(text, head_a, head_b)
        except ValueError as e:
            print(f"Error: {e}", file=sys.stderr)
            sys.exit(1)

        print("Pairwise Comparison:")
        print(format_comparison(result, head_a, head_b))

    elif args.contrasts is not None:
        # Show top contrasts. `is not None` so an explicit --contrasts 0
        # isn't silently treated as "flag absent" (0 is falsy).
        contrasts = model.top_contrasts(text, args.contrasts)
        print(f"Top {len(contrasts)} Pairwise Contrasts:")
        for c in contrasts:
            print(f"  {c['head_a']} vs {c['head_b']}: Δ={c['delta']:.3f} (higher: {c['higher']})")

    else:
        # Show all scores (default)
        scores = model.score(text)
        print("All Head Scores:")
        print(format_scores(scores))
        print(f"\nTip: Use --heads {list(scores.keys())[0]},{list(scores.keys())[1]} for pairwise comparison")


if __name__ == "__main__":
    main()