#!/usr/bin/env python3
"""
Pairwise BTRM comparison script.
Compare text against pairs of heads for intuitive interpretation.
Uses fully fine-tuned models (base LLM + heads trained together).
Usage:
python compare.py --text "Your text here"
python compare.py --text "Your text" --heads oblivion,skyrim
python compare.py --file input.txt --model gemma
echo "Some text" | python compare.py --stdin
"""
import argparse
import sys
import torch
import yaml
from pathlib import Path
from typing import Optional
# Head descriptions for pretty printing.
# NOTE(review): keys are assumed to match the head names saved in each model's
# btrm_heads.pt checkpoint — confirm against the checkpoint if heads change.
# Iteration order matters: it is used verbatim in error listings.
HEAD_INFO = {
    "skyrim": "Nordic fantasy RPG",
    "oblivion": "Imperial fantasy RPG",
    "fonv": "Post-apocalyptic Western",
    "gallia": "Franco-Roman bureaucratic",
    "marmotte": "Alpine corporate dystopia",
    "multiturn_dialogue": "Raw quoted dialogue",
    "brainrot_aesop": "Vocabulary teaching",
}
class BTRMModel:
    """Load and run a fully fine-tuned BTRM model (base LLM + reward heads).

    Expects a model directory next to this script containing:
      - config.yaml     training/run configuration
      - base_model/     a fine-tuned HF causal-LM checkpoint
      - btrm_heads.pt   {"head_names": [...], "heads": {name: state_dict}}
    """

    def __init__(self, model_name: str = "gemma", device: Optional[str] = None):
        """
        Args:
            model_name: "gemma" or "qwen" (selects <script_dir>/<name>_btrm)
            device: "cuda", "cpu", or None for auto-detect

        Raises:
            ValueError: if model_name is not recognized.
            FileNotFoundError: if the model directory is missing.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name

        # Resolve the model directory relative to this script's location.
        script_dir = Path(__file__).parent
        if model_name == "gemma":
            model_dir = script_dir / "gemma_btrm"
        elif model_name == "qwen":
            model_dir = script_dir / "qwen_btrm"
        else:
            raise ValueError(f"Unknown model: {model_name}. Use 'gemma' or 'qwen'")
        if not model_dir.exists():
            raise FileNotFoundError(f"Model directory not found: {model_dir}")

        # Load config
        with open(model_dir / "config.yaml") as f:
            self.config = yaml.safe_load(f)

        # Load the FINE-TUNED base model (not the original!)
        base_model_path = model_dir / "base_model"
        print(f"Loading fine-tuned {model_name} from {base_model_path}...", file=sys.stderr)
        from transformers import AutoModelForCausalLM, AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            str(base_model_path), trust_remote_code=True
        )
        self.base_model = AutoModelForCausalLM.from_pretrained(
            str(base_model_path),
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        ).to(self.device)
        self.base_model.eval()

        # Load the per-style reward heads (one Linear(hidden_dim, 1) each).
        # weights_only=True: the checkpoint holds only tensors and basic
        # containers, and this closes the pickle arbitrary-code-execution
        # surface of a plain torch.load.
        heads_path = model_dir / "btrm_heads.pt"
        heads_data = torch.load(heads_path, map_location=self.device, weights_only=True)
        self.head_names = heads_data["head_names"]
        self.heads = torch.nn.ModuleDict()
        hidden_dim = self.base_model.config.hidden_size
        for name in self.head_names:
            head = torch.nn.Linear(hidden_dim, 1)
            head.load_state_dict(heads_data["heads"][name])
            self.heads[name] = head.to(self.device)
        print(f"Loaded {len(self.head_names)} heads: {', '.join(self.head_names)}", file=sys.stderr)

    def score(self, text: str) -> dict[str, float]:
        """Return {head_name: scalar score} for `text` across all heads."""
        # Tokenize (truncate long inputs to the model's usable window).
        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=2048
        ).to(self.device)

        with torch.no_grad():
            outputs = self.base_model(**inputs, output_hidden_states=True)
            # Last layer, last token: the usual reward-model pooling choice.
            hidden = outputs.hidden_states[-1][:, -1, :]

            scores = {}
            for name, head in self.heads.items():
                # Cast to the head's dtype: the base model runs in bfloat16
                # while the Linear heads may be stored in float32
                # (no-op when the dtypes already match).
                scores[name] = head(hidden.to(head.weight.dtype)).item()
        return scores

    def compare(self, text: str, head_a: str, head_b: str) -> dict:
        """Compare two specific heads.

        Returns a dict with both scores, their signed delta (a - b), and the
        "winner" (head_b on exact ties).

        Raises:
            ValueError: if either head name is unknown.
        """
        if head_a not in self.head_names:
            raise ValueError(f"Unknown head: {head_a}. Available: {self.head_names}")
        if head_b not in self.head_names:
            raise ValueError(f"Unknown head: {head_b}. Available: {self.head_names}")
        scores = self.score(text)
        return {
            head_a: scores[head_a],
            head_b: scores[head_b],
            "delta": scores[head_a] - scores[head_b],
            "winner": head_a if scores[head_a] > scores[head_b] else head_b,
        }

    def top_contrasts(self, text: str, n: int = 5) -> list[dict]:
        """Return the `n` head pairs with the largest absolute score gap."""
        scores = self.score(text)
        pairs = []
        names = list(scores.keys())
        for i, a in enumerate(names):
            for b in names[i+1:]:
                pairs.append({
                    "head_a": a,
                    "head_b": b,
                    "score_a": scores[a],
                    "score_b": scores[b],
                    "delta": abs(scores[a] - scores[b]),
                    "higher": a if scores[a] > scores[b] else b,
                })
        return sorted(pairs, key=lambda x: x["delta"], reverse=True)[:n]
def format_scores(scores: dict[str, float], head_info: Optional[dict] = None) -> str:
    """Pretty-print head scores as an aligned bar chart, highest first.

    Args:
        scores: {head_name: scalar score}.
        head_info: optional {head_name: description} override; defaults to
            the module-level HEAD_INFO table.

    Returns:
        Newline-joined lines; empty string for an empty `scores` dict.
    """
    info = HEAD_INFO if head_info is None else head_info
    if not scores:  # guard: max() below would raise on an empty dict
        return ""
    lines = []
    sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    max_name_len = max(len(name) for name in scores.keys())
    for name, score in sorted_scores:
        # Create bar (scale: -2 to +2 mapped to 0-20 chars), clamped so
        # out-of-range scores pin to an end of the bar.
        bar_pos = int((score + 2) / 4 * 20)
        bar_pos = max(0, min(20, bar_pos))
        bar = "─" * bar_pos + "│" + "─" * (20 - bar_pos)
        desc = info.get(name, "")
        lines.append(f" {name:<{max_name_len}} {score:+.3f} [{bar}] {desc}")
    return "\n".join(lines)
def format_comparison(result: dict, head_a: str, head_b: str,
                      head_info: Optional[dict] = None) -> str:
    """Pretty-print a pairwise comparison produced by BTRMModel.compare().

    Args:
        result: dict with keys head_a, head_b, "delta", and "winner".
        head_a: name of the first head (sign convention: delta = a - b).
        head_b: name of the second head.
        head_info: optional {head_name: description} override; defaults to
            the module-level HEAD_INFO table.
    """
    info = HEAD_INFO if head_info is None else head_info
    delta = result["delta"]
    winner = result["winner"]
    lines = [
        f" {head_a}: {result[head_a]:+.3f} ({info.get(head_a, '')})",
        f" {head_b}: {result[head_b]:+.3f} ({info.get(head_b, '')})",
        f" {'─' * 21}",
        f" Δ = {delta:+.3f} → leans **{winner}**",
    ]
    return "\n".join(lines)
def main():
    """CLI entry point: read text, score it with a BTRM model, print a report."""
    parser = argparse.ArgumentParser(
        description="Compare text using multi-head BTRM (fully fine-tuned models)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Score all heads
  python compare.py --text "The Dragonborn climbed to High Hrothgar..."

  # Compare specific pair
  python compare.py --text "Breaking news today" --heads skyrim,fonv

  # Use Qwen model instead of Gemma
  python compare.py --text "Your text" --model qwen

  # Read from file
  python compare.py --file story.txt

  # Pipe from stdin
  echo "Your text" | python compare.py --stdin

Available heads:
  skyrim, oblivion, fonv, gallia, marmotte, multiturn_dialogue, brainrot_aesop
"""
    )
    parser.add_argument("--text", "-t", help="Text to analyze")
    parser.add_argument("--file", "-f", help="Read text from file")
    parser.add_argument("--stdin", action="store_true", help="Read from stdin")
    parser.add_argument("--model", "-m", default="gemma", choices=["gemma", "qwen"],
                        help="Model to use (default: gemma, recommended for better discrimination)")
    parser.add_argument("--heads", help="Comma-separated pair of heads to compare (e.g., oblivion,skyrim)")
    parser.add_argument("--contrasts", type=int, metavar="N",
                        help="Show top N pairwise contrasts")
    args = parser.parse_args()

    # Get input text. `is not None` (rather than truthiness) so an explicit
    # empty --text reports "Empty input" instead of silently falling
    # through to stdin.
    if args.text is not None:
        text = args.text
    elif args.file:
        # Explicit UTF-8 so behavior doesn't depend on the platform locale.
        text = Path(args.file).read_text(encoding="utf-8")
    elif args.stdin or not sys.stdin.isatty():
        text = sys.stdin.read()
    else:
        parser.print_help()
        print("\n\nError: Provide text via --text, --file, or --stdin", file=sys.stderr)
        sys.exit(1)

    text = text.strip()
    if not text:
        print("Error: Empty input", file=sys.stderr)
        sys.exit(1)

    # Load model (fails fast with a hint if the model folders are absent).
    try:
        model = BTRMModel(args.model)
    except FileNotFoundError as e:
        print(f"Error: {e}", file=sys.stderr)
        print("Make sure you're running from the repo directory with model folders.", file=sys.stderr)
        sys.exit(1)

    # Banner goes to stderr so stdout stays clean for the report itself.
    print(f"\n{'='*60}", file=sys.stderr)
    print(f"Text: {text[:70]}{'...' if len(text) > 70 else ''}", file=sys.stderr)
    print(f"{'='*60}\n", file=sys.stderr)

    if args.heads:
        # Pairwise comparison of exactly two named heads.
        parts = args.heads.split(",")
        if len(parts) != 2:
            print("Error: --heads requires exactly 2 comma-separated head names", file=sys.stderr)
            print(f"Available: {', '.join(HEAD_INFO.keys())}", file=sys.stderr)
            sys.exit(1)
        head_a, head_b = parts[0].strip(), parts[1].strip()
        try:
            result = model.compare(text, head_a, head_b)
        except ValueError as e:
            print(f"Error: {e}", file=sys.stderr)
            sys.exit(1)
        print("Pairwise Comparison:")
        print(format_comparison(result, head_a, head_b))
    elif args.contrasts is not None:
        # `is not None` so `--contrasts 0` is honored (prints an empty list)
        # instead of falling through to the all-scores default.
        contrasts = model.top_contrasts(text, args.contrasts)
        print(f"Top {len(contrasts)} Pairwise Contrasts:")
        for c in contrasts:
            print(f" {c['head_a']} vs {c['head_b']}: Δ={c['delta']:.3f} (higher: {c['higher']})")
    else:
        # Default: show scores for every head.
        scores = model.score(text)
        print("All Head Scores:")
        print(format_scores(scores))
        names = list(scores.keys())
        if len(names) >= 2:  # guard: the tip needs at least two head names
            print(f"\nTip: Use --heads {names[0]},{names[1]} for pairwise comparison")


if __name__ == "__main__":
    main()