| | |
| | """ |
| | Production Encoder LoRA Training for Stablebridge |
| | |
| | Trains LoRA adapters on BAAI/bge-m3 for US regulatory domain. |
| | Implements tech spec requirements: |
| | - LoRA rank 16, alpha 32 |
| | - 8192 token context window |
| | - MultipleNegativesRankingLoss (in-batch negatives) |
| | - WandB logging, checkpointing, evaluation |
| | - Model Hub push |
| | """ |
| |
|
| | import argparse |
| | import json |
| | import os |
| | import torch |
| | import wandb |
| | from pathlib import Path |
| | from datetime import datetime |
| | from typing import Dict, List, Optional |
| | from dataclasses import dataclass, field |
| |
|
| | from transformers import ( |
| | AutoTokenizer, |
| | AutoModel, |
| | get_cosine_schedule_with_warmup, |
| | TrainingArguments, |
| | ) |
| | from peft import LoraConfig, get_peft_model, TaskType |
| | from torch.utils.data import Dataset, DataLoader |
| | from torch.nn import functional as F |
| | from torch.cuda.amp import autocast, GradScaler |
| | import numpy as np |
| | from tqdm import tqdm |
| |
|
| |
|
@dataclass
class EncoderTrainingConfig:
    """Complete training configuration matching tech spec."""

    # Base encoder and context window (tech spec: 8192-token context).
    base_model: str = "BAAI/bge-m3"
    max_length: int = 8192

    # LoRA adapter hyperparameters (tech spec: rank 16, alpha 32).
    lora_rank: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    # Attention projections to adapt; names must match modules in the base model.
    target_modules: List[str] = field(default_factory=lambda: ["query", "key", "value"])

    # Optimization schedule. Effective batch size is
    # per_device_batch_size * gradient_accumulation_steps (4 * 16 = 64).
    epochs: int = 3
    per_device_batch_size: int = 4
    gradient_accumulation_steps: int = 16
    learning_rate: float = 5e-5
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    max_grad_norm: float = 1.0

    # Mixed-precision mode; "bf16" is preferred, runtime falls back to fp16
    # (see dtype selection in main()).
    mixed_precision: str = "bf16"

    # Cadence, in optimizer steps, for checkpointing / evaluation / logging.
    save_steps: int = 500
    eval_steps: int = 500
    logging_steps: int = 50

    # Filesystem locations: triplets JSONL, raw corpus JSON dir, output dir.
    data_path: str = "/workspace/data/labels/encoder_triplets.jsonl"
    corpus_dir: str = "/workspace/data/raw"
    output_dir: str = "/workspace/checkpoints/bge-m3-us-regulatory-lora"

    # WandB experiment tracking (run name auto-generated when None).
    wandb_project: str = "stablebridge-encoder"
    wandb_run_name: Optional[str] = None

    # Hugging Face Hub push settings (token falls back to HF_TOKEN env var).
    push_to_hub: bool = True
    hub_model_id: str = "cognilogue/bge-m3-us-regulatory-lora"
    hub_token: Optional[str] = None

    # Fraction of triplets held out for evaluation, and metrics reported.
    eval_split: float = 0.1
    eval_metrics: List[str] = field(default_factory=lambda: ["ndcg@10", "mrr@10", "recall@100"])
| |
|
| |
|
class TripletDataset(Dataset):
    """Torch dataset yielding tokenized (query, positive) pairs.

    No explicit negatives are stored per example: the contrastive loss
    treats every other positive in the batch as an in-batch negative.
    """

    def __init__(
        self,
        triplets: List[Dict],
        corpus: Dict[str, str],
        tokenizer,
        max_length: int = 8192
    ):
        self.triplets = triplets
        self.corpus = corpus
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.triplets)

    def _encode(self, text):
        # Pad every example to max_length so the default collate_fn can
        # stack tensors without a custom collator.
        return self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )

    def __getitem__(self, idx):
        example = self.triplets[idx]
        query_enc = self._encode(example["query"])
        # A positive id missing from the corpus degrades to the empty
        # string rather than raising.
        pos_enc = self._encode(self.corpus.get(example["positive"], ""))

        return {
            "query_input_ids": query_enc["input_ids"].squeeze(0),
            "query_attention_mask": query_enc["attention_mask"].squeeze(0),
            "pos_input_ids": pos_enc["input_ids"].squeeze(0),
            "pos_attention_mask": pos_enc["attention_mask"].squeeze(0),
        }
| |
|
| |
|
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence, excluding padding.

    model_output[0] is the token-embedding tensor (batch, seq, dim);
    attention_mask is (batch, seq) with 1 for real tokens. The denominator
    is clamped so an all-padding row yields zeros instead of NaN.
    """
    embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = torch.sum(embeddings * mask, 1)
    counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / counts
| |
|
| |
|
def compute_loss(query_emb, pos_emb, temperature=0.05):
    """
    Multiple Negatives Ranking Loss (InfoNCE) with in-batch negatives.

    Row i of the similarity matrix scores query i against every positive in
    the batch; the matching positive sits on the diagonal, so cross-entropy
    against labels [0..B-1] pushes each query toward its own positive and
    away from the rest.

    Args:
        query_emb: (batch_size, hidden_dim) query embeddings (normalized here)
        pos_emb: (batch_size, hidden_dim) positive embeddings (normalized here)
        temperature: softmax temperature (default 0.05)

    Returns:
        Scalar loss tensor.
    """
    q = F.normalize(query_emb, p=2, dim=1)
    p = F.normalize(pos_emb, p=2, dim=1)

    # Cosine similarities scaled by temperature: (B, B) logits.
    logits = (q @ p.T) / temperature

    # Diagonal entries are the correct "class" for each query.
    targets = torch.arange(logits.size(0), device=logits.device)

    return F.cross_entropy(logits, targets)
| |
|
| |
|
def evaluate_retrieval(model, tokenizer, eval_data, corpus, device, config):
    """
    Evaluate retrieval quality on validation set.

    Encodes the full corpus once, then ranks all documents per query. Each
    triplet has exactly one relevant document, so the metrics reduce to
    rank-based formulas for the positive's rank r:
    - NDCG@10: 1/log2(r+1) if r <= 10 else 0
    - MRR@10: 1/r if r <= 10 else 0
    - Recall@100: 1 if r <= 100 else 0

    Args:
        model: encoder whose output[0] holds token embeddings
        tokenizer: HF tokenizer used for both queries and documents
        eval_data: list of {"query": str, "positive": doc_id} dicts
        corpus: doc_id -> document text
        device: torch device used for encoding
        config: supplies max_length

    Returns:
        Dict of "eval/<metric>" -> mean score.

    Note: puts the model in eval mode; the caller restores train mode.
    """
    model.eval()

    # Encode every corpus document once up front.
    print("\nEncoding corpus for evaluation...")
    doc_ids = list(corpus.keys())
    # Fix: a dict lookup replaces doc_ids.index(pos_id) inside the query
    # loop, which was O(num_docs) per query (quadratic overall).
    id_to_idx = {doc_id: i for i, doc_id in enumerate(doc_ids)}
    doc_embeddings = []

    with torch.no_grad():
        for doc_id in tqdm(doc_ids, desc="Encoding docs"):
            doc_text = corpus[doc_id]
            doc_enc = tokenizer(
                doc_text,
                max_length=config.max_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            ).to(device)

            doc_output = model(**doc_enc)
            doc_emb = mean_pooling(doc_output, doc_enc["attention_mask"])
            doc_emb = F.normalize(doc_emb, p=2, dim=1)
            # Keep the cached matrix on CPU to bound GPU memory.
            doc_embeddings.append(doc_emb.cpu())

    doc_embeddings = torch.cat(doc_embeddings, dim=0)

    ndcg_scores = []
    mrr_scores = []
    recall_scores = []

    with torch.no_grad():
        for triplet in tqdm(eval_data, desc="Evaluating"):
            query = triplet["query"]
            pos_id = triplet["positive"]

            query_enc = tokenizer(
                query,
                max_length=config.max_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt"
            ).to(device)

            query_output = model(**query_enc)
            query_emb = mean_pooling(query_output, query_enc["attention_mask"])
            query_emb = F.normalize(query_emb, p=2, dim=1)

            # Cosine similarity (both sides are L2-normalized).
            similarities = torch.matmul(query_emb.cpu(), doc_embeddings.T).squeeze(0)

            # Ranked document indices, best first.
            ranks = torch.argsort(similarities, descending=True)

            pos_idx = id_to_idx.get(pos_id)
            if pos_idx is None:
                # Positive missing from corpus: score as worst possible rank.
                pos_rank = len(doc_ids) + 1
            else:
                pos_rank = (ranks == pos_idx).nonzero(as_tuple=True)[0].item() + 1

            # Single-relevant-document NDCG@10.
            ndcg_scores.append(1.0 / np.log2(pos_rank + 1) if pos_rank <= 10 else 0.0)
            mrr_scores.append(1.0 / pos_rank if pos_rank <= 10 else 0.0)
            recall_scores.append(1.0 if pos_rank <= 100 else 0.0)

    metrics = {
        "eval/ndcg@10": np.mean(ndcg_scores),
        "eval/mrr@10": np.mean(mrr_scores),
        "eval/recall@100": np.mean(recall_scores),
    }

    return metrics
| |
|
| |
|
def load_data(config: "EncoderTrainingConfig"):
    """Load triplets and the raw-document corpus, then carve off an eval split.

    The first ``eval_split`` fraction of the triplet file becomes the eval
    set; the rest is used for training (order preserved, no shuffling).
    Returns (train_triplets, eval_triplets, corpus).
    """
    print(f"Loading triplets from {config.data_path}...")
    with open(config.data_path) as f:
        triplets = [json.loads(line) for line in f if line.strip()]
    print(f"✅ {len(triplets)} triplets")

    print(f"Loading corpus from {config.corpus_dir}...")
    corpus = {}
    for json_file in Path(config.corpus_dir).glob("*.json"):
        with open(json_file) as f:
            doc = json.load(f)
        doc_id = doc.get("doc_id")
        content = doc.get("content", "")
        # Documents without an id or with empty content are dropped.
        if doc_id and content:
            corpus[doc_id] = content
    print(f"✅ {len(corpus)} documents")

    num_eval = int(len(triplets) * config.eval_split)
    eval_triplets = triplets[:num_eval]
    train_triplets = triplets[num_eval:]

    print(f"\nSplit: {len(train_triplets)} train, {len(eval_triplets)} eval")

    return train_triplets, eval_triplets, corpus
| |
|
| |
|
def main():
    """CLI entry point for the encoder LoRA training run.

    Pipeline: parse CLI overrides -> load triplets/corpus -> wrap the base
    encoder with LoRA adapters -> contrastive training with gradient
    accumulation and periodic eval/checkpointing -> save final adapter ->
    optional Hub push -> final evaluation pass.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): --config is accepted but never read below; YAML config
    # loading looks unimplemented — confirm before documenting it as supported.
    parser.add_argument("--config", type=str, help="Path to YAML config file (optional)")
    parser.add_argument("--data-path", type=str, help="Override triplets path")
    parser.add_argument("--output-dir", type=str, help="Override output directory")
    parser.add_argument("--batch-size", type=int, help="Override batch size")
    parser.add_argument("--epochs", type=int, help="Override number of epochs")
    parser.add_argument("--no-wandb", action="store_true", help="Disable WandB logging")
    parser.add_argument("--no-push", action="store_true", help="Disable Hub push")
    args = parser.parse_args()

    # Dataclass defaults first, then CLI overrides layered on top.
    config = EncoderTrainingConfig()

    if args.data_path:
        config.data_path = args.data_path
    if args.output_dir:
        config.output_dir = args.output_dir
    if args.batch_size:
        config.per_device_batch_size = args.batch_size
    if args.epochs:
        config.epochs = args.epochs
    if args.no_push:
        config.push_to_hub = False

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("=" * 80)
    print("STABLEBRIDGE ENCODER LORA TRAINING")
    print("=" * 80)
    print(f"Device: {device}")
    print(f"Base model: {config.base_model}")
    print(f"LoRA rank: {config.lora_rank}, alpha: {config.lora_alpha}")
    print(f"Max length: {config.max_length}")
    print(f"Batch size: {config.per_device_batch_size} × {config.gradient_accumulation_steps} = {config.per_device_batch_size * config.gradient_accumulation_steps}")
    print(f"Epochs: {config.epochs}")
    print(f"Output: {config.output_dir}")

    # WandB is active only when not disabled AND an API key is present.
    use_wandb = not args.no_wandb and os.getenv("WANDB_API_KEY")
    if use_wandb:
        wandb.init(
            project=config.wandb_project,
            name=config.wandb_run_name or f"encoder-lora-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
            config=vars(config)
        )

    train_triplets, eval_triplets, corpus = load_data(config)

    print("\n" + "=" * 80)
    print("MODEL SETUP")
    print("=" * 80)

    print("\nLoading tokenizer and model...")
    # local_files_only=True: weights must already be present in the local HF cache.
    tokenizer = AutoTokenizer.from_pretrained(config.base_model, trust_remote_code=True, local_files_only=True)

    # NOTE(review): torch.cuda.is_bf16_supported() is queried even when the
    # selected device is CPU — confirm this is safe on CUDA-less machines.
    if config.mixed_precision == "bf16" and torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
        print("Using bfloat16 precision")
    else:
        dtype = torch.float16
        print("Using float16 precision")

    model = AutoModel.from_pretrained(
        config.base_model,
        torch_dtype=dtype,
        trust_remote_code=True,
        local_files_only=True
    ).to(device)

    # Wrap attention projections with LoRA adapters; base weights stay frozen.
    print(f"\nApplying LoRA (rank={config.lora_rank}, alpha={config.lora_alpha})...")
    lora_config = LoraConfig(
        r=config.lora_rank,
        lora_alpha=config.lora_alpha,
        target_modules=config.target_modules,
        lora_dropout=config.lora_dropout,
        bias="none",
        task_type=TaskType.FEATURE_EXTRACTION,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    train_dataset = TripletDataset(train_triplets, corpus, tokenizer, config.max_length)
    # Eval keeps the raw triplet dicts: evaluate_retrieval tokenizes on the fly.
    eval_dataset = eval_triplets

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.per_device_batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    # AdamW over all parameters; PEFT has already frozen the non-LoRA weights.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    # The scheduler steps once per optimizer step (not per micro-batch),
    # hence the division by gradient_accumulation_steps.
    num_training_steps = len(train_loader) * config.epochs // config.gradient_accumulation_steps
    num_warmup_steps = int(num_training_steps * config.warmup_ratio)

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    # Loss scaling is only needed for fp16; bf16 trains without a scaler.
    scaler = GradScaler() if dtype == torch.float16 else None

    print("\n" + "=" * 80)
    print("TRAINING")
    print("=" * 80)
    print(f"Total steps: {num_training_steps}")
    print(f"Warmup steps: {num_warmup_steps}")

    global_step = 0
    best_ndcg = 0.0

    for epoch in range(config.epochs):
        print(f"\n{'='*80}")
        print(f"EPOCH {epoch + 1}/{config.epochs}")
        print(f"{'='*80}")

        model.train()
        epoch_loss = 0.0
        # NOTE: zeroing here discards any gradients from a trailing
        # incomplete accumulation window of the previous epoch.
        optimizer.zero_grad()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
        for step, batch in enumerate(pbar):
            query_ids = batch["query_input_ids"].to(device)
            query_mask = batch["query_attention_mask"].to(device)
            pos_ids = batch["pos_input_ids"].to(device)
            pos_mask = batch["pos_attention_mask"].to(device)

            # Forward both sides under autocast; loss uses in-batch negatives.
            with autocast(dtype=dtype):
                query_output = model(input_ids=query_ids, attention_mask=query_mask)
                query_emb = mean_pooling(query_output, query_mask)

                pos_output = model(input_ids=pos_ids, attention_mask=pos_mask)
                pos_emb = mean_pooling(pos_output, pos_mask)

                loss = compute_loss(query_emb, pos_emb)
                # Scale down so accumulated gradients average over the window.
                loss = loss / config.gradient_accumulation_steps

            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Undo the accumulation scaling for human-readable loss tracking.
            epoch_loss += loss.item() * config.gradient_accumulation_steps

            # Optimizer step only at the end of each accumulation window.
            if (step + 1) % config.gradient_accumulation_steps == 0:
                # Unscale before clipping so the norm is measured in true units.
                if scaler:
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

                if scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                # Periodic console + WandB logging.
                if global_step % config.logging_steps == 0:
                    lr = scheduler.get_last_lr()[0]
                    pbar.set_postfix({
                        "loss": f"{loss.item() * config.gradient_accumulation_steps:.4f}",
                        "lr": f"{lr:.2e}"
                    })

                    if use_wandb:
                        wandb.log({
                            "train/loss": loss.item() * config.gradient_accumulation_steps,
                            "train/learning_rate": lr,
                            "train/epoch": epoch,
                            "train/step": global_step,
                        })

                # Periodic evaluation; best checkpoint tracked by NDCG@10.
                if global_step % config.eval_steps == 0:
                    print("\n" + "-" * 80)
                    print(f"EVALUATION at step {global_step}")
                    print("-" * 80)

                    eval_metrics = evaluate_retrieval(
                        model, tokenizer, eval_dataset, corpus, device, config
                    )

                    print("\nEvaluation Results:")
                    for metric, value in eval_metrics.items():
                        print(f"  {metric}: {value:.4f}")

                    if use_wandb:
                        wandb.log(eval_metrics)

                    if eval_metrics["eval/ndcg@10"] > best_ndcg:
                        best_ndcg = eval_metrics["eval/ndcg@10"]
                        print(f"\n✅ New best NDCG@10: {best_ndcg:.4f}")
                        best_model_dir = Path(config.output_dir) / "best"
                        best_model_dir.mkdir(parents=True, exist_ok=True)
                        model.save_pretrained(best_model_dir)
                        tokenizer.save_pretrained(best_model_dir)

                    # evaluate_retrieval leaves the model in eval mode; restore.
                    model.train()
                    print("-" * 80)

                # Periodic rolling checkpoint, independent of eval quality.
                if global_step % config.save_steps == 0:
                    checkpoint_dir = Path(config.output_dir) / f"checkpoint-{global_step}"
                    checkpoint_dir.mkdir(parents=True, exist_ok=True)
                    model.save_pretrained(checkpoint_dir)
                    tokenizer.save_pretrained(checkpoint_dir)
                    print(f"\n💾 Checkpoint saved: {checkpoint_dir}")

        avg_loss = epoch_loss / len(train_loader)
        print(f"\nEpoch {epoch + 1} - Average Loss: {avg_loss:.4f}")

    print("\n" + "=" * 80)
    print("SAVING FINAL MODEL")
    print("=" * 80)

    # Final adapter + tokenizer land at the output root (not a checkpoint dir).
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"✅ Model saved to: {output_dir}")

    if config.push_to_hub:
        print("\n" + "=" * 80)
        print("PUSHING TO HUGGING FACE HUB")
        print("=" * 80)

        # Hub push is best-effort: failure is reported but never aborts the run.
        try:
            model.push_to_hub(
                config.hub_model_id,
                token=config.hub_token or os.getenv("HF_TOKEN")
            )
            tokenizer.push_to_hub(
                config.hub_model_id,
                token=config.hub_token or os.getenv("HF_TOKEN")
            )
            print(f"✅ Model pushed to: {config.hub_model_id}")
        except Exception as e:
            print(f"❌ Failed to push to Hub: {e}")

    print("\n" + "=" * 80)
    print("FINAL EVALUATION")
    print("=" * 80)

    final_metrics = evaluate_retrieval(
        model, tokenizer, eval_dataset, corpus, device, config
    )

    print("\nFinal Results:")
    for metric, value in final_metrics.items():
        print(f"  {metric}: {value:.4f}")

    if use_wandb:
        # Re-key "eval/<metric>" as "final/<metric>" for the last log entry.
        wandb.log({"final/" + k.split("/")[1]: v for k, v in final_metrics.items()})
        wandb.finish()

    print("\n" + "=" * 80)
    print("TRAINING COMPLETE!")
    print("=" * 80)
    print(f"Best NDCG@10: {best_ndcg:.4f}")
    print(f"Model saved to: {output_dir}")
    if config.push_to_hub:
        print(f"Hub: https://huggingface.co/{config.hub_model_id}")
| |
|
| |
|
# Script entry point: run training when executed directly.
if __name__ == "__main__":
    main()
| |
|