---
language: en
license: apache-2.0
base_model: Falconsai/intent_classification
tags:
- text-classification
- multi-label-classification
- intent-classification
- distilbert
pipeline_tag: text-classification
---

# Intent Classification — Multi-Label (`web_search` / `diagram_enabled`)

Fine-tuned from [`Falconsai/intent_classification`](https://huggingface.co/Falconsai/intent_classification) (DistilBERT-base-uncased, apache-2.0) for **multi-label binary intent classification**. The original 15-class head was replaced with a 2-label sigmoid head trained with `BCEWithLogitsLoss`. Each label is therefore scored independently, and a query can trigger both labels, one, or neither.

## Labels

| Index | Label | Meaning |
|-------|-------|---------|
| 0 | `web_search` | Query requires a live web search |
| 1 | `diagram_enabled` | Query benefits from a diagram or visualisation |

## Training details

| Setting | Value |
|---------|-------|
| Base model | `Falconsai/intent_classification` (DistilBERT) |
| Problem type | `multi_label_classification` |
| Frozen layers | embeddings + `transformer.layer[0-3]` |
| Trainable params | ~7M / 67M total (~10%) |
| Classifier dropout | 0.3 |
| Learning rate | 5e-6 |
| Early stopping | patience = 3 on `eval_loss` |
| Threshold floor | 0.30 |
| Max sequence length | 128 tokens |
| Split | 80 / 10 / 10 (train / val / test) |
| Seed | 42 |

## Decision thresholds

| Label | Threshold |
|-------|-----------|
| `web_search` | 0.35 |
| `diagram_enabled` | 0.6 |

> Thresholds are stored in `thresholds.json` **and** embedded in `config.json`
> under the `thresholds` key (read back as `model.config.thresholds`), so no separate download is needed.
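The card does not say how the per-label thresholds were chosen. One plausible procedure, consistent with the 0.30 threshold floor in the training table, is a per-label grid search on the validation split that maximises F1 but never proposes a value below the floor. The sketch below is an assumption, not the author's script: `f1_at`, `tune_threshold`, the F1 objective and the 0.05 grid step are all hypothetical, and the toy probabilities are made up.

```python
# Hypothetical sketch: per-label threshold search with a lower floor.
# The F1 objective and the 0.05 step are assumptions, not the author's method.

def f1_at(probs: list[float], labels: list[int], threshold: float) -> float:
    """F1 of the positive class when predicting 1 for prob >= threshold."""
    tp = fp = fn = 0
    for p, y in zip(probs, labels):
        pred = 1 if p >= threshold else 0
        if pred and y:
            tp += 1
        elif pred and not y:
            fp += 1
        elif not pred and y:
            fn += 1
    if tp == 0:
        return 0.0
    return 2 * tp / (2 * tp + fp + fn)


def tune_threshold(probs: list[float], labels: list[int],
                   floor: float = 0.30, step: float = 0.05) -> tuple[float, float]:
    """Return (threshold, F1) for the best grid point in [floor, 0.95].

    Ties keep the lowest threshold, i.e. the one found first.
    """
    best_t, best_f1 = floor, -1.0
    n_steps = int(round((0.95 - floor) / step))
    for i in range(n_steps + 1):
        t = round(floor + i * step, 2)
        score = f1_at(probs, labels, t)
        if score > best_f1:
            best_t, best_f1 = t, score
    return best_t, best_f1


# Toy validation data for a single label (made up for illustration).
val_probs = [0.20, 0.45, 0.52, 0.61, 0.80, 0.93]
val_labels = [0, 1, 0, 1, 1, 1]
print(tune_threshold(val_probs, val_labels))  # (0.3, 0.8888888888888888): the floor wins
```

In a real run you would call `tune_threshold` once per label column and write the results to `thresholds.json`.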
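Because the head is multi-label, each of the two logits is trained as its own binary problem. The plain-Python sketch below computes the mean `BCEWithLogitsLoss` value for a single example. The helper name `bce_with_logits` is hypothetical; it uses the numerically stable rewrite of -[y·log σ(x) + (1 - y)·log(1 - σ(x))], and averaging over the two labels matches PyTorch's default `reduction="mean"` for one example.

```python
import math


def bce_with_logits(logits: list[float], targets: list[int]) -> float:
    """Mean binary cross-entropy over labels, computed from raw logits.

    max(x, 0) - x*y + log(1 + exp(-|x|)) equals
    -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))]
    but never overflows for large |x|.
    """
    losses = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)


# A query that needs both a web search and a diagram has target [1, 1],
# which a softmax (single-label) head cannot represent.
print(bce_with_logits([2.0, -1.5], [1, 1]))
```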
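The frozen-layer and dropout rows of the training table can be approximated as follows. This sketch assumes the "Classifier dropout" row corresponds to DistilBERT's `seq_classif_dropout`. To run offline, it builds a randomly initialised model from `DistilBertConfig`. For an actual fine-tune, a common pattern (not confirmed by this card) is `AutoModelForSequenceClassification.from_pretrained("Falconsai/intent_classification", num_labels=2, problem_type="multi_label_classification", ignore_mismatched_sizes=True)`, which discards the 15-class head.

```python
from transformers import DistilBertConfig, DistilBertForSequenceClassification

config = DistilBertConfig(
    num_labels=2,
    problem_type="multi_label_classification",  # makes the model use BCEWithLogitsLoss
    seq_classif_dropout=0.3,                    # assumed to be the "Classifier dropout" row
)
model = DistilBertForSequenceClassification(config)  # random weights, no download

# Freeze the embeddings and transformer layers 0-3.
for param in model.distilbert.embeddings.parameters():
    param.requires_grad = False
for layer in model.distilbert.transformer.layer[:4]:
    for param in layer.parameters():
        param.requires_grad = False

# Layers 4-5 plus the head (pre_classifier, classifier) remain trainable.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable[:3], "...", trainable[-2:])
```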
## Usage

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

REPO = "aitraineracc/intent-classification-multilabel"

tokenizer = AutoTokenizer.from_pretrained(REPO)
model = AutoModelForSequenceClassification.from_pretrained(REPO)
model.eval()  # disable dropout for inference

# Per-label thresholds read from config.json:
# {'web_search': 0.35, 'diagram_enabled': 0.6}
thresholds = model.config.thresholds


def predict(text: str) -> dict:
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        logits = model(**inputs).logits  # shape: (1, 2)
    # Sigmoid, not softmax: each label gets an independent probability.
    probs = torch.sigmoid(logits)[0].tolist()
    return {
        "web_search": int(probs[0] >= thresholds["web_search"]),
        "diagram_enabled": int(probs[1] >= thresholds["diagram_enabled"]),
        "probs": {
            "web_search": round(probs[0], 4),
            "diagram_enabled": round(probs[1], 4),
        },
    }


print(predict("What is the weather today in Singapore?"))
# {'web_search': 1, 'diagram_enabled': 0, 'probs': ...}
print(predict("Draw me a diagram of how TCP/IP works"))
# {'web_search': 0, 'diagram_enabled': 1, 'probs': ...}
```
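`predict` handles one string at a time. For a batch, one option is to compute `model(**tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")).logits.tolist()` and post-process the rows in plain Python. The `decide` helper below is a hypothetical sketch of that step. It assumes the label order from the Labels table and the thresholds stored in `config.json`.

```python
import math

LABELS = ["web_search", "diagram_enabled"]  # index order from the Labels table


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def decide(batch_logits: list[list[float]], thresholds: dict[str, float],
           labels: list[str] = LABELS) -> list[dict[str, int]]:
    """Map each row of raw logits to 0/1 flags using per-label thresholds."""
    results = []
    for row in batch_logits:
        probs = [sigmoid(x) for x in row]
        results.append({name: int(p >= thresholds[name]) for name, p in zip(labels, probs)})
    return results


thresholds = {"web_search": 0.35, "diagram_enabled": 0.6}
print(decide([[2.0, -1.0], [-3.0, 0.1], [0.5, 1.5]], thresholds))
# [{'web_search': 1, 'diagram_enabled': 0},
#  {'web_search': 0, 'diagram_enabled': 0},
#  {'web_search': 1, 'diagram_enabled': 1}]
```

In the second row, a `diagram_enabled` logit of 0.1 gives a probability of about 0.52. That would pass a naive 0.5 cut-off but stays below the tuned 0.6 threshold.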