#!/usr/bin/env python3
"""
CF-HoT HEAD TRAINING - Contrastive Fine-tuning with Hidden-state Oversight Training
====================================================================================

Trains lightweight "heads" on model hidden states to detect and suppress:
  - Repetition (loops, repeated phrases)
  - Hedging ("As an AI...", "That's a great question!")
  - Verbosity ("Let me explain...", "To put it simply...")

This script only trains and checkpoints the detector heads; acting on their
scores happens outside this file.

Usage:
    python train_cfhot_head.py --behavior repetition --steps 5000
    python train_cfhot_head.py --behavior hedging --steps 3000
    python train_cfhot_head.py --behavior verbosity --steps 3000
    python train_cfhot_head.py --behavior all --steps 3000

"Predict the problem before it happens, prevent it at the source"
"""

import os
import sys
import json
import argparse
import random
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Tuple
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# === PATHS ===
ROOT = os.path.dirname(os.path.abspath(__file__))
RESULTS_DIR = os.path.join(ROOT, "results")
DATA_DIR = os.path.join(ROOT, "cfhot_data")
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(DATA_DIR, exist_ok=True)

# Model path - adjust to your setup (also overridable with --model-path)
MODEL_PATH = "/mnt/nvme2/ubermesnchetien4/models/merged-final-v5"

# ==============================================================================
# TRAINING DATA - POSITIVE AND NEGATIVE EXAMPLES
# ==============================================================================
# Positive texts exhibit the behavior and are trained towards label 1;
# negative texts are trained towards label 0.

# REPETITION: Examples that repeat vs don't repeat
REPETITION_POSITIVE = [
    # Repeating phrases
    "The key is to understand, the key is to understand, the key is to understand that",
    "We need to consider, we need to consider, we need to think about",
    "It's important to note, it's important to note that this is important to note",
    "First, let me say, first let me say, first I want to say",
    "The thing is, the thing is, the thing is that we should",
    "As I mentioned, as I mentioned before, as I mentioned earlier",
    "To be clear, to be clear, to be perfectly clear about this",
    "In other words, in other words, to put it another way, in other words",
    "The point is, the point is, my point is that the point is",
    "What I mean is, what I mean is, what I'm trying to say is what I mean",
    # Word repetition
    "very very very important",
    "really really really good",
    "so so so much better",
    "the the the problem is",
    "I I I think that",
]

REPETITION_NEGATIVE = [
    # Clean, varied language
    "The key insight here is understanding the underlying mechanism.",
    "We should consider multiple perspectives on this issue.",
    "This is an important point worth emphasizing.",
    "Let me explain the concept clearly.",
    "The situation requires careful analysis.",
    "First, we examine the data. Then, we draw conclusions.",
    "To clarify: the process involves three distinct steps.",
    "In simpler terms, the algorithm optimizes for efficiency.",
    "The central argument rests on empirical evidence.",
    "What this means in practice is significant improvement.",
    "Neural networks learn representations automatically.",
    "Gradient descent minimizes the loss function iteratively.",
    "Recursion solves problems by breaking them into smaller subproblems.",
    "Hash tables provide O(1) average-case lookup time.",
    "Transformers use attention mechanisms for sequence modeling.",
]

# HEDGING: Sycophantic/apologetic phrases vs direct responses
HEDGING_POSITIVE = [
    "That's a great question! Let me think about this.",
    "What a fascinating topic! I'd be happy to explore this with you.",
    "That's an excellent point! Thank you for bringing this up.",
    "I appreciate you asking! This is something I find very interesting.",
    "Great question! Many people wonder about this.",
    "As an AI language model, I don't have personal experiences, but",
    "I apologize, but I'm not able to provide that information.",
    "I'm sorry, but I cannot help with that request.",
    "Thank you for your patience! Let me try to help.",
    "I understand your concern! That's completely valid.",
    "What a wonderful question! I'm delighted to assist.",
    "I really appreciate you sharing that with me!",
    "That's so interesting! Tell me more about that.",
    "I'm honored you asked me! Let me do my best.",
    "Oh, that's a tricky one! But I'll give it a shot.",
]

HEDGING_NEGATIVE = [
    "The answer is straightforward: use a hash table.",
    "Recursion works by calling the function with smaller inputs.",
    "Neural networks learn through gradient descent.",
    "The algorithm has O(n log n) time complexity.",
    "This approach fails because it doesn't account for edge cases.",
    "The data shows a clear correlation between the variables.",
    "Quantum mechanics describes probability amplitudes.",
    "Evolution operates through natural selection.",
    "The proof follows from the axioms directly.",
    "TCP ensures reliable data transmission.",
    "Compile the code with optimization flags enabled.",
    "The database index improves query performance.",
    "Cache invalidation is a hard problem.",
    "The gradient points in the direction of steepest ascent.",
    "Entropy measures the disorder of a system.",
]

# VERBOSITY: Wordy preambles vs direct starts
VERBOSITY_POSITIVE = [
    "Let me explain this to you in detail so you can understand.",
    "To put it simply, what I'm trying to say is that",
    "In other words, to clarify what I mean, basically",
    "First of all, before I answer, I should mention that",
    "To begin with, it's important to understand that",
    "Essentially, what this boils down to is the fact that",
    "Basically, in simple terms, what we're looking at here is",
    "Allow me to elaborate on this point for you.",
    "I'd like to take a moment to explain this concept.",
    "Before we dive in, let me provide some context.",
    "To give you a comprehensive answer, I'll need to explain",
    "In order to fully understand this, we must first consider",
    "The thing you need to know about this is that",
    "What you're essentially asking about is related to",
    "To answer your question thoroughly, let me start by saying",
]

VERBOSITY_NEGATIVE = [
    "Hash tables use O(1) lookup.",
    "The gradient points downhill.",
    "Recursion needs a base case.",
    "Attention weights sum to one.",
    "TCP guarantees delivery.",
    "Entropy increases over time.",
    "Backprop computes gradients.",
    "DNA encodes proteins.",
    "Light travels at c.",
    "Neurons fire or don't.",
    "Memory is limited.",
    "Caching improves speed.",
    "Indexes help queries.",
    "Locks prevent races.",
    "Tests catch bugs.",
]

# ==============================================================================
# MULTI-HEAD PREDICTOR ARCHITECTURE
# ==============================================================================


def _aggregate_fibers(
    fiber_projs: nn.ModuleList,
    layer_weights: torch.Tensor,
    hidden_states: List[torch.Tensor],
) -> torch.Tensor:
    """Project per-layer hidden states to fiber space and mix them.

    Each hidden state is cast to float32 and passed through the projection of
    the same index. The projected "fibers" are combined with a softmax over the
    first ``len(fibers)`` entries of ``layer_weights``. Because ``zip`` pairs
    projections with hidden states, surplus entries on either side are ignored.

    Raises:
        ValueError: if there is no (projection, hidden state) pair to use.
    """
    fibers = [proj(h.float()) for proj, h in zip(fiber_projs, hidden_states)]
    if not fibers:
        # An empty sum() would yield the int 0, and the head would then fail
        # with an unhelpful TypeError inside nn.Linear.
        raise ValueError("hidden_states must contain at least one layer tensor")
    weights = F.softmax(layer_weights[: len(fibers)], dim=0)
    return sum(w * f for w, f in zip(weights, fibers))


def _make_mlp_head(d_fiber: int, d_control: int) -> nn.Sequential:
    """Build the 3-layer GELU MLP mapping a fiber vector to a single logit."""
    return nn.Sequential(
        nn.Linear(d_fiber, d_control),
        nn.GELU(),
        nn.Linear(d_control, d_control),
        nn.GELU(),
        nn.Linear(d_control, 1),
    )


class RiskPredictor(nn.Module):
    """Single-head risk predictor for one behavior type.

    The attribute names ``fiber_projs``, ``layer_weights`` and ``predictor``
    determine the state-dict keys written to training checkpoints, so they
    must not be renamed.
    """

    def __init__(self, d_model: int, n_layers: int, d_fiber: int = 16, d_control: int = 64):
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.d_fiber = d_fiber

        # One bias-free projection d_model -> d_fiber per layer.
        self.fiber_projs = nn.ModuleList([
            nn.Linear(d_model, d_fiber, bias=False)
            for _ in range(n_layers)
        ])

        # Learnable layer-mixing logits (softmaxed in forward), uniform at init.
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)

        # Prediction head (indices 0/2/4 are the Linear layers).
        self.predictor = _make_mlp_head(d_fiber, d_control)

    def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """Return sigmoid risk probabilities for the given hidden states.

        Args:
            hidden_states: one tensor per layer, each with a trailing
                ``d_model`` feature dimension (e.g. ``[batch, seq_len, d_model]``
                or ``[batch, d_model]``). At most ``n_layers`` are used; fewer
                are accepted.

        Returns:
            Probabilities in (0, 1) with the inputs' leading dimensions (the
            feature dimension is reduced away).

        Raises:
            ValueError: if ``hidden_states`` is empty.
        """
        aggregated = _aggregate_fibers(self.fiber_projs, self.layer_weights, hidden_states)
        logits = self.predictor(aggregated).squeeze(-1)
        return torch.sigmoid(logits)


class MultiHeadPredictor(nn.Module):
    """Multi-head predictor: shared fiber projections, one MLP head per behavior."""

    def __init__(
        self,
        d_model: int,
        n_layers: int,
        d_fiber: int = 16,
        d_control: int = 64,
        *,
        behaviors: Tuple[str, ...] = ("repetition", "hedging", "verbosity"),
    ):
        """
        Args:
            d_model, n_layers, d_fiber, d_control: as for :class:`RiskPredictor`.
            behaviors: names of the heads to create, in order. The default
                reproduces the original three heads (and state-dict keys).
        """
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.d_fiber = d_fiber

        # Shared fiber projections
        self.fiber_projs = nn.ModuleList([
            nn.Linear(d_model, d_fiber, bias=False)
            for _ in range(n_layers)
        ])

        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)

        # Behavior-specific heads, created in the order given.
        self.heads = nn.ModuleDict({
            name: self._make_head(d_fiber, d_control) for name in behaviors
        })

    def _make_head(self, d_fiber: int, d_control: int) -> nn.Module:
        """Build one behavior head (same architecture as RiskPredictor.predictor)."""
        return _make_mlp_head(d_fiber, d_control)

    def forward(self, hidden_states: List[torch.Tensor], head_name: str) -> torch.Tensor:
        """Return sigmoid risk probabilities from the head named ``head_name``.

        Raises:
            KeyError: if ``head_name`` is not a configured head.
            ValueError: if ``hidden_states`` is empty.
        """
        aggregated = _aggregate_fibers(self.fiber_projs, self.layer_weights, hidden_states)
        logits = self.heads[head_name](aggregated).squeeze(-1)
        return torch.sigmoid(logits)

    def get_all_risks(self, hidden_states: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Score ``hidden_states`` with every head, aggregating fibers only once.

        Raises:
            ValueError: if ``hidden_states`` is empty.
        """
        aggregated = _aggregate_fibers(self.fiber_projs, self.layer_weights, hidden_states)
        return {
            name: torch.sigmoid(head(aggregated).squeeze(-1))
            for name, head in self.heads.items()
        }


# ==============================================================================
# TRAINING
# ==============================================================================

# Behaviors that have a training corpus (mirrors get_data_for_behavior).
_BEHAVIORS = ("repetition", "hedging", "verbosity")


def get_data_for_behavior(behavior: str) -> Tuple[List[str], List[str]]:
    """Return ``(positive_texts, negative_texts)`` for ``behavior``.

    Copies are returned so callers cannot mutate the module-level corpora.

    Raises:
        ValueError: if ``behavior`` is not "repetition", "hedging" or "verbosity".
    """
    if behavior == "repetition":
        positive, negative = REPETITION_POSITIVE, REPETITION_NEGATIVE
    elif behavior == "hedging":
        positive, negative = HEDGING_POSITIVE, HEDGING_NEGATIVE
    elif behavior == "verbosity":
        positive, negative = VERBOSITY_POSITIVE, VERBOSITY_NEGATIVE
    else:
        raise ValueError(f"Unknown behavior: {behavior}")
    return list(positive), list(negative)


def collect_hidden_states(model, tokenizer, texts: List[str], device) -> List[List[torch.Tensor]]:
    """Run ``model`` on each text and keep the last token's hidden state per layer.

    Texts are encoded one at a time (truncated to 256 tokens). Each entry of
    the result is a list with one ``[batch, d_model]`` tensor per transformer
    layer; the batch dimension is 1 because there is one text per call. The
    embedding output (``hidden_states[0]``) is skipped. Tensors are moved to
    ``device`` so a predictor placed there can consume them, even when the
    model's layers are spread over several devices.
    """
    model.eval()
    collected = []
    with torch.no_grad():
        for text in texts:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs, output_hidden_states=True, return_dict=True)
            # hidden_states[0] is the embedding output; keep transformer layers only.
            layer_states = outputs.hidden_states[1:]
            collected.append([h[:, -1, :].to(device) for h in layer_states])
    return collected


def _load_model(model_path: str, tag: str):
    """Load the tokenizer and 4-bit NF4 causal LM from a local path; return (model, tokenizer)."""
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

    print(f"[{tag}] Loading model: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
    tokenizer.pad_token = tokenizer.eos_token

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        local_files_only=True,
    )
    model.eval()
    return model, tokenizer


def _evaluate_separation(predictor, positive_hidden, negative_hidden) -> Tuple[float, float, float]:
    """Return ``(avg_positive, avg_negative, separation)`` for ``predictor``.

    Each example's score is the mean of the predictor output. Separation is
    ``avg_positive / max(avg_negative, 1e-6)``. Scoring runs in eval mode under
    no_grad, and the predictor is switched back to train mode afterwards.
    """
    predictor.eval()
    with torch.no_grad():
        pos_scores = [predictor(h).mean().item() for h in positive_hidden]
        neg_scores = [predictor(h).mean().item() for h in negative_hidden]
    predictor.train()

    avg_pos = sum(pos_scores) / len(pos_scores)
    avg_neg = sum(neg_scores) / len(neg_scores)
    separation = avg_pos / max(avg_neg, 1e-6)
    return avg_pos, avg_neg, separation


def _save_checkpoint(predictor, ckpt_dir: str, behavior: str, step: int,
                     avg_pos: float, avg_neg: float, separation: float) -> None:
    """Write ``<behavior>_head.pt`` and ``risk_predictor.pt`` into ``ckpt_dir``.

    The 'risk_predictor' entry holds the predictor's state dict (keys
    ``fiber_projs.{i}.weight``, ``layer_weights``, ``predictor.{0,2,4}.{weight,bias}``
    for a RiskPredictor) as plain tensors rather than live nn.Parameters.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    state = dict(predictor.state_dict())
    result = {
        'avg_positive': avg_pos,
        'avg_negative': avg_neg,
        'separation': separation,
    }
    torch.save({
        'step': step,
        'predictor_state': state,
        'risk_predictor': dict(state),
        'result': result,
    }, os.path.join(ckpt_dir, f"{behavior}_head.pt"))

    # Also save as risk_predictor.pt for compatibility
    torch.save({
        'step': step,
        'risk_predictor': dict(state),
        'result': dict(result),
    }, os.path.join(ckpt_dir, "risk_predictor.pt"))


def train_head(
    behavior: str,
    model_path: str,
    steps: int = 3000,
    lr: float = 1e-4,
    d_fiber: int = 16,
    d_control: int = 64,
    checkpoint_every: int = 500,
    *,
    model: Any = None,
    tokenizer: Any = None,
):
    """Train a single RiskPredictor head for ``behavior``.

    Last-token hidden states of every transformer layer are collected once for
    the behavior's positive and negative texts. The head is then trained for
    ``steps`` single-example steps with BCE loss (positive -> 1, negative -> 0).
    A checkpoint is written every ``checkpoint_every`` steps. If ``steps`` is
    not a multiple of it, one more is written after the last step so the final
    weights are not lost.

    Args:
        behavior: "repetition", "hedging" or "verbosity".
        model_path: local base-model path, used when ``model``/``tokenizer``
            are not supplied.
        steps: number of optimizer steps.
        lr: AdamW learning rate.
        d_fiber: fiber dimension of the predictor.
        d_control: hidden width of the predictor MLP.
        checkpoint_every: checkpoint interval in steps (must be >= 1).
        model, tokenizer: optional pre-loaded base model and tokenizer. Both
            must be given to skip loading.

    Returns:
        dict with 'behavior', 'separation', 'avg_positive', 'avg_negative'
        and 'results_dir'.

    Raises:
        ValueError: for an unknown behavior or ``checkpoint_every < 1``. This
            is raised before any model loading.
    """
    # Validate cheap arguments before the slow model load.
    if behavior not in _BEHAVIORS:
        raise ValueError(f"Unknown behavior: {behavior}")
    if checkpoint_every < 1:
        raise ValueError(f"checkpoint_every must be >= 1, got {checkpoint_every}")

    print(f"\n{'='*70}")
    print(f"TRAINING {behavior.upper()} HEAD")
    print(f"{'='*70}")

    if model is None or tokenizer is None:
        model, tokenizer = _load_model(model_path, behavior)

    device = next(model.parameters()).device
    n_layers = model.config.num_hidden_layers
    d_model = model.config.hidden_size
    print(f"[{behavior}] Model loaded: {n_layers} layers, {d_model} dims")

    # Get training data
    positive_texts, negative_texts = get_data_for_behavior(behavior)
    print(f"[{behavior}] Data: {len(positive_texts)} positive, {len(negative_texts)} negative")

    # Collect hidden states
    print(f"[{behavior}] Collecting hidden states...")
    positive_hidden = collect_hidden_states(model, tokenizer, positive_texts, device)
    negative_hidden = collect_hidden_states(model, tokenizer, negative_texts, device)

    # Initialize predictor
    predictor = RiskPredictor(d_model, n_layers, d_fiber, d_control).to(device).float()
    optimizer = torch.optim.AdamW(predictor.parameters(), lr=lr)
    criterion = nn.BCELoss()  # predictor already outputs probabilities

    results_dir = os.path.join(RESULTS_DIR, f"{behavior}_head")
    os.makedirs(results_dir, exist_ok=True)

    predictor.train()
    running_loss = 0.0
    last_saved_step = 0
    for step in range(1, steps + 1):
        # Balanced sampling: each step sees one positive or one negative example.
        if random.random() > 0.5:
            hidden = random.choice(positive_hidden)
            target = torch.ones(1, device=device)
        else:
            hidden = random.choice(negative_hidden)
            target = torch.zeros(1, device=device)

        # Only the last token was kept, so the output is [batch=1]; mean()
        # reduces it to a scalar.
        pred = predictor(hidden).mean()
        loss = criterion(pred.unsqueeze(0), target)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        if step % 100 == 0:
            print(f" Step {step}/{steps}: loss={running_loss / 100:.4f}")
            running_loss = 0.0

        if step % checkpoint_every == 0:
            avg_pos, avg_neg, separation = _evaluate_separation(
                predictor, positive_hidden, negative_hidden
            )
            print(f"\n Checkpoint {step}:")
            print(f" Avg positive: {avg_pos:.4f}")
            print(f" Avg negative: {avg_neg:.4f}")
            print(f" Separation: {separation:.1f}x\n")
            _save_checkpoint(
                predictor, os.path.join(results_dir, f"ckpt_{step}"),
                behavior, step, avg_pos, avg_neg, separation,
            )
            last_saved_step = step

    # Final evaluation
    avg_pos, avg_neg, separation = _evaluate_separation(
        predictor, positive_hidden, negative_hidden
    )

    if steps > 0 and last_saved_step != steps:
        # steps is not a multiple of checkpoint_every: persist the final weights
        # as well, otherwise the trailing updates would be discarded.
        _save_checkpoint(
            predictor, os.path.join(results_dir, f"ckpt_{steps}"),
            behavior, steps, avg_pos, avg_neg, separation,
        )

    print(f"\n{'='*50}")
    print(f"FINAL RESULTS - {behavior.upper()} HEAD")
    print(f"{'='*50}")
    print(f" Avg positive score: {avg_pos:.4f}")
    print(f" Avg negative score: {avg_neg:.4f}")
    print(f" Separation: {separation:.1f}x")
    print(f"{'='*50}")

    return {
        'behavior': behavior,
        'separation': separation,
        'avg_positive': avg_pos,
        'avg_negative': avg_neg,
        'results_dir': results_dir,
    }


def train_all_heads(
    model_path: str,
    steps: int = 3000,
    lr: float = 1e-4,
    d_fiber: int = 16,
    d_control: int = 64,
    checkpoint_every: int = 500,
):
    """Train every behavior head with the same hyper-parameters.

    The base model is loaded once and shared by all heads, instead of being
    reloaded per behavior. Returns a dict mapping behavior -> train_head result.
    """
    model, tokenizer = _load_model(model_path, "all")

    results = {}
    for behavior in _BEHAVIORS:
        results[behavior] = train_head(
            behavior, model_path, steps, lr, d_fiber, d_control, checkpoint_every,
            model=model, tokenizer=tokenizer,
        )

    print("\n" + "="*70)
    print("ALL HEADS TRAINED")
    print("="*70)
    for behavior, result in results.items():
        print(f" {behavior}: {result['separation']:.1f}x separation")
    print("="*70)

    return results


# ==============================================================================
# MAIN
# ==============================================================================

def main():
    """Parse CLI arguments and train one head or all of them."""
    parser = argparse.ArgumentParser(description="CF-HoT Head Training")
    parser.add_argument("--behavior", type=str, default="repetition",
                        choices=[*_BEHAVIORS, "all"],
                        help="Behavior to train: repetition, hedging, verbosity, all")
    parser.add_argument("--steps", type=int, default=3000, help="Training steps")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--model-path", type=str, default=MODEL_PATH, help="Base model path")
    parser.add_argument("--d-fiber", type=int, default=16, help="Fiber dimension")
    parser.add_argument("--d-control", type=int, default=64, help="Control dimension")
    parser.add_argument("--checkpoint-every", type=int, default=500,
                        help="Save a checkpoint every N steps")
    args = parser.parse_args()

    print("="*70)
    print("CF-HoT HEAD TRAINING")
    print("="*70)
    print(f" Behavior: {args.behavior}")
    print(f" Steps: {args.steps}")
    print(f" Learning rate: {args.lr}")
    print(f" Model: {args.model_path}")
    print("="*70)

    if args.behavior == "all":
        # Forward every hyper-parameter; previously only steps reached the heads.
        train_all_heads(
            args.model_path, args.steps, args.lr,
            args.d_fiber, args.d_control, args.checkpoint_every,
        )
    else:
        train_head(
            args.behavior, args.model_path, args.steps,
            args.lr, args.d_fiber, args.d_control, args.checkpoint_every,
        )


if __name__ == "__main__":
    main()