---
language: en
license: apache-2.0
base_model: Falconsai/intent_classification
tags:
  - text-classification
  - multi-label-classification
  - intent-classification
  - distilbert
pipeline_tag: text-classification
---

# Intent Classification — Multi-Label (`web_search` / `diagram_enabled`)

This model is fine-tuned from [`Falconsai/intent_classification`](https://huggingface.co/Falconsai/intent_classification)
(DistilBERT-base-uncased, apache-2.0) for **multi-label binary intent classification**.
The original 15-class head was replaced with a 2-label sigmoid head trained with `BCEWithLogitsLoss`.
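
As a rough illustration of the training objective, the sketch below reimplements in plain Python what `BCEWithLogitsLoss` computes with its default mean reduction: a sigmoid followed by binary cross-entropy, evaluated independently for each label. The function name `bce_with_logits` and the example logits are illustrative assumptions and are not taken from the training script.

```python
import math

def bce_with_logits(logits: list[float], targets: list[float]) -> float:
    """Mean binary cross-entropy over independent labels, computed from raw logits."""
    total = 0.0
    for x, t in zip(logits, targets):
        # Numerically stable form of -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))]
        total += max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
    return total / len(logits)

# Hypothetical logits for one query, ordered [web_search, diagram_enabled]
loss = bce_with_logits([2.0, -1.0], [1.0, 0.0])
print(round(loss, 4))
```

Because each label contributes its own term, a query can be positive for both labels, for one, or for neither, which is what distinguishes this setup from a softmax over mutually exclusive classes.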

## Labels

| Index | Label | Meaning |
|-------|-------|---------|
| 0 | `web_search` | Query requires a live web search |
| 1 | `diagram_enabled` | Query benefits from a diagram / visualisation |

## Training details

| Setting | Value |
|---------|-------|
| Base model | `Falconsai/intent_classification` (DistilBERT) |
| Problem type | `multi_label_classification` |
| Frozen layers | embeddings + transformer.layer[0-3] |
| Trainable params | ~7M / 67M total (~10%) |
| Classifier dropout | 0.3 |
| Learning rate | 5e-6 |
| Early stopping | patience=3 on eval_loss |
| Threshold floor | 0.30 |
| Max sequence length | 128 |
| Split | 80 / 10 / 10 (train / val / test) |
| Seed | 42 |
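
The frozen-layer row can be pictured with the following sketch. It assumes the standard DistilBERT parameter naming used by `transformers` (`distilbert.embeddings.*`, `distilbert.transformer.layer.<i>.*`). The helper `freeze_lower_layers` is hypothetical and is not the author's training code; it only toggles `requires_grad` on whatever `named_parameters()` yields.

```python
from types import SimpleNamespace

# Name prefixes of the parameters to freeze: embeddings plus layers 0-3.
FROZEN_PREFIXES = ("distilbert.embeddings.",) + tuple(
    f"distilbert.transformer.layer.{i}." for i in range(4)
)

def freeze_lower_layers(named_parameters) -> int:
    """Disable gradients for embeddings and layers 0-3; return the number of tensors frozen."""
    frozen = 0
    for name, param in named_parameters:
        if name.startswith(FROZEN_PREFIXES):
            param.requires_grad = False
            frozen += 1
    return frozen

# Stand-in parameters so the sketch runs without downloading the model.
# With the real model, pass model.named_parameters() instead.
demo_params = [
    (name, SimpleNamespace(requires_grad=True))
    for name in (
        "distilbert.embeddings.word_embeddings.weight",
        "distilbert.transformer.layer.0.attention.q_lin.weight",
        "distilbert.transformer.layer.3.ffn.lin1.weight",
        "distilbert.transformer.layer.4.attention.q_lin.weight",
        "distilbert.transformer.layer.5.ffn.lin2.weight",
        "pre_classifier.weight",
        "classifier.weight",
    )
]
n_frozen = freeze_lower_layers(demo_params)
trainable = [name for name, p in demo_params if p.requires_grad]
print(n_frozen, trainable)
```

Matching on name prefixes with a trailing dot keeps the rule explicit and avoids touching the classification head, which stays trainable.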

## Decision thresholds

| Label | Threshold |
|-------|-----------|
| `web_search` | 0.35 |
| `diagram_enabled` | 0.6 |

> Thresholds are stored in `thresholds.json` **and** embedded in `config.json`
> (exposed as `model.config.thresholds`), so no separate download is needed.
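
A minimal sketch of how per-label thresholds turn sigmoid probabilities into binary decisions. The JSON layout shown for `thresholds.json` is an assumption (a flat label-to-threshold mapping that mirrors the table above), and `apply_thresholds` is an illustrative helper, not part of the repository.

```python
import json

# Assumed contents of thresholds.json: a flat {label: threshold} mapping.
THRESHOLDS = json.loads('{"web_search": 0.35, "diagram_enabled": 0.6}')

def apply_thresholds(probs: dict[str, float], thresholds: dict[str, float]) -> dict[str, int]:
    """Return 1 for every label whose probability meets or exceeds its threshold."""
    return {label: int(probs[label] >= cut) for label, cut in thresholds.items()}

# Hypothetical probabilities: both labels score 0.5, but the cut-offs differ.
decisions = apply_thresholds({"web_search": 0.5, "diagram_enabled": 0.5}, THRESHOLDS)
print(decisions)  # {'web_search': 1, 'diagram_enabled': 0}
```

Because each label has its own cut-off, the same probability can switch one label on and leave the other off.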

## Usage

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

REPO = "aitraineracc/intent-classification-multilabel"
tokenizer = AutoTokenizer.from_pretrained(REPO)
model     = AutoModelForSequenceClassification.from_pretrained(REPO)
model.eval()

thresholds = model.config.thresholds  # {'web_search': 0.35, 'diagram_enabled': 0.6}

def predict(text: str) -> dict:
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Drop only the batch dimension (size 1) so probs is always [web_search, diagram_enabled]
    probs = torch.sigmoid(logits).squeeze(0).tolist()
    return {
        "web_search"      : int(probs[0] >= thresholds["web_search"]),
        "diagram_enabled" : int(probs[1] >= thresholds["diagram_enabled"]),
        "probs"           : {"web_search": round(probs[0], 4),
                              "diagram_enabled": round(probs[1], 4)},
    }

print(predict("What is the weather today in Singapore?"))
# {'web_search': 1, 'diagram_enabled': 0, 'probs': ...}

print(predict("Draw me a diagram of how TCP/IP works"))
# {'web_search': 0, 'diagram_enabled': 1, 'probs': ...}
```