---
license: mit
language:
- en
library_name: transformers
tags:
- lora
- peft
- reinforcement-learning
- domain-adaptation
- sentence-embeddings
- curriculum-learning
- multi-task-learning
- rag
- information-retrieval
- cross-domain
- sentence-transformers
base_model:
- sentence-transformers/all-MiniLM-L6-v2
- EphAsad/FireDevourerEmbedder-RL-v3.6
pipeline_tag: sentence-similarity
datasets:
- sentence-transformers/stsb
- nyu-mll/multi_nli
- quora
- google-research-datasets/paws
- nyu-mll/glue
- GBaker/MedQA-USMLE-4-options-hf
- lex_glue
- gbharti/finance-alpaca
- scientific_papers
model-index:
- name: DomainEmbedder-v2.6
  results:
  - task:
      type: domain-classification
      name: Domain Classification
    metrics:
    - type: accuracy
      value: 0.925
      name: Training Accuracy
    - type: accuracy
      value: 0.56
      name: Stress-Test Accuracy
---

# DomainEmbedder-v2.6

> **High-Information-Density Embeddings for Cross-Domain RAG and Retrieval**

DomainEmbedder-v2.6 produces **information-dense embeddings** optimized for retrieval-augmented generation (RAG) and cross-domain similarity matching. It combines a multi-task base embedder with domain-adaptive LoRA routing.
## What This Model Does

| Component | Description |
|-----------|-------------|
| **Base Embedder** | FireDevourerEmbedder-RL-v3.6, trained on 5 NLP tasks with RL-based task weighting |
| **Domain LoRAs** | 5 specialized adapters (Medical, Legal, Code, Finance, Scientific) |
| **RL Policy** | Automatically selects the optimal domain adapter for any input |

**Why this matters for RAG/retrieval:**

- Embeddings encode multiple facets of meaning (similarity, entailment, paraphrase, questions)
- Domain routing provides context-appropriate representations
- Results in more precise retrieval across diverse content types

## Key Innovation: Dual RL Architecture

| Stage | RL Application | Purpose |
|-------|---------------|---------|
| Base Model Training | Task Weight Policy | Dynamically balance 5 NLP objectives during training |
| Domain Extension | Adapter Selection Policy | Route to the appropriate domain LoRA at inference |

This dual RL approach is novel: **RL at training time AND at inference time**.
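Whichever adapter produced them, the embeddings feed ordinary nearest-neighbour retrieval. A minimal cosine-similarity ranking sketch, independent of this package (`cosine` and `top_k` are illustrative helpers; random vectors stand in for real 384-dimensional document embeddings):

```python
import math
import random

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def top_k(query, corpus, k=3):
    """Indices of the k corpus vectors most similar to the query."""
    ranked = sorted(range(len(corpus)),
                    key=lambda i: cosine(query, corpus[i]),
                    reverse=True)
    return ranked[:k]

random.seed(0)
# Stand-ins for embedded documents (384-dim, matching the model's output size)
corpus = [[random.gauss(0, 1) for _ in range(384)] for _ in range(100)]
# A query that is a lightly perturbed copy of document 42
query = [x + 0.01 * random.gauss(0, 1) for x in corpus[42]]
print(top_k(query, corpus)[0])  # 42
```

In production this brute-force loop would be replaced by an approximate-nearest-neighbour index; the point is only that domain-adapted embeddings drop into any standard vector-search pipeline.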
## Quick Start

### Installation

```bash
pip install torch transformers peft
```

### Loading the Model

```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

# Define the base embedder architecture
class FireDevourerEmbedder(nn.Module):
    def __init__(self, base_model_name='sentence-transformers/all-MiniLM-L6-v2'):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)
        self.hidden_size = 384

        # Task heads
        self.sts_head = nn.Sequential(nn.Linear(self.hidden_size, 1), nn.Sigmoid())
        self.nli_head = nn.Linear(self.hidden_size, 3)
        self.qqp_head = nn.Linear(self.hidden_size, 2)
        self.paws_head = nn.Linear(self.hidden_size, 2)
        self.domain_head = nn.Linear(self.hidden_size, 5)

    def mean_pool(self, token_embeddings, attention_mask):
        # Average token embeddings, ignoring padding positions
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

    def forward(self, input_ids, attention_mask, task='encode'):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        embedding = self.mean_pool(outputs.last_hidden_state, attention_mask)
        if task == 'encode':
            return embedding
        elif task == 'domain':
            return self.domain_head(embedding)
        # Add other tasks as needed

# Define the RL policy network
class RLPolicyNetwork(nn.Module):
    def __init__(self, input_dim=384, hidden_dim=128, num_actions=5):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.policy_head = nn.Linear(hidden_dim, num_actions)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        features = self.network(x)
        policy = torch.softmax(self.policy_head(features), dim=-1)
        value = self.value_head(features)
        return policy, value

# Load model
model_dir = "path/to/DomainEmbedder-v2.6"

# 1. Load base model from checkpoint
base_model = FireDevourerEmbedder()
checkpoint = torch.load(f"{model_dir}/FireDevourerEmbedder-RL-v3.6.pt", map_location=device)
base_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
base_model.to(device)
base_model.eval()

# 2. Load RL policy
rl_policy = RLPolicyNetwork()
rl_checkpoint = torch.load(f"{model_dir}/rl_policy.pt", map_location=device)
rl_policy.load_state_dict(rl_checkpoint['policy_state_dict'])
rl_policy.to(device)
rl_policy.eval()

# 3. Load LoRA adapters (example: medical)
lora_model = PeftModel.from_pretrained(
    base_model.encoder,
    f"{model_dir}/medical_lora"
)
```

### Computing Embeddings with Domain Selection

```python
def get_domain_embedding(text, base_model, rl_policy, lora_models, tokenizer, device):
    """Get a domain-aware embedding for the input text."""
    # Tokenize
    inputs = tokenizer(text, return_tensors='pt', padding=True,
                       truncation=True, max_length=512).to(device)

    with torch.no_grad():
        # Get base embedding
        base_emb = base_model(inputs['input_ids'], inputs['attention_mask'], task='encode')
        # Get domain selection from the RL policy
        policy_probs, _ = rl_policy(base_emb)

    domain_idx = torch.argmax(policy_probs, dim=-1).item()
    domains = ['medical', 'legal', 'code', 'finance', 'scientific']
    selected_domain = domains[domain_idx]
    confidence = policy_probs[0, domain_idx].item()

    return {
        'embedding': base_emb,
        'domain': selected_domain,
        'confidence': confidence,
        'all_probs': policy_probs[0].cpu().numpy()
    }

# Example usage (lora_models=None: no adapter swapping in this minimal example)
result = get_domain_embedding(
    "What are the symptoms of diabetes?",
    base_model, rl_policy, None, tokenizer, device
)
print(f"Domain: {result['domain']} (confidence: {result['confidence']:.2%})")
```

## Architecture

```
Input Text
    │
    ▼
┌────────────────────────────────────────────┐
│ MiniLM-L6-v2 Encoder (FROZEN)              │
│ + Optional LoRA Adapter (domain-specific)  │
│ 384-dimensional output                     │
└────────────────────────────────────────────┘
    │
    ├──────────────────────────────────────────┐
    │                                          │
    ▼                                          ▼
┌─────────────────┐                  ┌──────────────────┐
│ Base Embedding  │                  │ RL Policy Net    │
│ (384-dim)       │                  │ (66K params)     │
└─────────────────┘                  └──────────────────┘
                                               │
                                               ▼
                                       Domain Selection
                     [Medical, Legal, Code, Finance, Scientific]
                                               │
                                               ▼
                              Load corresponding LoRA adapter
                                               │
                                               ▼
                                   Domain-Adapted Embedding
```

### Component Details

| Component | Specification |
|-----------|---------------|
| Base Encoder | MiniLM-L6-v2 (22M params) |
| Embedding Dim | 384 |
| LoRA Rank | 16 |
| LoRA Alpha | 32 |
| LoRA Target | Query, Value projections |
| LoRA Params | 147,456 per adapter (0.645%) |
| RL Policy | 66,566 params |
| Domains | Medical, Legal, Code, Finance, Scientific |

## Performance

### Base Model: Multi-Task Embedding Quality

The base FireDevourerEmbedder averages **0.71** across 5 distinct NLP tasks:

| Task | Dataset | Score | What It Measures |
|------|---------|-------|------------------|
| Question Similarity | QQP | 0.8636 | Intent matching |
| Paraphrase Detection | PAWS | 0.8459 | Adversarial robustness |
| Paraphrase Detection | MRPC | 0.7744 | News-domain paraphrase |
| NLI | MultiNLI | 0.7465 | Logical relationships |
| Semantic Similarity | STS-B | 0.3366 | Fine-grained similarity |
| **Average** | | **0.7134** | **Cross-task capability** |

**Philosophy**: Individual task scores are traded for cross-domain information density, making the embeddings more versatile for RAG and retrieval across diverse content.
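The 147,456-parameter figure in the Component Details table follows directly from the configuration: each rank-16 LoRA adapter adds an A matrix (r × d) and a B matrix (d × r) to the query and value projections in all 6 encoder layers of MiniLM-L6-v2:

```python
layers, d, r = 6, 384, 16        # MiniLM-L6-v2 depth, hidden size, LoRA rank
per_projection = 2 * r * d       # A (r x d) plus B (d x r) = 12,288 params
per_layer = 2 * per_projection   # applied to both query and value projections
total = layers * per_layer
print(total)  # 147456, matching the table above
```

Relative to the ~22M-parameter encoder, that is well under 1% of the weights per adapter, which is why all five adapters together add only about 3 MB to the package.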
### Domain Routing Accuracy

**Training Results (In-Distribution)**

| Metric | Value |
|--------|-------|
| Domain Accuracy | 92.5% |
| Average Reward | 1.527 |
| Training Steps | 5,000 |

**Stress-Test Benchmark (Semantically Similar Cross-Domain Phrases)**

The benchmark intentionally uses complex, semantically similar phrases across domains to test robustness:

| Metric | DomainEmbedder (RL+LoRA) | Base Model | Improvement |
|--------|--------------------------|------------|-------------|
| Domain Accuracy | 56.0% | 20.4% | **+35.6 points** |
| Avg Confidence | 28.5% | 77.6% | Better calibrated |

### Per-Domain Breakdown

| Domain | DomainEmbedder | Base Model | Note |
|--------|----------------|------------|------|
| Finance | 78.0% | 0.0% | +78.0 points |
| Medical | 73.0% | 0.0% | +73.0 points |
| Legal | 53.0% | 15.0% | +38.0 points |
| Scientific | 48.0% | 1.0% | +47.0 points |
| Code | 28.0% | 86.0% | Base over-predicted code |

**Key Insight**: The base model predicted "code" 86% of the time, with high confidence. The RL+LoRA system corrects this bias, producing a balanced, better-calibrated domain distribution.
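Because the policy is deliberately under-confident on ambiguous inputs (28.5% average confidence on the stress test), a deployment may want to skip adapter loading when no domain is clearly favored. A hypothetical thresholding wrapper (the `route_domain` helper and the 0.35 cutoff are illustrative, not part of the released package):

```python
DOMAINS = ['medical', 'legal', 'code', 'finance', 'scientific']

def route_domain(policy_probs, threshold=0.35):
    """Select a domain adapter only when the policy is confident enough.

    policy_probs: list of 5 probabilities from the RL policy head.
    threshold: illustrative cutoff, not a value from the model card.
    """
    conf = max(policy_probs)
    if conf < threshold:
        return None, conf  # fall back to the base (no-adapter) embedding
    return DOMAINS[policy_probs.index(conf)], conf

print(route_domain([0.70, 0.10, 0.10, 0.05, 0.05]))  # ('medical', 0.7)
print(route_domain([0.25, 0.20, 0.20, 0.20, 0.15]))  # (None, 0.25)
```

Treating "no confident domain" as a valid outcome exploits the calibration noted above: a near-uniform distribution signals a cross-domain query better served by the general-purpose base embedding.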
## Training Details

### Domain Training Data

| Domain | Samples | Sources |
|--------|---------|---------|
| Medical | 40,000 | MedQA-USMLE, MedQuAD, PubMedQA, Medical Meadow, ChatDoctor |
| Legal | 40,000 | EUR-LEX, CaseHold, ECTHR-A, ECTHR-B |
| Code | 40,000 | Code Alpaca, MBPP, Code Contests, Python Instructions |
| Finance | 40,000 | Finance Alpaca, FinGPT-FiQA, Financial QA |
| Scientific | 40,000 | arXiv, PubMed (87.3% real + 12.7% augmented) |
| **Total** | **200,000** | |

### LoRA Training Configuration

| Parameter | Value |
|-----------|-------|
| Epochs | 3 per domain |
| Batch Size | 32 |
| Learning Rate | 2e-4 |
| Loss | Contrastive (InfoNCE-style) |
| Trainable Params | 147,456 (0.645% of base) |
| Warmup Steps | 500 |
| Max Gradient Norm | 1.0 |

### RL Training (Supervised A2C)

| Parameter | Value |
|-----------|-------|
| Algorithm | Actor-Critic (A2C) |
| Total Steps | 5,000 |
| Episodes per Step | 5 |
| Gamma (discount) | 0.99 |
| Entropy Coef | 0.1 (high exploration) |
| Value Coef | 0.5 |
| Correctness Bonus | +1.0 |
| Correctness Penalty | -0.5 |
| Baseline Decay | 0.99 |

### Curriculum Learning Phases

| Phase | Steps | Data | Accuracy |
|-------|-------|------|----------|
| 1 (Easy) | 0-1,500 | Clear domain examples (10K) | 68.8% → 87.5% |
| 2 (Moderate) | 1,500-3,500 | Easy + ambiguous (20K) | 87.5% → 89.3% |
| 3 (Hard) | 3,500-5,000 | All data incl. hybrid (28K) | 89.3% → 92.5% |

### Training Progress

| Version | Step | Accuracy | Reward |
|---------|------|----------|--------|
| v2.1 | 500 | 68.8% | 1.100 |
| v2.2 | 1,000 | 80.1% | 1.336 |
| v2.3 | 1,500 | 87.5% | 1.454 |
| v2.4 | 2,000 | 88.9% | 1.480 |
| v2.5 | 3,000 | 89.3% | 1.507 |
| **v2.6** | **4,000** | **92.5%** | **1.527** |

## Package Contents

```
DomainEmbedder-v2.6/
├── FireDevourerEmbedder-RL-v3.6.pt   # Base model checkpoint (86.7 MB)
├── rl_policy.pt                      # Trained RL policy (0.27 MB)
├── metadata.json                     # Training metadata
├── README.md                         # This file
├── medical_lora/                     # Medical domain adapter (0.6 MB)
│   ├── adapter_config.json
│   └── adapter_model.safetensors
├── legal_lora/                       # Legal domain adapter (0.6 MB)
├── code_lora/                        # Code domain adapter (0.6 MB)
├── finance_lora/                     # Finance domain adapter (0.6 MB)
└── scientific_lora/                  # Scientific domain adapter (0.6 MB)
```

**Total Size**: ~90 MB (self-contained)

## Intended Use

### Best Use Cases

- **RAG Systems**: Domain-aware retrieval for multi-domain knowledge bases
- **Cross-Domain Search**: Finding similar content across the Medical, Legal, Code, Finance, and Scientific domains
- **Document Classification**: Automatic domain routing for document-processing pipelines
- **Semantic Similarity**: Information-dense embeddings for precise matching
- **Multi-Domain Chatbots**: Context-appropriate responses based on the detected domain

### Limitations

- **English Only**: Trained exclusively on English data
- **Max Length**: 512-token maximum input length
- **Domain Coverage**: 5 domains only (Medical, Legal, Code, Finance, Scientific)
- **Stress-Test Accuracy**: 56% on semantically similar cross-domain queries
- **STS-B Trade-off**: Lower fine-grained similarity (0.34) in exchange for broader task coverage

## Citation

```bibtex
@misc{domainembedder2025,
  author = {Asad, Zain},
  title = {DomainEmbedder: Domain-Adaptive Embeddings with Dual RL and LoRA},
  year = {2025},
  publisher = {Hugging Face},
  note = {Multi-task base embedder with RL-based task weighting + domain-specific LoRA adapters with curriculum learning}
}
```

## Author

**Zain Asad**

## License

MIT License