#!/usr/bin/env python3
"""
Production Encoder LoRA Training for Stablebridge

Trains LoRA adapters on BAAI/bge-m3 for US regulatory domain.

Implements tech spec requirements:
- LoRA rank 16, alpha 32
- 8192 token context window
- MultipleNegativesRankingLoss (in-batch negatives)
- WandB logging, checkpointing, evaluation
- Model Hub push
"""

import argparse
import json
import os
import torch
import wandb
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional
from dataclasses import dataclass, field

from transformers import (
    AutoTokenizer,
    AutoModel,
    get_cosine_schedule_with_warmup,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, TaskType
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
from tqdm import tqdm


@dataclass
class EncoderTrainingConfig:
    """Complete training configuration matching tech spec."""

    # Model
    base_model: str = "BAAI/bge-m3"
    max_length: int = 8192  # Full context per tech spec

    # LoRA
    lora_rank: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    target_modules: List[str] = field(default_factory=lambda: ["query", "key", "value"])

    # Training
    epochs: int = 3
    per_device_batch_size: int = 4  # RTX 6000 Ada - will adjust based on memory
    gradient_accumulation_steps: int = 16  # Effective batch size = 64
    learning_rate: float = 5e-5
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    max_grad_norm: float = 1.0

    # Precision
    mixed_precision: str = "bf16"  # Will fall back to fp16 if needed

    # Checkpointing
    save_steps: int = 500
    eval_steps: int = 500
    logging_steps: int = 50

    # Paths
    data_path: str = "/workspace/data/labels/encoder_triplets.jsonl"
    corpus_dir: str = "/workspace/data/raw"
    output_dir: str = "/workspace/checkpoints/bge-m3-us-regulatory-lora"

    # Monitoring
    wandb_project: str = "stablebridge-encoder"
    wandb_run_name: Optional[str] = None

    # Hub
    push_to_hub: bool = True
    hub_model_id: str = "cognilogue/bge-m3-us-regulatory-lora"
    hub_token: Optional[str] = None

    # Evaluation
    eval_split: float = 0.1  # Hold out 10% for validation
    eval_metrics: List[str] = field(default_factory=lambda: ["ndcg@10", "mrr@10", "recall@100"])


class TripletDataset(Dataset):
    """Dataset for encoder triplet training with in-batch negatives."""

    def __init__(
        self,
        triplets: List[Dict],
        corpus: Dict[str, str],
        tokenizer,
        max_length: int = 8192
    ):
        self.triplets = triplets
        self.corpus = corpus
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        triplet = self.triplets[idx]
        query = triplet["query"]
        pos_id = triplet["positive"]

        # Get positive document
        positive_text = self.corpus.get(pos_id, "")

        # Tokenize
        query_enc = self.tokenizer(
            query,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        pos_enc = self.tokenizer(
            positive_text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )

        return {
            "query_input_ids": query_enc["input_ids"].squeeze(0),
            "query_attention_mask": query_enc["attention_mask"].squeeze(0),
            "pos_input_ids": pos_enc["input_ids"].squeeze(0),
            "pos_attention_mask": pos_enc["attention_mask"].squeeze(0),
        }


def mean_pooling(model_output, attention_mask):
    """Mean pooling over token embeddings (ignore padding)."""
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )
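
# Illustrative shapes (a sketch, assuming bge-m3's 1024-dim hidden size): for a
# batch of 4 sequences padded to 8192 tokens, model_output[0] has shape
# (4, 8192, 1024) and the pooled result has shape (4, 1024).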
def compute_loss(query_emb, pos_emb, temperature=0.05):
    """
    Multiple Negatives Ranking Loss (InfoNCE).

    Uses in-batch negatives: all other positives in the batch serve as negatives.
    Standard approach in sentence-transformers contrastive learning.

    Args:
        query_emb: (batch_size, hidden_dim) - normalized query embeddings
        pos_emb: (batch_size, hidden_dim) - normalized positive embeddings
        temperature: Temperature for softmax (default 0.05)

    Returns:
        loss: Scalar loss value
    """
    # Normalize embeddings
    query_emb = F.normalize(query_emb, p=2, dim=1)
    pos_emb = F.normalize(pos_emb, p=2, dim=1)

    # Compute similarity matrix: (batch_size, batch_size)
    # query_emb[i] @ pos_emb[j].T gives similarity between query i and doc j
    sim_matrix = torch.matmul(query_emb, pos_emb.T) / temperature

    # Labels: diagonal elements are positive pairs
    labels = torch.arange(sim_matrix.size(0)).to(sim_matrix.device)

    # Cross-entropy: pulls positives closer, pushes negatives away
    loss = F.cross_entropy(sim_matrix, labels)

    return loss


def evaluate_retrieval(model, tokenizer, eval_data, corpus, device, config):
    """
    Evaluate retrieval quality on validation set.

    Metrics:
    - NDCG@10: Ranking quality
    - MRR@10: Mean Reciprocal Rank
    - Recall@100: Coverage
    """
    model.eval()

    # Encode all documents
    print("\nEncoding corpus for evaluation...")
    doc_ids = list(corpus.keys())
    doc_embeddings = []

    with torch.no_grad():
        for doc_id in tqdm(doc_ids, desc="Encoding docs"):
            doc_text = corpus[doc_id]
            doc_enc = tokenizer(
                doc_text,
                max_length=config.max_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            ).to(device)

            doc_output = model(**doc_enc)
            doc_emb = mean_pooling(doc_output, doc_enc["attention_mask"])
            doc_emb = F.normalize(doc_emb, p=2, dim=1)
            doc_embeddings.append(doc_emb.cpu())

    doc_embeddings = torch.cat(doc_embeddings, dim=0)  # (num_docs, hidden_dim)

    # Evaluate queries
    ndcg_scores = []
    mrr_scores = []
    recall_scores = []

    with torch.no_grad():
        for triplet in tqdm(eval_data, desc="Evaluating"):
            query = triplet["query"]
            pos_id = triplet["positive"]

            # Encode query
            query_enc = tokenizer(
                query,
                max_length=config.max_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            ).to(device)

            query_output = model(**query_enc)
            query_emb = mean_pooling(query_output, query_enc["attention_mask"])
            query_emb = F.normalize(query_emb, p=2, dim=1)

            # Compute similarities
            similarities = torch.matmul(query_emb.cpu(), doc_embeddings.T).squeeze(0)

            # Rank documents
            ranks = torch.argsort(similarities, descending=True)

            # Find position of positive document
            try:
                pos_idx = doc_ids.index(pos_id)
                pos_rank = (ranks == pos_idx).nonzero(as_tuple=True)[0].item() + 1
            except (ValueError, IndexError):
                pos_rank = len(doc_ids) + 1  # Not found

            # NDCG@10
            if pos_rank <= 10:
                ndcg = 1.0 / np.log2(pos_rank + 1)
            else:
                ndcg = 0.0
            ndcg_scores.append(ndcg)

            # MRR@10
            if pos_rank <= 10:
                mrr = 1.0 / pos_rank
            else:
                mrr = 0.0
            mrr_scores.append(mrr)

            # Recall@100
            recall = 1.0 if pos_rank <= 100 else 0.0
            recall_scores.append(recall)

    metrics = {
        "eval/ndcg@10": np.mean(ndcg_scores),
        "eval/mrr@10": np.mean(mrr_scores),
        "eval/recall@100": np.mean(recall_scores),
    }

    return metrics


def load_data(config: EncoderTrainingConfig):
    """Load triplets and corpus, split train/eval."""
    # Load triplets
    print(f"Loading triplets from {config.data_path}...")
    triplets = []
    with open(config.data_path) as f:
        for line in f:
            if line.strip():
                triplets.append(json.loads(line))
    print(f"āœ… {len(triplets)} triplets")

    # Load corpus
    print(f"Loading corpus from {config.corpus_dir}...")
    corpus = {}
    corpus_dir = Path(config.corpus_dir)
    for json_file in corpus_dir.glob("*.json"):
        with open(json_file) as f:
            doc = json.load(f)
            doc_id = doc.get("doc_id")
            content = doc.get("content", "")
            if doc_id and content:
                corpus[doc_id] = content
    print(f"āœ… {len(corpus)} documents")

    # Train/eval split
    num_eval = int(len(triplets) * config.eval_split)
    eval_triplets = triplets[:num_eval]
    train_triplets = triplets[num_eval:]

    print(f"\nSplit: {len(train_triplets)} train, {len(eval_triplets)} eval")

    return train_triplets, eval_triplets, corpus
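
# Expected input formats (illustrative sketch inferred from the keys read above;
# the authoritative schema lives with the data-labeling pipeline):
#
#   encoder_triplets.jsonl, one JSON object per line:
#     {"query": "What are the reserve requirements for payment stablecoins?",
#      "positive": "<doc_id of the relevant document>"}
#
#   corpus files (*.json under corpus_dir), one document each:
#     {"doc_id": "<unique id>", "content": "<full document text>"}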
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, help="Path to YAML config file (optional)")
    parser.add_argument("--data-path", type=str, help="Override triplets path")
    parser.add_argument("--output-dir", type=str, help="Override output directory")
    parser.add_argument("--batch-size", type=int, help="Override batch size")
    parser.add_argument("--epochs", type=int, help="Override number of epochs")
    parser.add_argument("--no-wandb", action="store_true", help="Disable WandB logging")
    parser.add_argument("--no-push", action="store_true", help="Disable Hub push")
    args = parser.parse_args()

    # Load config (defaults; --config YAML loading is not wired up yet, so only
    # the CLI overrides below take effect)
    config = EncoderTrainingConfig()

    # Override from args
    if args.data_path:
        config.data_path = args.data_path
    if args.output_dir:
        config.output_dir = args.output_dir
    if args.batch_size:
        config.per_device_batch_size = args.batch_size
    if args.epochs:
        config.epochs = args.epochs
    if args.no_push:
        config.push_to_hub = False

    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("=" * 80)
    print("STABLEBRIDGE ENCODER LORA TRAINING")
    print("=" * 80)
    print(f"Device: {device}")
    print(f"Base model: {config.base_model}")
    print(f"LoRA rank: {config.lora_rank}, alpha: {config.lora_alpha}")
    print(f"Max length: {config.max_length}")
    print(f"Batch size: {config.per_device_batch_size} Ɨ {config.gradient_accumulation_steps} = {config.per_device_batch_size * config.gradient_accumulation_steps}")
    print(f"Epochs: {config.epochs}")
    print(f"Output: {config.output_dir}")

    # Initialize WandB
    use_wandb = not args.no_wandb and os.getenv("WANDB_API_KEY")
    if use_wandb:
        wandb.init(
            project=config.wandb_project,
            name=config.wandb_run_name or f"encoder-lora-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
            config=vars(config)
        )

    # Load data
    train_triplets, eval_triplets, corpus = load_data(config)

    # Load model
    print("\n" + "=" * 80)
    print("MODEL SETUP")
    print("=" * 80)

    print("\nLoading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained(config.base_model, trust_remote_code=True, local_files_only=True)

    # Determine dtype
    if config.mixed_precision == "bf16" and torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
        print("Using bfloat16 precision")
    else:
        dtype = torch.float16
        print("Using float16 precision")

    model = AutoModel.from_pretrained(
        config.base_model,
        torch_dtype=dtype,
        trust_remote_code=True,
        local_files_only=True
    ).to(device)

    # Apply LoRA
    print(f"\nApplying LoRA (rank={config.lora_rank}, alpha={config.lora_alpha})...")
    lora_config = LoraConfig(
        r=config.lora_rank,
        lora_alpha=config.lora_alpha,
        target_modules=config.target_modules,
        lora_dropout=config.lora_dropout,
        bias="none",
        task_type=TaskType.FEATURE_EXTRACTION,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Create datasets
    train_dataset = TripletDataset(train_triplets, corpus, tokenizer, config.max_length)
    eval_dataset = eval_triplets  # Will process differently in evaluation

    # Create dataloader
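    # Note: with in-batch negatives, each query only contrasts against the other
    # (per_device_batch_size - 1) positives in its micro-batch; gradient
    # accumulation enlarges the optimizer step, not the negative pool.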
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.per_device_batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    # Learning rate scheduler
    num_training_steps = len(train_loader) * config.epochs // config.gradient_accumulation_steps
    num_warmup_steps = int(num_training_steps * config.warmup_ratio)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    # Gradient scaler for mixed precision
    scaler = GradScaler() if dtype == torch.float16 else None

    # Training
    print("\n" + "=" * 80)
    print("TRAINING")
    print("=" * 80)
    print(f"Total steps: {num_training_steps}")
    print(f"Warmup steps: {num_warmup_steps}")

    global_step = 0
    best_ndcg = 0.0

    for epoch in range(config.epochs):
        print(f"\n{'='*80}")
        print(f"EPOCH {epoch + 1}/{config.epochs}")
        print(f"{'='*80}")

        model.train()
        epoch_loss = 0.0
        optimizer.zero_grad()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
        for step, batch in enumerate(pbar):
            # Move to device
            query_ids = batch["query_input_ids"].to(device)
            query_mask = batch["query_attention_mask"].to(device)
            pos_ids = batch["pos_input_ids"].to(device)
            pos_mask = batch["pos_attention_mask"].to(device)

            # Forward pass with mixed precision
            with autocast(dtype=dtype):
                query_output = model(input_ids=query_ids, attention_mask=query_mask)
                query_emb = mean_pooling(query_output, query_mask)

                pos_output = model(input_ids=pos_ids, attention_mask=pos_mask)
                pos_emb = mean_pooling(pos_output, pos_mask)

                # Compute loss
                loss = compute_loss(query_emb, pos_emb)
                loss = loss / config.gradient_accumulation_steps

            # Backward pass
            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            epoch_loss += loss.item() * config.gradient_accumulation_steps

            # Update weights
            if (step + 1) % config.gradient_accumulation_steps == 0:
                # Gradient clipping
                if scaler:
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

                # Optimizer step
                if scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                # Logging
                if global_step % config.logging_steps == 0:
                    lr = scheduler.get_last_lr()[0]
                    pbar.set_postfix({
                        "loss": f"{loss.item() * config.gradient_accumulation_steps:.4f}",
                        "lr": f"{lr:.2e}"
                    })
                    if use_wandb:
                        wandb.log({
                            "train/loss": loss.item() * config.gradient_accumulation_steps,
                            "train/learning_rate": lr,
                            "train/epoch": epoch,
                            "train/step": global_step,
                        })

                # Evaluation
                if global_step % config.eval_steps == 0:
                    print("\n" + "-" * 80)
                    print(f"EVALUATION at step {global_step}")
                    print("-" * 80)

                    eval_metrics = evaluate_retrieval(
                        model, tokenizer, eval_dataset, corpus, device, config
                    )

                    print("\nEvaluation Results:")
                    for metric, value in eval_metrics.items():
                        print(f"  {metric}: {value:.4f}")

                    if use_wandb:
                        wandb.log(eval_metrics)

                    # Save best model
                    if eval_metrics["eval/ndcg@10"] > best_ndcg:
                        best_ndcg = eval_metrics["eval/ndcg@10"]
                        print(f"\nāœ… New best NDCG@10: {best_ndcg:.4f}")

                        best_model_dir = Path(config.output_dir) / "best"
                        best_model_dir.mkdir(parents=True, exist_ok=True)
                        model.save_pretrained(best_model_dir)
                        tokenizer.save_pretrained(best_model_dir)

                    model.train()
                    print("-" * 80)

                # Checkpointing
                if global_step % config.save_steps == 0:
                    checkpoint_dir = Path(config.output_dir) / f"checkpoint-{global_step}"
                    checkpoint_dir.mkdir(parents=True, exist_ok=True)
                    model.save_pretrained(checkpoint_dir)
                    tokenizer.save_pretrained(checkpoint_dir)
                    print(f"\nšŸ’¾ Checkpoint saved: {checkpoint_dir}")

        avg_loss = epoch_loss / len(train_loader)
        print(f"\nEpoch {epoch + 1} - Average Loss: {avg_loss:.4f}")
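    # Note: save_pretrained / push_to_hub on the PEFT-wrapped model writes only the
    # LoRA adapter (adapter_config.json plus adapter weights), not a merged base
    # model; downstream code must load BAAI/bge-m3 and attach the adapter.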
    # Final save
    print("\n" + "=" * 80)
    print("SAVING FINAL MODEL")
    print("=" * 80)

    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"āœ… Model saved to: {output_dir}")

    # Push to Hub
    if config.push_to_hub:
        print("\n" + "=" * 80)
        print("PUSHING TO HUGGING FACE HUB")
        print("=" * 80)

        try:
            model.push_to_hub(
                config.hub_model_id,
                token=config.hub_token or os.getenv("HF_TOKEN")
            )
            tokenizer.push_to_hub(
                config.hub_model_id,
                token=config.hub_token or os.getenv("HF_TOKEN")
            )
            print(f"āœ… Model pushed to: {config.hub_model_id}")
        except Exception as e:
            print(f"āŒ Failed to push to Hub: {e}")

    # Final evaluation
    print("\n" + "=" * 80)
    print("FINAL EVALUATION")
    print("=" * 80)

    final_metrics = evaluate_retrieval(
        model, tokenizer, eval_dataset, corpus, device, config
    )

    print("\nFinal Results:")
    for metric, value in final_metrics.items():
        print(f"  {metric}: {value:.4f}")

    if use_wandb:
        wandb.log({"final/" + k.split("/")[1]: v for k, v in final_metrics.items()})
        wandb.finish()

    print("\n" + "=" * 80)
    print("TRAINING COMPLETE!")
    print("=" * 80)
    print(f"Best NDCG@10: {best_ndcg:.4f}")
    print(f"Model saved to: {output_dir}")
    if config.push_to_hub:
        print(f"Hub: https://huggingface.co/{config.hub_model_id}")


if __name__ == "__main__":
    main()
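
# Loading the trained adapter for inference: a minimal sketch, assuming the
# default hub_model_id above (adjust the path or repo id for your deployment):
#
#   from transformers import AutoModel, AutoTokenizer
#   from peft import PeftModel
#
#   base = AutoModel.from_pretrained("BAAI/bge-m3")
#   model = PeftModel.from_pretrained(base, "cognilogue/bge-m3-us-regulatory-lora")
#   tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
#   # Embed text with mean_pooling(...) exactly as during training.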