SaiCharan7829 committed on
Commit 38b2bbe · verified · 1 Parent(s): ad5c50b

Upload folder using huggingface_hub
README.md CHANGED
@@ -1,98 +1,220 @@
1
  ---
2
  language: en
3
- license: mit
4
  tags:
5
  - text-classification
 
 
6
  - distilbert
7
- - query-classification
8
  - pytorch
9
  datasets:
10
- - synthetic
11
  metrics:
12
  - accuracy
13
  - f1
14
  ---
15
 
16
- # Query Classification Model
 
 
17
 
18
  ## Model Description
19
 
20
- This is a fine-tuned DistilBERT-base-uncased model for classifying user queries into 4 categories: basic_actions, script_writing, information, conversation.
21
 
22
- ## Intended Uses & Limitations
23
 
24
- ### Intended Uses
25
- - Classifying user queries for routing to appropriate handlers
26
- - Chatbot query categorization
27
- - Automated response systems
28
 
29
- ### Limitations
30
- - Trained on synthetic data, may not generalize to all real-world prompts
31
- - Limited to 4 predefined categories
32
- - English language only
33
 
34
- ## Training Details
35
 
36
- ### Training Data
37
- - 640 synthetic queries (scaled from 32,500 target)
38
- - Augmented with synonyms, paraphrasing, room variations
39
- - Deduplicated and filtered (3-50 words)
40
- - Format: JSONL with "context" and "output" fields
41
- - Split: 28 train, 6 validation, 40 test
42
-
43
- ### Training Procedure
44
- - Base model: distilbert-base-uncased (66M parameters)
45
- - Task: Sequence Classification (4 classes)
46
- - Fine-tuning: 3 epochs
47
- - Learning rate: 2e-5
48
- - Batch size: 1 (gradient accumulation 4)
49
- - Optimizer: AdamW
50
-
51
- ### Training Logs
52
- - Epoch 1: Eval Loss 1.37, Accuracy 0.17, F1 0.05
53
- - Epoch 2: Eval Loss 1.35, Accuracy 0.67, F1 0.54
54
- - Epoch 3: Eval Loss 1.32, Accuracy 0.50, F1 0.33
55
 
56
  ## Performance
57
 
58
- | Metric | Value |
59
- |--------|-------|
60
- | Accuracy | 67% |
61
- | F1 Score | 54% |
62
 
63
- ## How to Use
64
 
65
  ```python
66
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
67
  import torch
68
 
69
- tokenizer = AutoTokenizer.from_pretrained("SaiCharan7829/query_classification-distilBERT-66M")
70
- model = AutoModelForSequenceClassification.from_pretrained("SaiCharan7829/query_classification-distilBERT-66M")
71
-
72
- categories = ["basic_actions", "script_writing", "information", "conversation"]
73
 
74
- prompt = "Turn on the lights in the living room"
 
 
75
 
76
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512)
77
  with torch.no_grad():
78
  outputs = model(**inputs)
79
- predicted_class = torch.argmax(outputs.logits).item()
80
- print(f"Predicted Category: {categories[predicted_class]}")
81
  ```
82
 
83
- ## Model Files
84
 
85
- - `model.safetensors`: Model weights
86
- - `config.json`: Model configuration
87
- - `tokenizer.json`: Tokenizer files
88
- - `vocab.txt`: Vocabulary
89
- - `special_tokens_map.json`: Special tokens
90
- - `training_args.bin`: Training arguments
91
 
92
- ## Dataset
93
 
94
- The synthetic dataset is included as `train_data.jsonl`, `val_data.jsonl`, `test_data.jsonl`.
95
 
96
  ## License
97
 
98
- MIT License
1
  ---
2
  language: en
3
+ license: apache-2.0
4
  tags:
5
  - text-classification
6
+ - intent-classification
7
+ - task-routing
8
  - distilbert
 
9
  - pytorch
10
  datasets:
11
+ - custom
12
  metrics:
13
  - accuracy
14
  - f1
15
+ model-index:
16
+ - name: query_classification-distilBERT-66M
17
+ results:
18
+ - task:
19
+ type: text-classification
20
+ name: Intent Classification
21
+ metrics:
22
+ - type: accuracy
23
+ value: 98.03
24
+ name: Test Accuracy
25
+ - type: f1
26
+ value: 98.03
27
+ name: F1 Score (Weighted)
28
  ---
29
 
30
+ # DistilBERT Task Router - Query Classification Model (V5)
31
+
32
+ A high-performance intent classification model based on DistilBERT, fine-tuned to classify user queries into 5 categories with **98.03% accuracy** on a challenging test set of 7,320 samples.
33
 
34
  ## Model Description
35
 
36
+ - **Base Model:** distilbert-base-uncased (66M parameters)
37
+ - **Task:** Multi-class text classification (5 categories)
38
+ - **Language:** English
39
+ - **Training Data:** 58,560 samples (custom generated)
40
+ - **Test Accuracy:** **98.03%** ✓
41
+ - **Inference Speed:** ~3ms average latency
42
 
43
+ ## Categories
44
 
45
+ This model classifies text into 5 intent categories:
 
 
 
46
 
47
+ 1. **basic_actions** - One-time, immediate commands
48
+ - Examples: "Turn on the lights", "Set temperature to 22 degrees", "Play music"
 
 
49
 
50
+ 2. **automator** - Recurring, scheduled, or conditional automations
51
+ - Examples: "Turn on lights every day at 6pm", "AC on if temperature > 28", "Every morning at 8am, start coffee"
52
+
53
+ 3. **information** - Educational, factual, or informational queries
54
+ - Examples: "What is quantum computing?", "How does photosynthesis work?", "What's the weather?"
55
 
56
+ 4. **conversation** - Social interactions and casual chat
57
+ - Examples: "Hello", "How are you?", "Good morning", "Nice to meet you"
58
+
59
+ 5. **irrelevant** - Abusive, meaningless, or off-topic content
60
+ - Examples: "asdfghjkl", "You're stupid", "Random gibberish"
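+ For task routing, each category maps naturally onto a handler. A minimal dispatch sketch (the handler functions and their behavior are illustrative placeholders, not part of this model):

```python
# Minimal routing sketch: map each predicted category to a handler.
# All handler functions here are hypothetical placeholders.

def handle_basic_action(text):
    return f"executing: {text}"

def handle_automation(text):
    return f"scheduling: {text}"

def handle_information(text):
    return f"answering: {text}"

def handle_conversation(text):
    return f"chatting: {text}"

def handle_irrelevant(text):
    return "sorry, I can't help with that"

ROUTES = {
    "basic_actions": handle_basic_action,
    "automator": handle_automation,
    "information": handle_information,
    "conversation": handle_conversation,
    "irrelevant": handle_irrelevant,
}

def route(category, text):
    # Unknown labels fall back to the irrelevant handler.
    return ROUTES.get(category, handle_irrelevant)(text)

print(route("basic_actions", "Turn on the lights"))
# -> executing: Turn on the lights
```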
61
 
62
  ## Performance
63
 
64
+ ### Test Set Results (7,320 samples)
65
+
66
+ | Category | Precision | Recall | F1-Score | Support |
67
+ |----------------|-----------|---------|----------|---------|
68
+ | basic_actions | 95.92% | 100.00% | 97.92% | 1,833 |
69
+ | automator | 100.00% | 94.50% | 97.17% | 1,418 |
70
+ | information | 100.00% | 95.39% | 97.64% | 1,432 |
71
+ | conversation | 100.00% | 100.00% | 100.00% | 1,456 |
72
+ | irrelevant | 94.71% | 100.00% | 97.28% | 1,181 |
73
+ | **Overall** | **98.12%** | **98.03%** | **98.03%** | **7,320** |
74
+
75
+ ### Key Metrics
76
 
77
+ - **Accuracy:** 98.03%
78
+ - **F1 Score (Weighted):** 98.03%
79
+ - **F1 Score (Macro):** 98.00%
80
+ - **Error Rate:** 1.97% (144 errors / 7,320 samples)
81
+
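+ The aggregate figures follow directly from the per-class table: weighted F1 averages per-class F1 by support, macro F1 averages them equally. A quick check:

```python
# Recompute the aggregate F1 scores from the per-class table above.
f1 = {
    "basic_actions": 97.92, "automator": 97.17, "information": 97.64,
    "conversation": 100.00, "irrelevant": 97.28,
}
support = {
    "basic_actions": 1833, "automator": 1418, "information": 1432,
    "conversation": 1456, "irrelevant": 1181,
}
total = sum(support.values())  # 7320

# Weighted F1: support-weighted average of per-class F1
weighted_f1 = sum(f1[c] * support[c] for c in f1) / total
# Macro F1: unweighted average over classes
macro_f1 = sum(f1.values()) / len(f1)

print(f"Weighted F1: {weighted_f1:.2f}%")  # 98.03%
print(f"Macro F1: {macro_f1:.2f}%")        # 98.00%
```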
82
+ ### Latency
83
+
84
+ - **Average:** 2.91ms
85
+ - **Median:** 2.80ms
86
+ - **P95:** 3.36ms
87
+ - **P99:** 3.88ms
88
+
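+ Percentile latencies like the figures above can be reproduced with a simple timing harness; the workload below is a stand-in for a real `model(**inputs)` call:

```python
import statistics
import time

def measure_latency_ms(fn, n=1000):
    """Time n calls to fn and return average / median / p95 / p99 in milliseconds."""
    samples = []
    for _ in range(n):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    cuts = statistics.quantiles(samples, n=100)  # cut points for percentiles 1..99
    return {
        "avg": statistics.fmean(samples),
        "median": statistics.median(samples),
        "p95": cuts[94],
        "p99": cuts[98],
    }

# Stand-in workload; replace the lambda with a real forward pass.
stats = measure_latency_ms(lambda: sum(range(1000)))
print({k: round(v, 3) for k, v in stats.items()})
```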
89
+ ## Usage
90
+
91
+ ### Quick Start
92
 
93
  ```python
94
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
95
  import torch
96
 
97
+ # Load model and tokenizer
98
+ model_name = "SaiCharan7829/query_classification-distilBERT-66M"
99
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
100
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
101
 
102
+ # Prepare input
103
+ text = "Turn on the lights every evening at 6pm"
104
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
105
 
106
+ # Get prediction
107
  with torch.no_grad():
108
  outputs = model(**inputs)
109
+ logits = outputs.logits
110
+ predicted_class = torch.argmax(logits, dim=1).item()
111
+
112
+ # Categories mapping
113
+ categories = ["basic_actions", "automator", "information", "conversation", "irrelevant"]
114
+ print(f"Predicted category: {categories[predicted_class]}")
115
+ # Output: Predicted category: automator
116
  ```
117
 
118
+ ### With Confidence Scores
119
+
120
+ ```python
121
+ import torch.nn.functional as F
122
+
123
+ # Get probabilities
124
+ probs = F.softmax(logits, dim=1)[0]
125
+ confidence = probs[predicted_class].item()
126
+
127
+ print(f"Category: {categories[predicted_class]}")
128
+ print(f"Confidence: {confidence:.2%}")
129
 
130
+ # Show all probabilities
131
+ for i, category in enumerate(categories):
132
+ print(f"{category}: {probs[i].item():.2%}")
133
+ ```
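+ In a router it is common to fall back when the top probability is low. A self-contained sketch with a plain-Python softmax; the 0.5 threshold and the example logits are assumptions for illustration, not values from this model:

```python
import math

categories = ["basic_actions", "automator", "information", "conversation", "irrelevant"]

def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def classify_with_fallback(logits, threshold=0.5):
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    if probs[best] < threshold:
        # Low confidence: route to the fallback category.
        return "irrelevant", probs[best]
    return categories[best], probs[best]

# Illustrative logits, not real model output.
label, conf = classify_with_fallback([4.1, 0.2, 0.3, 0.1, 0.0])
print(label)  # basic_actions
```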
 
 
134
 
135
+ ## Training Details
136
 
137
+ ### Training Hyperparameters
138
+
139
+ - **Epochs:** 30
140
+ - **Batch Size:** 64 (effective, with gradient accumulation)
141
+ - **Learning Rate:** 2e-5
142
+ - **Warmup Steps:** 500
143
+ - **Weight Decay:** 0.01
144
+ - **Label Smoothing:** 0.1
145
+ - **Learning Rate Schedule:** Cosine with warmup
146
+ - **Optimizer:** AdamW
147
+ - **Class Weights:** Applied (automator: 1.31x, basic_actions: 1.48x, irrelevant: 0.98x)
148
+
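+ The class weights listed above were tuned from error analysis. As a baseline, a common frequency-based starting point is the "balanced" heuristic (as in scikit-learn): `weight_i = n_samples / (n_classes * count_i)`. A sketch with toy counts, not the card's dataset:

```python
def balanced_weights(counts):
    """'Balanced' class-weight heuristic: n_samples / (n_classes * count_i)."""
    n = sum(counts.values())
    k = len(counts)
    return {label: n / (k * c) for label, c in counts.items()}

# Toy counts (illustrative only): the minority classes get weights > 1.
weights = balanced_weights({"a": 100, "b": 50, "c": 50})
print(weights)
# -> {'a': 0.666..., 'b': 1.333..., 'c': 1.333...}
```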
149
+ ### Dataset
150
+
151
+ - **Training Samples:** 58,560
152
+ - **Validation Samples:** 7,320
153
+ - **Test Samples:** 7,320
154
+ - **Data Split:** 80% / 10% / 10%
155
+
156
+ **Distribution:**
157
+ - basic_actions: 24.4% (15,000 samples with 40% short commands)
158
+ - automator: 19.8%
159
+ - information: 19.7%
160
+ - conversation: 19.8%
161
+ - irrelevant: 16.4%
162
+
163
+ ### Training Infrastructure
164
+
165
+ - **Framework:** Transformers 4.x, PyTorch 2.x
166
+ - **Device:** Apple Silicon (MPS)
167
+ - **Precision:** FP32
168
+
169
+ ## Limitations & Biases
170
+
171
+ - The model is trained on English text only
172
+ - Performance may degrade on domain-specific jargon not seen during training
173
+ - Short ambiguous commands (1-2 words) may have lower confidence
174
+ - The "irrelevant" category includes abusive content, which may reflect biases in training data
175
+
176
+ ## Intended Use
177
+
178
+ This model is designed for:
179
+ - Smart home assistants and IoT platforms
180
+ - Chatbot intent classification
181
+ - Task routing and workflow automation
182
+ - Virtual assistant command parsing
183
+
184
+ **Not recommended for:**
185
+ - Sensitive content moderation (use dedicated safety models)
186
+ - Medical or legal decision-making
187
+ - Financial advice classification
188
+
189
+ ## Version History
190
+
191
+ ### v5 (Current) - November 2024
192
+ - **Accuracy:** 98.03% (test set)
193
+ - Major improvements to basic_actions recall (100%)
194
+ - Optimized class weights based on error analysis
195
+ - Enhanced dataset with better short command coverage
196
+
197
+ ### v4
198
+ - **Accuracy:** 94.86% (test set)
199
+ - Initial release with 72k training samples
200
+ - Identified issues with short command classification
201
+
202
+ ## Citation
203
+
204
+ ```bibtex
205
+ @misc{query_classification_distilbert_2024,
206
+ author = {SaiCharan7829},
207
+ title = {DistilBERT Task Router - Query Classification Model},
208
+ year = {2024},
209
+ publisher = {HuggingFace},
210
+ howpublished = {\url{https://huggingface.co/SaiCharan7829/query_classification-distilBERT-66M}}
211
+ }
212
+ ```
213
 
214
  ## License
215
 
216
+ Apache 2.0
217
+
218
+ ## Model Card Authors
219
+
220
+ SaiCharan7829
config.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "activation": "gelu",
3
+ "architectures": [
4
+ "DistilBertForSequenceClassification"
5
+ ],
6
+ "attention_dropout": 0.3,
7
+ "dim": 768,
8
+ "dropout": 0.3,
9
+ "dtype": "float32",
10
+ "hidden_dim": 3072,
11
+ "id2label": {
12
+ "0": "LABEL_0",
13
+ "1": "LABEL_1",
14
+ "2": "LABEL_2",
15
+ "3": "LABEL_3",
16
+ "4": "LABEL_4"
17
+ },
18
+ "initializer_range": 0.02,
19
+ "label2id": {
20
+ "LABEL_0": 0,
21
+ "LABEL_1": 1,
22
+ "LABEL_2": 2,
23
+ "LABEL_3": 3,
24
+ "LABEL_4": 4
25
+ },
26
+ "max_position_embeddings": 512,
27
+ "model_type": "distilbert",
28
+ "n_heads": 12,
29
+ "n_layers": 6,
30
+ "pad_token_id": 0,
31
+ "qa_dropout": 0.1,
32
+ "seq_classif_dropout": 0.3,
33
+ "sinusoidal_pos_embds": false,
34
+ "tie_weights_": true,
35
+ "transformers_version": "4.57.1",
36
+ "vocab_size": 30522
37
+ }
label_map.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "basic_actions": 0,
3
+ "automator": 1,
4
+ "information": 2,
5
+ "conversation": 3,
6
+ "irrelevant": 4
7
+ }
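+ Note that the shipped config.json maps class ids to generic `LABEL_0`..`LABEL_4`; the human-readable names live in label_map.json. Inverting it gives the `id2label` form that `transformers` configs use (a sketch; in practice you would load label_map.json from the repo rather than inline it):

```python
# label_map.json contents, inlined here for illustration.
label_map = {
    "basic_actions": 0,
    "automator": 1,
    "information": 2,
    "conversation": 3,
    "irrelevant": 4,
}

# Invert to id2label, e.g. for
# AutoModelForSequenceClassification.from_pretrained(..., id2label=id2label, label2id=label_map)
id2label = {i: name for name, i in label_map.items()}
print(id2label[1])  # automator
```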
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e646c4db75cca299970ef495042c3ddfa322737f8fa3466b240384bedaabd3f
3
+ size 267841796
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": false,
45
+ "cls_token": "[CLS]",
46
+ "do_lower_case": true,
47
+ "extra_special_tokens": {},
48
+ "mask_token": "[MASK]",
49
+ "model_max_length": 512,
50
+ "pad_token": "[PAD]",
51
+ "sep_token": "[SEP]",
52
+ "strip_accents": null,
53
+ "tokenize_chinese_chars": true,
54
+ "tokenizer_class": "DistilBertTokenizer",
55
+ "unk_token": "[UNK]"
56
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:346d95449450462a9af6a37abd6886f382d3c25debf1065d0ae773f601d4a53c
3
+ size 5841
vocab.txt ADDED
The diff for this file is too large to render. See raw diff