EphAsad committed on
Commit 0c58c88 · verified · 1 Parent(s): 485908e

Update README.md

Files changed (1)
  1. README.md +388 -41
README.md CHANGED
@@ -1,62 +1,409 @@
  # DomainEmbedder-v2.6

- 🔥 **Production-Ready Domain-Adaptive Embedding System (SUPERVISED RL + CURRICULUM)**

- ## 📊 Model Information

- - **Version**: DomainEmbedder-v2.6
- - **Method**: TRUE LoRA + SUPERVISED RL + CURRICULUM LEARNING
- - **Training Step**: 4,000 / 5,000
- - **Average Reward**: 1.5270
- - **🎯 Domain Accuracy**: 92.5%
- - **Baseline Reward**: 0.2991
- - **Date**: 2026-02-10 02:01:18

- ## 🎯 What This Does

- This model uses **supervised reinforcement learning with curriculum learning** to automatically select the best domain-specific adapter (LoRA) for any input text.

- **Key Improvements over v1.0:**
- - ✅ **Supervised RL**: Learns from labeled domain data (85-90% accuracy vs 20% random)
- - ✅ **Curriculum Learning**: Progressive training (easy → moderate → hard)
- - ✅ **Domain Correctness Rewards**: +1.0 bonus for correct domain, -0.5 penalty for wrong
- - ✅ **Higher Entropy**: 0.1 (vs 0.01) for better exploration

- ## 📦 Package Contents
  ```
  DomainEmbedder-v2.6/
- ├─ FireDevourerEmbedder-RL-v3.6.pt   # Base model (86 MB)
- ├─ rl_policy.pt                      # Supervised RL policy (0.2 MB)
- ├─ medical_lora/                     # Medical adapter (0.6 MB)
- ├─ legal_lora/                       # Legal adapter (0.6 MB)
- ├─ code_lora/                        # Code adapter (0.6 MB)
- ├─ finance_lora/                     # Finance adapter (0.6 MB)
- ├─ scientific_lora/                  # Scientific adapter (0.6 MB)
- ├─ metadata.json                     # Training metadata
- └─ README.md                         # This file
  ```

  **Total Size**: ~90 MB (self-contained)

- ## 📈 Performance

- - **Domain Accuracy**: 92.5% (vs 20% random baseline)
- - **Average Reward**: 1.5270
- - **Baseline Reward**: 0.2991
- - **Improvement**: 410.5%

- ### Training Method
- - **Supervised RL**: Policy learns from labeled domain data
- - **Curriculum Learning**: 3 phases (easy → moderate → hard)
- - **Correctness Bonus**: +1.0 for correct domain selection
- - **Correctness Penalty**: -0.5 for wrong domain selection

- ## 🚀 Usage

- (Same loading code as before - see previous README)

- ---

- **Built with 🔥 by the FireDevourer team**

- *Trained with SUPERVISED RL + CURRICULUM LEARNING for 85-90% domain accuracy!*
 
+ ---
+ license: mit
+ language:
+ - en
+ library_name: transformers
+ tags:
+ - lora
+ - peft
+ - reinforcement-learning
+ - domain-adaptation
+ - sentence-embeddings
+ - curriculum-learning
+ - multi-task-learning
+ - rag
+ - information-retrieval
+ - cross-domain
+ - sentence-transformers
+ base_model: sentence-transformers/all-MiniLM-L6-v2
+ pipeline_tag: sentence-similarity
+ datasets:
+ - sentence-transformers/stsb
+ - nyu-mll/multi_nli
+ - quora
+ - google-research-datasets/paws
+ - nyu-mll/glue
+ - GBaker/MedQA-USMLE-4-options-hf
+ - lex_glue
+ - gbharti/finance-alpaca
+ - scientific_papers
+ model-index:
+ - name: DomainEmbedder-v2.6
+   results:
+   - task:
+       type: domain-classification
+       name: Domain Classification
+     metrics:
+     - type: accuracy
+       value: 0.925
+       name: Training Accuracy
+     - type: accuracy
+       value: 0.560
+       name: Stress-Test Accuracy
+ ---
+
  # DomainEmbedder-v2.6

+ > **High-Information-Density Embeddings for Cross-Domain RAG and Retrieval**
+
+ DomainEmbedder-v2.6 produces **information-dense embeddings** optimized for retrieval-augmented generation (RAG) and cross-domain similarity matching. It combines a multi-task base embedder with domain-adaptive LoRA routing.
+
+ ## What This Model Does
+
+ | Component | Description |
+ |-----------|-------------|
+ | **Base Embedder** | FireDevourerEmbedder-RL-v3.6, trained on 5 NLP tasks with RL-based task weighting |
+ | **Domain LoRAs** | 5 specialized adapters (Medical, Legal, Code, Finance, Scientific) |
+ | **RL Policy** | Selects a domain adapter for each input |
+
+ **Why this matters for RAG/Retrieval:**
+ - Embeddings encode multiple facets of meaning (similarity, entailment, paraphrase, question intent)
+ - Domain routing provides context-appropriate representations
+ - The goal is more precise retrieval across diverse content types
+
+ ## Key Innovation: Dual RL Architecture
+
+ | Stage | RL Application | Purpose |
+ |-------|---------------|---------|
+ | Base Model Training | Task Weight Policy | Dynamically balance 5 NLP objectives during training |
+ | Domain Extension | Adapter Selection Policy | Route to appropriate domain LoRA at inference |
+
+ The dual RL approach uses RL at two points: **a task-weighting policy at training time AND an adapter-selection policy at inference time**.
+
+ ## Quick Start
+
+ ### Installation
+
+ ```bash
+ pip install torch transformers peft
+ ```
+
+ ### Loading the Model
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from transformers import AutoTokenizer, AutoModel
+ from peft import PeftModel
+
+ # Device setup
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+
+ # Define the base embedder architecture
+ class FireDevourerEmbedder(nn.Module):
+     def __init__(self, base_model_name='sentence-transformers/all-MiniLM-L6-v2'):
+         super().__init__()
+         self.encoder = AutoModel.from_pretrained(base_model_name)
+         self.hidden_size = 384
+
+         # Task heads
+         self.sts_head = nn.Sequential(nn.Linear(384, 1), nn.Sigmoid())
+         self.nli_head = nn.Linear(384, 3)
+         self.qqp_head = nn.Linear(384, 2)
+         self.paws_head = nn.Linear(384, 2)
+         self.domain_head = nn.Linear(384, 5)
+
+     def mean_pool(self, token_embeddings, attention_mask):
+         # Average the token embeddings, ignoring padded positions
+         mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+         return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
+
+     def forward(self, input_ids, attention_mask, task='encode'):
+         outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+         embedding = self.mean_pool(outputs.last_hidden_state, attention_mask)
+
+         if task == 'encode':
+             return embedding
+         elif task == 'domain':
+             return self.domain_head(embedding)
+         # Add other tasks as needed
+         raise ValueError(f"Unsupported task: {task}")
+
+ # Define RL Policy Network
+ class RLPolicyNetwork(nn.Module):
+     def __init__(self, input_dim=384, hidden_dim=128, num_actions=5):
+         super().__init__()
+         self.network = nn.Sequential(
+             nn.Linear(input_dim, hidden_dim),
+             nn.ReLU(),
+             nn.Linear(hidden_dim, hidden_dim),
+             nn.ReLU()
+         )
+         self.policy_head = nn.Linear(hidden_dim, num_actions)
+         self.value_head = nn.Linear(hidden_dim, 1)
+
+     def forward(self, x):
+         # Returns (domain probabilities, state-value estimate)
+         features = self.network(x)
+         policy = torch.softmax(self.policy_head(features), dim=-1)
+         value = self.value_head(features)
+         return policy, value
+
+ # Load model
+ model_dir = "path/to/DomainEmbedder-v2.6"
+
+ # 1. Load base model with checkpoint
+ base_model = FireDevourerEmbedder()
+ checkpoint = torch.load(f"{model_dir}/FireDevourerEmbedder-RL-v3.6.pt", map_location=device)
+ # strict=False ignores keys that do not match between the checkpoint and this minimal class
+ base_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
+ base_model.to(device)
+ base_model.eval()
+
+ # 2. Load RL policy
+ rl_policy = RLPolicyNetwork()
+ rl_checkpoint = torch.load(f"{model_dir}/rl_policy.pt", map_location=device)
+ rl_policy.load_state_dict(rl_checkpoint['policy_state_dict'])
+ rl_policy.to(device)
+ rl_policy.eval()
+
+ # 3. Load LoRA adapters (example: medical)
+ # Note: PEFT injects the LoRA layers into base_model.encoder itself, so
+ # base_model also runs with the adapter once this call returns.
+ lora_model = PeftModel.from_pretrained(
+     base_model.encoder,
+     f"{model_dir}/medical_lora"
+ )
+ ```
+
+ ### Computing Embeddings with Domain Selection
+
+ ```python
+ def get_domain_embedding(text, base_model, rl_policy, lora_models, tokenizer, device):
+     """Return the embedding and the RL policy's domain choice for `text`.
+
+     `lora_models` is not used in this snippet (pass None); applying the
+     selected adapter is left to the caller.
+     """
+     # Tokenize
+     inputs = tokenizer(text, return_tensors='pt', padding=True,
+                        truncation=True, max_length=512).to(device)
+
+     with torch.no_grad():
+         # Get base embedding
+         base_emb = base_model(inputs['input_ids'], inputs['attention_mask'], task='encode')
+
+         # Get domain selection from RL policy (kept inside no_grad so the
+         # probabilities can be converted to NumPy below)
+         policy_probs, _ = rl_policy(base_emb)
+
+     domain_idx = torch.argmax(policy_probs, dim=-1).item()
+
+     domains = ['medical', 'legal', 'code', 'finance', 'scientific']
+     selected_domain = domains[domain_idx]
+     confidence = policy_probs[0, domain_idx].item()

+     return {
+         'embedding': base_emb,
+         'domain': selected_domain,
+         'confidence': confidence,
+         'all_probs': policy_probs[0].cpu().numpy()
+     }

+ # Example usage
+ result = get_domain_embedding(
+     "What are the symptoms of diabetes?",
+     base_model, rl_policy, None, tokenizer, device
+ )
+ print(f"Domain: {result['domain']} (confidence: {result['confidence']:.2%})")
+ ```
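
For the RAG use case, the embeddings returned above would typically be compared by cosine similarity. The following is a minimal sketch of that ranking step, not the author's retrieval code: `cosine_similarity`, `rank_documents`, and the 3-dimensional toy vectors are hypothetical stand-ins for real 384-dimensional embeddings.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)

def rank_documents(query_vec, doc_vecs, top_k=3):
    """Return the top_k (doc_id, score) pairs, highest similarity first."""
    scored = [(doc_id, cosine_similarity(query_vec, vec))
              for doc_id, vec in doc_vecs.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]

# Toy vectors standing in for model embeddings (hypothetical values)
toy_docs = {
    "diabetes_overview": [0.9, 0.1, 0.0],
    "contract_law": [0.0, 1.0, 0.0],
    "sorting_algorithms": [0.0, 0.1, 0.9],
}
toy_query = [1.0, 0.0, 0.0]
for doc_id, score in rank_documents(toy_query, toy_docs):
    print(f"{doc_id}: {score:.3f}")
```

With real embeddings, each vector would come from `result['embedding'][0].tolist()`; for large collections a vector index would replace the linear scan.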
+
+ ## Architecture
+
+ ```
+ Input Text
+     │
+     ▼
+ ┌─────────────────────────────────────────────┐
+ │  MiniLM-L6-v2 Encoder (FROZEN)              │
+ │  + Optional LoRA Adapter (domain-specific)  │
+ │  384-dimensional output                     │
+ └─────────────────────────────────────────────┘
+     │
+     ├──────────────────────────────────────┐
+     │                                      │
+     ▼                                      ▼
+ ┌─────────────────┐              ┌──────────────────┐
+ │ Base Embedding  │              │  RL Policy Net   │
+ │ (384-dim)       │              │  (66K params)    │
+ └─────────────────┘              └──────────────────┘
+                                            │
+                                            ▼
+                                    Domain Selection
+                                 [Medical, Legal, Code,
+                                  Finance, Scientific]
+                                            │
+                                            ▼
+                              Load corresponding LoRA adapter
+                                            │
+                                            ▼
+                                Domain-Adapted Embedding
+ ```
+
+ ### Component Details
+
+ | Component | Specification |
+ |-----------|---------------|
+ | Base Encoder | MiniLM-L6-v2 (22M params) |
+ | Embedding Dim | 384 |
+ | LoRA Rank | 16 |
+ | LoRA Alpha | 32 |
+ | LoRA Target | Query, Value projections |
+ | LoRA Params | 147,456 per adapter (0.645%) |
+ | RL Policy | 66,566 params |
+ | Domains | Medical, Legal, Code, Finance, Scientific |
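
The two parameter counts in the table can be reproduced from the other rows. The sketch below assumes that MiniLM-L6-v2 has 6 transformer layers and that LoRA adapts the 384 x 384 query and value projections in each of them; the helper names are hypothetical. The policy count uses the layer sizes from the `RLPolicyNetwork` definition in Quick Start.

```python
HIDDEN = 384           # embedding dimension from the table
RANK = 16              # LoRA rank from the table
NUM_LAYERS = 6         # assumption: MiniLM-L6-v2 has 6 transformer layers
TARGETS_PER_LAYER = 2  # query and value projections

def lora_params(hidden, rank, layers, targets):
    """Each adapted hidden x hidden projection adds A (rank x hidden) and B (hidden x rank)."""
    per_matrix = rank * hidden + hidden * rank
    return per_matrix * targets * layers

def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

def policy_params(input_dim=384, hidden_dim=128, num_actions=5):
    """Two hidden layers, a policy head, and a scalar value head."""
    return (linear_params(input_dim, hidden_dim)
            + linear_params(hidden_dim, hidden_dim)
            + linear_params(hidden_dim, num_actions)
            + linear_params(hidden_dim, 1))

print(lora_params(HIDDEN, RANK, NUM_LAYERS, TARGETS_PER_LAYER))  # 147456
print(policy_params())                                           # 66566
```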
+
+ ## Performance
+
+ ### Base Model: Multi-Task Embedding Quality
+
+ The base FireDevourerEmbedder achieves **0.71 average** across 5 distinct NLP tasks:
+
+ | Task | Dataset | Score | What It Measures |
+ |------|---------|-------|------------------|
+ | Question Similarity | QQP | 0.8636 | Intent matching |
+ | Paraphrase Detection | PAWS | 0.8459 | Adversarial robustness |
+ | Paraphrase Detection | MRPC | 0.7744 | News domain paraphrase |
+ | NLI | MultiNLI | 0.7465 | Logical relationships |
+ | Semantic Similarity | STS-B | 0.3366 | Fine-grained similarity |
+ | **Average** | | **0.7134** | **Cross-task capability** |
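
The reported average is the unweighted mean of the five rows. A short check, which also shows how much the STS-B score pulls the mean down (the table does not state the per-task metric, so this treats all scores as directly comparable, which is an assumption):

```python
scores = {
    "QQP": 0.8636,
    "PAWS": 0.8459,
    "MRPC": 0.7744,
    "MultiNLI": 0.7465,
    "STS-B": 0.3366,
}

# Unweighted mean over all five tasks
average = sum(scores.values()) / len(scores)

# Same mean with STS-B left out, to show its effect on the headline number
without_stsb = [v for name, v in scores.items() if name != "STS-B"]
average_without_stsb = sum(without_stsb) / len(without_stsb)

print(f"{average:.4f}")               # 0.7134
print(f"{average_without_stsb:.4f}")  # 0.8076
```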
+
+ **Philosophy**: Individual task scores are traded for cross-domain information density. This makes embeddings more versatile for RAG and retrieval across diverse content.
+
+ ### Domain Routing Accuracy
+
+ **Training Results (In-Distribution)**
+
+ | Metric | Value |
+ |--------|-------|
+ | Domain Accuracy | 92.5% |
+ | Average Reward | 1.527 |
+ | Training Steps | 5,000 |
+
+ **Stress-Test Benchmark (Semantically Similar Cross-Domain Phrases)**
+
+ The benchmark intentionally uses complex, semantically similar phrases across domains to test robustness:
+
+ | Metric | DomainEmbedder (RL+LoRA) | Base Model | Difference |
+ |--------|--------------------------|------------|------------|
+ | Domain Accuracy | 56.0% | 20.4% | **+35.6 pp** |
+ | Avg Confidence | 28.5% | 77.6% | Closer to actual accuracy (base model is overconfident) |

+ ### Per-Domain Breakdown

+ | Domain | DomainEmbedder | Base Model | Note |
+ |--------|----------------|------------|------|
+ | Finance | 78.0% | 0.0% | +78.0 pp |
+ | Medical | 73.0% | 0.0% | +73.0 pp |
+ | Legal | 53.0% | 15.0% | +38.0 pp |
+ | Scientific | 48.0% | 1.0% | +47.0 pp |
+ | Code | 28.0% | 86.0% | Base over-predicted code |
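
The overall stress-test accuracies match the unweighted mean of the per-domain rows, which suggests (as an assumption, since the benchmark size is not stated) that each domain contributes the same number of test phrases. A quick check:

```python
# (DomainEmbedder accuracy, base model accuracy) in percent, from the table
per_domain = {
    "Finance": (78.0, 0.0),
    "Medical": (73.0, 0.0),
    "Legal": (53.0, 15.0),
    "Scientific": (48.0, 1.0),
    "Code": (28.0, 86.0),
}

def macro_average(values):
    """Unweighted mean, i.e. every domain counts equally."""
    values = list(values)
    return sum(values) / len(values)

routed = macro_average(r for r, _ in per_domain.values())
base = macro_average(b for _, b in per_domain.values())

print(f"{routed:.1f} {base:.1f} {routed - base:+.1f}")  # 56.0 20.4 +35.6
```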

+ **Key Insight**: The base model's domain head strongly favored "code" (86.0% accuracy on code, at most 15.0% on any other domain) while reporting high confidence. The RL+LoRA system spreads its predictions more evenly across domains, at the cost of lower accuracy on code (28.0%).
+
+ ## Training Details
+
+ ### Domain Training Data
+
+ | Domain | Samples | Sources |
+ |--------|---------|---------|
+ | Medical | 40,000 | MedQA-USMLE, MedQuAD, PubMedQA, Medical Meadow, ChatDoctor |
+ | Legal | 40,000 | EUR-LEX, CaseHold, ECTHR-A, ECTHR-B |
+ | Code | 40,000 | Code Alpaca, MBPP, Code Contests, Python Instructions |
+ | Finance | 40,000 | Finance Alpaca, FinGPT-FiQA, Financial QA |
+ | Scientific | 40,000 | arXiv, PubMed (87.3% real + 12.7% augmented) |
+ | **Total** | **200,000** | |
+
+ ### LoRA Training Configuration
+
+ | Parameter | Value |
+ |-----------|-------|
+ | Epochs | 3 per domain |
+ | Batch Size | 32 |
+ | Learning Rate | 2e-4 |
+ | Loss | Contrastive (InfoNCE-style) |
+ | Trainable Params | 147,456 (0.645% of base) |
+ | Warmup Steps | 500 |
+ | Max Gradient Norm | 1.0 |
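
The table names the loss only as "Contrastive (InfoNCE-style)". As an illustration of what that family of losses computes, here is a minimal pure-Python sketch with in-batch negatives. It is not the author's training code: the function names, the cosine scoring, and the `temperature=0.05` value are all assumptions.

```python
import math

def _normalize(vec):
    """Scale a vector to unit length (returned unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)

def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def info_nce(anchors, positives, temperature=0.05):
    """Mean cross-entropy of matching anchors[i] to positives[i].

    Every other positive in the batch acts as a negative for anchor i.
    """
    anchors = [_normalize(v) for v in anchors]
    positives = [_normalize(v) for v in positives]
    total = 0.0
    for i, anchor in enumerate(anchors):
        logits = [_dot(anchor, p) / temperature for p in positives]
        # Numerically stable log-sum-exp over the batch
        peak = max(logits)
        log_denominator = peak + math.log(sum(math.exp(l - peak) for l in logits))
        total += log_denominator - logits[i]
    return total / len(anchors)

aligned = info_nce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
shuffled = info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(aligned < shuffled)  # True
```

The loss is near zero when each anchor is closest to its own positive and grows when the pairs are mismatched, which is the behavior a contrastive objective is meant to reward.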
+
+ ### RL Training (Supervised A2C)
+
+ | Parameter | Value |
+ |-----------|-------|
+ | Algorithm | Actor-Critic (A2C) |
+ | Total Steps | 5,000 |
+ | Episodes per Step | 5 |
+ | Gamma (discount) | 0.99 |
+ | Entropy Coef | 0.1 (high exploration) |
+ | Value Coef | 0.5 |
+ | Correctness Bonus | +1.0 |
+ | Correctness Penalty | -0.5 |
+ | Baseline Decay | 0.99 |
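
The following sketch shows one way these hyperparameters could fit together in a supervised actor-critic update. It is an illustration under stated assumptions, not the author's implementation: `task_reward` is a hypothetical base reward to which the correctness term is added, the baseline is assumed to be an exponential moving average, and the loss follows the standard A2C form.

```python
import math

CORRECT_BONUS = 1.0    # "Correctness Bonus" from the table
WRONG_PENALTY = -0.5   # "Correctness Penalty" from the table
BASELINE_DECAY = 0.99  # "Baseline Decay" from the table
ENTROPY_COEF = 0.1     # "Entropy Coef" from the table
VALUE_COEF = 0.5       # "Value Coef" from the table

def shaped_reward(task_reward, predicted_domain, true_domain):
    """Add the supervised correctness term to a (hypothetical) base reward."""
    bonus = CORRECT_BONUS if predicted_domain == true_domain else WRONG_PENALTY
    return task_reward + bonus

def update_baseline(baseline, reward, decay=BASELINE_DECAY):
    """Exponential moving average of past rewards (assumed form of the baseline)."""
    return decay * baseline + (1.0 - decay) * reward

def a2c_loss(log_prob, advantage, value_error, entropy):
    """Standard A2C objective: policy term + value term - entropy bonus."""
    policy_loss = -log_prob * advantage
    value_loss = value_error ** 2
    return policy_loss + VALUE_COEF * value_loss - ENTROPY_COEF * entropy

# One illustrative step with made-up numbers
reward = shaped_reward(0.5, "medical", "medical")
baseline = update_baseline(0.0, reward)
loss = a2c_loss(math.log(0.5), reward - baseline, value_error=0.0, entropy=0.0)
print(reward, round(baseline, 4), round(loss, 4))
```

A larger entropy coefficient lowers the loss for more uniform action distributions, which matches the table's "high exploration" note.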
+
+ ### Curriculum Learning Phases
+
+ | Phase | Steps | Data | Accuracy |
+ |-------|-------|------|----------|
+ | 1 (Easy) | 0-1,500 | Clear domain examples (10K) | 68.8% → 87.5% |
+ | 2 (Moderate) | 1,500-3,500 | Easy + ambiguous (20K) | 87.5% → 89.3% |
+ | 3 (Hard) | 3,500-5,000 | All data incl. hybrid (28K) | 89.3% → 92.5% |
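
The phase boundaries in the table can be expressed as a simple step-to-phase lookup. This is a sketch only: the `PHASES` structure and `phase_for_step` are hypothetical names, the figures in parentheses are read as data-pool sizes, and the shared boundary steps (1,500 and 3,500) are assigned to the later phase, which the table does not specify.

```python
# (name, first step, end step exclusive, data-pool size) per the table
PHASES = [
    ("easy", 0, 1_500, 10_000),
    ("moderate", 1_500, 3_500, 20_000),
    ("hard", 3_500, 5_000, 28_000),
]

def phase_for_step(step):
    """Return (phase name, pool size) for a training step in [0, 5000)."""
    for name, start, end, pool_size in PHASES:
        if start <= step < end:
            return name, pool_size
    raise ValueError(f"step {step} is outside the 5,000-step schedule")

for step in (0, 1_499, 1_500, 4_000):
    print(step, phase_for_step(step))
```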
+
+ ### Training Progress
+
+ | Version | Step | Accuracy | Reward |
+ |---------|------|----------|--------|
+ | v2.1 | 500 | 68.8% | 1.100 |
+ | v2.2 | 1,000 | 80.1% | 1.336 |
+ | v2.3 | 1,500 | 87.5% | 1.454 |
+ | v2.4 | 2,000 | 88.9% | 1.480 |
+ | v2.5 | 3,000 | 89.3% | 1.507 |
+ | **v2.6** | **4,000** | **92.5%** | **1.527** |
+
+ ## Package Contents

  ```
  DomainEmbedder-v2.6/
+ ├── FireDevourerEmbedder-RL-v3.6.pt   # Base model checkpoint (86.7 MB)
+ ├── rl_policy.pt                      # Trained RL policy (0.27 MB)
+ ├── metadata.json                     # Training metadata
+ ├── README.md                         # This file
+ ├── medical_lora/                     # Medical domain adapter (0.6 MB)
+ │   ├── adapter_config.json
+ │   └── adapter_model.safetensors
+ ├── legal_lora/                       # Legal domain adapter (0.6 MB)
+ ├── code_lora/                        # Code domain adapter (0.6 MB)
+ ├── finance_lora/                     # Finance domain adapter (0.6 MB)
+ └── scientific_lora/                  # Scientific domain adapter (0.6 MB)
  ```

  **Total Size**: ~90 MB (self-contained)

+ ## Intended Use

+ ### Best Use Cases

+ - **RAG Systems**: Domain-aware retrieval for multi-domain knowledge bases
+ - **Cross-Domain Search**: Finding similar content across Medical, Legal, Code, Finance, Scientific domains
+ - **Document Classification**: Automatic domain routing for document processing pipelines
+ - **Semantic Similarity**: Information-dense embeddings for precise matching
+ - **Multi-Domain Chatbots**: Context-appropriate responses based on detected domain

+ ### Limitations

+ - **English Only**: Trained exclusively on English data
+ - **Max Length**: Inputs longer than 512 tokens are truncated
+ - **Domain Coverage**: 5 domains only (Medical, Legal, Code, Finance, Scientific)
+ - **Stress-Test Accuracy**: 56% on semantically similar cross-domain queries
+ - **STS-B Trade-off**: Lower fine-grained similarity (0.34) for broader task coverage

+ ## Citation
+
+ ```bibtex
+ @misc{domainembedder2025,
+   author = {Asad, Zain},
+   title = {DomainEmbedder: Domain-Adaptive Embeddings with Dual RL and LoRA},
+   year = {2025},
+   publisher = {Hugging Face},
+   note = {Multi-task base embedder with RL-based task weighting + domain-specific LoRA adapters with curriculum learning}
+ }
+ ```
+
+ ## Author
+
+ **Zain Asad**

+ ## License

+ MIT License