| | |
| | """ |
| | Pairwise BTRM comparison script. |
| | |
| | Compare text against pairs of heads for intuitive interpretation. |
| | Uses fully fine-tuned models (base LLM + heads trained together). |
| | |
| | Usage: |
| | python compare.py --text "Your text here" |
| | python compare.py --text "Your text" --heads oblivion,skyrim |
| | python compare.py --file input.txt --model gemma |
| | echo "Some text" | python compare.py --stdin |
| | """ |
| |
|
| | import argparse |
| | import sys |
| | import torch |
| | import yaml |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | |
# Short human-readable description of each reward head, used when
# pretty-printing scores/comparisons and in the --heads error message.
# Keys must match the head names stored in btrm_heads.pt.
HEAD_INFO = {
    "skyrim": "Nordic fantasy RPG",
    "oblivion": "Imperial fantasy RPG",
    "fonv": "Post-apocalyptic Western",
    "gallia": "Franco-Roman bureaucratic",
    "marmotte": "Alpine corporate dystopia",
    "multiturn_dialogue": "Raw quoted dialogue",
    "brainrot_aesop": "Vocabulary teaching",
}
| |
|
| |
|
class BTRMModel:
    """Load and run a fully fine-tuned BTRM model.

    Wraps a causal-LM backbone plus one linear reward head per style, all
    loaded from a model directory that sits next to this script.
    """

    def __init__(self, model_name: str = "gemma", device: Optional[str] = None):
        """
        Args:
            model_name: "gemma" or "qwen"
            device: "cuda", "cpu", or None for auto-detection

        Raises:
            ValueError: if model_name is not "gemma" or "qwen".
            FileNotFoundError: if the model directory is missing.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name

        # Model folders are expected next to this script.
        script_dir = Path(__file__).parent
        if model_name == "gemma":
            model_dir = script_dir / "gemma_btrm"
        elif model_name == "qwen":
            model_dir = script_dir / "qwen_btrm"
        else:
            raise ValueError(f"Unknown model: {model_name}. Use 'gemma' or 'qwen'")

        if not model_dir.exists():
            raise FileNotFoundError(f"Model directory not found: {model_dir}")

        with open(model_dir / "config.yaml") as f:
            self.config = yaml.safe_load(f)

        base_model_path = model_dir / "base_model"
        print(f"Loading fine-tuned {model_name} from {base_model_path}...", file=sys.stderr)

        # Imported lazily so argparse --help stays fast and transformers is
        # only required once a model actually gets loaded.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(
            str(base_model_path), trust_remote_code=True
        )
        self.base_model = AutoModelForCausalLM.from_pretrained(
            str(base_model_path),
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        ).to(self.device)
        self.base_model.eval()

        # Fix: weights_only=True restricts the checkpoint to tensors and
        # plain containers, avoiding arbitrary-code pickle execution.
        # (The file holds only head names + state dicts, so this is safe;
        # TODO confirm against the checkpoint writer.)
        heads_path = model_dir / "btrm_heads.pt"
        heads_data = torch.load(heads_path, map_location=self.device, weights_only=True)
        self.head_names = heads_data["head_names"]
        self.heads = torch.nn.ModuleDict()

        hidden_dim = self.base_model.config.hidden_size
        for name in self.head_names:
            head = torch.nn.Linear(hidden_dim, 1)
            head.load_state_dict(heads_data["heads"][name])
            self.heads[name] = head.to(self.device)

        print(f"Loaded {len(self.head_names)} heads: {', '.join(self.head_names)}", file=sys.stderr)

    def score(self, text: str) -> dict[str, float]:
        """Return {head_name: scalar score} for every head on *text*.

        Text is truncated to 2048 tokens; the final-layer hidden state of
        the last token is used as the sequence representation.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=2048
        ).to(self.device)

        with torch.no_grad():
            outputs = self.base_model(**inputs, output_hidden_states=True)
            hidden = outputs.hidden_states[-1][:, -1, :]

        return {name: head(hidden).item() for name, head in self.heads.items()}

    def compare(self, text: str, head_a: str, head_b: str) -> dict:
        """Compare two specific heads on *text*.

        Returns a dict with both raw scores, "delta" (a minus b) and
        "winner". Ties go to head_b (strict > comparison).

        Raises:
            ValueError: if either head name is unknown.
        """
        for head in (head_a, head_b):
            if head not in self.head_names:
                raise ValueError(f"Unknown head: {head}. Available: {self.head_names}")

        scores = self.score(text)
        return {
            head_a: scores[head_a],
            head_b: scores[head_b],
            "delta": scores[head_a] - scores[head_b],
            "winner": head_a if scores[head_a] > scores[head_b] else head_b,
        }

    def top_contrasts(self, text: str, n: int = 5) -> list[dict]:
        """Return the *n* head pairs with the largest absolute score gap."""
        scores = self.score(text)
        names = list(scores.keys())
        pairs = [
            {
                "head_a": a,
                "head_b": b,
                "score_a": scores[a],
                "score_b": scores[b],
                "delta": abs(scores[a] - scores[b]),
                "higher": a if scores[a] > scores[b] else b,
            }
            for i, a in enumerate(names)
            for b in names[i + 1:]
        ]
        return sorted(pairs, key=lambda p: p["delta"], reverse=True)[:n]
| |
|
| |
|
def format_scores(scores: dict[str, float], head_info: Optional[dict[str, str]] = None) -> str:
    """Pretty-print scores with a bar chart, highest score first.

    Args:
        scores: head name -> scalar score; the bar assumes scores roughly
            in [-2, 2] and clamps outside that range.
        head_info: optional {head: description} mapping; defaults to the
            module-level HEAD_INFO.

    Returns:
        One line per head, or "" for an empty mapping (fix: the original
        raised ValueError on empty input via max()).
    """
    if head_info is None:
        head_info = HEAD_INFO
    if not scores:
        return ""

    sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    max_name_len = max(len(name) for name in scores)

    lines = []
    for name, score in sorted_scores:
        # Map [-2, 2] onto 21 marker slots and clamp.
        bar_pos = int((score + 2) / 4 * 20)
        bar_pos = max(0, min(20, bar_pos))
        # Fix: the track/marker glyphs were mojibake ("β") in the original;
        # render a marker on a dashed track instead.
        bar = "─" * bar_pos + "●" + "─" * (20 - bar_pos)

        desc = head_info.get(name, "")
        lines.append(f" {name:<{max_name_len}} {score:+.3f} [{bar}] {desc}")

    return "\n".join(lines)
| |
|
| |
|
def format_comparison(result: dict, head_a: str, head_b: str,
                      head_info: Optional[dict[str, str]] = None) -> str:
    """Pretty-print a pairwise comparison produced by BTRMModel.compare().

    Args:
        result: dict holding result[head_a], result[head_b], "delta", "winner".
        head_a, head_b: the two head names being compared.
        head_info: optional {head: description} mapping; defaults to the
            module-level HEAD_INFO.
    """
    if head_info is None:
        head_info = HEAD_INFO
    delta = result["delta"]
    winner = result["winner"]

    # Fix: the separator, delta symbol and arrow were mojibake ("β…β", "Ξ")
    # in the original output strings.
    lines = [
        f" {head_a}: {result[head_a]:+.3f} ({head_info.get(head_a, '')})",
        f" {head_b}: {result[head_b]:+.3f} ({head_info.get(head_b, '')})",
        " " + "─" * 21,
        f" Δ = {delta:+.3f} → leans **{winner}**",
    ]
    return "\n".join(lines)
| |
|
| |
|
def main():
    """CLI entry point: resolve input text, load the model, print results.

    Diagnostics (banner, errors, load progress) go to stderr so stdout
    stays clean for piping.
    """
    parser = argparse.ArgumentParser(
        description="Compare text using multi-head BTRM (fully fine-tuned models)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Score all heads
  python compare.py --text "The Dragonborn climbed to High Hrothgar..."

  # Compare specific pair
  python compare.py --text "Breaking news today" --heads skyrim,fonv

  # Use Qwen model instead of Gemma
  python compare.py --text "Your text" --model qwen

  # Read from file
  python compare.py --file story.txt

  # Pipe from stdin
  echo "Your text" | python compare.py --stdin

Available heads:
  skyrim, oblivion, fonv, gallia, marmotte, multiturn_dialogue, brainrot_aesop
"""
    )
    parser.add_argument("--text", "-t", help="Text to analyze")
    parser.add_argument("--file", "-f", help="Read text from file")
    parser.add_argument("--stdin", action="store_true", help="Read from stdin")
    parser.add_argument("--model", "-m", default="gemma", choices=["gemma", "qwen"],
                        help="Model to use (default: gemma, recommended for better discrimination)")
    parser.add_argument("--heads", help="Comma-separated pair of heads to compare (e.g., oblivion,skyrim)")
    parser.add_argument("--contrasts", type=int, metavar="N",
                        help="Show top N pairwise contrasts")

    args = parser.parse_args()

    # Input priority: --text, then --file, then stdin (explicit flag or
    # piped data detected via isatty).
    if args.text:
        text = args.text
    elif args.file:
        text = Path(args.file).read_text()
    elif args.stdin or not sys.stdin.isatty():
        text = sys.stdin.read()
    else:
        parser.print_help()
        print("\n\nError: Provide text via --text, --file, or --stdin", file=sys.stderr)
        sys.exit(1)

    text = text.strip()
    if not text:
        print("Error: Empty input", file=sys.stderr)
        sys.exit(1)

    try:
        model = BTRMModel(args.model)
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        print("Make sure you're running from the repo directory with model folders.", file=sys.stderr)
        sys.exit(1)

    print(f"\n{'='*60}", file=sys.stderr)
    print(f"Text: {text[:70]}{'...' if len(text) > 70 else ''}", file=sys.stderr)
    print(f"{'='*60}\n", file=sys.stderr)

    if args.heads:
        parts = args.heads.split(",")
        if len(parts) != 2:
            print("Error: --heads requires exactly 2 comma-separated head names", file=sys.stderr)
            print(f"Available: {', '.join(HEAD_INFO.keys())}", file=sys.stderr)
            sys.exit(1)
        head_a, head_b = parts[0].strip(), parts[1].strip()

        try:
            result = model.compare(text, head_a, head_b)
        except ValueError as e:
            print(f"Error: {e}", file=sys.stderr)
            sys.exit(1)

        print("Pairwise Comparison:")
        print(format_comparison(result, head_a, head_b))

    # Fix: compare against None so "--contrasts 0" is not silently treated
    # as if the flag were absent (0 is falsy).
    elif args.contrasts is not None:
        contrasts = model.top_contrasts(text, args.contrasts)
        print(f"Top {len(contrasts)} Pairwise Contrasts:")
        for c in contrasts:
            # Fix: the delta symbol was mojibake ("Ξ") in the original.
            print(f" {c['head_a']} vs {c['head_b']}: Δ={c['delta']:.3f} (higher: {c['higher']})")

    else:
        scores = model.score(text)
        print("All Head Scores:")
        print(format_scores(scores))
        names = list(scores)
        print(f"\nTip: Use --heads {names[0]},{names[1]} for pairwise comparison")
| |
|
| |
|
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()
| |
|