---
language: en
license: apache-2.0
base_model: Falconsai/intent_classification
tags:
- text-classification
- multi-label-classification
- intent-classification
- distilbert
pipeline_tag: text-classification
---
# Intent Classification — Multi-Label (`web_search` / `diagram_enabled`)
Fine-tuned from [`Falconsai/intent_classification`](https://huggingface.co/Falconsai/intent_classification)
(DistilBERT-base-uncased, apache-2.0) for **multi-label binary intent classification**.
The original 15-class head was replaced with a 2-label sigmoid head trained with `BCEWithLogitsLoss`.
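The multi-label objective above can be illustrated with a minimal, dependency-free sketch of what `BCEWithLogitsLoss` computes per label: each label gets an independent sigmoid and binary cross-entropy, unlike softmax classes which compete with each other. (Pure-Python illustration only; actual training uses PyTorch's fused, numerically stable implementation.)

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def bce_with_logits(logits: list[float], targets: list[float]) -> float:
    """Mean per-label binary cross-entropy: each label is an independent
    Bernoulli, so any combination of labels can be active at once."""
    total = 0.0
    for z, y in zip(logits, targets):
        p = sigmoid(z)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(logits)
```

At zero logits the model is maximally uncertain, so the loss per label is `log 2` regardless of the target.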
## Labels
| Index | Label | Meaning |
|-------|-------|---------|
| 0 | `web_search` | Query requires a live web search |
| 1 | `diagram_enabled` | Query benefits from a diagram / visualisation |
## Training details
| Setting | Value |
|---------|-------|
| Base model | `Falconsai/intent_classification` (DistilBERT) |
| Problem type | `multi_label_classification` |
| Frozen layers | embeddings + transformer.layer[0-3] |
| Trainable params | ~7M / 67M total (~10%) |
| Classifier dropout | 0.3 |
| Learning rate | 5e-6 |
| Early stopping | patience=3 on eval_loss |
| Threshold floor | 0.30 |
| Max sequence length | 128 |
| Split | 80 / 10 / 10 (train / val / test) |
| Seed | 42 |
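The freezing scheme in the table (embeddings plus the first four transformer blocks frozen, upper blocks and classifier head trainable) can be sketched as below. `freeze_lower_layers` is an illustrative helper, not part of this repo, and exact trainable-parameter counts depend on the checkpoint:

```python
from transformers import DistilBertConfig, DistilBertForSequenceClassification

def freeze_lower_layers(model, n_frozen: int = 4):
    """Freeze the embeddings and the first `n_frozen` transformer blocks,
    leaving the remaining blocks and the classifier head trainable."""
    for p in model.distilbert.embeddings.parameters():
        p.requires_grad = False
    for layer in model.distilbert.transformer.layer[:n_frozen]:
        for p in layer.parameters():
            p.requires_grad = False
    return model

# Tiny randomly initialised config for demonstration — the real run
# starts from the Falconsai/intent_classification checkpoint instead.
cfg = DistilBertConfig(dim=32, hidden_dim=64, n_layers=6, n_heads=2,
                       num_labels=2,
                       problem_type="multi_label_classification",
                       seq_classif_dropout=0.3)
model = freeze_lower_layers(DistilBertForSequenceClassification(cfg))
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable}/{total}")
```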
## Decision thresholds
| Label | Threshold |
|-------|-----------|
| `web_search` | 0.35 |
| `diagram_enabled` | 0.6 |
> Thresholds are stored in `thresholds.json` **and** embedded in `config.json`
> under `config.thresholds` — no separate download needed.
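If you want to binarise probabilities yourself, the per-label decision rule can be sketched as follows. `apply_thresholds` is an illustrative helper, not shipped with the model, and it interprets the 0.30 "threshold floor" from the training table as a lower bound on any per-label threshold — an assumption, not something the card states explicitly:

```python
def apply_thresholds(probs: dict, thresholds: dict, floor: float = 0.30) -> dict:
    """Binarise per-label sigmoid probabilities independently,
    clamping each label's threshold to a minimum of `floor`."""
    return {label: int(p >= max(thresholds.get(label, 0.5), floor))
            for label, p in probs.items()}
```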
## Usage
```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
REPO = "aitraineracc/intent-classification-multilabel"
tokenizer = AutoTokenizer.from_pretrained(REPO)
model = AutoModelForSequenceClassification.from_pretrained(REPO)
model.eval()
thresholds = model.config.thresholds # {'web_search': 0.35, 'diagram_enabled': 0.6}
def predict(text: str) -> dict:
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.sigmoid(logits).squeeze().tolist()
    return {
        "web_search": int(probs[0] >= thresholds["web_search"]),
        "diagram_enabled": int(probs[1] >= thresholds["diagram_enabled"]),
        "probs": {"web_search": round(probs[0], 4),
                  "diagram_enabled": round(probs[1], 4)},
    }
print(predict("What is the weather today in Singapore?"))
# {'web_search': 1, 'diagram_enabled': 0, 'probs': ...}
print(predict("Draw me a diagram of how TCP/IP works"))
# {'web_search': 0, 'diagram_enabled': 1, 'probs': ...}
```