Upload README.md
Browse files
README.md
CHANGED
|
@@ -1,142 +1,110 @@
|
|
| 1 |
-
---
|
| 2 |
-
tags:
|
| 3 |
-
- ml-intern
|
| 4 |
-
---
|
| 5 |
# HR Conversations Multi-Label Classifier
|
| 6 |
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
## Model Details
|
| 10 |
|
| 11 |
| Attribute | Value |
|
| 12 |
|-----------|-------|
|
| 13 |
-
|
|
| 14 |
-
|
|
| 15 |
-
|
|
| 16 |
-
|
|
| 17 |
-
|
|
|
|
|
| 18 |
|
| 19 |
## 20 HR Topic Labels
|
| 20 |
|
| 21 |
-
1. Benefits
|
| 22 |
-
2. Career Development
|
| 23 |
-
3. Compliance & Legal
|
| 24 |
-
4. Contracts
|
| 25 |
-
5. Diversity, Equity & Inclusion
|
| 26 |
-
6. Expense Management
|
| 27 |
-
7. Harassment
|
| 28 |
-
8. Health
|
| 29 |
-
9. IT & Equipment
|
| 30 |
-
10. Leave & Absence
|
| 31 |
-
11. Mobility
|
| 32 |
-
12. Offboarding
|
| 33 |
-
13. Onboarding
|
| 34 |
-
14. Payroll
|
| 35 |
-
15. Performance Management
|
| 36 |
-
16. Recruitment
|
| 37 |
-
17. Safety
|
| 38 |
-
18. Timetracking
|
| 39 |
-
19. Training
|
| 40 |
20. Work Arrangements
|
| 41 |
|
| 42 |
-
##
|
| 43 |
-
|
| 44 |
-
```python
|
| 45 |
-
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 46 |
-
import torch
|
| 47 |
-
|
| 48 |
-
model_id = "AurelPx/hr-conversations-classifier"
|
| 49 |
-
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 50 |
-
model = AutoModelForSequenceClassification.from_pretrained(model_id)
|
| 51 |
-
|
| 52 |
-
LABELS = [
|
| 53 |
-
"Benefits", "Career Development", "Compliance & Legal", "Contracts",
|
| 54 |
-
"Diversity, Equity & Inclusion", "Expense Management", "Harassment", "Health",
|
| 55 |
-
"IT & Equipment", "Leave & Absence", "Mobility", "Offboarding",
|
| 56 |
-
"Onboarding", "Payroll", "Performance Management", "Recruitment",
|
| 57 |
-
"Safety", "Timetracking", "Training", "Work Arrangements",
|
| 58 |
-
]
|
| 59 |
-
|
| 60 |
-
def classify(text: str, threshold: float = 0.5):
|
| 61 |
-
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
|
| 62 |
-
with torch.no_grad():
|
| 63 |
-
logits = model(**inputs).logits
|
| 64 |
-
probs = torch.sigmoid(logits).numpy()[0]
|
| 65 |
-
return [LABELS[i] for i, p in enumerate(probs) if p >= threshold]
|
| 66 |
-
|
| 67 |
-
# Example
|
| 68 |
-
conversation = "USER: I haven't received my payslip for March yet. Could you please check what's going on?"
|
| 69 |
-
print(classify(conversation)) # ['Payroll']
|
| 70 |
-
```
|
| 71 |
-
|
| 72 |
-
## Improve the Model (Colab / Local)
|
| 73 |
-
|
| 74 |
-
The current model was trained on only **100 real conversations** — this is too small for reliable classification. We provide a complete training script that:
|
| 75 |
-
|
| 76 |
-
1. **Generates 5,000 synthetic HR conversations** from templates (no LLM needed)
|
| 77 |
-
2. **Runs 5-fold stratified cross-validation** (no data leakage)
|
| 78 |
-
3. **Fine-tunes DistilBERT** and pushes the best model to Hub
|
| 79 |
-
|
| 80 |
-
### Run in Google Colab (Free T4 GPU)
|
| 81 |
|
| 82 |
```python
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
```
|
| 94 |
|
| 95 |
-
|
| 96 |
-
- Load your dataset from `AurelPx/ml-intern-a2d69eee-datasets`
|
| 97 |
-
- Generate 5,000 synthetic conversations preserving the real distribution
|
| 98 |
-
- Run 5-fold cross-validation, reporting mean ± std F1-micro
|
| 99 |
-
- Push the best model to your Hub
|
| 100 |
|
| 101 |
-
|
| 102 |
-
-
|
| 103 |
-
-
|
| 104 |
|
| 105 |
## Files in this Repo
|
| 106 |
|
| 107 |
| File | Description |
|
| 108 |
|------|-------------|
|
| 109 |
-
| `
|
| 110 |
-
| `
|
| 111 |
-
| `
|
| 112 |
-
| `
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
## Dataset
|
| 115 |
|
| 116 |
- [AurelPx/ml-intern-a2d69eee-datasets](https://huggingface.co/datasets/AurelPx/ml-intern-a2d69eee-datasets)
|
| 117 |
- 100 English HR conversations with multi-label annotations
|
| 118 |
-
- Conversations include Payroll, Benefits, Leave, Contracts, Training, etc.
|
| 119 |
|
| 120 |
## License
|
| 121 |
|
| 122 |
-
Apache 2.0
|
| 123 |
-
|
| 124 |
-
<!-- ml-intern-provenance -->
|
| 125 |
-
## Generated by ML Intern
|
| 126 |
-
|
| 127 |
-
This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.
|
| 128 |
-
|
| 129 |
-
- Try ML Intern: https://smolagents-ml-intern.hf.space
|
| 130 |
-
- Source code: https://github.com/huggingface/ml-intern
|
| 131 |
-
|
| 132 |
-
## Usage
|
| 133 |
-
|
| 134 |
-
```python
|
| 135 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 136 |
-
|
| 137 |
-
model_id = 'AurelPx/hr-conversations-classifier'
|
| 138 |
-
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 139 |
-
model = AutoModelForCausalLM.from_pretrained(model_id)
|
| 140 |
-
```
|
| 141 |
-
|
| 142 |
-
For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# HR Conversations Multi-Label Classifier
|
| 2 |
|
| 3 |
+
SETFit-style classifier for **20 HR topic labels** on employee–agent conversations, trained with **5,000 synthetic + 100 real samples** and evaluated via **5-fold stratified cross-validation** (no data leakage).
|
| 4 |
+
|
| 5 |
+
## Results
|
| 6 |
+
|
| 7 |
+
| Metric | Score |
|
| 8 |
+
|--------|-------|
|
| 9 |
+
| **F1-micro (5-fold CV)** | **0.7962 ± 0.0098** |
|
| 10 |
+
| **F1-macro (5-fold CV)** | **0.7721** |
|
| 11 |
+
| Fold 1 | 0.7851 |
|
| 12 |
+
| Fold 2 | 0.7989 |
|
| 13 |
+
| Fold 3 | 0.8031 |
|
| 14 |
+
| Fold 4 | 0.7846 |
|
| 15 |
+
| Fold 5 | **0.8091** |
|
| 16 |
|
| 17 |
## Model Details
|
| 18 |
|
| 19 |
| Attribute | Value |
|
| 20 |
|-----------|-------|
|
| 21 |
+
| Encoder | `sentence-transformers/all-MiniLM-L6-v2` (384-dim) |
|
| 22 |
+
| Classifier | Multi-output Logistic Regression (scikit-learn) |
|
| 23 |
+
| Training samples | 5,100 (5,000 synthetic + 100 real) |
|
| 24 |
+
| Labels | 20 HR topics |
|
| 25 |
+
| Validation | 5-fold stratified cross-validation |
|
| 26 |
+
| Framework | Sentence-Transformers + scikit-learn |
|
| 27 |
|
| 28 |
## 20 HR Topic Labels
|
| 29 |
|
| 30 |
+
1. Benefits
|
| 31 |
+
2. Career Development
|
| 32 |
+
3. Compliance & Legal
|
| 33 |
+
4. Contracts
|
| 34 |
+
5. Diversity, Equity & Inclusion
|
| 35 |
+
6. Expense Management
|
| 36 |
+
7. Harassment
|
| 37 |
+
8. Health
|
| 38 |
+
9. IT & Equipment
|
| 39 |
+
10. Leave & Absence
|
| 40 |
+
11. Mobility
|
| 41 |
+
12. Offboarding
|
| 42 |
+
13. Onboarding
|
| 43 |
+
14. Payroll
|
| 44 |
+
15. Performance Management
|
| 45 |
+
16. Recruitment
|
| 46 |
+
17. Safety
|
| 47 |
+
18. Timetracking
|
| 48 |
+
19. Training
|
| 49 |
20. Work Arrangements
|
| 50 |
|
| 51 |
+
## Usage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
```python
|
| 54 |
+
from sentence_transformers import SentenceTransformer
|
| 55 |
+
import pickle, json
|
| 56 |
+
from huggingface_hub import hf_hub_download
|
| 57 |
+
|
| 58 |
+
# Download artifacts
|
| 59 |
+
classifier_path = hf_hub_download("AurelPx/hr-conversations-classifier", "setfit_classifier.pkl")
|
| 60 |
+
label_path = hf_hub_download("AurelPx/hr-conversations-classifier", "setfit_label_config.json")
|
| 61 |
+
|
| 62 |
+
# Load
|
| 63 |
+
encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
| 64 |
+
with open(classifier_path, 'rb') as f:
|
| 65 |
+
classifier = pickle.load(f)
|
| 66 |
+
with open(label_path) as f:
|
| 67 |
+
config = json.load(f)
|
| 68 |
+
|
| 69 |
+
LABELS = config['label_names']
|
| 70 |
+
|
| 71 |
+
# Classify
|
| 72 |
+
sample = (
|
| 73 |
+
"USER: I haven't received my payslip for March yet. Could you please check what's going on?\n"
|
| 74 |
+
"AGENT: Good morning. I've checked the payroll system and it appears your March payslip "
|
| 75 |
+
"was generated on the 28th but there was a distribution delay. I've resent it to your "
|
| 76 |
+
"registered email. You should receive it within the next hour."
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
emb = encoder.encode([sample])
|
| 80 |
+
proba = classifier.predict_proba(emb)
|
| 81 |
+
preds = [LABELS[i] for i, p in enumerate(proba) if p[0][1] >= 0.5]
|
| 82 |
+
print(preds) # ['Payroll']
|
| 83 |
```
|
| 84 |
|
| 85 |
+
## Training Approach
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
+
1. **Data augmentation** — 5,000 synthetic HR conversations generated from real conversation templates (no LLM, no external API).
|
| 88 |
+
2. **Stratified 5-fold CV** — splits by primary label, ensuring each fold preserves label distribution.
|
| 89 |
+
3. **SETFit-style pipeline** — MiniLM embeddings + Logistic Regression, fast and accurate on small data.
|
| 90 |
|
| 91 |
## Files in this Repo
|
| 92 |
|
| 93 |
| File | Description |
|
| 94 |
|------|-------------|
|
| 95 |
+
| `setfit_classifier.pkl` | Trained Logistic Regression classifier |
|
| 96 |
+
| `setfit_encoder.pkl` | SentenceTransformer MiniLM encoder (optional, for offline use) |
|
| 97 |
+
| `setfit_cv_results.json` | Cross-validation scores per fold |
|
| 98 |
+
| `setfit_label_config.json` | Label names and threshold |
|
| 99 |
+
| `training_script.py` | Full training pipeline (augmentation + CV + inference) |
|
| 100 |
+
| `inference.py` | Standalone inference script (DistilBERT legacy — not recommended) |
|
| 101 |
+
| `model.safetensors` | Legacy DistilBERT checkpoint (kept for compatibility) |
|
| 102 |
|
| 103 |
## Dataset
|
| 104 |
|
| 105 |
- [AurelPx/ml-intern-a2d69eee-datasets](https://huggingface.co/datasets/AurelPx/ml-intern-a2d69eee-datasets)
|
| 106 |
- 100 English HR conversations with multi-label annotations
|
|
|
|
| 107 |
|
| 108 |
## License
|
| 109 |
|
| 110 |
+
Apache 2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|