AurelPx commited on
Commit
fd4c2cf
·
verified ·
1 Parent(s): f903d42

Upload README.md

Browse files
Files changed (1) hide show
  1. README.md +80 -112
README.md CHANGED
@@ -1,142 +1,110 @@
1
- ---
2
- tags:
3
- - ml-intern
4
- ---
5
  # HR Conversations Multi-Label Classifier
6
 
7
- Fine-tuned **DistilBERT-base-uncased** (66M parameters) for multi-label classification of HR support conversations.
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  ## Model Details
10
 
11
  | Attribute | Value |
12
  |-----------|-------|
13
- | Base Model | `distilbert/distilbert-base-uncased` |
14
- | Task | Multi-label text classification |
15
- | Labels | 20 HR topics (see below) |
16
- | Training Data | 100 synthetic HR conversations (first version) |
17
- | Framework | Hugging Face Transformers |
 
18
 
19
  ## 20 HR Topic Labels
20
 
21
- 1. Benefits
22
- 2. Career Development
23
- 3. Compliance & Legal
24
- 4. Contracts
25
- 5. Diversity, Equity & Inclusion
26
- 6. Expense Management
27
- 7. Harassment
28
- 8. Health
29
- 9. IT & Equipment
30
- 10. Leave & Absence
31
- 11. Mobility
32
- 12. Offboarding
33
- 13. Onboarding
34
- 14. Payroll
35
- 15. Performance Management
36
- 16. Recruitment
37
- 17. Safety
38
- 18. Timetracking
39
- 19. Training
40
  20. Work Arrangements
41
 
42
- ## Quick Usage
43
-
44
- ```python
45
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
46
- import torch
47
-
48
- model_id = "AurelPx/hr-conversations-classifier"
49
- tokenizer = AutoTokenizer.from_pretrained(model_id)
50
- model = AutoModelForSequenceClassification.from_pretrained(model_id)
51
-
52
- LABELS = [
53
- "Benefits", "Career Development", "Compliance & Legal", "Contracts",
54
- "Diversity, Equity & Inclusion", "Expense Management", "Harassment", "Health",
55
- "IT & Equipment", "Leave & Absence", "Mobility", "Offboarding",
56
- "Onboarding", "Payroll", "Performance Management", "Recruitment",
57
- "Safety", "Timetracking", "Training", "Work Arrangements",
58
- ]
59
-
60
- def classify(text: str, threshold: float = 0.5):
61
- inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
62
- with torch.no_grad():
63
- logits = model(**inputs).logits
64
- probs = torch.sigmoid(logits).numpy()[0]
65
- return [LABELS[i] for i, p in enumerate(probs) if p >= threshold]
66
-
67
- # Example
68
- conversation = "USER: I haven't received my payslip for March yet. Could you please check what's going on?"
69
- print(classify(conversation)) # ['Payroll']
70
- ```
71
-
72
- ## Improve the Model (Colab / Local)
73
-
74
- The current model was trained on only **100 real conversations** — this is too small for reliable classification. We provide a complete training script that:
75
-
76
- 1. **Generates 5,000 synthetic HR conversations** from templates (no LLM needed)
77
- 2. **Runs 5-fold stratified cross-validation** (no data leakage)
78
- 3. **Fine-tunes DistilBERT** and pushes the best model to Hub
79
-
80
- ### Run in Google Colab (Free T4 GPU)
81
 
82
  ```python
83
- # Step 1: Install dependencies
84
- !pip install -q transformers datasets accelerate pandas scikit-learn huggingface_hub
85
-
86
- # Step 2: Login to Hugging Face
87
- from huggingface_hub import notebook_login
88
- notebook_login()
89
-
90
- # Step 3: Download and run the training script
91
- !curl -L -o train.py https://huggingface.co/AurelPx/hr-conversations-classifier/raw/main/training_script.py
92
- !python train.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  ```
94
 
95
- The script will:
96
- - Load your dataset from `AurelPx/ml-intern-a2d69eee-datasets`
97
- - Generate 5,000 synthetic conversations preserving the real distribution
98
- - Run 5-fold cross-validation, reporting mean ± std F1-micro
99
- - Push the best model to your Hub
100
 
101
- ### Expected Runtime
102
- - CPU (8 cores): ~6-8 hours
103
- - Google Colab T4 GPU: ~30-45 minutes
104
 
105
  ## Files in this Repo
106
 
107
  | File | Description |
108
  |------|-------------|
109
- | `model.safetensors` | Fine-tuned model weights |
110
- | `training_script.py` | Full training pipeline (augmentation + 5-fold CV + fine-tuning) |
111
- | `inference.py` | Standalone inference script |
112
- | `label_config.json` | Label names and classification threshold |
 
 
 
113
 
114
  ## Dataset
115
 
116
  - [AurelPx/ml-intern-a2d69eee-datasets](https://huggingface.co/datasets/AurelPx/ml-intern-a2d69eee-datasets)
117
  - 100 English HR conversations with multi-label annotations
118
- - Conversations include Payroll, Benefits, Leave, Contracts, Training, etc.
119
 
120
  ## License
121
 
122
- Apache 2.0 (same as DistilBERT)
123
-
124
- <!-- ml-intern-provenance -->
125
- ## Generated by ML Intern
126
-
127
- This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.
128
-
129
- - Try ML Intern: https://smolagents-ml-intern.hf.space
130
- - Source code: https://github.com/huggingface/ml-intern
131
-
132
- ## Usage
133
-
134
- ```python
135
- from transformers import AutoModelForCausalLM, AutoTokenizer
136
-
137
- model_id = 'AurelPx/hr-conversations-classifier'
138
- tokenizer = AutoTokenizer.from_pretrained(model_id)
139
- model = AutoModelForCausalLM.from_pretrained(model_id)
140
- ```
141
-
142
- For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.
 
 
 
 
 
1
  # HR Conversations Multi-Label Classifier
2
 
3
+ SETFit-style classifier for **20 HR topic labels** on employee–agent conversations, trained with **5,000 synthetic + 100 real samples** and evaluated via **5-fold stratified cross-validation** (no data leakage).
4
+
5
+ ## Results
6
+
7
+ | Metric | Score |
8
+ |--------|-------|
9
+ | **F1-micro (5-fold CV)** | **0.7962 ± 0.0098** |
10
+ | **F1-macro (5-fold CV)** | **0.7721** |
11
+ | Fold 1 | 0.7851 |
12
+ | Fold 2 | 0.7989 |
13
+ | Fold 3 | 0.8031 |
14
+ | Fold 4 | 0.7846 |
15
+ | Fold 5 | **0.8091** |
16
 
17
  ## Model Details
18
 
19
  | Attribute | Value |
20
  |-----------|-------|
21
+ | Encoder | `sentence-transformers/all-MiniLM-L6-v2` (384-dim) |
22
+ | Classifier | Multi-output Logistic Regression (scikit-learn) |
23
+ | Training samples | 5,100 (5,000 synthetic + 100 real) |
24
+ | Labels | 20 HR topics |
25
+ | Validation | 5-fold stratified cross-validation |
26
+ | Framework | Sentence-Transformers + scikit-learn |
27
 
28
  ## 20 HR Topic Labels
29
 
30
+ 1. Benefits
31
+ 2. Career Development
32
+ 3. Compliance & Legal
33
+ 4. Contracts
34
+ 5. Diversity, Equity & Inclusion
35
+ 6. Expense Management
36
+ 7. Harassment
37
+ 8. Health
38
+ 9. IT & Equipment
39
+ 10. Leave & Absence
40
+ 11. Mobility
41
+ 12. Offboarding
42
+ 13. Onboarding
43
+ 14. Payroll
44
+ 15. Performance Management
45
+ 16. Recruitment
46
+ 17. Safety
47
+ 18. Timetracking
48
+ 19. Training
49
  20. Work Arrangements
50
 
51
+ ## Usage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  ```python
54
+ from sentence_transformers import SentenceTransformer
55
+ import pickle, json
56
+ from huggingface_hub import hf_hub_download
57
+
58
+ # Download artifacts
59
+ classifier_path = hf_hub_download("AurelPx/hr-conversations-classifier", "setfit_classifier.pkl")
60
+ label_path = hf_hub_download("AurelPx/hr-conversations-classifier", "setfit_label_config.json")
61
+
62
+ # Load
63
+ encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
64
+ with open(classifier_path, 'rb') as f:
65
+ classifier = pickle.load(f)
66
+ with open(label_path) as f:
67
+ config = json.load(f)
68
+
69
+ LABELS = config['label_names']
70
+
71
+ # Classify
72
+ sample = (
73
+ "USER: I haven't received my payslip for March yet. Could you please check what's going on?\n"
74
+ "AGENT: Good morning. I've checked the payroll system and it appears your March payslip "
75
+ "was generated on the 28th but there was a distribution delay. I've resent it to your "
76
+ "registered email. You should receive it within the next hour."
77
+ )
78
+
79
+ emb = encoder.encode([sample])
80
+ proba = classifier.predict_proba(emb)
81
+ preds = [LABELS[i] for i, p in enumerate(proba) if p[0][1] >= 0.5]
82
+ print(preds) # ['Payroll']
83
  ```
84
 
85
+ ## Training Approach
 
 
 
 
86
 
87
+ 1. **Data augmentation** — 5,000 synthetic HR conversations generated from real conversation templates (no LLM, no external API).
88
+ 2. **Stratified 5-fold CV** splits by primary label, ensuring each fold preserves label distribution.
89
+ 3. **SETFit-style pipeline** MiniLM embeddings + Logistic Regression, fast and accurate on small data.
90
 
91
  ## Files in this Repo
92
 
93
  | File | Description |
94
  |------|-------------|
95
+ | `setfit_classifier.pkl` | Trained Logistic Regression classifier |
96
+ | `setfit_encoder.pkl` | SentenceTransformer MiniLM encoder (optional, for offline use) |
97
+ | `setfit_cv_results.json` | Cross-validation scores per fold |
98
+ | `setfit_label_config.json` | Label names and threshold |
99
+ | `training_script.py` | Full training pipeline (augmentation + CV + inference) |
100
+ | `inference.py` | Standalone inference script (DistilBERT legacy — not recommended) |
101
+ | `model.safetensors` | Legacy DistilBERT checkpoint (kept for compatibility) |
102
 
103
  ## Dataset
104
 
105
  - [AurelPx/ml-intern-a2d69eee-datasets](https://huggingface.co/datasets/AurelPx/ml-intern-a2d69eee-datasets)
106
  - 100 English HR conversations with multi-label annotations
 
107
 
108
  ## License
109
 
110
+ Apache 2.0