File size: 26,376 Bytes
95d3457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
#!/usr/bin/env -S python3 -u
# `env -S` splits "python3 -u" into program + flag. Without it, Linux passes
# "python3 -u" to env as ONE program name and direct execution (./script.py)
# fails. Alternatively run `python3 -u <script>` or set PYTHONUNBUFFERED=1.
"""
HR Conversations: data augmentation + 5-fold stratified cross-validation + fine-tuning.
Copy this script and run it in Google Colab (free T4) or any Python environment.

Pipeline: download the real HR conversation CSV from the Hugging Face Hub,
add template-based synthetic conversations, run stratified K-fold CV of a
DistilBERT multi-label classifier, then train on all data and push the final
model to the Hub.
"""
import json
import os
import random
import re
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict, Sequence, Value
from huggingface_hub import hf_hub_download
from sklearn.metrics import accuracy_score, f1_score, hamming_loss, precision_score, recall_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

MODEL_ID = "distilbert/distilbert-base-uncased"
OUTPUT_DIR = "./hr-distilbert-cv"
HUB_MODEL_ID = "AurelPx/hr-conversations-classifier"
NUM_SYNTHETIC = 5000   # number of template-generated conversations to add
MAX_LENGTH = 512       # tokenizer truncation length
BATCH_SIZE = 16
LR = 3e-5
WEIGHT_DECAY = 0.1
EPOCHS = 4
N_FOLDS = 5


def _parse_labels(raw):
    """Split a ``"A | B"`` label cell into a list of stripped, non-empty labels.

    Missing cells (NaN) yield ``[]`` rather than the bogus label ``"nan"``.
    """
    if pd.isna(raw):
        return []
    return [part.strip() for part in str(raw).split("|") if part.strip()]


csv_path = hf_hub_download(
    repo_id="AurelPx/ml-intern-a2d69eee-datasets",
    filename="uploads/76ee47c7699c/HRDatasetConv-English.csv",
    repo_type="dataset",
)
# Semicolon-separated, BOM-tolerant; quoting=1 is csv.QUOTE_ALL.
df_real = pd.read_csv(csv_path, sep=";", encoding="utf-8-sig", quoting=1)
df_real["label_list"] = df_real["labels"].apply(_parse_labels)
# Downstream code indexes label_list[0], so rows without any label are dropped.
has_labels = df_real["label_list"].map(len) > 0
if not has_labels.all():
    print(f"Dropping {int((~has_labels).sum())} real rows with no labels")
    df_real = df_real[has_labels].reset_index(drop=True)
print(f"Real samples: {len(df_real)}")

# Extract reusable per-topic message fragments from the real conversations,
# keyed by each row's first label. User messages are lightly anonymised with
# placeholder tokens; agent messages are kept verbatim.
# NOTE(review): neither dict is read anywhere later in this script (the
# generator uses the hand-written TOPIC_POOLS); kept for inspection/reuse.
_SPEAKER_SPLIT_RE = re.compile(r"(USER:|AGENT:)")
_DATE_RE = re.compile(r"\b\d{1,2} [A-Za-z]+ \d{4}\b")
_FULL_NAME_RE = re.compile(r"\b[A-Z][a-z]+ [A-Z][a-z]+\b")
_CAPITALISED_WORD_RE = re.compile(r"\b[A-Z][a-z]+\b")
_LONG_NUMBER_RE = re.compile(r"\b\d{4,}\b")

user_messages_by_topic = defaultdict(list)
agent_messages_by_topic = defaultdict(list)
for _, row in df_real.iterrows():
    primary = row["label_list"][0]
    # re.split with a capture group yields [prefix, tag, text, tag, text, ...]:
    # speaker tags sit at odd indices, each followed by its message.
    turns = _SPEAKER_SPLIT_RE.split(str(row["conversation"]))
    for i in range(1, len(turns), 2):
        speaker = turns[i].strip(":")
        msg = turns[i + 1].strip() if i + 1 < len(turns) else ""
        if speaker == "USER":
            # Dates are masked first: masking 4+ digit runs beforehand turned
            # "12 March 2025" into "12 March {NUM}", so {DATE} never matched,
            # and the {FIRSTNAME} pass could also swallow the month name.
            msg_clean = _DATE_RE.sub("{DATE}", msg)
            msg_clean = _FULL_NAME_RE.sub("{NAME}", msg_clean)
            # NOTE(review): heuristic — replaces the first remaining capitalised
            # word, which is often a sentence-initial word rather than a name.
            msg_clean = _CAPITALISED_WORD_RE.sub("{FIRSTNAME}", msg_clean, count=1)
            msg_clean = _LONG_NUMBER_RE.sub("{NUM}", msg_clean)
            user_messages_by_topic[primary].append(msg_clean)
        elif speaker == "AGENT":
            agent_messages_by_topic[primary].append(msg)

# The 20 target classes, listed alphabetically. This order fixes the column
# order of the multi-hot label matrix built with MultiLabelBinarizer.
ALL_LABELS = [
    "Benefits",
    "Career Development",
    "Compliance & Legal",
    "Contracts",
    "Diversity, Equity & Inclusion",
    "Expense Management",
    "Harassment",
    "Health",
    "IT & Equipment",
    "Leave & Absence",
    "Mobility",
    "Offboarding",
    "Onboarding",
    "Payroll",
    "Performance Management",
    "Recruitment",
    "Safety",
    "Timetracking",
    "Training",
    "Work Arrangements",
]

# Template pools
# Name pools used to fill {NAME} / {FIRSTNAME} placeholders in synthetic text.
# "André" and "François" were previously mis-encoded (UTF-8 bytes decoded as
# cp1253, giving "AndrΓ©" / "FranΓ§ois").
FIRSTNAMES = ["Emma","Lucas","Sophie","Thomas","Léa","Alexandre","Julia","Maxime","Camille","Nathan","Chloé","Antoine","Manon","Louis","Sarah","Hugo","Zoé","Gabriel","Inès","Raphaël"]
LASTNAMES = ["Martin","Bernard","Dubois","Thomas","Robert","Petit","Durand","Leroy","Moreau","Simon","Laurent","Lefebvre","Michel","Garcia","Roux","Bonnet","André","François","Mercier","Dupont"]

# Hand-written user/agent message templates per label. Placeholders such as
# {MONTH}, {NUM} or {DATE} are filled in when synthetic conversations are built.
# Three agent templates previously contained the mojibake "β€”" (an em dash
# whose UTF-8 bytes were decoded as cp1253); they now use a real "—".
TOPIC_POOLS = {
    "Payroll": {
        "user": ["My payslip for {MONTH} seems incorrect. Can you check?", "I haven't received my salary for {MONTH}. What's happening?", "There's a {NUM} EUR deduction on my payslip I don't understand.", "When is the next payroll run?", "I need to update my bank details for salary payments.", "My withholding tax rate looks wrong on the latest payslip."],
        "agent": ["Let me pull up your payroll record. I see the issue — {EXPLANATION}.", "The payroll for {MONTH} was processed on {DATE}. It should arrive within 24 hours.", "The deduction corresponds to {REASON}. This was communicated on {DATE}.", "You can update your banking info in the HR portal.", "I'll flag this for the payroll team and you'll receive a corrected statement."],
    },
    "Benefits": {
        "user": ["Does the company offer gym membership as part of benefits?", "I want to understand the retirement savings plan.", "How does health insurance coverage work during maternity leave?", "I opted out of meal vouchers but still see a deduction.", "What wellness benefits are available?"],
        "agent": ["Yes, we partner with {PARTNER}. The company subsidizes {PERCENT}% up to {AMOUNT} EUR/month.", "The PERCO allows voluntary contributions. The company matches up to {AMOUNT} EUR/year.", "Your health insurance remains active during leave. Meal vouchers pause for non-worked days.", "Your opt-out was processed after the payroll cutoff. The overcharge will be reimbursed.", "We offer {LIST} through the benefits portal."],
    },
    "Leave & Absence": {
        "user": ["I need to take sick leave starting today.", "I'd like to understand the rules around parental leave in France.", "My leave balance is wrong. It should be {NUM} days, not {NUM}.", "Can I split my paternity leave into two periods?", "I'm going on maternity leave in {NUM} months. How does coverage work?"],
        "agent": ["Please submit a sick leave request in the HR portal and upload your medical certificate within 48 hours.", "Under French law, the second parent is entitled to {DAYS} calendar days.", "I can see a {NUM}-day absence was double-counted. I'll submit the correction.", "Yes, the first {NUM} days must be taken immediately after birth leave.", "Your health insurance remains fully active. Meal vouchers pause for non-worked days."],
    },
    "Contracts": {
        "user": ["I would like a copy of my current employment contract.", "My fixed-term contract ends in {NUM} months. Will it be renewed?", "I signed an amendment for a raise but my salary hasn't changed.", "What does the non-compete clause in my contract mean?", "I'm transferring to the London office. Do I get a new contract?"],
        "agent": ["I'll send your current contract to your registered email as a PDF within the hour.", "The renewal decision is expected by {DATE}. Your manager will discuss it with you.", "The effective date on the amendment is {DATE}, not the signing date.", "Your non-compete is {MONTHS} months limited to {SECTOR} within {COUNTRY}. Monthly compensation is {PERCENT}% of gross.", "For international transfers, you'll receive a new local contract under local law."],
    },
    "Recruitment": {
        "user": ["I applied for the Data Scientist position {NUM} weeks ago. Any update?", "I got my offer letter. I have questions before signing.", "As a hiring manager, how do I open a new headcount?", "I saw an internal posting for ML Engineer. Can I apply without all qualifications?"],
        "agent": ["Your application is under review. You should hear back within {NUM} business days.", "The probation period is standard and not typically negotiated.", "Submit a headcount request through the Talent portal. Approval takes {NUM}-{NUM} business days.", "If you meet 60-70% of requirements, you should apply. We value growth potential."],
    },
    "Training": {
        "user": ["I'd like to enroll in the AWS certification. Is it covered?", "My manager suggested I work on leadership skills. Any programs?", "I requested a Python course {NUM} weeks ago and still haven't heard back.", "Can I use my CPF for company-related training?", "Is there a structured training plan for new joiners?"],
        "agent": ["Yes, cloud certifications are fully covered. Submit a request through the portal.", "We offer a 6-month Leadership Accelerator. Your manager can nominate you.", "Your request got stuck at budget validation. I've escalated it.", "Yes, you can combine CPF credits with company budget for more expensive programs.", "Every new joiner should have a 30-60-90 day onboarding plan."],
    },
    "Performance Management": {
        "user": ["When is the next performance review cycle?", "My rating was 'Exceeds Expectations' but I didn't get a raise.", "My manager put me on a PIP. Am I being fired?", "I want to dispute my performance review.", "I've been rated top performer for {NUM} years. Can I move to a senior role?"],
        "agent": ["The next cycle opens on {DATE}. Self-assessments due by {DATE}.", "Performance ratings are one input into compensation.", "A PIP is not termination — it's structured support lasting {MONTHS} months.", "You can submit a written appeal through the HR portal within 15 days.", "With your track record, you're well-positioned for internal mobility."],
    },
    "Onboarding": {
        "user": ["I'm joining next week. What do I need to bring?", "I finished my first week. When do I get my first payslip?", "I'm onboarding a new team member next Monday. Is there a manager checklist?", "I'm starting as a Legal Counsel. Are there specific documents I need?"],
        "agent": ["Bring a valid photo ID and your signed contract. Arrive at 9 AM at reception.", "First payslip at the end of your first full month. Virtual meal card is active now.", "Yes — the Manager Onboarding Checklist covers IT setup, access requests, 30-60-90 plan.", "We need your valid work permit, passport, and proof of address."],
    },
    "Compliance & Legal": {
        "user": ["I need to complete the GDPR refresher training. Is it mandatory?", "A colleague is sharing confidential salary info externally. What should I do?", "I received an email asking for a criminal background check. Is this legal?", "I want to understand my rights regarding a sabbatical leave.", "A vendor asked me to share internal process documentation. Is this allowed?"],
        "agent": ["Yes, it's mandatory annually. Takes about 20 minutes.", "Please forward evidence to ethics@company.com. You're protected under our whistleblower policy.", "For certain roles, a background check may be requested under French law, but only after a conditional offer.", "With {NUM} years seniority, you're eligible for a sabbatical of 6-11 months.", "Internal documentation cannot be shared without approval from the Data Protection Officer."],
    },
    "Mobility": {
        "user": ["I'm interested in transferring from Engineering to Product. Is this common?", "I'm moving to the Tokyo office in September. How does salary work?", "I'm moving from Madrid to Paris HQ. Will my seniority carry over?", "I've been offered a {MONTHS}-month project in Dubai. Any tax implications?"],
        "agent": ["About 20% of open roles are filled internally each year.", "You'll receive a split salary: base in EUR, local allowance in JPY.", "Seniority recognition depends on the transfer agreement.", "For a {MONTHS}-month assignment, you remain on French payroll. UAE has no income tax."],
    },
    "Harassment": {
        "user": ["I need to file a harassment complaint. What documentation is required?", "If the behavior happened in virtual meetings, how do I document it?", "What protections exist against retaliation after filing?", "Who has access to my complaint details during the investigation?"],
        "agent": ["Submit a written statement with dates, times, locations, descriptions.", "Document the platform, meeting ID, timestamp, participant list, chat logs.", "Retaliation is strictly prohibited. HR monitors for adverse actions.", "Only assigned HR investigators, their supervisor, and legal counsel if required."],
    },
    "Career Development": {
        "user": ["I want to pursue an AWS certification. Is it reimbursed?", "My manager said I need leadership skills for a team lead role.", "Can I use my CPF for an MBA part-time?", "How do I request a learning plan for ML training?"],
        "agent": ["Senior-level certifications are reimbursed up to {AMOUNT} USD per attempt.", "We offer a 6-month Leadership Accelerator. Your manager submits a nomination.", "An MBA would need explicit approval from your HRBP and department head.", "Request a learning plan through the Training portal."],
    },
    "IT & Equipment": {
        "user": ["My laptop keeps crashing. Can I get a replacement?", "I need a second monitor for my home office.", "I forgot my VPN password. How do I reset it?", "What software licenses are available for developers?"],
        "agent": ["Please open an IT ticket.", "Submit an equipment request through the IT portal.", "Use the self-service password reset portal.", "Developers have access to JetBrains, GitHub Pro, Docker Desktop."],
    },
    "Expense Management": {
        "user": ["I submitted an expense report {NUM} weeks ago and still haven't been reimbursed.", "What's the policy on business travel expenses?", "Can I expense my monthly public transport pass?", "I lost the receipt for a client dinner. Can I still claim it?"],
        "agent": ["Let me check the status. It appears stuck at manager approval.", "Business travel covers transport, accommodation up to {AMOUNT} EUR/night, and meals up to {AMOUNT} EUR/day.", "Yes, monthly transport passes are fully reimbursable.", "Without a receipt, the claim may be rejected per policy."],
    },
    "Health": {
        "user": ["I'm going on long-term sick leave. Will this affect my health insurance?", "How do I access the Employee Assistance Programme for mental health support?", "I need a medical certificate for my sick leave. What format is required?"],
        "agent": ["Your health insurance remains fully active during sick leave.", "The EAP provides confidential counseling at no cost.", "A standard medical certificate from any licensed physician is sufficient."],
    },
    "Safety": {
        "user": ["There's a loose cable in the open space that someone could trip on.", "What should I do if there's a fire alarm during work hours?", "I noticed the emergency exit on floor {NUM} is blocked. Who should I report this to?"],
        "agent": ["Please report it immediately through the Safety Portal.", "Follow the evacuation plan. Assembly point is in the parking lot.", "Report blocked exits to facilities@company.com and your floor safety marshal."],
    },
    "Offboarding": {
        "user": ["I'm resigning and my last day is in {NUM} weeks. What do I need to do?", "Will I be paid for my unused vacation days?", "Do I need to return my laptop and badge?", "Can I keep my company email address after leaving?"],
        "agent": ["HR will schedule an offboarding checklist meeting.", "Yes, accrued but unused leave is paid out in your final settlement.", "All IT equipment, access badges, and parking passes must be returned.", "Company email addresses are deactivated on the last working day."],
    },
    "Work Arrangements": {
        "user": ["Can I work from home {NUM} days a week?", "What's the policy on flexible working hours?", "I want to switch to part-time (80%). What's the process?"],
        "agent": ["The hybrid policy allows up to {NUM} remote days per week.", "Core hours are 10:00-16:00. Outside that, you can flex start and end times.", "Submit a part-time request through the HR portal at least 1 month in advance."],
    },
    "Timetracking": {
        "user": ["How do I log overtime hours in the system?", "My timesheet for last week was rejected. What should I fix?", "Are breaks automatically deducted from my tracked hours?"],
        "agent": ["Use the 'Overtime' category in the time tracking tool.", "Common rejections: missing project codes, incorrect dates.", "Yes, a 30-minute lunch break is auto-deducted for any workday over 6 hours."],
    },
    "Diversity, Equity & Inclusion": {
        "user": ["Does the company have employee resource groups?", "I experienced age discrimination in my performance review. What can I do?", "Are diversity metrics published in the annual report?"],
        "agent": ["Yes, we have active ERGs for women in tech, LGBTQ+, parents.", "Age discrimination is prohibited. You can file a formal complaint through the Ethics hotline.", "Yes, our annual report includes diversity metrics."],
    },
}

# Guarantee every entry of ALL_LABELS has a pool: labels without hand-written
# templates get a generic three-message user/agent fallback.
for label in ALL_LABELS:
    topic = label.lower()
    TOPIC_POOLS.setdefault(label, {
        "user": [
            f"I have a question about {topic}.",
            f"Can you help me with {topic}?",
            f"I need information regarding {topic}.",
        ],
        "agent": [
            f"I'll look into your {topic} query right away.",
            f"Here's what you need to know about {topic}.",
            f"Let me pull up the {topic} policy for you.",
        ],
    })

# ── Synthetic generation ──────────────────────────────────────────────
# How often each (alphabetically sorted) label combination occurs in the real
# data; synthetic label sets are sampled from this distribution.
seen_combos = Counter(tuple(sorted(label_list)) for label_list in df_real["label_list"])

# Matches template placeholders such as "{NUM}" or "{MONTH}".
_PLACEHOLDER_RE = re.compile(r"\{([A-Z]+)\}")


def _fill_placeholders(template, fillers):
    """Replace every ``{KEY}`` in *template* with a fresh ``fillers[KEY]()`` value.

    Each occurrence gets its own draw, so "should be {NUM} days, not {NUM}"
    cannot collapse to "should be 7 days, not 7". Placeholders without a
    filler are left untouched.
    """
    def _substitute(match):
        make_value = fillers.get(match.group(1))
        return match.group(0) if make_value is None else make_value()

    return _PLACEHOLDER_RE.sub(_substitute, template)


def generate_synthetic(n=5000, p_keep_combo=0.53, p_extra_label=0.2):
    """Generate *n* template-based synthetic HR conversations.

    Label sets follow the real label-combination distribution (``seen_combos``):
    a multi-label combination is kept whole with probability *p_keep_combo*,
    otherwise a single label is drawn from it; with probability *p_extra_label*
    one random label from ``ALL_LABELS`` is appended if not already present.

    Each conversation has 1-3 USER/AGENT exchanges taken from ``TOPIC_POOLS``,
    cycling through the assigned labels (and extended to ``len(labels)``
    exchanges when needed) so every label has at least one exchange on its topic.

    Args:
        n: number of conversations to generate.
        p_keep_combo: probability of keeping a multi-label combination intact.
        p_extra_label: probability of appending one random extra label.

    Returns:
        DataFrame with columns ``conversation``, ``labels`` (" | "-joined) and
        ``label_list``.
    """
    # Built once per call (not per turn); keys are the placeholder names.
    fillers = {
        "NAME": lambda: random.choice(FIRSTNAMES) + " " + random.choice(LASTNAMES),
        "FIRSTNAME": lambda: random.choice(FIRSTNAMES),
        "NUM": lambda: str(random.randint(2, 30)),
        "DATE": lambda: f"{random.randint(1,28)} {random.choice(['Jan','Feb','Mar','Apr','May','Jun','Jul','Aug','Sep','Oct','Nov','Dec'])} 2025",
        "MONTH": lambda: random.choice(['January','February','March','April','May','June','July','August','September','October','November','December']),
        "AMOUNT": lambda: str(random.choice([100,200,300,500,1000,1500,2000])),
        "PERCENT": lambda: str(random.randint(30,80)),
        "DAYS": lambda: str(random.randint(5,30)),
        "MONTHS": lambda: str(random.randint(3,24)),
        "COUNTRY": lambda: random.choice(['France','the UK','Germany','Spain','the US']),
        "SECTOR": lambda: random.choice(['software development','finance','healthcare','marketing']),
        "REASON": lambda: random.choice(['the new health insurance premium','a tax adjustment','a pension contribution','a benefit opt-in']),
        "EXPLANATION": lambda: random.choice(['there was a sync error','the cutoff date had passed','the rate was incorrectly updated']),
        "PARTNER": lambda: random.choice(['Gymlib','Edenred','Alan','Swile','Amundi']),
        "LIST": lambda: random.choice(['gym access, mental health counseling, and wellness workshops','yoga classes, nutrition coaching, and mindfulness apps']),
    }
    combos = list(seen_combos.keys())
    combo_weights = list(seen_combos.values())

    synth = []
    for _ in range(n):
        combo = random.choices(combos, weights=combo_weights, k=1)[0]
        if random.random() < p_keep_combo and len(combo) > 1:
            labels = list(combo)
        else:
            labels = [random.choice(combo)]
        if random.random() < p_extra_label:
            extra = random.choice(ALL_LABELS)
            if extra not in labels:
                labels.append(extra)

        # Previously only labels[0] produced text, so secondary and "extra"
        # labels had no supporting content at all (pure label noise). Cycle
        # through the labels so each one contributes at least one exchange.
        num_turns = max(random.randint(1, 3), len(labels))
        conv_parts = []
        for turn in range(num_turns):
            pool = TOPIC_POOLS[labels[turn % len(labels)]]
            u_msg = _fill_placeholders(random.choice(pool["user"]), fillers)
            a_msg = _fill_placeholders(random.choice(pool["agent"]), fillers)
            conv_parts.append(f"USER: {u_msg}\nAGENT: {a_msg}")
        synth.append({
            "conversation": "\n\n".join(conv_parts),
            "labels": " | ".join(labels),
            "label_list": labels,
        })
    # Explicit columns keep the schema stable even when n == 0.
    return pd.DataFrame(synth, columns=["conversation", "labels", "label_list"])

print(f"Generating {NUM_SYNTHETIC} synthetic samples...")
df_synth = generate_synthetic(NUM_SYNTHETIC)

# Provenance flag so real and synthetic rows stay distinguishable after concat.
df_real["synthetic"] = False
df_synth["synthetic"] = True
kept_columns = ["conversation", "labels", "label_list", "synthetic"]
df_combined = pd.concat([df_real[kept_columns], df_synth], ignore_index=True)
synthetic_flags = df_combined["synthetic"]
print(f"Total: {len(df_combined)} (real: {(~synthetic_flags).sum()}, synth: {synthetic_flags.sum()})")

# Multi-hot label matrix whose column order is given by ALL_LABELS.
mlb = MultiLabelBinarizer(classes=ALL_LABELS)
y = mlb.fit_transform(df_combined["label_list"])
label_names = list(mlb.classes_)
num_labels = len(label_names)
print(f"Label matrix shape: {y.shape}")

# ── 5-Fold Stratified CV ──────────────────────────────────────────────
# StratifiedKFold needs exactly one class per row, so multi-label rows are
# stratified on their first label.
primary_labels = [row_labels[0] for row_labels in df_combined["label_list"].tolist()]
skf = StratifiedKFold(
    n_splits=N_FOLDS,
    shuffle=True,
    random_state=SEED,
)

tok = AutoTokenizer.from_pretrained(MODEL_ID)

def preprocess(examples):
    """Tokenize a batch of ``text`` and carry its multi-hot ``labels`` through.

    Padding is deferred (``padding=False``); batches are padded later by the
    ``DataCollatorWithPadding`` passed to the Trainer.
    """
    encoded = tok(
        examples["text"],
        truncation=True,
        max_length=MAX_LENGTH,
        padding=False,
    )
    encoded["labels"] = examples["labels"]
    return encoded

def compute_metrics(eval_pred):
    """Compute multi-label metrics for ``Trainer`` evaluation.

    Logits go through a sigmoid and are thresholded at 0.5 per label.

    Args:
        eval_pred: ``(logits, labels)`` pair of arrays supplied by ``Trainer``.

    Returns:
        dict of floats: micro/macro F1, micro precision and recall,
        ``accuracy`` (sklearn's exact-match accuracy for multi-label input)
        and Hamming loss.
    """
    logits, labels = eval_pred
    probs = torch.sigmoid(torch.tensor(logits)).numpy()
    y_pred = (probs >= 0.5).astype(int)
    y_true = labels.astype(int)
    return {
        "f1_micro": float(f1_score(y_true, y_pred, average="micro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
        "precision": float(precision_score(y_true, y_pred, average="micro", zero_division=0)),
        "recall": float(recall_score(y_true, y_pred, average="micro", zero_division=0)),
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "hamming": float(hamming_loss(y_true, y_pred)),
    }


def _to_text_label_frame(idx):
    """Build the ``text``/``labels`` DataFrame for rows *idx* of ``df_combined``."""
    return pd.DataFrame({
        "text": df_combined["conversation"].iloc[idx].tolist(),
        "labels": y[idx].astype(float).tolist(),
    })


# Validation folds contain REAL conversations only; every synthetic row is
# always on the training side. Synthetic rows come from small template pools,
# so letting them into validation scores the model on near-copies of its
# training data and inflates the CV estimate.
is_synthetic = df_combined["synthetic"].to_numpy(dtype=bool)
real_idx = np.flatnonzero(~is_synthetic)
synth_idx = np.flatnonzero(is_synthetic)
real_primary_labels = np.asarray(primary_labels, dtype=object)[real_idx]

cv_metrics = []
fold_best_scores = []  # eval f1_micro of each fold, after loading its best epoch

for fold, (train_pos, val_pos) in enumerate(skf.split(np.zeros(len(real_idx)), real_primary_labels)):
    print(f"\n{'='*50}\nFold {fold+1}/{N_FOLDS}\n{'='*50}")
    train_idx = np.concatenate([real_idx[train_pos], synth_idx])
    val_idx = real_idx[val_pos]
    train_df = _to_text_label_frame(train_idx)
    val_df = _to_text_label_frame(val_idx)
    print(f"  Train: {len(train_df)} (real: {len(train_pos)}, synth: {len(synth_idx)})  Val (real only): {len(val_df)}")

    ds = DatasetDict({"train": Dataset.from_pandas(train_df), "validation": Dataset.from_pandas(val_df)})
    tok_ds = ds.map(preprocess, batched=True, remove_columns=["text"])
    tok_ds = tok_ds.cast_column("labels", Sequence(Value("float32")))
    tok_ds.set_format("torch")

    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, num_labels=num_labels, problem_type="multi_label_classification")

    fold_dir = f"{OUTPUT_DIR}/fold_{fold}"
    args = TrainingArguments(
        output_dir=fold_dir,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=1,
        learning_rate=LR,
        weight_decay=WEIGHT_DECAY,
        warmup_steps=max(1, len(tok_ds["train"]) // BATCH_SIZE // 4),
        eval_strategy="epoch",
        save_strategy="epoch",
        # Cap per-fold disk usage; the best checkpoint is still retained
        # because load_best_model_at_end is enabled.
        save_total_limit=1,
        logging_strategy="steps",
        logging_steps=50,
        logging_first_step=True,
        disable_tqdm=True,
        load_best_model_at_end=True,
        metric_for_best_model="f1_micro",
        greater_is_better=True,
        report_to=[],
        seed=SEED + fold,
        bf16=False, fp16=False,
    )

    trainer = Trainer(
        model=model, args=args,
        train_dataset=tok_ds["train"],
        eval_dataset=tok_ds["validation"],
        processing_class=tok,
        data_collator=DataCollatorWithPadding(tok),
        compute_metrics=compute_metrics,
    )
    trainer.train()
    m = trainer.evaluate()
    print(f"  Fold {fold+1} eval: f1_micro={m['eval_f1_micro']:.4f} f1_macro={m['eval_f1_macro']:.4f}")
    cv_metrics.append(m)
    fold_best_scores.append(m["eval_f1_micro"])

    if m["eval_f1_micro"] == max(fold_best_scores):
        print(f"  -> New best fold! Saving to {OUTPUT_DIR}/best")
        trainer.save_model(f"{OUTPUT_DIR}/best")
        with open(f"{OUTPUT_DIR}/best/label_config.json", "w", encoding="utf-8") as f:
            json.dump({"label_names": label_names, "threshold": 0.5, "num_labels": num_labels}, f)

    # Release this fold's model/trainer before the next fold allocates new ones.
    del trainer, model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Per-fold scores, then mean ± population std (np.std default ddof=0) of f1_micro.
# The "±" was previously mojibake ("Β±": UTF-8 bytes decoded as cp1253).
print(f"\n{'='*50}\nCV Summary\n{'='*50}")
for fold_no, fold_m in enumerate(cv_metrics, start=1):
    print(f"  Fold {fold_no}: f1_micro={fold_m['eval_f1_micro']:.4f} f1_macro={fold_m['eval_f1_macro']:.4f}")
fold_f1_micro = [fold_m["eval_f1_micro"] for fold_m in cv_metrics]
mean_f1_micro = np.mean(fold_f1_micro)
std_f1_micro = np.std(fold_f1_micro)
print(f"\n  Mean f1_micro: {mean_f1_micro:.4f} ± {std_f1_micro:.4f}")

# Final train on all data
# Train one model on every real + synthetic sample and push it to the Hub.
print(f"\nFinal training on all {len(df_combined)} samples...")
all_df = pd.DataFrame({"text": df_combined["conversation"].tolist(), "labels": y.astype(float).tolist()})
# Every row is in the training set, so this "validation" sample is only an
# in-sample sanity check; the CV scores above are the generalisation estimate.
# min() avoids a crash on small datasets; reset_index keeps Dataset.from_pandas
# from storing the shuffled index as an extra "__index_level_0__" column.
sanity_df = all_df.sample(min(200, len(all_df)), random_state=SEED).reset_index(drop=True)
ds = DatasetDict({"train": Dataset.from_pandas(all_df), "validation": Dataset.from_pandas(sanity_df)})
tok_ds = ds.map(preprocess, batched=True, remove_columns=["text"])
tok_ds = tok_ds.cast_column("labels", Sequence(Value("float32")))
tok_ds.set_format("torch")

model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, num_labels=num_labels, problem_type="multi_label_classification")
args = TrainingArguments(
    output_dir=f"{OUTPUT_DIR}/final",
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    learning_rate=LR,
    weight_decay=WEIGHT_DECAY,
    warmup_steps=max(1, len(tok_ds["train"]) // BATCH_SIZE // 4),
    # Picking a "best" epoch on in-sample data is meaningless, so train for
    # exactly EPOCHS and keep the final weights (no per-epoch eval/checkpoints).
    eval_strategy="no",
    save_strategy="no",
    logging_strategy="steps",
    logging_steps=100,
    logging_first_step=True,
    disable_tqdm=True,
    report_to=[],
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy="end",
    seed=SEED,
    bf16=False, fp16=False,
)
trainer = Trainer(
    model=model, args=args,
    train_dataset=tok_ds["train"],
    eval_dataset=tok_ds["validation"],
    processing_class=tok,
    data_collator=DataCollatorWithPadding(tok),
    compute_metrics=compute_metrics,
)
trainer.train()
final_metrics = trainer.evaluate()
print(f"Final eval (in-sample sanity check): f1_micro={final_metrics['eval_f1_micro']:.4f}")
trainer.push_to_hub(commit_message=f"DistilBERT HR | 5-fold CV f1_micro={mean_f1_micro:.4f}±{std_f1_micro:.4f} | N={len(df_combined)}")

# Persist the CV summary together with the final model's eval metrics.
results_path = f"{OUTPUT_DIR}/cv_results.json"
with open(results_path, "w") as fh:
    summary = {
        "cv_f1_micro_mean": float(mean_f1_micro),
        "cv_f1_micro_std": float(std_f1_micro),
        "fold_scores": list(map(float, fold_best_scores)),
        "final_eval": {name: float(value) for name, value in final_metrics.items()},
        "num_synthetic": NUM_SYNTHETIC,
        "num_real": int(len(df_real)),
        "total": int(len(df_combined)),
    }
    json.dump(summary, fh, indent=2)

print("\nDONE!")