---
license: gemma
tags:
- reward-model
- bradley-terry
- dialogue
- multi-head
- corpus-membership
---

# BTRM+ (Bradley-Terry Reward Model Plus)

Multi-head reward models for corpus membership and structural genre classification. Trained on situated dialogue from video games and synthetic settings.

## Models in This Repository

| Model | Base | Heads | Training | Logsquare | Loss | L2 Drift |
|-------|------|-------|----------|-----------|------|----------|
| `qwen_2head_probe/` | Qwen2.5-0.5B | 2 | 1 epoch (LoRA) | 0.1 | ~0.42 | **0.00** (frozen) |
| `gemma_2head_probe/` | Gemma-3 270M | 2 | 1 epoch (LoRA) | 0.1 | ~0.38 | **0.00** (frozen) |
| `gemma_9head_btrm/` | Gemma-3 270M | 9 | 10x coverage | 0.01 | 0.32 | **15.53** (full FT) |

### Training Evolution

**Phase 1: Frozen Probes (LoRA)**
- Quick validation that the Bradley-Terry loss works
- Base transformer frozen; only the LoRA adapter and BTRM heads are trained
- Higher logsquare (0.1) = stronger regularization toward unit logits
- Result: loss converges, but expressivity is limited
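
The Phase 1 setup can be illustrated with a minimal hand-rolled low-rank adapter (a sketch, not the actual training code; dimensions, rank, and alpha are illustrative): the frozen base matrix receives no gradient, which is why the probes show exactly 0.00 L2 drift.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """A frozen base linear plus a trainable low-rank update A @ B."""
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False               # freeze pre-trained weights
        self.A = nn.Parameter(torch.randn(base.in_features, rank) * 0.01)
        self.B = nn.Parameter(torch.zeros(rank, base.out_features))
        self.scale = alpha / rank

    def forward(self, x):
        return self.base(x) + (x @ self.A @ self.B) * self.scale

layer = LoRALinear(nn.Linear(64, 64))
before = layer.base.weight.detach().clone()

# one training step on a dummy objective: only A and B can move
opt = torch.optim.SGD([p for p in layer.parameters() if p.requires_grad], lr=0.1)
layer(torch.randn(4, 64)).pow(2).mean().backward()
opt.step()

drift = (layer.base.weight - before).norm().item()
print(f"base weight L2 drift: {drift:.2f}")  # 0.00 -- the frozen base never moves
```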

**Phase 2: Full Fine-Tuning**
- Unfroze the base transformer for end-to-end training
- Lower logsquare (0.01) = allows larger logit magnitudes
- Added synthetic corpora and structural genre heads
- Result: substantial weight drift (vs. zero for the frozen probes), better discrimination

### Weight Drift Analysis

Post-training comparison against the original pre-trained weights:

**Frozen (LoRA) models**: zero drift on the base transformer
```
qwen_2head_probe:  0.00 L2 (472M params unchanged)
gemma_2head_probe: 0.00 L2 (253M params unchanged)
```

**Full fine-tuned model**: significant drift, especially in the MLP layers
```
gemma_9head_btrm: 15.53 L2 total (268M params)
- MLP:       11.20 L2 (3.26% relative)
- Embedding:  7.94 L2 (1.60% relative)
- Attention:  7.26 L2 (2.07% relative)
- Norm:       0.01 L2 (0.00% relative)
```

The top drifting layers are MLP `down_proj` weights (up to 15.7% relative change).
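
The measurement above can be sketched as a comparison between two state dicts (a minimal sketch, not the analysis script used here): per-tensor drift is the L2 norm of the weight difference, and "relative" divides by the original tensor's norm.

```python
import torch

def l2_drift(original: dict, tuned: dict) -> dict:
    """Map each shared parameter name to (absolute L2 drift, relative drift)."""
    report = {}
    for name, w0 in original.items():
        w1 = tuned.get(name)
        if w1 is None or w1.shape != w0.shape:
            continue  # skip heads/adapters that only exist in one checkpoint
        diff = (w1.float() - w0.float()).norm().item()
        rel = diff / (w0.float().norm().item() + 1e-12)
        report[name] = (diff, rel)
    return report

# toy example with made-up tensors
orig = {"mlp.down_proj.weight": torch.ones(4, 4)}
tuned = {"mlp.down_proj.weight": torch.ones(4, 4) * 1.5}
print(l2_drift(orig, tuned)["mlp.down_proj.weight"])  # (2.0, ~0.5)
```

Group totals (MLP, Attention, ...) combine per-tensor values as the square root of the sum of squares, which is consistent with the numbers above (sqrt(11.20² + 7.94² + 7.26² + 0.01²) ≈ 15.53).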

## Head Types

### Corpus Membership (6 heads in the 9-head model)
Each head scores whether text belongs to a specific narrative setting:

| Head | Description | In Probes? |
|------|-------------|------------|
| `oblivion` | Imperial fantasy RPG (TES IV) | Yes |
| `fonv` | Post-apocalyptic Western (Fallout NV) | Yes |
| `skyrim` | Nordic fantasy RPG (TES V) | 9-head only |
| `gallia` | Franco-Roman bureaucratic fantasy (synthetic) | 9-head only |
| `marmotte` | Alpine corporate dystopia (synthetic) | 9-head only |
| `sanguo` | Three Kingdoms romance/otome (synthetic) | 9-head only |

### Structural Genre (3 heads, 9-head model only)
Each head scores text format/style:

| Head | Description |
|------|-------------|
| `multiturn_dialogue` | Raw quoted dialogue walks |
| `fk_normed_prose` | Flesch-Kincaid-controlled prose |
| `brainrot_aesop` | Vocabulary-teaching passages |

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the 9-head model (full fine-tuned)
model = AutoModelForCausalLM.from_pretrained(
    "SQCU/brainrot-partition-BTRMplus",
    subfolder="gemma_9head_btrm/base_model",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(
    "SQCU/brainrot-partition-BTRMplus",
    subfolder="gemma_9head_btrm/base_model",
)

# Load the BTRM heads
from huggingface_hub import hf_hub_download

btrm_path = hf_hub_download(
    "SQCU/brainrot-partition-BTRMplus",
    "gemma_9head_btrm/btrm_heads.pt",
)
btrm_state = torch.load(btrm_path)
# btrm_state["btrm_state_dict"] contains the head weights
# btrm_state["head_names"] = ["skyrim", "oblivion", "fonv", ...]
```
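
Scoring follows the Architecture section below (mean-pooled last hidden state, RMSNorm, linear heads, soft tanh cap). A sketch of that head math, assuming the key layout inside `btrm_state["btrm_state_dict"]`; `norm_weight` and `head_weight` are illustrative placeholder names, not documented keys:

```python
import torch

def head_scores(last_hidden, norm_weight, head_weight):
    """Score pooled hidden states with the BTRM heads.

    last_hidden: (batch, seq, hidden) from the base model
    norm_weight: (hidden,) RMSNorm scale -- placeholder name
    head_weight: (n_heads, hidden) linear heads -- placeholder name
    """
    h = last_hidden.mean(dim=1)                                  # mean pool over tokens
    h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + 1e-6)  # RMSNorm
    h = h * norm_weight
    logits = h @ head_weight.T                                   # (batch, n_heads)
    return 10.0 * torch.tanh(logits / 10.0)                      # soft cap at ±10

# e.g., with the model and heads loaded above:
# out = model(**tokenizer("some text", return_tensors="pt"), output_hidden_states=True)
# scores = head_scores(out.hidden_states[-1].float(), norm_w, head_w)
# dict(zip(btrm_state["head_names"], scores[0].tolist()))
```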

## Training Data

- **Reference**: Oblivion, Fallout NV, and Skyrim dialogue with emotion annotations
- **Synthetic**: Gallia v9, Marmotte v6, Sanguo v1 (structural translation pipeline)
- **Negatives**: cross-corpus soft negatives, Wattpad, FineWeb, WikiText

## Architecture

```
Input Text
    ↓
[Gemma-3 270M Transformer]  ← frozen (probes) or fine-tuned (9-head)
    ↓
Last Hidden State (mean pooled)
    ↓
[RMSNorm → Linear(hidden → N_heads)]
    ↓
Per-head logits (soft tanh capped at ±10)
```

Loss: `-log(sigmoid(pos - neg))` + logsquare regularization on logit magnitudes.
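
One plausible reading of this objective, as a sketch (the exact form of the logsquare term isn't spelled out here; a squared-logit penalty is assumed, with `lam` playing the role of the 0.1 / 0.01 values in the table above):

```python
import torch
import torch.nn.functional as F

def btrm_loss(pos_logits, neg_logits, lam=0.01):
    # Bradley-Terry pairwise term: -log sigmoid(pos - neg)
    pair = -F.logsigmoid(pos_logits - neg_logits).mean()
    # "logsquare" regularization, assumed here to be a squared-logit penalty
    reg = lam * (pos_logits.pow(2).mean() + neg_logits.pow(2).mean())
    return pair + reg

loss = btrm_loss(torch.tensor([2.0]), torch.tensor([0.0]), lam=0.0)
print(round(loss.item(), 4))  # 0.1269 = -log(sigmoid(2))
```

A larger `lam` pulls logits toward zero (the tighter 0.1 setting of the probes); the 0.01 used for full fine-tuning lets logit magnitudes grow.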

## Observations

1. **Reference corpora discriminate better than synthetic ones** - the skyrim/oblivion heads are accurate, while gallia/sanguo are confused
2. **Structural heads work excellently** - prose vs. dialogue vs. aesop are cleanly separated
3. **Full fine-tuning helps** - the 9-head model achieves lower loss than the frozen probes
4. **MLP layers adapt most** - `down_proj` weights show the highest relative drift

## License

Base model weights: Google Gemma License / Qwen License
Training data: Bethesda game dialogue (fair use for research), synthetic generation