---
language: en
license: apache-2.0
base_model: Falconsai/intent_classification
tags:
- text-classification
- multi-label-classification
- intent-classification
- distilbert
pipeline_tag: text-classification
---

# Intent Classification — Multi-Label (`web_search` / `diagram_enabled`)

Fine-tuned from [`Falconsai/intent_classification`](https://huggingface.co/Falconsai/intent_classification)
(DistilBERT-base-uncased, apache-2.0) for **multi-label binary intent classification**.
The original 15-class head was replaced with a 2-label sigmoid head trained with `BCEWithLogitsLoss`.

## Labels

| Index | Label | Meaning |
|-------|-------|---------|
| 0 | `web_search` | Query requires a live web search |
| 1 | `diagram_enabled` | Query benefits from a diagram / visualisation |
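The two labels are independent, so a query can activate neither, one, or both. A minimal sketch of turning a multi-hot prediction back into label names (`ID2LABEL` and `decode` are illustrative helpers, not part of the repo; the mapping mirrors the table above):

```python
# Index-to-label mapping from the Labels table above.
ID2LABEL = {0: "web_search", 1: "diagram_enabled"}

def decode(multi_hot):
    """Return the active label names for a 0/1 vector."""
    return [ID2LABEL[i] for i, v in enumerate(multi_hot) if v]

print(decode([1, 0]))  # ['web_search']
print(decode([1, 1]))  # ['web_search', 'diagram_enabled'] — both intents at once
```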
|
|
| ## Training details |
|
|
| | Setting | Value | |
| |---------|-------| |
| | Base model | `Falconsai/intent_classification` (DistilBERT) | |
| | Problem type | `multi_label_classification` | |
| | Frozen layers | embeddings + transformer.layer[0-3] | |
| | Trainable params | ~7M / 67M total (~10%) | |
| | Classifier dropout | 0.3 | |
| | Learning rate | 5e-6 | |
| | Early stopping | patience=3 on eval_loss | |
| | Threshold floor | 0.30 | |
| | Max sequence length | 128 | |
| | Split | 80 / 10 / 10 (train / val / test) | |
| | Seed | 42 | |
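The freezing scheme in the table can be sketched with the `transformers` API. A randomly initialised DistilBERT stands in here for the actual base checkpoint (the real run starts from `Falconsai/intent_classification`); attribute paths follow `DistilBertForSequenceClassification`:

```python
from transformers import DistilBertConfig, DistilBertForSequenceClassification

# Config values mirror the table above; weights here are random, for illustration.
config = DistilBertConfig(
    num_labels=2,
    problem_type="multi_label_classification",
    seq_classif_dropout=0.3,  # classifier dropout from the table
)
model = DistilBertForSequenceClassification(config)

# Freeze the embeddings and transformer layers 0-3, leaving the last
# two transformer layers and the classification head trainable.
for p in model.distilbert.embeddings.parameters():
    p.requires_grad = False
for layer in model.distilbert.transformer.layer[:4]:
    for p in layer.parameters():
        p.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,}")
```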

## Decision thresholds

| Label | Threshold |
|-------|-----------|
| `web_search` | 0.35 |
| `diagram_enabled` | 0.60 |

> Thresholds are stored in `thresholds.json` **and** embedded in `config.json`
> under `config.thresholds` — no separate download needed.
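The threshold floor of 0.30 from the training table means no tuned per-label threshold is allowed to drop below that value. A minimal pure-Python sketch of that clipping (`apply_floor` is a hypothetical helper, not part of the repo):

```python
FLOOR = 0.30  # threshold floor from the training table

def apply_floor(thresholds: dict, floor: float = FLOOR) -> dict:
    """Clip tuned per-label thresholds so none falls below the floor."""
    return {label: max(t, floor) for label, t in thresholds.items()}

# A tuned value below the floor gets raised; values above pass through.
print(apply_floor({"web_search": 0.25, "diagram_enabled": 0.6}))
# {'web_search': 0.3, 'diagram_enabled': 0.6}
```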
|
|
| ## Usage |
|
|
| ```python |
| import torch |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| |
| REPO = "aitraineracc/intent-classification-multilabel" |
| tokenizer = AutoTokenizer.from_pretrained(REPO) |
| model = AutoModelForSequenceClassification.from_pretrained(REPO) |
| model.eval() |
| |
| thresholds = model.config.thresholds # {'web_search': 0.35, 'diagram_enabled': 0.6} |
| |
| def predict(text: str) -> dict: |
| inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128) |
| with torch.no_grad(): |
| logits = model(**inputs).logits |
| probs = torch.sigmoid(logits).squeeze().tolist() |
| return { |
| "web_search" : int(probs[0] >= thresholds["web_search"]), |
| "diagram_enabled" : int(probs[1] >= thresholds["diagram_enabled"]), |
| "probs" : {"web_search": round(probs[0], 4), |
| "diagram_enabled": round(probs[1], 4)}, |
| } |
| |
| print(predict("What is the weather today in Singapore?")) |
| # {'web_search': 1, 'diagram_enabled': 0, 'probs': ...} |
| |
| print(predict("Draw me a diagram of how TCP/IP works")) |
| # {'web_search': 0, 'diagram_enabled': 1, 'probs': ...} |
| ``` |
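The thresholding step inside `predict` can also be vectorised over a batch with plain tensor operations. A sketch on dummy logits (no model download needed; the logit values are made up for illustration):

```python
import torch

# Per-label thresholds in label-index order: web_search, diagram_enabled.
thresholds = torch.tensor([0.35, 0.60])

# Dummy logits for a batch of two queries (illustrative values only).
logits = torch.tensor([[2.0, -1.0],
                       [-1.0, 1.5]])

# Sigmoid then elementwise compare; broadcasting applies each label's
# threshold down its column.
preds = (torch.sigmoid(logits) >= thresholds).int()
print(preds)  # [[1, 0], [0, 1]]
```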