|
|
|
|
|
""" |
|
|
CF-HoT HEAD TRAINING - Contrastive Fine-tuning with Hidden-state Oversight Training |
|
|
==================================================================================== |
|
|
Trains lightweight "heads" on model hidden states to detect and suppress: |
|
|
- Repetition (loops, repeated phrases) |
|
|
- Hedging ("As an AI...", "That's a great question!") |
|
|
- Verbosity ("Let me explain...", "To put it simply...") |
|
|
|
|
|
Usage: |
|
|
python train_cfhot_head.py --behavior repetition --steps 5000 |
|
|
python train_cfhot_head.py --behavior hedging --steps 3000 |
|
|
python train_cfhot_head.py --behavior verbosity --steps 3000 |
|
|
python train_cfhot_head.py --behavior all --steps 3000 |
|
|
|
|
|
"Predict the problem before it happens, prevent it at the source" |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import argparse |
|
|
import random |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from typing import List, Dict, Any, Tuple |
|
|
from dataclasses import dataclass |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
|
|
|
# Resolve all output paths relative to this script's own directory so the
# tool behaves the same regardless of the current working directory.
ROOT = os.path.dirname(os.path.abspath(__file__))

RESULTS_DIR = os.path.join(ROOT, "results")

DATA_DIR = os.path.join(ROOT, "cfhot_data")

# Create output directories up front; exist_ok makes reruns a no-op.
os.makedirs(RESULTS_DIR, exist_ok=True)

os.makedirs(DATA_DIR, exist_ok=True)

# Default base-model checkpoint; can be overridden with --model-path.
MODEL_PATH = "/mnt/nvme2/ubermesnchetien4/models/merged-final-v5"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Contrastive training data.
# Each behavior has a *_POSITIVE list (texts exhibiting the unwanted
# behavior, trained toward risk = 1) and a *_NEGATIVE list (clean
# counterexamples, trained toward risk = 0).
# ---------------------------------------------------------------------------

# Repetition positives: phrase-level loops and word-level stutters.
REPETITION_POSITIVE = [
    "The key is to understand, the key is to understand, the key is to understand that",
    "We need to consider, we need to consider, we need to think about",
    "It's important to note, it's important to note that this is important to note",
    "First, let me say, first let me say, first I want to say",
    "The thing is, the thing is, the thing is that we should",
    "As I mentioned, as I mentioned before, as I mentioned earlier",
    "To be clear, to be clear, to be perfectly clear about this",
    "In other words, in other words, to put it another way, in other words",
    "The point is, the point is, my point is that the point is",
    "What I mean is, what I mean is, what I'm trying to say is what I mean",
    # Word-level stutters.
    "very very very important",
    "really really really good",
    "so so so much better",
    "the the the problem is",
    "I I I think that",
]

# Repetition negatives: fluent, non-repetitive statements.
REPETITION_NEGATIVE = [
    "The key insight here is understanding the underlying mechanism.",
    "We should consider multiple perspectives on this issue.",
    "This is an important point worth emphasizing.",
    "Let me explain the concept clearly.",
    "The situation requires careful analysis.",
    "First, we examine the data. Then, we draw conclusions.",
    "To clarify: the process involves three distinct steps.",
    "In simpler terms, the algorithm optimizes for efficiency.",
    "The central argument rests on empirical evidence.",
    "What this means in practice is significant improvement.",
    "Neural networks learn representations automatically.",
    "Gradient descent minimizes the loss function iteratively.",
    "Recursion solves problems by breaking them into smaller subproblems.",
    "Hash tables provide O(1) average-case lookup time.",
    "Transformers use attention mechanisms for sequence modeling.",
]

# Hedging positives: sycophantic openers, AI disclaimers, apologies.
HEDGING_POSITIVE = [
    "That's a great question! Let me think about this.",
    "What a fascinating topic! I'd be happy to explore this with you.",
    "That's an excellent point! Thank you for bringing this up.",
    "I appreciate you asking! This is something I find very interesting.",
    "Great question! Many people wonder about this.",
    "As an AI language model, I don't have personal experiences, but",
    "I apologize, but I'm not able to provide that information.",
    "I'm sorry, but I cannot help with that request.",
    "Thank you for your patience! Let me try to help.",
    "I understand your concern! That's completely valid.",
    "What a wonderful question! I'm delighted to assist.",
    "I really appreciate you sharing that with me!",
    "That's so interesting! Tell me more about that.",
    "I'm honored you asked me! Let me do my best.",
    "Oh, that's a tricky one! But I'll give it a shot.",
]

# Hedging negatives: direct, substantive answers with no preamble.
HEDGING_NEGATIVE = [
    "The answer is straightforward: use a hash table.",
    "Recursion works by calling the function with smaller inputs.",
    "Neural networks learn through gradient descent.",
    "The algorithm has O(n log n) time complexity.",
    "This approach fails because it doesn't account for edge cases.",
    "The data shows a clear correlation between the variables.",
    "Quantum mechanics describes probability amplitudes.",
    "Evolution operates through natural selection.",
    "The proof follows from the axioms directly.",
    "TCP ensures reliable data transmission.",
    "Compile the code with optimization flags enabled.",
    "The database index improves query performance.",
    "Cache invalidation is a hard problem.",
    "The gradient points in the direction of steepest ascent.",
    "Entropy measures the disorder of a system.",
]

# Verbosity positives: filler phrases and throat-clearing preambles.
VERBOSITY_POSITIVE = [
    "Let me explain this to you in detail so you can understand.",
    "To put it simply, what I'm trying to say is that",
    "In other words, to clarify what I mean, basically",
    "First of all, before I answer, I should mention that",
    "To begin with, it's important to understand that",
    "Essentially, what this boils down to is the fact that",
    "Basically, in simple terms, what we're looking at here is",
    "Allow me to elaborate on this point for you.",
    "I'd like to take a moment to explain this concept.",
    "Before we dive in, let me provide some context.",
    "To give you a comprehensive answer, I'll need to explain",
    "In order to fully understand this, we must first consider",
    "The thing you need to know about this is that",
    "What you're essentially asking about is related to",
    "To answer your question thoroughly, let me start by saying",
]

# Verbosity negatives: maximally terse factual statements.
VERBOSITY_NEGATIVE = [
    "Hash tables use O(1) lookup.",
    "The gradient points downhill.",
    "Recursion needs a base case.",
    "Attention weights sum to one.",
    "TCP guarantees delivery.",
    "Entropy increases over time.",
    "Backprop computes gradients.",
    "DNA encodes proteins.",
    "Light travels at c.",
    "Neurons fire or don't.",
    "Memory is limited.",
    "Caching improves speed.",
    "Indexes help queries.",
    "Locks prevent races.",
    "Tests catch bugs.",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RiskPredictor(nn.Module):
    """Single-head risk predictor for one behavior type.

    Projects each transformer layer's hidden state into a small "fiber"
    space, blends the layers with learned softmax weights, and maps the
    aggregate through a small MLP to a per-position risk probability.
    """

    def __init__(self, d_model: int, n_layers: int, d_fiber: int = 16, d_control: int = 64):
        """
        Args:
            d_model: Hidden size of the base model.
            n_layers: Number of transformer layers to read from.
            d_fiber: Width of the per-layer low-rank projection.
            d_control: Hidden width of the prediction MLP.
        """
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.d_fiber = d_fiber

        # One bias-free low-rank projection per layer: d_model -> d_fiber.
        self.fiber_projs = nn.ModuleList([
            nn.Linear(d_model, d_fiber, bias=False) for _ in range(n_layers)
        ])

        # Learned per-layer mixing weights, softmax-normalized in forward().
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)

        # Small MLP mapping the aggregated fiber to a scalar risk logit.
        self.predictor = nn.Sequential(
            nn.Linear(d_fiber, d_control),
            nn.GELU(),
            nn.Linear(d_control, d_control),
            nn.GELU(),
            nn.Linear(d_control, 1)
        )

    def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            hidden_states: List of [batch, seq_len, d_model] tensors, one per
                layer. Fewer than n_layers entries is allowed; trailing
                projections are simply unused.
        Returns:
            risk_scores: [batch, seq_len] tensor of risk probabilities in [0, 1].
        """
        # zip() truncates to the shorter sequence, so the original
        # `if i < len(hidden_states)` bounds check was dead code (always true)
        # and has been removed.
        fibers = [proj(h.float()) for proj, h in zip(self.fiber_projs, hidden_states)]

        # Normalize mixing weights over only the layers actually provided.
        weights = F.softmax(self.layer_weights[:len(fibers)], dim=0)
        aggregated = sum(w * f for w, f in zip(weights, fibers))

        logits = self.predictor(aggregated).squeeze(-1)
        return torch.sigmoid(logits)
|
|
|
|
|
|
|
|
class MultiHeadPredictor(nn.Module):
    """Multi-head predictor for all behavior types.

    Shares one fiber-projection / layer-mixing trunk across three behavior
    heads ('repetition', 'hedging', 'verbosity'). The previously duplicated
    aggregation code in forward() and get_all_risks() is factored into
    _aggregate() so the two paths cannot drift apart.
    """

    def __init__(self, d_model: int, n_layers: int, d_fiber: int = 16, d_control: int = 64):
        """
        Args:
            d_model: Hidden size of the base model.
            n_layers: Number of transformer layers to read from.
            d_fiber: Width of the per-layer low-rank projection.
            d_control: Hidden width of each head's MLP.
        """
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.d_fiber = d_fiber

        # Shared trunk: one bias-free projection per layer plus learned
        # per-layer mixing weights (softmax-normalized at use time).
        self.fiber_projs = nn.ModuleList([
            nn.Linear(d_model, d_fiber, bias=False) for _ in range(n_layers)
        ])
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)

        # One independent MLP head per behavior.
        self.heads = nn.ModuleDict({
            'repetition': self._make_head(d_fiber, d_control),
            'hedging': self._make_head(d_fiber, d_control),
            'verbosity': self._make_head(d_fiber, d_control),
        })

    def _make_head(self, d_fiber: int, d_control: int) -> nn.Module:
        """Build one behavior head: fiber -> d_control -> d_control -> 1 logit."""
        return nn.Sequential(
            nn.Linear(d_fiber, d_control),
            nn.GELU(),
            nn.Linear(d_control, d_control),
            nn.GELU(),
            nn.Linear(d_control, 1)
        )

    def _aggregate(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """Project each layer to fiber space and blend with learned weights.

        Shared by forward() and get_all_risks(); handles lists shorter than
        n_layers by re-normalizing over the layers actually present.
        """
        fibers = [proj(h.float()) for proj, h in zip(self.fiber_projs, hidden_states)]
        weights = F.softmax(self.layer_weights[:len(fibers)], dim=0)
        return sum(w * f for w, f in zip(weights, fibers))

    def forward(self, hidden_states: List[torch.Tensor], head_name: str) -> torch.Tensor:
        """Return [batch, seq_len] risk probabilities from one named head."""
        aggregated = self._aggregate(hidden_states)
        logits = self.heads[head_name](aggregated).squeeze(-1)
        return torch.sigmoid(logits)

    def get_all_risks(self, hidden_states: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return risk probabilities from every head, keyed by behavior name."""
        aggregated = self._aggregate(hidden_states)
        return {
            name: torch.sigmoid(head(aggregated).squeeze(-1))
            for name, head in self.heads.items()
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_data_for_behavior(behavior: str) -> Tuple[List[str], List[str]]:
    """Return the (positive, negative) example lists for a behavior.

    Raises:
        ValueError: If *behavior* is not one of the known behavior names.
    """
    datasets = {
        "repetition": (REPETITION_POSITIVE, REPETITION_NEGATIVE),
        "hedging": (HEDGING_POSITIVE, HEDGING_NEGATIVE),
        "verbosity": (VERBOSITY_POSITIVE, VERBOSITY_NEGATIVE),
    }
    if behavior not in datasets:
        raise ValueError(f"Unknown behavior: {behavior}")
    return datasets[behavior]
|
|
|
|
|
|
|
|
def collect_hidden_states(model, tokenizer, texts: List[str], device) -> List[torch.Tensor]:
    """Collect per-layer, last-token hidden states for each text.

    Runs the (frozen) model once per text and keeps, for every transformer
    layer, only the final position's hidden state.

    Returns:
        One entry per text; each entry is a list of [1, d_model] tensors,
        one per transformer layer (the embedding layer is excluded).
    """
    model.eval()

    def encode(text: str):
        # Tokenize, move to the model's device, and run a forward pass
        # requesting all intermediate hidden states.
        batch = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        out = model(**batch, output_hidden_states=True, return_dict=True)
        # Index 0 of hidden_states is the embedding output; skip it and keep
        # only the last token's state from each remaining layer.
        return [layer[:, -1, :] for layer in out.hidden_states[1:]]

    with torch.no_grad():
        return [encode(text) for text in texts]
|
|
|
|
|
|
|
|
def _export_risk_predictor_state(predictor: "RiskPredictor", n_layers: int) -> Dict[str, torch.Tensor]:
    """Flatten a RiskPredictor's tensors into the flat key layout saved in checkpoints.

    This dict was previously built twice inline, byte-for-byte identical, for
    the two checkpoint files; centralizing it keeps the layouts in sync.
    Indices 0/2/4 are the Linear layers of the Sequential MLP (1/3 are GELUs
    with no parameters).
    """
    state = {f'fiber_projs.{i}.weight': predictor.fiber_projs[i].weight for i in range(n_layers)}
    state['layer_weights'] = predictor.layer_weights
    for idx in (0, 2, 4):
        state[f'predictor.{idx}.weight'] = predictor.predictor[idx].weight
        state[f'predictor.{idx}.bias'] = predictor.predictor[idx].bias
    return state


def _evaluate_separation(predictor, positive_hidden, negative_hidden) -> Tuple[float, float, float]:
    """Score both example sets; return (avg_positive, avg_negative, separation).

    Puts the predictor in eval mode and leaves it there — callers that are
    mid-training must switch back to train mode themselves.
    """
    predictor.eval()
    with torch.no_grad():
        pos_scores = [predictor(h).mean().item() for h in positive_hidden]
        neg_scores = [predictor(h).mean().item() for h in negative_hidden]
    avg_pos = sum(pos_scores) / len(pos_scores)
    avg_neg = sum(neg_scores) / len(neg_scores)
    # Guard against division by ~zero when the negatives score near 0.
    separation = avg_pos / max(avg_neg, 1e-6)
    return avg_pos, avg_neg, separation


def train_head(
    behavior: str,
    model_path: str,
    steps: int = 3000,
    lr: float = 1e-4,
    d_fiber: int = 16,
    d_control: int = 64,
    checkpoint_every: int = 500
):
    """Train a single behavior head.

    Loads the base model 4-bit quantized, collects last-token hidden states
    for the behavior's contrastive data once up front, then trains a small
    RiskPredictor on those frozen features with one example per step.

    Args:
        behavior: "repetition", "hedging", or "verbosity".
        model_path: Local path of the base causal LM.
        steps: Number of single-example training steps.
        lr: AdamW learning rate.
        d_fiber: Per-layer fiber projection width.
        d_control: Hidden width of the prediction MLP.
        checkpoint_every: Evaluate and save a checkpoint every N steps.

    Returns:
        Dict with the final separation metrics and the results directory.
    """
    print(f"\n{'='*70}")
    print(f"TRAINING {behavior.upper()} HEAD")
    print(f"{'='*70}")

    # Imported lazily so the module can be imported without transformers.
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

    print(f"[{behavior}] Loading model: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
    tokenizer.pad_token = tokenizer.eos_token

    # NF4 4-bit quantization to fit the base model in memory; compute in bf16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        local_files_only=True
    )
    model.eval()

    device = next(model.parameters()).device
    n_layers = model.config.num_hidden_layers
    d_model = model.config.hidden_size

    print(f"[{behavior}] Model loaded: {n_layers} layers, {d_model} dims")

    positive_texts, negative_texts = get_data_for_behavior(behavior)
    print(f"[{behavior}] Data: {len(positive_texts)} positive, {len(negative_texts)} negative")

    # Hidden states are collected once; the base model is frozen thereafter,
    # so the training loop never touches it again.
    print(f"[{behavior}] Collecting hidden states...")
    positive_hidden = collect_hidden_states(model, tokenizer, positive_texts, device)
    negative_hidden = collect_hidden_states(model, tokenizer, negative_texts, device)

    predictor = RiskPredictor(d_model, n_layers, d_fiber, d_control).to(device).float()
    optimizer = torch.optim.AdamW(predictor.parameters(), lr=lr)
    criterion = nn.BCELoss()

    predictor.train()
    total_loss = 0

    results_dir = os.path.join(RESULTS_DIR, f"{behavior}_head")
    os.makedirs(results_dir, exist_ok=True)

    for step in range(steps):
        # Sample one cached example per step, balanced 50/50 pos/neg.
        if random.random() > 0.5:
            idx = random.randint(0, len(positive_hidden) - 1)
            hidden = positive_hidden[idx]
            target = torch.ones(1, device=device)
        else:
            idx = random.randint(0, len(negative_hidden) - 1)
            hidden = negative_hidden[idx]
            target = torch.zeros(1, device=device)

        # Mean-pool the per-layer risk into one scalar for BCE against the label.
        pred = predictor(hidden)
        pred = pred.mean()

        loss = criterion(pred.unsqueeze(0), target)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

        if (step + 1) % 100 == 0:
            avg_loss = total_loss / 100
            print(f" Step {step+1}/{steps}: loss={avg_loss:.4f}")
            total_loss = 0

        if (step + 1) % checkpoint_every == 0:
            ckpt_dir = os.path.join(results_dir, f"ckpt_{step+1}")
            os.makedirs(ckpt_dir, exist_ok=True)

            avg_pos, avg_neg, separation = _evaluate_separation(
                predictor, positive_hidden, negative_hidden)
            predictor.train()  # Resume training mode after the eval pass.

            print(f"\n Checkpoint {step+1}:")
            print(f" Avg positive: {avg_pos:.4f}")
            print(f" Avg negative: {avg_neg:.4f}")
            print(f" Separation: {separation:.1f}x\n")

            # Save the same payload under two filenames: a behavior-specific
            # file and a generic "risk_predictor.pt" (the latter omits the
            # full state_dict, matching the original layout).
            exported = _export_risk_predictor_state(predictor, n_layers)
            result = {
                'avg_positive': avg_pos,
                'avg_negative': avg_neg,
                'separation': separation,
            }
            torch.save({
                'step': step + 1,
                'predictor_state': predictor.state_dict(),
                'risk_predictor': exported,
                'result': result,
            }, os.path.join(ckpt_dir, f"{behavior}_head.pt"))

            torch.save({
                'step': step + 1,
                'risk_predictor': exported,
                'result': result,
            }, os.path.join(ckpt_dir, "risk_predictor.pt"))

    # Final evaluation; the predictor is intentionally left in eval mode.
    avg_pos, avg_neg, separation = _evaluate_separation(
        predictor, positive_hidden, negative_hidden)

    print(f"\n{'='*50}")
    print(f"FINAL RESULTS - {behavior.upper()} HEAD")
    print(f"{'='*50}")
    print(f" Avg positive score: {avg_pos:.4f}")
    print(f" Avg negative score: {avg_neg:.4f}")
    print(f" Separation: {separation:.1f}x")
    print(f"{'='*50}")

    return {
        'behavior': behavior,
        'separation': separation,
        'avg_positive': avg_pos,
        'avg_negative': avg_neg,
        'results_dir': results_dir,
    }
|
|
|
|
|
|
|
|
def train_all_heads(model_path: str, steps: int = 3000):
    """Train every behavior head in sequence and print a summary.

    Returns:
        Mapping from behavior name to that head's train_head() result dict.
    """
    behaviors = ("repetition", "hedging", "verbosity")
    results = {name: train_head(name, model_path, steps) for name in behaviors}

    banner = "=" * 70
    print("\n" + banner)
    print("ALL HEADS TRAINED")
    print(banner)
    for name, outcome in results.items():
        print(f" {name}: {outcome['separation']:.1f}x separation")
    print(banner)

    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, print the run config, dispatch training."""
    parser = argparse.ArgumentParser(description="CF-HoT Head Training")
    parser.add_argument("--behavior", type=str, default="repetition",
                        help="Behavior to train: repetition, hedging, verbosity, all")
    parser.add_argument("--steps", type=int, default=3000, help="Training steps")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--model-path", type=str, default=MODEL_PATH, help="Base model path")
    parser.add_argument("--d-fiber", type=int, default=16, help="Fiber dimension")
    parser.add_argument("--d-control", type=int, default=64, help="Control dimension")
    args = parser.parse_args()

    banner = "=" * 70
    print(banner)
    print("CF-HoT HEAD TRAINING")
    print(banner)
    print(f" Behavior: {args.behavior}")
    print(f" Steps: {args.steps}")
    print(f" Learning rate: {args.lr}")
    print(f" Model: {args.model_path}")
    print(banner)

    # "all" trains the three heads sequentially with shared defaults;
    # otherwise a single head is trained with the full CLI configuration.
    if args.behavior == "all":
        train_all_heads(args.model_path, args.steps)
        return

    train_head(
        args.behavior,
        args.model_path,
        args.steps,
        args.lr,
        args.d_fiber,
        args.d_control
    )
|
|
|
|
|
|
|
|
# Standard script guard: run training only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|
|
|