license: mit
language:
  - en
tags:
  - sentence-transformers
  - sentence-embeddings
  - multi-task-learning
  - reinforcement-learning
  - semantic-similarity
  - nli
  - paraphrase-detection
datasets:
  - sentence-transformers/stsb
  - nyu-mll/multi_nli
  - quora
  - google-research-datasets/paws
  - nyu-mll/glue
pipeline_tag: sentence-similarity
base_model: sentence-transformers/all-MiniLM-L6-v2
model-index:
  - name: FireDevourerEmbedder-RL-v3.6
    results:
      - task:
          type: semantic-similarity
          name: Semantic Textual Similarity
        dataset:
          type: sentence-transformers/stsb
          name: STS-B
        metrics:
          - type: pearson_spearman_avg
            value: 0.3366
      - task:
          type: natural-language-inference
          name: Natural Language Inference
        dataset:
          type: nyu-mll/multi_nli
          name: MultiNLI
        metrics:
          - type: accuracy
            value: 0.7465
      - task:
          type: text-classification
          name: Question Duplicate Detection
        dataset:
          type: quora
          name: QQP
        metrics:
          - type: accuracy
            value: 0.8636
      - task:
          type: text-classification
          name: Paraphrase Detection
        dataset:
          type: google-research-datasets/paws
          name: PAWS
        metrics:
          - type: accuracy
            value: 0.8459
      - task:
          type: text-classification
          name: Paraphrase Detection
        dataset:
          type: nyu-mll/glue
          name: MRPC
        metrics:
          - type: accuracy
            value: 0.7744

FireDevourerEmbedder-RL-v3.6

A multi-task sentence embedding model that uses Reinforcement Learning to dynamically optimize task weights during training. The model learns to balance multiple NLU tasks simultaneously, producing robust sentence embeddings suitable for semantic similarity, natural language inference, and paraphrase detection.

Key Innovation

FireDevourerEmbedder introduces an RL-based adaptive task weighting system that adjusts the importance of each training task based on validation performance. Instead of using fixed task weights, a policy network proposes small weight adjustments at each evaluation checkpoint, with the goal of improving average performance across the NLU benchmarks it is trained on.

Why Multi-Task? Information-Dense Embeddings

The core philosophy behind FireDevourerEmbedder is that multi-task learning creates richer, more information-dense embeddings than single-task approaches.

By training with multiple task heads simultaneously, the shared encoder is forced to learn representations that capture:

| Dimension | Learned From | What It Captures |
| --- | --- | --- |
| Semantic Similarity | STS-B | Fine-grained meaning overlap |
| Logical Relationships | MultiNLI | Entailment, contradiction, neutrality |
| Question Semantics | QQP | Intent and duplicate detection |
| Adversarial Patterns | PAWS | Word-order sensitivity, paraphrase robustness |
| Domain Awareness | All datasets | Context-appropriate representations |

The intended result is embeddings that are:

  • More robust - trained to handle diverse linguistic phenomena
  • More transferable - expected to generalize better to unseen tasks
  • More informative - a single embedding vector encodes several kinds of semantic signal at once

Unlike single-task embedders that optimize for one objective, FireDevourerEmbedder's embeddings simultaneously encode multiple facets of meaning, making them suitable for a wide range of downstream applications without fine-tuning.

Model Details

| Property | Value |
| --- | --- |
| Base Model | sentence-transformers/all-MiniLM-L6-v2 |
| Hidden Size | 384 |
| Version | v3.6 |
| Training Steps | 80,000 |
| Total Parameters | ~22M |

Architecture

The model consists of a shared MiniLM-L6-v2 encoder (a BERT-style transformer), mean pooling, and task-specific output heads:

Input Sentence(s)
       │
       ▼
┌─────────────────────────┐
│   MiniLM-L6-v2 Encoder  │
│     (384-dim output)    │
└─────────────────────────┘
       │
       ▼
   Mean Pooling
       │
       ├──► STS Head (384→1) ──► Similarity Score [0,1]
       ├──► NLI Head (384→3) ──► [Contradiction, Neutral, Entailment]
       ├──► QQP Head (384→2) ──► [Not Duplicate, Duplicate]
       ├──► PAWS Head (384→2) ──► [Not Paraphrase, Paraphrase]
       └──► Domain Head (384→5) ──► [General, Entailment, Questions, Adversarial, News]

Performance

| Task | Dataset | Metric | Score |
| --- | --- | --- | --- |
| Question Duplicate Detection | QQP | Accuracy + F1 | 0.8636 |
| Paraphrase Detection | PAWS | Accuracy + F1 | 0.8459 |
| Paraphrase Detection | MRPC | Accuracy + F1 | 0.7744 |
| Natural Language Inference | MultiNLI | Accuracy + F1 | 0.7465 |
| Semantic Textual Similarity | STS-B | Pearson/Spearman | 0.3366 |
| Average | | | 0.7134 |
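
The reported average is consistent with an unweighted mean of the five per-task scores. A quick check, with illustrative variable names:

```python
# Unweighted mean of the five per-task scores reported in the table.
scores = {
    "QQP": 0.8636,
    "PAWS": 0.8459,
    "MRPC": 0.7744,
    "MultiNLI": 0.7465,
    "STS-B": 0.3366,
}

average = sum(scores.values()) / len(scores)
print(f"Average: {average:.4f}")  # Average: 0.7134
```

Because the mean is unweighted, the low STS-B score pulls the average well below the four classification scores.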

Training Details

Datasets

The model was trained on 5 balanced datasets with 100,000 samples each (500,000 total):

| Dataset | Task Type | Domain | Samples |
| --- | --- | --- | --- |
| STS-B | Semantic Similarity | General | 100,000 |
| MultiNLI | Natural Language Inference | Entailment | 100,000 |
| QQP | Duplicate Question Detection | Questions | 100,000 |
| PAWS | Paraphrase Detection | Adversarial | 100,000 |
| MRPC | Paraphrase Detection | News | 100,000 |

Data Augmentation Strategy

To prevent training bias, all datasets were balanced to exactly 100,000 samples each:

| Dataset | Original Size | Augmentation Method |
| --- | --- | --- |
| STS-B | ~8,600 | Repetition (~12x) + pair swapping |
| MultiNLI | ~433,000 | Subsampling |
| QQP | ~400,000 | Subsampling |
| PAWS | ~49,000 | Repetition (~2x) + pair swapping |
| MRPC | ~3,600 | Repetition (~10x, capped) + pair swapping |

Why this matters:

  • Without balancing, larger datasets (QQP, MultiNLI) would dominate training
  • Smaller but valuable datasets (MRPC, STS-B) would be underrepresented
  • Equal representation gives every task type the same number of training examples

Augmentation techniques:

  • Repetition: Smaller datasets repeated up to 10x maximum to prevent memorization
  • Sentence pair swapping: For symmetric tasks, (A, B) pairs also trained as (B, A)

Training Configuration

| Parameter | Value |
| --- | --- |
| Epochs | 3 |
| Batch Size | 16 |
| Learning Rate | 2e-5 |
| Total Steps | 93,750 |
| Warmup Steps | 9,375 (10%) |
| Evaluation Frequency | Every 10,000 steps |
| Early Stopping | 3 consecutive decreases |
| Training Time | 3.29 hours |
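
The step counts follow directly from the dataset size, batch size, and epoch count. The sketch below reproduces that arithmetic and adds one possible reading of the early stopping rule; `should_stop` is a hypothetical helper, and treating "3 consecutive decreases" as three successive drops in the evaluated average score is an assumption.

```python
import math

# Values taken from the dataset and configuration tables.
total_samples = 5 * 100_000
batch_size = 16
epochs = 3
warmup_percent = 10

steps_per_epoch = math.ceil(total_samples / batch_size)  # 31,250
total_steps = steps_per_epoch * epochs                   # 93,750
warmup_steps = total_steps * warmup_percent // 100       # 9,375


def should_stop(eval_scores, patience=3):
    """True once the score has dropped at `patience` evaluations in a row."""
    if len(eval_scores) <= patience:
        return False
    recent = eval_scores[-(patience + 1):]
    return all(later < earlier for earlier, later in zip(recent, recent[1:]))


print(steps_per_epoch, total_steps, warmup_steps)  # 31250 93750 9375
```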

RL Weight Adaptation System

The model uses a policy network to dynamically adjust task weights during training:

| Parameter | Value |
| --- | --- |
| RL Learning Rate | 0.001 |
| State Dimension | 6 (5 task scores + average) |
| Action Dimension | 5 (weight deltas) |
| Hidden Dimension | 32 |
| Delta Scale | ±5% per update |
| Update Frequency | Every 10,000 steps |

Weight Evolution During Training:

| Task | Initial Weight | Final Weight | Change |
| --- | --- | --- | --- |
| STS | 0.250 | 0.282 | +0.032 |
| NLI | 0.300 | 0.337 | +0.037 |
| QQP | 0.200 | 0.063 | -0.137 |
| PAWS | 0.150 | 0.173 | +0.023 |
| MRPC | 0.100 | 0.145 | +0.045 |

The RL system reduced the QQP weight from 0.200 to 0.063 (QQP was already the highest-scoring task) and redistributed that weight across the four harder tasks.
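
One way to picture a single weight update is sketched below. It is not the author's implementation: `build_state` and `apply_deltas` are hypothetical helpers, and the additive update, the `min_weight` floor, the renormalization step, and the assumption that policy actions lie in [-1, 1] are all illustrative choices.

```python
def build_state(task_scores):
    """6-dim state: the five task scores followed by their average."""
    scores = list(task_scores)
    return scores + [sum(scores) / len(scores)]


def apply_deltas(weights, actions, delta_scale=0.05, min_weight=0.01):
    """Shift each weight by at most `delta_scale`, then renormalize to sum to 1.

    `actions` are assumed to lie in [-1, 1], e.g. the tanh output of a policy net.
    """
    updated = []
    for weight, action in zip(weights, actions):
        action = max(-1.0, min(1.0, action))  # clip the raw action
        updated.append(max(min_weight, weight + delta_scale * action))
    total = sum(updated)
    return [weight / total for weight in updated]


# Task order: STS, NLI, QQP, PAWS, MRPC
initial = [0.250, 0.300, 0.200, 0.150, 0.100]
state = build_state([0.3366, 0.7465, 0.8636, 0.8459, 0.7744])
new_weights = apply_deltas(initial, actions=[0.5, 0.5, -1.0, 0.2, 0.4])
print([round(w, 3) for w in new_weights])
```

In this sketch the renormalization keeps the weights a valid distribution, which matches the initial and final weights in the table both summing to 1.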

Training Progress

| Version | Step | Average Score | Best Task | Improvement |
| --- | --- | --- | --- | --- |
| v3.1 | 10,000 | 0.6133 | QQP (0.8093) | +0.6133 |
| v3.2 | 20,000 | 0.6430 | QQP (0.8351) | +0.0297 |
| v3.3 | 30,000 | 0.6813 | QQP (0.8391) | +0.0383 |
| v3.4 | 40,000 | 0.6925 | QQP (0.8527) | +0.0111 |
| v3.5 | 50,000 | 0.7099 | QQP (0.8579) | +0.0175 |
| v3.6 | 80,000 | 0.7134 | QQP (0.8636) | +0.0035 |

Usage

Installation

pip install torch transformers

Loading the Model

import torch
from transformers import AutoTokenizer, AutoModel

# Load tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained("path/to/FireDevourerEmbedder-RL-v3.6")
base_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

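# Load checkpoint onto the CPU (move tensors to a GPU afterwards if needed)
checkpoint = torch.load("path/to/FireDevourerEmbedder-RL-v3.6/full_checkpoint.pt", map_location="cpu")

# Load model weights (you'll need to reconstruct the full model class)
# See the training script for the complete FireDevourerEmbedder class definition
# The examples below expect that reconstructed instance in a variable named `model`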

Computing Embeddings

def mean_pooling(model_output, attention_mask):
    """Apply mean pooling to get sentence embeddings."""
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def get_embedding(text, model, tokenizer):
    """Get sentence embedding for a single text."""
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    return mean_pooling(outputs, inputs["attention_mask"])

# Example
embedding = get_embedding("This is a sample sentence.", model, tokenizer)
print(f"Embedding shape: {embedding.shape}")  # [1, 384]

Computing Similarity

from torch.nn.functional import cosine_similarity

sentence1 = "A man is eating pizza"
sentence2 = "A person is eating food"

emb1 = get_embedding(sentence1, model, tokenizer)
emb2 = get_embedding(sentence2, model, tokenizer)

similarity = cosine_similarity(emb1, emb2)
print(f"Similarity: {similarity.item():.4f}")  # ~0.9448

Task-Specific Predictions

# After loading the full model with task heads:

def predict_nli(sentence1, sentence2, model, tokenizer):
    """Predict entailment relationship."""
    # Get embeddings for both sentences
    emb1 = get_embedding(sentence1, model, tokenizer)
    emb2 = get_embedding(sentence2, model, tokenizer)

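    # Combine embeddings (concatenate with absolute element-wise difference and product)
    # The result has 4 x 384 = 1536 features, so nli_head must accept that input width
    combined = torch.cat([emb1, emb2, torch.abs(emb1 - emb2), emb1 * emb2], dim=-1)

    # Pass through NLI head (no gradients needed at inference time)
    with torch.no_grad():
        logits = model.nli_head(combined)
    prediction = torch.argmax(logits, dim=-1)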

    labels = ["Contradiction", "Neutral", "Entailment"]
    return labels[prediction.item()]

# Example
result = predict_nli("It's raining outside", "The weather is sunny", model, tokenizer)
print(f"NLI Prediction: {result}")  # Contradiction

Evaluation Results

Test Suite Statistics (20 diverse test cases)

Cosine Similarity:

| Statistic | Value |
| --- | --- |
| Mean | 0.8001 |
| Std | 0.1562 |
| Min | 0.3139 |
| Max | 0.9831 |
| Median | 0.8149 |

STS Score:

| Statistic | Value |
| --- | --- |
| Mean | 0.5672 |
| Std | 0.2270 |
| Min | 0.0182 |
| Max | 0.9468 |
| Median | 0.5788 |

Example Predictions

| Sentence 1 | Sentence 2 | Cosine Sim | NLI | Domain |
| --- | --- | --- | --- | --- |
| "A man is eating pizza" | "A person is eating food" | 0.9448 | Entailment | General |
| "It's raining outside" | "The weather is sunny" | 0.7124 | Contradiction | Entailment |
| "How do I learn Python?" | "What's the best way to learn Python?" | 0.8915 | Entailment | Questions |
| "The quick brown fox jumps..." | "A fast brown fox leaps..." | 0.7837 | Entailment | General |

Intended Use

Best Use Cases

  • Semantic Search: Finding similar documents or passages
  • Duplicate Detection: Identifying duplicate questions or content
  • Paraphrase Mining: Finding paraphrased text pairs
  • Clustering: Grouping similar sentences or documents
  • Natural Language Inference: Determining textual entailment
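
For the search, duplicate detection, and paraphrase mining use cases, the usual pattern is to embed every text once and rank candidates by cosine similarity to a query. The sketch below shows that ranking step with toy 3-dimensional vectors standing in for 384-dimensional embeddings; `cosine` and `rank_by_similarity` are illustrative helpers, not part of the model's API.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def rank_by_similarity(query_vec, corpus_vecs, top_k=3):
    """Return (index, score) pairs for the top_k most similar corpus vectors."""
    scored = [(i, cosine(query_vec, vec)) for i, vec in enumerate(corpus_vecs)]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:top_k]


# Toy vectors; real inputs would be sentence embeddings converted to lists.
corpus = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 1.0, 0.0]]
query = [1.0, 0.05, 0.0]
results = rank_by_similarity(query, corpus, top_k=2)
print([index for index, _ in results])  # [0, 1]
```

For larger corpora, stacking the embeddings into a single tensor and computing all similarities in one matrix operation is usually faster than a Python loop.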

Limitations

  • STS-B Performance: The model shows lower performance on fine-grained semantic similarity regression (0.3366). For tasks requiring precise similarity scores, consider using dedicated STS models.
  • English Only: Trained exclusively on English data.
  • Max Length: Inputs longer than 512 tokens are truncated.
  • Adversarial Robustness: While trained on PAWS adversarial data, performance on novel adversarial examples may vary.

Training Loss Progression

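| Epoch | STS Loss | NLI Loss | QQP Loss | PAWS Loss | MRPC Loss | Domain Loss | Total Loss |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 0.0073 | 0.2508 | 0.0742 | 0.0966 | 0.0287 | 0.0529 | 0.4977 |
| 2 | 0.0038 | 0.1970 | 0.0430 | 0.0638 | 0.0025 | 0.0196 | 0.3211 |
| 3 | 0.0031 | 0.1822 | 0.0221 | 0.0479 | 0.0009 | 0.0141 | 0.2631 |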

Citation

If you use this model in your research, please cite:

@misc{firedevourerembedder2025,
  author = {Asad, Zain},
  title = {FireDevourerEmbedder: Multi-Task Sentence Embeddings with RL-Adaptive Task Weighting},
  year = {2025},
  publisher = {Hugging Face},
  url = {https://huggingface.co/zainasad/FireDevourerEmbedder-RL-v3.6}
}

Author

Zain Asad

License

MIT License