---
license: other
license_name: link-attribution
license_link: https://dejanmarketing.com/link-attribution/
language:
- en
metrics:
- accuracy
- f1
- precision
- recall
base_model: microsoft/deberta-v3-large
pipeline_tag: text-classification
tags:
- grounding
- retrieval
- LLM-enhancement
- DejanAI
---

[Read the blog post](https://dejan.ai/blog/grounding-classifier/)

# Prompt Grounding Classifier

This model predicts whether a prompt **requires grounding** in external sources such as web search, databases, or RAG pipelines.
|
|
|
|
|
It was fine-tuned from [microsoft/deberta-v3-large](https://huggingface.co/microsoft/deberta-v3-large) using binary labels: |
|
|
|
|
|
- `1` = grounding required |
|
|
- `0` = grounding not required |
|
|
|
|
|
--- |

## 🚀 Use Case

This classifier acts as a **routing layer** in an LLM pipeline, helping decide:

- When to trigger retrieval
- When to let the model respond from internal knowledge
- How to optimize for latency and cost
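
The routing decision above can be sketched as follows. This is an illustrative sketch, not part of the released model: `classify` stands in for the grounding classifier (returning a label and a confidence), and `fake_classify` is a stub used only for demonstration.

```python
# Hypothetical routing sketch: `classify` stands in for the grounding
# classifier and returns (label, confidence); names are illustrative.
def route(prompt, classify, threshold=0.5):
    label, confidence = classify(prompt)
    if label == 1 and confidence >= threshold:
        return "retrieve"          # trigger web search / RAG
    return "answer_directly"       # respond from internal knowledge

# Stub classifier for demonstration only.
def fake_classify(prompt):
    return (1, 0.99) if "right now" in prompt else (0, 0.95)

print(route("What's the exchange rate for USD to Yen right now?", fake_classify))  # → retrieve
print(route("Tell me a bedtime story.", fake_classify))  # → answer_directly
```

Raising `threshold` trades recall for precision: borderline prompts fall back to answering from internal knowledge, saving retrieval latency and cost.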

---

## 📦 Training Details

- Model: DeBERTa v3 Large
- Fine-tuning: Full (no adapters)
- Dropout: 0.1
- Scheduler: Cosine with warmup
- Batch size: 24 (accumulated)
- Evaluation: every 500 steps
- Metric used for best checkpoint: F1
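
The settings above might map onto Hugging Face `TrainingArguments` roughly as follows. This is a reconstruction from the listed hyperparameters, not the actual training script; `output_dir`, the warmup amount, and the batch/accumulation split are placeholders.

```python
from transformers import TrainingArguments

# Reconstruction from the settings listed above -- not the original script.
args = TrainingArguments(
    output_dir="./grounding-classifier",   # placeholder
    per_device_train_batch_size=8,         # assumed split; effective batch = 8 * 3 = 24
    gradient_accumulation_steps=3,
    lr_scheduler_type="cosine",
    warmup_ratio=0.06,                     # placeholder; warmup amount not stated
    eval_strategy="steps",                 # `evaluation_strategy` in older transformers
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)
```

Note that the 0.1 dropout is set on the model config (e.g. `hidden_dropout_prob`), not in `TrainingArguments`.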

---

## 🧪 Example Predictions

| Prompt | Grounding | Confidence |
|---------------------------------------------------------|-----------|------------|
| What’s the exchange rate for USD to Yen right now? | 1 | 0.999 |
| Tell me a bedtime story about a robot and a dragon. | 0 | 0.9961 |
| Who is the current CEO of Microsoft? | 1 | 0.9986 |

---

## 🧠 How to Use

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("dejanseo/query-grounding")
tokenizer = AutoTokenizer.from_pretrained("dejanseo/query-grounding")
model.eval()

prompt = "What time is the next train from Tokyo to Osaka?"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

probs = F.softmax(logits, dim=-1)
label = probs.argmax().item()          # 1 = grounding required, 0 = not required
confidence = probs[0][label].item()
```

---

## 🧾 Dataset Origin

Prompts were collected using a Gemini 2.5 Pro + Google Search toolchain with grounding enabled. Each prompt's response was parsed to extract Gemini's grounding confidence, which served as soft supervision for binary labeling:

- Label `1` if grounding confidence was present in the response
- Label `0` if the response required no external evidence
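
The labeling rule above amounts to a simple presence check. A minimal sketch, assuming the parsed grounding confidence arrives as a float or `None` (the function and field names here are illustrative, not the actual pipeline schema):

```python
# Hypothetical sketch of the labeling rule described above; names are
# illustrative, not the actual pipeline schema.
def label_from_response(grounding_confidence):
    """Map Gemini's parsed grounding confidence (float or None) to a binary label."""
    return 1 if grounding_confidence is not None else 0

print(label_from_response(0.87))  # → 1 (grounding confidence present)
print(label_from_response(None))  # → 0 (no external evidence used)
```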
|
|
|