---
language:
- en
license: cc-by-nc-2.0
library_name: transformers
tags:
- citation-verification
- retrieval-augmented-generation
- rag
- cross-lingual
- deberta
- cross-encoder
- nli
- attribution
pipeline_tag: text-classification
datasets:
- fever
- din0s/asqa
- miracl/hagrid
metrics:
- f1
- precision
- recall
- accuracy
- roc_auc
base_model: microsoft/deberta-v3-base
model-index:
- name: dualtrack-alignment-module
  results:
  - task:
      type: text-classification
      name: Citation Verification
    metrics:
    - type: f1
      value: 0.89
      name: F1 Score
    - type: accuracy
      value: 0.87
      name: Accuracy
    - type: roc_auc
      value: 0.94
      name: ROC-AUC
---

# DualTrack Alignment Module

> **Anonymous submission to ACL 2026**

A cross-encoder model for detecting **citation drift** in Retrieval-Augmented Generation (RAG) systems. Given a user-facing claim, an evidence representation, and a source passage, the model predicts whether the citation is valid (i.e., the source supports the claim).

## Model Description

This model addresses a critical reliability problem in RAG systems: **citation drift**, where generated text diverges from source documents in ways that break attribution. The problem is particularly severe in cross-lingual settings, where the answer language differs from the source document language.

### Architecture

```
Input: "[CLS] User claim: {claim}\n\nEvidence: {evidence}\n\nSource passage: {context} [SEP]"
  ↓
DeBERTa-v3-base (184M parameters)
  ↓
[CLS] embedding (768-dim)
  ↓
Linear(768, 2) → Softmax
  ↓
Output: P(valid citation)
```

### Why Cross-Encoder?

Unlike embedding-based approaches that encode texts separately, the cross-encoder sees all three components **together**, enabling:

- Cross-attention between claim and source
- Detection of subtle semantic mismatches
- Better handling of paraphrases vs. factual errors
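The usage examples below all build the cross-encoder input from a single template. A small helper (illustrative, not part of the released code) makes the format explicit; note that the tokenizer adds the `[CLS]`/`[SEP]` special tokens itself:

```python
def format_input(user_claim: str, evidence: str, source: str) -> str:
    """Build the joint cross-encoder input; special tokens are added by the tokenizer."""
    return f"User claim: {user_claim}\n\nEvidence: {evidence}\n\nSource passage: {source}"
```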

## Intended Use

### Primary Use Cases

1. **Post-hoc citation verification**: Validate citations in RAG outputs before serving them to users
2. **Citation drift detection**: Identify claims that have semantically drifted from their sources
3. **Training signal**: Provide rewards for citation-aware generation

### Out of Scope

- General NLI/entailment (the model is specialized for RAG citation patterns)
- Fact-checking against world knowledge (the model requires a source passage)
- Non-English source documents (trained on English sources only)

## How to Use

### Installation

```bash
pip install transformers torch
```

### Basic Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load model
model_name = "anonymous-acl2026/dualtrack-alignment"  # Replace with actual path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()


def check_citation(user_claim: str, evidence: str, source: str, threshold: float = 0.5) -> tuple[bool, float]:
    """
    Check if a citation is valid.

    Args:
        user_claim: The claim shown to the user
        evidence: Evidence track representation (can be the same as user_claim)
        source: The source passage being cited
        threshold: Classification threshold (default from training)

    Returns:
        (is_valid, probability)
    """
    # Format input
    text = f"User claim: {user_claim}\n\nEvidence: {evidence}\n\nSource passage: {source}"

    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
        prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item()

    return prob >= threshold, prob


# Example: Valid citation
is_valid, prob = check_citation(
    user_claim="Python was created by Guido van Rossum.",
    evidence="Python was created by Guido van Rossum.",
    source="Python is a programming language created by Guido van Rossum in 1991."
)
print(f"Valid: {is_valid}, Probability: {prob:.3f}")
# Expected: Valid: True, with probability close to 1 (e.g. ~0.95)

# Example: Invalid citation (wrong date)
is_valid, prob = check_citation(
    user_claim="Python was created in 1989.",
    evidence="Python was created in 1989.",
    source="Python is a programming language created by Guido van Rossum in 1991."
)
print(f"Valid: {is_valid}, Probability: {prob:.3f}")
# Expected: Valid: False, with probability close to 0 (e.g. ~0.12)
```

### Batch Processing

```python
def batch_check_citations(examples: list[dict], batch_size: int = 16) -> list[float]:
    """
    Check multiple citations efficiently.

    Args:
        examples: List of dicts with keys 'user', 'evidence', 'source'
        batch_size: Batch size for inference

    Returns:
        List of probabilities
    """
    all_probs = []

    for i in range(0, len(examples), batch_size):
        batch = examples[i:i + batch_size]

        texts = [
            f"User claim: {ex['user']}\n\nEvidence: {ex['evidence']}\n\nSource passage: {ex['source']}"
            for ex in batch
        ]

        inputs = tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)[:, 1].tolist()

        all_probs.extend(probs)

    return all_probs
```

### Integration with DualTrack

```python
import json
import os

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class DualTrackAlignmentModule:
    """
    Alignment module for the DualTrack RAG system.

    Detects citation drift between the user track and source documents.
    """

    def __init__(self, model_path: str, threshold: float | None = None, device: str | None = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # Load the optimal threshold from training metadata, if present
        metadata_path = os.path.join(model_path, "metadata.json")
        if threshold is not None:
            self.threshold = threshold
        elif os.path.exists(metadata_path):
            with open(metadata_path) as f:
                metadata = json.load(f)
            self.threshold = metadata.get("optimal_threshold", 0.5)
        else:
            self.threshold = 0.5

    def detect_drift(
        self,
        user_claims: list[str],
        evidence_claims: list[str],
        sources: list[str]
    ) -> list[dict]:
        """
        Detect citation drift for multiple claim-source pairs.

        Returns a list of {is_valid, probability, drift_detected}.
        """
        results = []

        for user, evidence, source in zip(user_claims, evidence_claims, sources):
            text = f"User claim: {user}\n\nEvidence: {evidence}\n\nSource passage: {source}"

            inputs = self.tokenizer(
                text, return_tensors="pt", truncation=True, max_length=512
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item()

            results.append({
                "is_valid": prob >= self.threshold,
                "probability": prob,
                "drift_detected": prob < self.threshold
            })

        return results
```

## Training Details

### Training Data

The model was trained on a curated dataset combining multiple sources:

| Source | Examples | Description |
|--------|----------|-------------|
| FEVER | ~8,000 | Fact verification with SUPPORTS/REFUTES labels |
| HAGRID | ~2,000 | Attributed QA with quote-based evidence |
| ASQA | ~3,000 | Ambiguous questions with long-form answers |

**Label Generation (V3 - LLM-Supervised)**:
- Training labels verified by GPT-4o-mini ("Does context support claim?")
- Evaluation uses an independent NLI model (DeBERTa-MNLI)
- This breaks circularity: the model learns LLM judgment and is evaluated by NLI

**Data Augmentation**:
- **Negative perturbations**: date_change, number_change, entity_swap, false_detail, negation, topic_drift
- **Positive perturbations**: paraphrase, synonym_swap, formal_informal register changes
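The perturbation generators themselves are not included in this card; as an illustration only, a minimal `date_change`-style negative perturbation (hypothetical implementation) could look like:

```python
import random
import re


def date_change(claim: str, seed: int = 0) -> str:
    """Illustrative date_change perturbation: shift the first 4-digit year by a few years."""
    rng = random.Random(seed)
    match = re.search(r"\b(1[0-9]{3}|20[0-9]{2})\b", claim)
    if match is None:
        return claim  # no year to perturb, leave the claim unchanged
    year = int(match.group(0))
    offset = rng.choice([-3, -2, -1, 1, 2, 3])  # never zero, so the label flips
    return claim[:match.start()] + str(year + offset) + claim[match.end():]


print(date_change("Python was created by Guido van Rossum in 1991."))
```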

### Training Procedure

| Hyperparameter | Value |
|----------------|-------|
| Base model | `microsoft/deberta-v3-base` |
| Max sequence length | 512 |
| Batch size | 8 |
| Gradient accumulation | 2 |
| Effective batch size | 16 |
| Learning rate | 2e-5 |
| Warmup ratio | 0.1 |
| Weight decay | 0.01 |
| Epochs | 5 |
| Early stopping patience | 3 |
| FP16 training | Yes |
| Optimizer | AdamW |

**Training Infrastructure**:
- Single GPU (NVIDIA T4/V100)
- Training time: ~2-3 hours
- Framework: HuggingFace Transformers + PyTorch
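The training script is not released with this card, but the hyperparameters above map onto HuggingFace `TrainingArguments` roughly as follows (a sketch, not the exact configuration used; the output path is an assumption):

```python
from transformers import EarlyStoppingCallback, TrainingArguments

args = TrainingArguments(
    output_dir="dualtrack-alignment",  # assumed output path
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,     # 8 x 2 = effective batch size 16
    num_train_epochs=5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    fp16=True,
    eval_strategy="epoch",             # named `evaluation_strategy` in older transformers releases
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)

# Early stopping (patience 3) is attached as a Trainer callback, e.g.:
# Trainer(model=model, args=args, ..., callbacks=[EarlyStoppingCallback(early_stopping_patience=3)])
```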

### Evaluation

**Validation Set Performance** (15% held-out, stratified):

| Metric | Score |
|--------|-------|
| Accuracy | 0.87 |
| Precision | 0.88 |
| Recall | 0.90 |
| F1 | 0.89 |
| ROC-AUC | 0.94 |

**Optimal Threshold**: 0.50 (determined via F1 maximization on the validation set)
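The threshold was chosen by maximizing F1; a minimal sketch of that selection (not the exact script used) over validation probabilities and gold labels:

```python
def best_f1_threshold(probs: list[float], labels: list[int]) -> tuple[float, float]:
    """Pick the decision threshold that maximizes F1 on validation data."""
    best_t, best_f1 = 0.5, -1.0
    for t in sorted(set(probs)):  # every distinct probability is a candidate threshold
        tp = sum(1 for p, y in zip(probs, labels) if p >= t and y == 1)
        fp = sum(1 for p, y in zip(probs, labels) if p >= t and y == 0)
        fn = sum(1 for p, y in zip(probs, labels) if p < t and y == 1)
        f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 0.0
        if f1 > best_f1:
            best_t, best_f1 = t, f1
    return best_t, best_f1
```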

**Performance by Perturbation Type**:

| Type | Accuracy | Notes |
|------|----------|-------|
| original | 0.91 | Clean examples |
| paraphrase | 0.88 | Meaning-preserving rewrites |
| entity_swap | 0.94 | Wrong person/place/org |
| date_change | 0.92 | Incorrect dates |
| negation | 0.89 | Reversed claims |
| topic_drift | 0.85 | Subtle semantic shifts |

## Limitations

1. **English only**: Trained on English source passages. Cross-lingual application requires translation or a multilingual encoder.

2. **RAG-specific**: Optimized for RAG citation patterns; may not generalize to arbitrary NLI tasks.

3. **Passage length**: Max 512 tokens. Longer documents require chunking or summarization.

4. **Threshold sensitivity**: The default threshold (0.5) may need tuning for specific applications. High-precision applications should use a higher threshold.

5. **Training data bias**: Performance may vary on domains not represented in FEVER/HAGRID/ASQA (e.g., legal, medical, code).
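For limitation 3, one common workaround is to score a long source in overlapping chunks and max-pool the probabilities (a citation counts as valid if any chunk supports the claim). A sketch, assuming a `score(claim, evidence, chunk)` callable such as `check_citation` above; the word-based splitter only approximates the 512-token limit:

```python
def chunk_text(text: str, chunk_words: int = 300, overlap: int = 50) -> list[str]:
    """Split a long document into overlapping word-window chunks."""
    words = text.split()
    if len(words) <= chunk_words:
        return [text]
    step = chunk_words - overlap
    return [" ".join(words[i:i + chunk_words]) for i in range(0, len(words) - overlap, step)]


def check_long_source(user_claim: str, evidence: str, source: str, score) -> float:
    """Max-pool the citation probability over chunks of a long source passage."""
    return max(score(user_claim, evidence, chunk) for chunk in chunk_text(source))
```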

## Ethical Considerations

### Intended Benefits

- Improved reliability of AI-generated citations
- Reduced misinformation from RAG hallucinations
- Better transparency in AI-assisted research

### Potential Risks

- Over-reliance on automated verification (human review still recommended for high-stakes applications)
- False negatives may incorrectly flag valid citations
- False positives may miss genuine attribution errors

### Recommendations

- Use as one signal among many, not sole arbiter
- Monitor performance on domain-specific data
- Combine with human review for critical applications

*This model is part of an anonymous submission to ACL 2026. Author information will be added upon acceptance.*