#!/usr/bin/env python3
"""
Production Encoder LoRA Training for Stablebridge
Trains LoRA adapters on BAAI/bge-m3 for the US regulatory domain.
Implements tech spec requirements:
- LoRA rank 16, alpha 32
- 8192 token context window
- MultipleNegativesRankingLoss (in-batch negatives)
- WandB logging, checkpointing, evaluation
- Model Hub push
"""
import argparse
import json
import os
import torch
import wandb
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional
from dataclasses import dataclass, field
from transformers import (
    AutoTokenizer,
    AutoModel,
    get_cosine_schedule_with_warmup,
)
from peft import LoraConfig, get_peft_model, TaskType
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
from tqdm import tqdm
@dataclass
class EncoderTrainingConfig:
"""Complete training configuration matching tech spec."""
# Model
base_model: str = "BAAI/bge-m3"
max_length: int = 8192 # Full context per tech spec
# LoRA
lora_rank: int = 16
lora_alpha: int = 32
lora_dropout: float = 0.1
target_modules: List[str] = field(default_factory=lambda: ["query", "key", "value"])
# Training
epochs: int = 3
per_device_batch_size: int = 4 # RTX 6000 Ada - will adjust based on memory
gradient_accumulation_steps: int = 16 # Effective batch size = 64
learning_rate: float = 5e-5
weight_decay: float = 0.01
warmup_ratio: float = 0.1
max_grad_norm: float = 1.0
# Precision
mixed_precision: str = "bf16" # Will fall back to fp16 if needed
# Checkpointing
save_steps: int = 500
eval_steps: int = 500
logging_steps: int = 50
# Paths
data_path: str = "/workspace/data/labels/encoder_triplets.jsonl"
corpus_dir: str = "/workspace/data/raw"
output_dir: str = "/workspace/checkpoints/bge-m3-us-regulatory-lora"
# Monitoring
wandb_project: str = "stablebridge-encoder"
wandb_run_name: Optional[str] = None
# Hub
push_to_hub: bool = True
hub_model_id: str = "cognilogue/bge-m3-us-regulatory-lora"
hub_token: Optional[str] = None
# Evaluation
eval_split: float = 0.1 # Hold out 10% for validation
eval_metrics: List[str] = field(default_factory=lambda: ["ndcg@10", "mrr@10", "recall@100"])
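# Expected on-disk formats (inferred from TripletDataset and load_data() below; field
# values are placeholders, not real examples):
#   encoder_triplets.jsonl -- one JSON object per line:
#       {"query": "<query text>", "positive": "<doc_id of the relevant document>"}
#   <corpus_dir>/*.json    -- one document per file:
#       {"doc_id": "<doc_id>", "content": "<full document text>"}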
class TripletDataset(Dataset):
"""Dataset for encoder triplet training with in-batch negatives."""
def __init__(
self,
triplets: List[Dict],
corpus: Dict[str, str],
tokenizer,
max_length: int = 8192
):
self.triplets = triplets
self.corpus = corpus
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.triplets)
def __getitem__(self, idx):
triplet = self.triplets[idx]
query = triplet["query"]
pos_id = triplet["positive"]
# Get positive document
positive_text = self.corpus.get(pos_id, "")
# Tokenize
query_enc = self.tokenizer(
query,
max_length=self.max_length,
truncation=True,
padding="max_length",
return_tensors="pt"
)
pos_enc = self.tokenizer(
positive_text,
max_length=self.max_length,
truncation=True,
padding="max_length",
return_tensors="pt"
)
return {
"query_input_ids": query_enc["input_ids"].squeeze(0),
"query_attention_mask": query_enc["attention_mask"].squeeze(0),
"pos_input_ids": pos_enc["input_ids"].squeeze(0),
"pos_attention_mask": pos_enc["attention_mask"].squeeze(0),
}
def mean_pooling(model_output, attention_mask):
"""Mean pooling over token embeddings (ignore padding)."""
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)
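# Shape note for mean_pooling: model_output[0] is the last hidden state with shape
# (batch, seq_len, hidden); the expanded mask zeroes out padding positions, so the result
# is a (batch, hidden) average over real tokens only, with the clamp guarding against
# division by zero for an all-padding row.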
def compute_loss(query_emb, pos_emb, temperature=0.05):
"""
Multiple Negatives Ranking Loss (InfoNCE).
Uses in-batch negatives: all other positives in the batch serve as negatives.
Standard approach in sentence-transformers contrastive learning.
Args:
query_emb: (batch_size, hidden_dim) - normalized query embeddings
pos_emb: (batch_size, hidden_dim) - normalized positive embeddings
temperature: Temperature for softmax (default 0.05)
Returns:
loss: Scalar loss value
"""
# Normalize embeddings
query_emb = F.normalize(query_emb, p=2, dim=1)
pos_emb = F.normalize(pos_emb, p=2, dim=1)
# Compute similarity matrix: (batch_size, batch_size)
# query_emb[i] @ pos_emb[j].T gives similarity between query i and doc j
sim_matrix = torch.matmul(query_emb, pos_emb.T) / temperature
# Labels: diagonal elements are positive pairs
labels = torch.arange(sim_matrix.size(0)).to(sim_matrix.device)
# Cross-entropy: pulls positives closer, pushes negatives away
loss = F.cross_entropy(sim_matrix, labels)
return loss
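# Worked example for compute_loss (illustrative): with a batch of 4, sim_matrix is 4x4 and
# labels = [0, 1, 2, 3], so row i is a 4-way softmax classification where pos_emb[i] is the
# target and the other three in-batch positives act as negatives. temperature=0.05 is the
# same scale (1/0.05 = 20) as the default in sentence-transformers' MultipleNegativesRankingLoss.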
def evaluate_retrieval(model, tokenizer, eval_data, corpus, device, config):
"""
Evaluate retrieval quality on validation set.
Metrics:
- NDCG@10: Ranking quality
- MRR@10: Mean Reciprocal Rank
- Recall@100: Coverage
"""
model.eval()
# Encode all documents
print("\nEncoding corpus for evaluation...")
doc_ids = list(corpus.keys())
doc_embeddings = []
with torch.no_grad():
for doc_id in tqdm(doc_ids, desc="Encoding docs"):
doc_text = corpus[doc_id]
doc_enc = tokenizer(
doc_text,
max_length=config.max_length,
truncation=True,
padding="max_length",
return_tensors="pt"
).to(device)
doc_output = model(**doc_enc)
doc_emb = mean_pooling(doc_output, doc_enc["attention_mask"])
doc_emb = F.normalize(doc_emb, p=2, dim=1)
            doc_embeddings.append(doc_emb.float().cpu())  # cast to float32 so the CPU matmul below works for fp16/bf16 models
doc_embeddings = torch.cat(doc_embeddings, dim=0) # (num_docs, hidden_dim)
# Evaluate queries
ndcg_scores = []
mrr_scores = []
recall_scores = []
with torch.no_grad():
for triplet in tqdm(eval_data, desc="Evaluating"):
query = triplet["query"]
pos_id = triplet["positive"]
# Encode query
query_enc = tokenizer(
query,
max_length=config.max_length,
truncation=True,
padding="max_length",
return_tensors="pt"
).to(device)
query_output = model(**query_enc)
query_emb = mean_pooling(query_output, query_enc["attention_mask"])
query_emb = F.normalize(query_emb, p=2, dim=1)
# Compute similarities
            similarities = torch.matmul(query_emb.float().cpu(), doc_embeddings.T).squeeze(0)
# Rank documents
ranks = torch.argsort(similarities, descending=True)
# Find position of positive document
try:
pos_idx = doc_ids.index(pos_id)
pos_rank = (ranks == pos_idx).nonzero(as_tuple=True)[0].item() + 1
except (ValueError, IndexError):
pos_rank = len(doc_ids) + 1 # Not found
# NDCG@10
if pos_rank <= 10:
ndcg = 1.0 / np.log2(pos_rank + 1)
else:
ndcg = 0.0
ndcg_scores.append(ndcg)
# MRR@10
if pos_rank <= 10:
mrr = 1.0 / pos_rank
else:
mrr = 0.0
mrr_scores.append(mrr)
# Recall@100
recall = 1.0 if pos_rank <= 100 else 0.0
recall_scores.append(recall)
metrics = {
"eval/ndcg@10": np.mean(ndcg_scores),
"eval/mrr@10": np.mean(mrr_scores),
"eval/recall@100": np.mean(recall_scores),
}
return metrics
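# Worked example for the metrics above (illustrative): if the positive document lands at
# rank 3 for a query, that query contributes NDCG@10 = 1/log2(3 + 1) = 0.5,
# MRR@10 = 1/3 ≈ 0.333, and Recall@100 = 1.0. A positive ranked outside the top 10
# contributes 0 to NDCG@10 and MRR@10 but still counts toward Recall@100 if it is
# within the top 100.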
def load_data(config: EncoderTrainingConfig):
"""Load triplets and corpus, split train/eval."""
# Load triplets
print(f"Loading triplets from {config.data_path}...")
triplets = []
with open(config.data_path) as f:
for line in f:
if line.strip():
triplets.append(json.loads(line))
print(f"✅ {len(triplets)} triplets")
# Load corpus
print(f"Loading corpus from {config.corpus_dir}...")
corpus = {}
corpus_dir = Path(config.corpus_dir)
for json_file in corpus_dir.glob("*.json"):
with open(json_file) as f:
doc = json.load(f)
doc_id = doc.get("doc_id")
content = doc.get("content", "")
if doc_id and content:
corpus[doc_id] = content
print(f"✅ {len(corpus)} documents")
# Train/eval split
num_eval = int(len(triplets) * config.eval_split)
eval_triplets = triplets[:num_eval]
train_triplets = triplets[num_eval:]
print(f"\nSplit: {len(train_triplets)} train, {len(eval_triplets)} eval")
return train_triplets, eval_triplets, corpus
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, help="Path to YAML config file (optional)")
parser.add_argument("--data-path", type=str, help="Override triplets path")
parser.add_argument("--output-dir", type=str, help="Override output directory")
parser.add_argument("--batch-size", type=int, help="Override batch size")
parser.add_argument("--epochs", type=int, help="Override number of epochs")
parser.add_argument("--no-wandb", action="store_true", help="Disable WandB logging")
parser.add_argument("--no-push", action="store_true", help="Disable Hub push")
args = parser.parse_args()
# Load config
config = EncoderTrainingConfig()
# Override from args
if args.data_path:
config.data_path = args.data_path
if args.output_dir:
config.output_dir = args.output_dir
if args.batch_size:
config.per_device_batch_size = args.batch_size
if args.epochs:
config.epochs = args.epochs
if args.no_push:
config.push_to_hub = False
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("=" * 80)
print("STABLEBRIDGE ENCODER LORA TRAINING")
print("=" * 80)
print(f"Device: {device}")
print(f"Base model: {config.base_model}")
print(f"LoRA rank: {config.lora_rank}, alpha: {config.lora_alpha}")
print(f"Max length: {config.max_length}")
print(f"Batch size: {config.per_device_batch_size} × {config.gradient_accumulation_steps} = {config.per_device_batch_size * config.gradient_accumulation_steps}")
print(f"Epochs: {config.epochs}")
print(f"Output: {config.output_dir}")
# Initialize WandB
use_wandb = not args.no_wandb and os.getenv("WANDB_API_KEY")
if use_wandb:
wandb.init(
project=config.wandb_project,
name=config.wandb_run_name or f"encoder-lora-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
config=vars(config)
)
# Load data
train_triplets, eval_triplets, corpus = load_data(config)
# Load model
print("\n" + "=" * 80)
print("MODEL SETUP")
print("=" * 80)
print("\nLoading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(config.base_model, trust_remote_code=True, local_files_only=True)
# Determine dtype
if config.mixed_precision == "bf16" and torch.cuda.is_bf16_supported():
dtype = torch.bfloat16
print("Using bfloat16 precision")
else:
dtype = torch.float16
print("Using float16 precision")
model = AutoModel.from_pretrained(
config.base_model,
torch_dtype=dtype,
trust_remote_code=True,
local_files_only=True
).to(device)
# Apply LoRA
print(f"\nApplying LoRA (rank={config.lora_rank}, alpha={config.lora_alpha})...")
lora_config = LoraConfig(
r=config.lora_rank,
lora_alpha=config.lora_alpha,
target_modules=config.target_modules,
lora_dropout=config.lora_dropout,
bias="none",
task_type=TaskType.FEATURE_EXTRACTION,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Create datasets
train_dataset = TripletDataset(train_triplets, corpus, tokenizer, config.max_length)
eval_dataset = eval_triplets # Will process differently in evaluation
# Create dataloader
train_loader = DataLoader(
train_dataset,
batch_size=config.per_device_batch_size,
shuffle=True,
num_workers=4,
pin_memory=True
)
# Optimizer
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config.learning_rate,
weight_decay=config.weight_decay
)
# Learning rate scheduler
num_training_steps = len(train_loader) * config.epochs // config.gradient_accumulation_steps
num_warmup_steps = int(num_training_steps * config.warmup_ratio)
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps
)
# Gradient scaler for mixed precision
scaler = GradScaler() if dtype == torch.float16 else None
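    # Note: loss scaling is only needed for fp16; bf16 has enough dynamic range, so in that
    # path the scaler is skipped and plain .backward() / optimizer.step() is used below.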
# Training
print("\n" + "=" * 80)
print("TRAINING")
print("=" * 80)
print(f"Total steps: {num_training_steps}")
print(f"Warmup steps: {num_warmup_steps}")
global_step = 0
best_ndcg = 0.0
for epoch in range(config.epochs):
print(f"\n{'='*80}")
print(f"EPOCH {epoch + 1}/{config.epochs}")
print(f"{'='*80}")
model.train()
epoch_loss = 0.0
optimizer.zero_grad()
pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
for step, batch in enumerate(pbar):
# Move to device
query_ids = batch["query_input_ids"].to(device)
query_mask = batch["query_attention_mask"].to(device)
pos_ids = batch["pos_input_ids"].to(device)
pos_mask = batch["pos_attention_mask"].to(device)
# Forward pass with mixed precision
with autocast(dtype=dtype):
query_output = model(input_ids=query_ids, attention_mask=query_mask)
query_emb = mean_pooling(query_output, query_mask)
pos_output = model(input_ids=pos_ids, attention_mask=pos_mask)
pos_emb = mean_pooling(pos_output, pos_mask)
# Compute loss
loss = compute_loss(query_emb, pos_emb)
loss = loss / config.gradient_accumulation_steps
# Backward pass
if scaler:
scaler.scale(loss).backward()
else:
loss.backward()
epoch_loss += loss.item() * config.gradient_accumulation_steps
# Update weights
if (step + 1) % config.gradient_accumulation_steps == 0:
                # Gradient clipping (unscale first when fp16 loss scaling is active,
                # and clip in both precision paths so max_grad_norm is always honored)
                if scaler:
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
# Optimizer step
if scaler:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
global_step += 1
# Logging
if global_step % config.logging_steps == 0:
lr = scheduler.get_last_lr()[0]
pbar.set_postfix({
"loss": f"{loss.item() * config.gradient_accumulation_steps:.4f}",
"lr": f"{lr:.2e}"
})
if use_wandb:
wandb.log({
"train/loss": loss.item() * config.gradient_accumulation_steps,
"train/learning_rate": lr,
"train/epoch": epoch,
"train/step": global_step,
})
# Evaluation
if global_step % config.eval_steps == 0:
print("\n" + "-" * 80)
print(f"EVALUATION at step {global_step}")
print("-" * 80)
eval_metrics = evaluate_retrieval(
model, tokenizer, eval_dataset, corpus, device, config
)
print("\nEvaluation Results:")
for metric, value in eval_metrics.items():
print(f" {metric}: {value:.4f}")
if use_wandb:
wandb.log(eval_metrics)
# Save best model
if eval_metrics["eval/ndcg@10"] > best_ndcg:
best_ndcg = eval_metrics["eval/ndcg@10"]
print(f"\n✅ New best NDCG@10: {best_ndcg:.4f}")
best_model_dir = Path(config.output_dir) / "best"
best_model_dir.mkdir(parents=True, exist_ok=True)
model.save_pretrained(best_model_dir)
tokenizer.save_pretrained(best_model_dir)
model.train()
print("-" * 80)
# Checkpointing
if global_step % config.save_steps == 0:
checkpoint_dir = Path(config.output_dir) / f"checkpoint-{global_step}"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
model.save_pretrained(checkpoint_dir)
tokenizer.save_pretrained(checkpoint_dir)
print(f"\n💾 Checkpoint saved: {checkpoint_dir}")
avg_loss = epoch_loss / len(train_loader)
print(f"\nEpoch {epoch + 1} - Average Loss: {avg_loss:.4f}")
# Final save
print("\n" + "=" * 80)
print("SAVING FINAL MODEL")
print("=" * 80)
output_dir = Path(config.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"✅ Model saved to: {output_dir}")
# Push to Hub
if config.push_to_hub:
print("\n" + "=" * 80)
print("PUSHING TO HUGGING FACE HUB")
print("=" * 80)
try:
model.push_to_hub(
config.hub_model_id,
token=config.hub_token or os.getenv("HF_TOKEN")
)
tokenizer.push_to_hub(
config.hub_model_id,
token=config.hub_token or os.getenv("HF_TOKEN")
)
print(f"✅ Model pushed to: {config.hub_model_id}")
except Exception as e:
print(f"❌ Failed to push to Hub: {e}")
# Final evaluation
print("\n" + "=" * 80)
print("FINAL EVALUATION")
print("=" * 80)
final_metrics = evaluate_retrieval(
model, tokenizer, eval_dataset, corpus, device, config
)
print("\nFinal Results:")
for metric, value in final_metrics.items():
print(f" {metric}: {value:.4f}")
if use_wandb:
wandb.log({"final/" + k.split("/")[1]: v for k, v in final_metrics.items()})
wandb.finish()
print("\n" + "=" * 80)
print("TRAINING COMPLETE!")
print("=" * 80)
print(f"Best NDCG@10: {best_ndcg:.4f}")
print(f"Model saved to: {output_dir}")
if config.push_to_hub:
print(f"Hub: https://huggingface.co/{config.hub_model_id}")
if __name__ == "__main__":
main()
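# Loading the trained adapter for inference (a sketch; the checkpoint path is the default
# output_dir/best from EncoderTrainingConfig and may differ in your setup):
#   from transformers import AutoModel, AutoTokenizer
#   from peft import PeftModel
#   base = AutoModel.from_pretrained("BAAI/bge-m3", trust_remote_code=True)
#   model = PeftModel.from_pretrained(base, "/workspace/checkpoints/bge-m3-us-regulatory-lora/best")
#   tokenizer = AutoTokenizer.from_pretrained("/workspace/checkpoints/bge-m3-us-regulatory-lora/best")
#   # then embed text with mean_pooling(...) exactly as in evaluate_retrieval() above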