---
license: mit
language:
- en
library_name: transformers
tags:
- lora
- peft
- reinforcement-learning
- domain-adaptation
- sentence-embeddings
- curriculum-learning
- multi-task-learning
- rag
- information-retrieval
- cross-domain
- sentence-transformers
base_model:
- sentence-transformers/all-MiniLM-L6-v2
- EphAsad/FireDevourerEmbedder-RL-v3.6
pipeline_tag: sentence-similarity
datasets:
- sentence-transformers/stsb
- nyu-mll/multi_nli
- quora
- google-research-datasets/paws
- nyu-mll/glue
- GBaker/MedQA-USMLE-4-options-hf
- lex_glue
- gbharti/finance-alpaca
- scientific_papers
model-index:
- name: DomainEmbedder-v2.6
results:
- task:
type: domain-classification
name: Domain Classification
metrics:
- type: accuracy
value: 0.925
name: Training Accuracy
- type: accuracy
value: 0.56
name: Stress-Test Accuracy
---
# DomainEmbedder-v2.6
> **High-Information-Density Embeddings for Cross-Domain RAG and Retrieval**

DomainEmbedder-v2.6 produces **information-dense embeddings** optimized for retrieval-augmented generation (RAG) and cross-domain similarity matching. It combines a multi-task base embedder with domain-adaptive LoRA routing.
## What This Model Does
| Component | Description |
|-----------|-------------|
| **Base Embedder** | FireDevourerEmbedder-RL-v3.6 trained on 5 NLP tasks with RL-based task weighting |
| **Domain LoRAs** | 5 specialized adapters (Medical, Legal, Code, Finance, Scientific) |
| **RL Policy** | Automatically selects the optimal domain adapter for any input |
**Why this matters for RAG/Retrieval:**
- Embeddings encode multiple facets of meaning (similarity, entailment, paraphrase, questions)
- Domain routing provides context-appropriate representations
- Results in more precise retrieval across diverse content types
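As a sketch of how such embeddings plug into retrieval, the following ranks a toy corpus by cosine similarity. Shapes and data are illustrative stand-ins; in practice the vectors would come from the model's `encode` task:

```python
import torch
import torch.nn.functional as F

# Toy corpus of 384-dim embeddings (random here, purely illustrative).
torch.manual_seed(0)
corpus_emb = F.normalize(torch.randn(100, 384), dim=-1)  # 100 documents
query_emb = F.normalize(torch.randn(1, 384), dim=-1)     # 1 query

# On L2-normalized vectors, cosine similarity is just a dot product.
scores = query_emb @ corpus_emb.T            # shape (1, 100)
top_k = torch.topk(scores, k=5, dim=-1)      # 5 most similar documents
print(top_k.indices.tolist())
```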
## Key Innovation: Dual RL Architecture
| Stage | RL Application | Purpose |
|-------|---------------|---------|
| Base Model Training | Task Weight Policy | Dynamically balance 5 NLP objectives during training |
| Domain Extension | Adapter Selection Policy | Route to appropriate domain LoRA at inference |
This dual RL approach is novel: **RL at training time AND inference time**.
## Quick Start
### Installation
```bash
pip install torch transformers peft
```
### Loading the Model
```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

# Define the base embedder architecture
class FireDevourerEmbedder(nn.Module):
    def __init__(self, base_model_name='sentence-transformers/all-MiniLM-L6-v2'):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)
        self.hidden_size = 384
        # Task heads
        self.sts_head = nn.Sequential(nn.Linear(384, 1), nn.Sigmoid())
        self.nli_head = nn.Linear(384, 3)
        self.qqp_head = nn.Linear(384, 2)
        self.paws_head = nn.Linear(384, 2)
        self.domain_head = nn.Linear(384, 5)

    def mean_pool(self, token_embeddings, attention_mask):
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

    def forward(self, input_ids, attention_mask, task='encode'):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        embedding = self.mean_pool(outputs.last_hidden_state, attention_mask)
        if task == 'encode':
            return embedding
        elif task == 'domain':
            return self.domain_head(embedding)
        # Add other tasks as needed

# Define RL Policy Network
class RLPolicyNetwork(nn.Module):
    def __init__(self, input_dim=384, hidden_dim=128, num_actions=5):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.policy_head = nn.Linear(hidden_dim, num_actions)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        features = self.network(x)
        policy = torch.softmax(self.policy_head(features), dim=-1)
        value = self.value_head(features)
        return policy, value

# Load model
model_dir = "path/to/DomainEmbedder-v2.6"

# 1. Load base model with checkpoint
base_model = FireDevourerEmbedder()
checkpoint = torch.load(f"{model_dir}/FireDevourerEmbedder-RL-v3.6.pt", map_location=device)
base_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
base_model.to(device)
base_model.eval()

# 2. Load RL policy
rl_policy = RLPolicyNetwork()
rl_checkpoint = torch.load(f"{model_dir}/rl_policy.pt", map_location=device)
rl_policy.load_state_dict(rl_checkpoint['policy_state_dict'])
rl_policy.to(device)
rl_policy.eval()

# 3. Load a LoRA adapter onto the encoder (example: medical)
lora_model = PeftModel.from_pretrained(
    base_model.encoder,
    f"{model_dir}/medical_lora"
)
```
### Computing Embeddings with Domain Selection
```python
def get_domain_embedding(text, base_model, rl_policy, lora_models, tokenizer, device):
    """Get a domain-aware embedding for the input text.

    `lora_models` is accepted for API symmetry but unused in this sketch;
    apply the selected adapter separately after routing.
    """
    # Tokenize
    inputs = tokenizer(text, return_tensors='pt', padding=True,
                       truncation=True, max_length=512).to(device)

    with torch.no_grad():
        # Get base embedding
        base_emb = base_model(inputs['input_ids'], inputs['attention_mask'], task='encode')
        # Get domain selection from RL policy
        policy_probs, _ = rl_policy(base_emb)

    domain_idx = torch.argmax(policy_probs, dim=-1).item()
    domains = ['medical', 'legal', 'code', 'finance', 'scientific']
    selected_domain = domains[domain_idx]
    confidence = policy_probs[0, domain_idx].item()

    return {
        'embedding': base_emb,
        'domain': selected_domain,
        'confidence': confidence,
        'all_probs': policy_probs[0].cpu().numpy()
    }

# Example usage
result = get_domain_embedding(
    "What are the symptoms of diabetes?",
    base_model, rl_policy, None, tokenizer, device
)
print(f"Domain: {result['domain']} (confidence: {result['confidence']:.2%})")
```
## Architecture
```
Input Text
β”‚
β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ MiniLM-L6-v2 Encoder (FROZEN) β”‚
β”‚ + Optional LoRA Adapter (domain-specific) β”‚
β”‚ 384-dimensional output β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ β”‚
β–Ό β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Base Embedding β”‚ β”‚ RL Policy Net β”‚
β”‚ (384-dim) β”‚ β”‚ (66K params) β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”‚
β–Ό
Domain Selection
[Medical, Legal, Code,
Finance, Scientific]
β”‚
β–Ό
Load corresponding LoRA adapter
β”‚
β–Ό
Domain-Adapted Embedding
```
### Component Details
| Component | Specification |
|-----------|---------------|
| Base Encoder | MiniLM-L6-v2 (22M params) |
| Embedding Dim | 384 |
| LoRA Rank | 16 |
| LoRA Alpha | 32 |
| LoRA Target | Query, Value projections |
| LoRA Params | 147,456 per adapter (0.645%) |
| RL Policy | 66,566 params |
| Domains | Medical, Legal, Code, Finance, Scientific |
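The per-adapter parameter count in the table follows directly from the LoRA configuration. A quick arithmetic check (layer count and hidden size are those of MiniLM-L6-v2):

```python
# Reproduce the "147,456 params per adapter" figure: rank-16 LoRA on the
# query and value projections of a 6-layer, 384-dim MiniLM encoder.
hidden = 384   # embedding / hidden size
rank = 16      # LoRA rank
layers = 6     # transformer layers in MiniLM-L6
targets = 2    # query and value projections per layer

# Each adapted matrix gains A (hidden x rank) and B (rank x hidden).
per_matrix = hidden * rank + rank * hidden   # 12,288
total = per_matrix * targets * layers
print(total)  # 147456
```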
## Performance
### Base Model: Multi-Task Embedding Quality
The base FireDevourerEmbedder achieves **0.71 average** across 5 distinct NLP tasks:
| Task | Dataset | Score | What It Measures |
|------|---------|-------|------------------|
| Question Similarity | QQP | 0.8636 | Intent matching |
| Paraphrase Detection | PAWS | 0.8459 | Adversarial robustness |
| Paraphrase Detection | MRPC | 0.7744 | News domain paraphrase |
| NLI | MultiNLI | 0.7465 | Logical relationships |
| Semantic Similarity | STS-B | 0.3366 | Fine-grained similarity |
| **Average** | | **0.7134** | **Cross-task capability** |
**Philosophy**: Individual task scores are traded for cross-domain information density. This makes embeddings more versatile for RAG and retrieval across diverse content.
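The reported cross-task average can be verified directly from the table:

```python
# Sanity-check the 0.7134 cross-task average reported above.
scores = {
    'QQP': 0.8636,
    'PAWS': 0.8459,
    'MRPC': 0.7744,
    'MultiNLI': 0.7465,
    'STS-B': 0.3366,
}
avg = sum(scores.values()) / len(scores)
print(round(avg, 4))  # 0.7134
```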
### Domain Routing Accuracy
**Training Results (In-Distribution)**
| Metric | Value |
|--------|-------|
| Domain Accuracy | 92.5% |
| Average Reward | 1.527 |
| Training Steps | 5,000 |
**Stress-Test Benchmark (Semantically Similar Cross-Domain Phrases)**
The benchmark intentionally uses complex, semantically similar phrases across domains to test robustness:
| Metric | DomainEmbedder (RL+LoRA) | Base Model | Improvement |
|--------|--------------------------|------------|-------------|
| Domain Accuracy | 56.0% | 20.4% | **+35.6%** |
| Avg Confidence | 28.5% | 77.6% | More calibrated |
### Per-Domain Breakdown
| Domain | DomainEmbedder | Base Model | Note |
|--------|----------------|------------|------|
| Finance | 78.0% | 0.0% | +78.0% |
| Medical | 73.0% | 0.0% | +73.0% |
| Legal | 53.0% | 15.0% | +38.0% |
| Scientific | 48.0% | 1.0% | +47.0% |
| Code | 28.0% | 86.0% | Base over-predicted code |
**Key Insight**: The base model predicted "code" for 86% of inputs, and did so with high confidence. The RL+LoRA system corrects this bias, yielding a more balanced and better-calibrated domain distribution.
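One way to make the calibration claim concrete is the gap between mean confidence and accuracy (positive = overconfident, negative = underconfident), using the stress-test figures above:

```python
# Calibration gap: mean confidence minus accuracy. Figures are taken from
# the stress-test tables; the helper itself is an illustrative sketch.
def calibration_gap(mean_confidence, accuracy):
    return mean_confidence - accuracy

base_gap = calibration_gap(0.776, 0.204)  # base model: strongly overconfident
rl_gap = calibration_gap(0.285, 0.560)    # RL+LoRA router: mildly underconfident
print(f"base: {base_gap:+.3f}, rl: {rl_gap:+.3f}")
```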
## Training Details
### Domain Training Data
| Domain | Samples | Sources |
|--------|---------|---------|
| Medical | 40,000 | MedQA-USMLE, MedQuAD, PubMedQA, Medical Meadow, ChatDoctor |
| Legal | 40,000 | EUR-LEX, CaseHold, ECTHR-A, ECTHR-B |
| Code | 40,000 | Code Alpaca, MBPP, Code Contests, Python Instructions |
| Finance | 40,000 | Finance Alpaca, FinGPT-FiQA, Financial QA |
| Scientific | 40,000 | arXiv, PubMed (87.3% real + 12.7% augmented) |
| **Total** | **200,000** | |
### LoRA Training Configuration
| Parameter | Value |
|-----------|-------|
| Epochs | 3 per domain |
| Batch Size | 32 |
| Learning Rate | 2e-4 |
| Loss | Contrastive (InfoNCE-style) |
| Trainable Params | 147,456 (0.645% of base) |
| Warmup Steps | 500 |
| Max Gradient Norm | 1.0 |
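The "InfoNCE-style" contrastive loss in the table can be sketched as cross-entropy over an in-batch similarity matrix; this minimal version assumes in-batch negatives (row *i* of `positives` matches anchor *i*, all other rows act as negatives) and an illustrative temperature, neither of which is specified by the table:

```python
import torch
import torch.nn.functional as F

def info_nce_loss(anchors, positives, temperature=0.07):
    # Normalize so the dot product equals cosine similarity.
    anchors = F.normalize(anchors, dim=-1)
    positives = F.normalize(positives, dim=-1)
    logits = anchors @ positives.T / temperature  # (B, B) similarity matrix
    labels = torch.arange(anchors.size(0))        # diagonal entries are positives
    return F.cross_entropy(logits, labels)

torch.manual_seed(0)
loss = info_nce_loss(torch.randn(32, 384), torch.randn(32, 384))
print(float(loss))
```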
### RL Training (Supervised A2C)
| Parameter | Value |
|-----------|-------|
| Algorithm | Actor-Critic (A2C) |
| Total Steps | 5,000 |
| Episodes per Step | 5 |
| Gamma (discount) | 0.99 |
| Entropy Coef | 0.1 (high exploration) |
| Value Coef | 0.5 |
| Correctness Bonus | +1.0 |
| Correctness Penalty | -0.5 |
| Baseline Decay | 0.99 |
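The table's coefficients map onto a standard actor-critic objective: a policy-gradient term with an advantage baseline, a value regression term weighted 0.5, and an entropy bonus weighted 0.1. The following is a single-step sketch; the tensor values are illustrative assumptions, with the reward following the +1.0 correctness bonus from the table:

```python
import torch

def a2c_loss(log_prob, entropy, value, reward, value_coef=0.5, entropy_coef=0.1):
    advantage = reward - value.detach()           # baseline-corrected return
    policy_loss = -log_prob * advantage           # policy-gradient term
    value_loss = (value - reward) ** 2            # critic regression
    return policy_loss + value_coef * value_loss - entropy_coef * entropy

reward = torch.tensor(1.0)     # correct adapter chosen -> +1.0 bonus
value = torch.tensor(0.4)      # critic's value estimate
log_prob = torch.tensor(-0.9)  # log-prob of the chosen action
entropy = torch.tensor(1.2)    # policy entropy
print(float(a2c_loss(log_prob, entropy, value, reward)))
```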
### Curriculum Learning Phases
| Phase | Steps | Data | Accuracy |
|-------|-------|------|----------|
| 1 (Easy) | 0-1,500 | Clear domain examples (10K) | 68.8% β†’ 87.5% |
| 2 (Moderate) | 1,500-3,500 | Easy + ambiguous (20K) | 87.5% β†’ 89.3% |
| 3 (Hard) | 3,500-5,000 | All data incl. hybrid (28K) | 89.3% β†’ 92.5% |
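The three-phase schedule above reduces to a simple step-to-phase lookup. Boundaries follow the table; the helper itself is an illustrative sketch:

```python
def curriculum_phase(step):
    if step < 1_500:
        return 1, 'easy'       # clear domain examples (10K)
    elif step < 3_500:
        return 2, 'moderate'   # easy + ambiguous (20K)
    else:
        return 3, 'hard'       # all data incl. hybrid (28K)

for s in (800, 2_000, 4_500):
    print(s, curriculum_phase(s))
```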
### Training Progress
| Version | Step | Accuracy | Reward |
|---------|------|----------|--------|
| v2.1 | 500 | 68.8% | 1.100 |
| v2.2 | 1,000 | 80.1% | 1.336 |
| v2.3 | 1,500 | 87.5% | 1.454 |
| v2.4 | 2,000 | 88.9% | 1.480 |
| v2.5 | 3,000 | 89.3% | 1.507 |
| **v2.6** | **4,000** | **92.5%** | **1.527** |
## Package Contents
```
DomainEmbedder-v2.6/
β”œβ”€β”€ FireDevourerEmbedder-RL-v3.6.pt # Base model checkpoint (86.7 MB)
β”œβ”€β”€ rl_policy.pt # Trained RL policy (0.27 MB)
β”œβ”€β”€ metadata.json # Training metadata
β”œβ”€β”€ README.md # This file
β”œβ”€β”€ medical_lora/ # Medical domain adapter (0.6 MB)
β”‚ β”œβ”€β”€ adapter_config.json
β”‚ └── adapter_model.safetensors
β”œβ”€β”€ legal_lora/ # Legal domain adapter (0.6 MB)
β”œβ”€β”€ code_lora/ # Code domain adapter (0.6 MB)
β”œβ”€β”€ finance_lora/ # Finance domain adapter (0.6 MB)
└── scientific_lora/ # Scientific domain adapter (0.6 MB)
```
**Total Size**: ~90 MB (self-contained)
## Intended Use
### Best Use Cases
- **RAG Systems**: Domain-aware retrieval for multi-domain knowledge bases
- **Cross-Domain Search**: Finding similar content across Medical, Legal, Code, Finance, Scientific domains
- **Document Classification**: Automatic domain routing for document processing pipelines
- **Semantic Similarity**: Information-dense embeddings for precise matching
- **Multi-Domain Chatbots**: Context-appropriate responses based on detected domain
### Limitations
- **English Only**: Trained exclusively on English data
- **Max Length**: 512 tokens maximum input length
- **Domain Coverage**: 5 domains only (Medical, Legal, Code, Finance, Scientific)
- **Stress-Test Accuracy**: 56% on semantically similar cross-domain queries
- **STS-B Trade-off**: Lower fine-grained similarity (0.34) for broader task coverage
## Citation
```bibtex
@misc{domainembedder2025,
author = {Asad, Zain},
title = {DomainEmbedder: Domain-Adaptive Embeddings with Dual RL and LoRA},
year = {2025},
publisher = {Hugging Face},
note = {Multi-task base embedder with RL-based task weighting + domain-specific LoRA adapters with curriculum learning}
}
```
## Author
**Zain Asad**
## License
MIT License