---
license: mit
language:
- en
library_name: transformers
tags:
- lora
- peft
- reinforcement-learning
- domain-adaptation
- sentence-embeddings
- curriculum-learning
- multi-task-learning
- rag
- information-retrieval
- cross-domain
- sentence-transformers
base_model:
- sentence-transformers/all-MiniLM-L6-v2
- EphAsad/FireDevourerEmbedder-RL-v3.6
pipeline_tag: sentence-similarity
datasets:
- sentence-transformers/stsb
- nyu-mll/multi_nli
- quora
- google-research-datasets/paws
- nyu-mll/glue
- GBaker/MedQA-USMLE-4-options-hf
- lex_glue
- gbharti/finance-alpaca
- scientific_papers
model-index:
- name: DomainEmbedder-v2.6
  results:
  - task:
      type: domain-classification
      name: Domain Classification
    metrics:
    - type: accuracy
      value: 0.925
      name: Training Accuracy
    - type: accuracy
      value: 0.56
      name: Stress-Test Accuracy
---

# DomainEmbedder-v2.6

> **High-Information-Density Embeddings for Cross-Domain RAG and Retrieval**

DomainEmbedder-v2.6 produces **information-dense embeddings** optimized for retrieval-augmented generation (RAG) and cross-domain similarity matching. It combines a multi-task base embedder with domain-adaptive LoRA routing.

## What This Model Does

| Component | Description |
|-----------|-------------|
| **Base Embedder** | FireDevourerEmbedder-RL-v3.6, trained on 5 NLP tasks with RL-based task weighting |
| **Domain LoRAs** | 5 specialized adapters (Medical, Legal, Code, Finance, Scientific) |
| **RL Policy** | Selects the best-matching domain adapter for each input |

**Why this matters for RAG/Retrieval:**

- Embeddings encode multiple facets of meaning (similarity, entailment, paraphrase, question intent)
- Domain routing provides context-appropriate representations
- Together, these aim to make retrieval more precise across diverse content types

## Key Innovation: Dual RL Architecture

| Stage | RL Application | Purpose |
|-------|----------------|---------|
| Base Model Training | Task Weight Policy | Dynamically balance 5 NLP objectives during training |
| Domain Extension | Adapter Selection Policy | Route each input to the appropriate domain LoRA at inference |

The distinguishing feature of this design is that RL is used **at both training time (task weighting) and inference time (adapter selection)**.
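
The card does not say how the task-weight policy's output enters the base model's training loss. One common pattern, shown here purely as a hypothetical sketch, turns per-task policy scores into softmax weights and minimizes the weighted sum of task losses. The task keys (taken from the five tasks in the performance table), the scores, and the loss values are illustrative assumptions, not the author's training code.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_multitask_loss(task_losses, policy_scores):
    """Combine per-task losses using softmax weights derived from policy scores.

    Hypothetical sketch: task_losses and policy_scores are dicts keyed by task
    name; the real task-weight policy and its inputs are not documented here.
    """
    tasks = sorted(task_losses)
    weights = softmax([policy_scores[t] for t in tasks])
    total = sum(w * task_losses[t] for w, t in zip(weights, tasks))
    return total, dict(zip(tasks, weights))


# Illustrative values only (not taken from training logs)
losses = {"sts": 0.9, "nli": 0.7, "qqp": 0.4, "paws": 0.5, "mrpc": 0.6}
scores = {"sts": 1.0, "nli": 0.5, "qqp": 0.0, "paws": 0.0, "mrpc": 0.2}
total, weights = weighted_multitask_loss(losses, scores)
print(round(total, 4), {t: round(w, 3) for t, w in weights.items()})
```

Because the weights form a probability distribution, the combined loss is always a convex combination of the individual task losses. Raising one task's score shifts training effort toward that task.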

## Quick Start

### Installation

```bash
pip install torch transformers peft
```

### Loading the Model

```python
import copy

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')


# Define the base embedder architecture
class FireDevourerEmbedder(nn.Module):
    def __init__(self, base_model_name='sentence-transformers/all-MiniLM-L6-v2'):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)
        self.hidden_size = 384

        # Task heads
        self.sts_head = nn.Sequential(nn.Linear(384, 1), nn.Sigmoid())
        self.nli_head = nn.Linear(384, 3)
        self.qqp_head = nn.Linear(384, 2)
        self.paws_head = nn.Linear(384, 2)
        self.domain_head = nn.Linear(384, 5)

    def mean_pool(self, token_embeddings, attention_mask):
        # Average token embeddings while ignoring padding positions
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

    def forward(self, input_ids, attention_mask, task='encode'):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        embedding = self.mean_pool(outputs.last_hidden_state, attention_mask)

        if task == 'encode':
            return embedding
        elif task == 'domain':
            return self.domain_head(embedding)
        # Add other tasks as needed; fail loudly instead of silently returning None
        raise ValueError(f"Unsupported task: {task}")


# Define RL Policy Network (actor-critic: action probabilities + state value)
class RLPolicyNetwork(nn.Module):
    def __init__(self, input_dim=384, hidden_dim=128, num_actions=5):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.policy_head = nn.Linear(hidden_dim, num_actions)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        features = self.network(x)
        policy = torch.softmax(self.policy_head(features), dim=-1)
        value = self.value_head(features)
        return policy, value


# Load model
model_dir = "path/to/DomainEmbedder-v2.6"

# 1. Load base model with checkpoint.
# weights_only=False allows checkpoints that store non-tensor metadata;
# only use it with files you trust.
base_model = FireDevourerEmbedder()
checkpoint = torch.load(f"{model_dir}/FireDevourerEmbedder-RL-v3.6.pt",
                        map_location=device, weights_only=False)
base_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
base_model.to(device)
base_model.eval()

# 2. Load RL policy
rl_policy = RLPolicyNetwork()
rl_checkpoint = torch.load(f"{model_dir}/rl_policy.pt",
                           map_location=device, weights_only=False)
rl_policy.load_state_dict(rl_checkpoint['policy_state_dict'])
rl_policy.to(device)
rl_policy.eval()

# 3. Load LoRA adapters (example: medical).
# PEFT injects LoRA layers into the model it wraps, so wrap a copy of the
# encoder to keep base_model (used for routing) unchanged.
lora_model = PeftModel.from_pretrained(
    copy.deepcopy(base_model.encoder),
    f"{model_dir}/medical_lora"
)
lora_model.to(device)
lora_model.eval()
```

### Computing Embeddings with Domain Selection

```python
DOMAINS = ['medical', 'legal', 'code', 'finance', 'scientific']


def get_domain_embedding(text, base_model, rl_policy, tokenizer, device):
    """Embed one text with the base model and predict its domain with the RL policy.

    The returned 'embedding' is the base (non-adapted) embedding; encode again
    with the selected domain's LoRA adapter to get a domain-adapted embedding.
    """
    # Tokenize
    inputs = tokenizer(text, return_tensors='pt', padding=True,
                       truncation=True, max_length=512).to(device)

    with torch.no_grad():
        # Get base embedding
        base_emb = base_model(inputs['input_ids'], inputs['attention_mask'], task='encode')
        # Get domain selection from RL policy
        policy_probs, _ = rl_policy(base_emb)

    domain_idx = torch.argmax(policy_probs, dim=-1).item()
    selected_domain = DOMAINS[domain_idx]
    confidence = policy_probs[0, domain_idx].item()

    return {
        'embedding': base_emb,
        'domain': selected_domain,
        'confidence': confidence,
        'all_probs': policy_probs[0].cpu().numpy()
    }


# Example usage
result = get_domain_embedding(
    "What are the symptoms of diabetes?",
    base_model, rl_policy, tokenizer, device
)
print(f"Domain: {result['domain']} (confidence: {result['confidence']:.2%})")
```
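
For retrieval, embeddings like those returned by the example above are typically compared with cosine similarity. The sketch below ranks a small corpus against a query using plain Python lists. With real model outputs, you would first convert each tensor, for example with `.squeeze(0).tolist()`. The `top_k` helper and the toy 3-dimensional vectors are hypothetical stand-ins for real 384-dimensional embeddings.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query_vec, corpus, k=2):
    """Return the k (doc_id, score) pairs most similar to query_vec.

    corpus maps a document id to its embedding (a list of floats).
    """
    scored = [(doc_id, cosine(query_vec, vec)) for doc_id, vec in corpus.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy vectors standing in for 384-dim embeddings
corpus = {
    "diabetes_faq": [0.9, 0.1, 0.0],
    "tax_guide": [0.0, 0.2, 0.9],
    "insulin_study": [0.8, 0.3, 0.1],
}
query = [1.0, 0.2, 0.0]
print(top_k(query, corpus, k=2))
```

For large corpora, the same ranking is usually delegated to a vector index. The scoring function stays the same.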

## Architecture

```
Input Text
         │
         ▼
┌────────────────────────────────────────────┐
│ MiniLM-L6-v2 Encoder (FROZEN)              │
│ + Optional LoRA Adapter (domain-specific)  │
│ 384-dimensional output                     │
└────────────────────────────────────────────┘
         │
         ├───────────────────────┐
         │                       │
         ▼                       ▼
┌─────────────────┐    ┌──────────────────┐
│ Base Embedding  │    │ RL Policy Net    │
│ (384-dim)       │    │ (66K params)     │
└─────────────────┘    └──────────────────┘
                                 │
                                 ▼
                         Domain Selection
                      [Medical, Legal, Code,
                       Finance, Scientific]
                                 │
                                 ▼
                  Load corresponding LoRA adapter
                                 │
                                 ▼
                     Domain-Adapted Embedding
```
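
The diagram's flow (base embedding, then policy-based domain selection, then re-encoding with the selected adapter) can be sketched with the model calls injected as plain functions. This is a hypothetical outline: `route_and_embed`, `encode_base`, `policy_probs`, and `encode_with_adapter` are placeholder names, not functions shipped with the model. With PEFT, `encode_with_adapter` could, for instance, load each adapter once via `PeftModel.load_adapter(path, adapter_name=...)` and switch between them with `set_adapter(name)` before encoding.

```python
DOMAINS = ["medical", "legal", "code", "finance", "scientific"]


def route_and_embed(text, encode_base, policy_probs, encode_with_adapter):
    """Select a domain with the policy, then embed with that domain's adapter.

    encode_base(text) -> base embedding
    policy_probs(embedding) -> 5 probabilities in DOMAINS order
    encode_with_adapter(text, domain) -> domain-adapted embedding
    All three callables are hypothetical stand-ins for the real model calls.
    """
    base_emb = encode_base(text)
    probs = policy_probs(base_emb)
    # Greedy selection: pick the highest-probability domain
    idx = max(range(len(DOMAINS)), key=lambda i: probs[i])
    domain = DOMAINS[idx]
    return {
        "domain": domain,
        "confidence": probs[idx],
        "embedding": encode_with_adapter(text, domain),
    }


# Stub callables so the flow can be exercised without model weights
result = route_and_embed(
    "What are the symptoms of diabetes?",
    encode_base=lambda text: [0.1, 0.2],
    policy_probs=lambda emb: [0.6, 0.1, 0.1, 0.1, 0.1],
    encode_with_adapter=lambda text, domain: [len(domain), 0.0],
)
print(result["domain"], result["confidence"])
```

Injecting the three calls keeps the routing logic independent of how the encoder and adapters are loaded. This also makes the routing step easy to test in isolation.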

### Component Details

| Component | Specification |
|-----------|---------------|
| Base Encoder | MiniLM-L6-v2 (22M params) |
| Embedding Dim | 384 |
| LoRA Rank | 16 |
| LoRA Alpha | 32 |
| LoRA Target | Query, Value projections |
| LoRA Params | 147,456 per adapter (0.645%) |
| RL Policy | 66,566 params |
| Domains | Medical, Legal, Code, Finance, Scientific |
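
The LoRA and policy sizes in the table follow from the architecture. The arithmetic below reproduces them under these assumptions:

- MiniLM-L6-v2 has 6 transformer layers with hidden size 384.
- LoRA is applied to the query and value projections.
- The `RLPolicyNetwork` layer sizes match the loading code (384 → 128 → 128, a 5-way policy head, and a 1-unit value head).

The 22,713,216 base-parameter figure is computed here for the BERT-style encoder, including its pooler layer. It reproduces the 0.645% figure when that percentage is taken over base plus adapter parameters.

```python
def linear_params(n_in, n_out, bias=True):
    """Parameter count of a fully connected layer."""
    return n_in * n_out + (n_out if bias else 0)


# LoRA: each adapted d x d projection gains A (r x d) and B (d x r), no bias
hidden, rank, layers, targets = 384, 16, 6, 2  # targets: query, value
lora_params = layers * targets * (rank * hidden + hidden * rank)

# RL policy: two hidden layers plus policy and value heads (all with bias)
policy_params = (
    linear_params(384, 128)
    + linear_params(128, 128)
    + linear_params(128, 5)
    + linear_params(128, 1)
)

# Assumed size of the all-MiniLM-L6-v2 encoder (BERT-style, including pooler)
base_params = 22_713_216
lora_share = 100 * lora_params / (base_params + lora_params)

print(lora_params, policy_params, round(lora_share, 3))
```

At 4 bytes per float32 parameter, these counts are also consistent with the listed file sizes: about 0.59 MB per adapter and about 0.27 MB for the policy.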

## Performance

### Base Model: Multi-Task Embedding Quality

The base FireDevourerEmbedder achieves a **0.71 average** across 5 distinct NLP tasks:

| Task | Dataset | Score | What It Measures |
|------|---------|-------|------------------|
| Question Similarity | QQP | 0.8636 | Intent matching |
| Paraphrase Detection | PAWS | 0.8459 | Adversarial robustness |
| Paraphrase Detection | MRPC | 0.7744 | News domain paraphrase |
| NLI | MultiNLI | 0.7465 | Logical relationships |
| Semantic Similarity | STS-B | 0.3366 | Fine-grained similarity |
| **Average** | | **0.7134** | **Cross-task capability** |

**Philosophy**: The design deliberately trades peak scores on individual tasks (most visibly STS-B) for embeddings that carry information from all five tasks. The aim is embeddings that are more versatile for RAG and retrieval across diverse content.

### Domain Routing Accuracy

**Training Results (In-Distribution)**

| Metric | Value |
|--------|-------|
| Domain Accuracy | 92.5% |
| Average Reward | 1.527 |
| Training Steps | 5,000 |

**Stress-Test Benchmark (Semantically Similar Cross-Domain Phrases)**

The benchmark intentionally uses complex phrases that are semantically similar across domains to test routing robustness. Differences are given in percentage points (pp):

| Metric | DomainEmbedder (RL+LoRA) | Base Model | Difference |
|--------|--------------------------|------------|------------|
| Domain Accuracy | 56.0% | 20.4% | **+35.6 pp** |
| Avg Confidence | 28.5% | 77.6% | Better calibrated (confidence-accuracy gap of 27.5 pp vs 57.2 pp) |

### Per-Domain Breakdown

| Domain | DomainEmbedder | Base Model | Difference |
|--------|----------------|------------|------------|
| Finance | 78.0% | 0.0% | +78.0 pp |
| Medical | 73.0% | 0.0% | +73.0 pp |
| Legal | 53.0% | 15.0% | +38.0 pp |
| Scientific | 48.0% | 1.0% | +47.0 pp |
| Code | 28.0% | 86.0% | -58.0 pp (base over-predicted code) |

**Key Insight**: The base model's 86% accuracy on Code, combined with near-zero accuracy on every other domain and 77.6% average confidence, indicates a strong, high-confidence bias toward predicting "code". The RL+LoRA system corrects much of this bias and produces a more balanced, better-calibrated domain distribution.
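
The overall stress-test accuracies match an unweighted mean of the per-domain rows. This suggests, although the card does not state it, an equal number of test phrases per domain. The short calculation below checks this and computes the gap between average confidence and accuracy. That gap serves here only as a rough calibration indicator, not a formal metric such as expected calibration error.

```python
# Per-domain stress-test accuracy (%) from the per-domain table
embedder_by_domain = {"finance": 78.0, "medical": 73.0, "legal": 53.0,
                      "scientific": 48.0, "code": 28.0}
base_by_domain = {"finance": 0.0, "medical": 0.0, "legal": 15.0,
                  "scientific": 1.0, "code": 86.0}


def macro_average(per_domain):
    """Unweighted mean over domains (assumes equal test counts per domain)."""
    return sum(per_domain.values()) / len(per_domain)


def confidence_gap(avg_confidence, accuracy):
    """Absolute difference between mean confidence and accuracy, in points."""
    return abs(avg_confidence - accuracy)


embedder_acc = macro_average(embedder_by_domain)
base_acc = macro_average(base_by_domain)
print(round(embedder_acc, 1), round(base_acc, 1))
print(round(confidence_gap(28.5, embedder_acc), 1),
      round(confidence_gap(77.6, base_acc), 1))
```

The gaps point in opposite directions. The base model is overconfident (77.6% confidence at 20.4% accuracy), while DomainEmbedder is underconfident (28.5% confidence at 56.0% accuracy).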

## Training Details

### Domain Training Data

| Domain | Samples | Sources |
|--------|---------|---------|
| Medical | 40,000 | MedQA-USMLE, MedQuAD, PubMedQA, Medical Meadow, ChatDoctor |
| Legal | 40,000 | EUR-LEX, CaseHold, ECTHR-A, ECTHR-B |
| Code | 40,000 | Code Alpaca, MBPP, Code Contests, Python Instructions |
| Finance | 40,000 | Finance Alpaca, FinGPT-FiQA, Financial QA |
| Scientific | 40,000 | arXiv, PubMed (87.3% real + 12.7% augmented) |
| **Total** | **200,000** | |

### LoRA Training Configuration

| Parameter | Value |
|-----------|-------|
| Epochs | 3 per domain |
| Batch Size | 32 |
| Learning Rate | 2e-4 |
| Loss | Contrastive (InfoNCE-style) |
| Trainable Params | 147,456 (0.645% of total parameters) |
| Warmup Steps | 500 |
| Max Gradient Norm | 1.0 |
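
The card describes the LoRA loss only as contrastive and InfoNCE-style. As an illustration of that family of losses, and not the exact training objective, here is an in-batch InfoNCE loss in plain Python. Each anchor's matching positive is scored against every other positive in the batch, which act as negatives. The temperature of 0.05 is an assumed value, since the card does not report one.

```python
import math


def cosine(a, b):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def info_nce(anchors, positives, temperature=0.05):
    """Mean in-batch InfoNCE loss.

    anchors[i] should match positives[i]; positives[j] for j != i act as
    negatives. temperature=0.05 is an illustrative assumption.
    """
    total = 0.0
    for i, anchor in enumerate(anchors):
        logits = [cosine(anchor, p) / temperature for p in positives]
        m = max(logits)  # subtract the max for numerical stability
        log_denominator = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_denominator - logits[i]  # -log softmax of the true pair
    return total / len(anchors)


aligned = info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.9, 0.1], [0.1, 0.9]])
swapped = info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.1, 0.9], [0.9, 0.1]])
print(round(aligned, 4), round(swapped, 4))
```

The loss is near zero when each anchor is closest to its own positive and grows when a negative is more similar. Larger batches supply more negatives per anchor.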

### RL Training (Supervised A2C)

| Parameter | Value |
|-----------|-------|
| Algorithm | Advantage Actor-Critic (A2C) |
| Total Steps | 5,000 |
| Episodes per Step | 5 |
| Gamma (discount) | 0.99 |
| Entropy Coef | 0.1 (high exploration) |
| Value Coef | 0.5 |
| Correctness Bonus | +1.0 |
| Correctness Penalty | -0.5 |
| Baseline Decay | 0.99 |
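
For reference, the standard A2C objective combines three terms: a policy-gradient term, a critic (value) term weighted by the value coefficient, and an entropy bonus weighted by the entropy coefficient. The single-step sketch below uses the table's coefficients (value 0.5, entropy 0.1) and a correctness reward of +1.0 or -0.5. It treats each routing decision as a one-step episode (an assumption), so the 0.99 discount does not enter.

This is an illustration, not the author's training loop. The reported average reward of 1.527 exceeds the +1.0 bonus, so the actual reward evidently includes terms the card does not describe. The role of the 0.99 baseline decay is also unspecified.

```python
import math

VALUE_COEF = 0.5     # from the table
ENTROPY_COEF = 0.1   # from the table


def correctness_reward(predicted, target, bonus=1.0, penalty=-0.5):
    """Reward for one routing decision (assumed: correctness terms only)."""
    return bonus if predicted == target else penalty


def a2c_loss(probs, action, reward, value):
    """Single-step A2C loss for a categorical policy.

    probs: action probabilities; action: chosen index; value: critic estimate.
    Assumes one-step episodes, so the return equals the reward.
    """
    advantage = reward - value  # treated as a constant in the policy term
    policy_loss = -math.log(probs[action]) * advantage
    value_loss = (reward - value) ** 2
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    return policy_loss + VALUE_COEF * value_loss - ENTROPY_COEF * entropy


probs = [0.7, 0.1, 0.1, 0.05, 0.05]
# During training the action would be sampled from probs; fixed here
# so the example is deterministic.
action = 0
reward = correctness_reward(action, target=0)
loss = a2c_loss(probs, action, reward, value=0.8)
print(reward, round(loss, 5))
```

The entropy term subtracts from the loss, so a 0.1 coefficient rewards spreading probability across domains. This matches the table's note that the setting favors exploration.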

### Curriculum Learning Phases

| Phase | Steps | Data | Accuracy |
|-------|-------|------|----------|
| 1 (Easy) | 0-1,500 | Clear domain examples (10K) | 68.8% → 87.5% |
| 2 (Moderate) | 1,500-3,500 | Easy + ambiguous (20K) | 87.5% → 89.3% |
| 3 (Hard) | 3,500-5,000 | All data incl. hybrid (28K) | 89.3% → 92.5% |
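
The phase boundaries above map each training step to the data pool it draws from. A minimal sketch of that lookup follows. The pool names are hypothetical labels. Because the table's ranges share their endpoints, the sketch assumes each boundary step starts the next phase.

```python
# (start_step, end_step, pool) from the curriculum table; end is exclusive here
PHASES = [
    (0, 1_500, "easy"),                 # clear domain examples (10K)
    (1_500, 3_500, "easy+ambiguous"),   # easy + ambiguous (20K)
    (3_500, 5_000, "all_incl_hybrid"),  # all data incl. hybrid (28K)
]


def phase_for_step(step):
    """Return (phase_number, pool_name) for a training step."""
    for number, (start, end, pool) in enumerate(PHASES, start=1):
        if start <= step < end:
            return number, pool
    raise ValueError(f"step {step} is outside the 0-5,000 schedule")


print(phase_for_step(0), phase_for_step(1_500), phase_for_step(4_999))
```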

### Training Progress

| Version | Step | Accuracy | Reward |
|---------|------|----------|--------|
| v2.1 | 500 | 68.8% | 1.100 |
| v2.2 | 1,000 | 80.1% | 1.336 |
| v2.3 | 1,500 | 87.5% | 1.454 |
| v2.4 | 2,000 | 88.9% | 1.480 |
| v2.5 | 3,000 | 89.3% | 1.507 |
| **v2.6** | **4,000** | **92.5%** | **1.527** |

## Package Contents

```
DomainEmbedder-v2.6/
├── FireDevourerEmbedder-RL-v3.6.pt   # Base model checkpoint (86.7 MB)
├── rl_policy.pt                      # Trained RL policy (0.27 MB)
├── metadata.json                     # Training metadata
├── README.md                         # This file
├── medical_lora/                     # Medical domain adapter (0.6 MB)
│   ├── adapter_config.json
│   └── adapter_model.safetensors
├── legal_lora/                       # Legal domain adapter (0.6 MB)
├── code_lora/                        # Code domain adapter (0.6 MB)
├── finance_lora/                     # Finance domain adapter (0.6 MB)
└── scientific_lora/                  # Scientific domain adapter (0.6 MB)
```

**Total Size**: ~90 MB (self-contained)

## Intended Use

### Best Use Cases

- **RAG Systems**: Domain-aware retrieval for multi-domain knowledge bases
- **Cross-Domain Search**: Finding similar content across Medical, Legal, Code, Finance, Scientific domains
- **Document Classification**: Automatic domain routing for document processing pipelines
- **Semantic Similarity**: Information-dense embeddings for precise matching
- **Multi-Domain Chatbots**: Context-appropriate responses based on detected domain

### Limitations

- **English Only**: Trained exclusively on English data
- **Max Length**: 512 tokens maximum input length
- **Domain Coverage**: 5 domains only (Medical, Legal, Code, Finance, Scientific)
- **Stress-Test Accuracy**: 56% on semantically similar cross-domain queries
- **STS-B Trade-off**: Lower fine-grained similarity (0.34) for broader task coverage

## Citation

```bibtex
@misc{domainembedder2025,
  author = {Asad, Zain},
  title = {DomainEmbedder: Domain-Adaptive Embeddings with Dual RL and LoRA},
  year = {2025},
  publisher = {Hugging Face},
  note = {Multi-task base embedder with RL-based task weighting + domain-specific LoRA adapters with curriculum learning}
}
```

## Author

**Zain Asad**

## License

MIT License