#!/usr/bin/env python3 -u
"""
HR Conversations: data augmentation + 5-fold stratified cross-validation + fine-tuning.
Copy this script and run it in Google Colab (free T4) or any Python environment.
"""
import os, json, re, random, numpy as np, pandas as pd, torch
from collections import defaultdict, Counter
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, hamming_loss
from datasets import Dataset, DatasetDict, Sequence, Value
from transformers import (
AutoTokenizer, AutoModelForSequenceClassification,
TrainingArguments, Trainer, DataCollatorWithPadding
)
from huggingface_hub import hf_hub_download
# Reproducibility: seed Python, NumPy, and torch RNGs with the same value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Model and training hyper-parameters.
MODEL_ID = "distilbert/distilbert-base-uncased"
OUTPUT_DIR = "./hr-distilbert-cv"  # local root for per-fold checkpoints and results
HUB_MODEL_ID = "AurelPx/hr-conversations-classifier"  # Hub repo for the final model
NUM_SYNTHETIC = 5000  # number of augmented conversations to generate
MAX_LENGTH = 512      # tokenizer truncation length (DistilBERT max)
BATCH_SIZE = 16
LR = 3e-5
WEIGHT_DECAY = 0.1
EPOCHS = 4
N_FOLDS = 5
# Download the labelled HR-conversations CSV from the Hugging Face dataset repo.
csv_path = hf_hub_download(
    repo_id="AurelPx/ml-intern-a2d69eee-datasets",
    filename="uploads/76ee47c7699c/HRDatasetConv-English.csv",
    repo_type="dataset",
)
# quoting=1 is csv.QUOTE_ALL; utf-8-sig tolerates a leading BOM.
df_real = pd.read_csv(csv_path, sep=";", encoding="utf-8-sig", quoting=1)
# The "labels" column holds "|"-separated multi-labels; split into a list per row.
df_real["label_list"] = df_real["labels"].apply(lambda x: [l.strip() for l in str(x).split("|")])
print(f"Real samples: {len(df_real)}")
# Extract reusable parts: bucket USER/AGENT messages from the real data by
# each row's primary (first) label, anonymizing USER text with placeholders.
# NOTE(review): neither pool is read again in this script — synthetic
# generation below uses the hand-written TOPIC_POOLS instead. Either wire
# these into generation or remove this pass.
user_messages_by_topic = defaultdict(list)
agent_messages_by_topic = defaultdict(list)
for _, row in df_real.iterrows():
    labels = row["label_list"]
    conv = str(row["conversation"])
    # Split on speaker markers, keeping them: odd indices are the markers,
    # the following even index is that speaker's message body.
    turns = re.split(r'(USER:|AGENT:)', conv)
    for i in range(1, len(turns), 2):
        speaker = turns[i].strip(":")
        msg = turns[i+1].strip() if i+1 < len(turns) else ""
        primary = labels[0]
        if speaker == "USER":
            # Anonymize: "First Last" names, then the first remaining
            # capitalized word, 4+ digit runs, and "D Month YYYY" dates.
            msg_clean = re.sub(r"\b[A-Z][a-z]+ [A-Z][a-z]+\b", "{NAME}", msg)
            msg_clean = re.sub(r"\b[A-Z][a-z]+\b", "{FIRSTNAME}", msg_clean, count=1)
            msg_clean = re.sub(r"\b\d{4,}\b", "{NUM}", msg_clean)
            msg_clean = re.sub(r"\b\d{1,2} [A-Za-z]+ \d{4}\b", "{DATE}", msg_clean)
            user_messages_by_topic[primary].append(msg_clean)
        elif speaker == "AGENT":
            agent_messages_by_topic[primary].append(msg)
# Canonical 20-topic HR label taxonomy (alphabetical). Passed to
# MultiLabelBinarizer below to pin the label-matrix column order.
ALL_LABELS = [
    "Benefits", "Career Development", "Compliance & Legal", "Contracts",
    "Diversity, Equity & Inclusion", "Expense Management", "Harassment", "Health",
    "IT & Equipment", "Leave & Absence", "Mobility", "Offboarding",
    "Onboarding", "Payroll", "Performance Management", "Recruitment",
    "Safety", "Timetracking", "Training", "Work Arrangements",
]
# Template pools
# Names used to fill the {NAME}/{FIRSTNAME} placeholders in templates.
# Fixed mojibake in two surnames: "AndrΓ©" -> "André", "FranΓ§ois" -> "François"
# (UTF-8 bytes previously decoded through a Greek codepage).
FIRSTNAMES = ["Emma","Lucas","Sophie","Thomas","Léa","Alexandre","Julia","Maxime","Camille","Nathan","Chloé","Antoine","Manon","Louis","Sarah","Hugo","Zoé","Gabriel","Inès","Raphaël"]
LASTNAMES = ["Martin","Bernard","Dubois","Thomas","Robert","Petit","Durand","Leroy","Moreau","Simon","Laurent","Lefebvre","Michel","Garcia","Roux","Bonnet","André","François","Mercier","Dupont"]
# Hand-written USER/AGENT message templates per topic; {PLACEHOLDER} tokens
# are filled at generation time. Fixed mojibake: three template strings
# contained "β" where an em dash "—" was intended (Payroll, Performance
# Management, Onboarding agent replies).
TOPIC_POOLS = {
    "Payroll": {
        "user": ["My payslip for {MONTH} seems incorrect. Can you check?", "I haven't received my salary for {MONTH}. What's happening?", "There's a {NUM} EUR deduction on my payslip I don't understand.", "When is the next payroll run?", "I need to update my bank details for salary payments.", "My withholding tax rate looks wrong on the latest payslip."],
        "agent": ["Let me pull up your payroll record. I see the issue — {EXPLANATION}.", "The payroll for {MONTH} was processed on {DATE}. It should arrive within 24 hours.", "The deduction corresponds to {REASON}. This was communicated on {DATE}.", "You can update your banking info in the HR portal.", "I'll flag this for the payroll team and you'll receive a corrected statement."],
    },
    "Benefits": {
        "user": ["Does the company offer gym membership as part of benefits?", "I want to understand the retirement savings plan.", "How does health insurance coverage work during maternity leave?", "I opted out of meal vouchers but still see a deduction.", "What wellness benefits are available?"],
        "agent": ["Yes, we partner with {PARTNER}. The company subsidizes {PERCENT}% up to {AMOUNT} EUR/month.", "The PERCO allows voluntary contributions. The company matches up to {AMOUNT} EUR/year.", "Your health insurance remains active during leave. Meal vouchers pause for non-worked days.", "Your opt-out was processed after the payroll cutoff. The overcharge will be reimbursed.", "We offer {LIST} through the benefits portal."],
    },
    "Leave & Absence": {
        "user": ["I need to take sick leave starting today.", "I'd like to understand the rules around parental leave in France.", "My leave balance is wrong. It should be {NUM} days, not {NUM}.", "Can I split my paternity leave into two periods?", "I'm going on maternity leave in {NUM} months. How does coverage work?"],
        "agent": ["Please submit a sick leave request in the HR portal and upload your medical certificate within 48 hours.", "Under French law, the second parent is entitled to {DAYS} calendar days.", "I can see a {NUM}-day absence was double-counted. I'll submit the correction.", "Yes, the first {NUM} days must be taken immediately after birth leave.", "Your health insurance remains fully active. Meal vouchers pause for non-worked days."],
    },
    "Contracts": {
        "user": ["I would like a copy of my current employment contract.", "My fixed-term contract ends in {NUM} months. Will it be renewed?", "I signed an amendment for a raise but my salary hasn't changed.", "What does the non-compete clause in my contract mean?", "I'm transferring to the London office. Do I get a new contract?"],
        "agent": ["I'll send your current contract to your registered email as a PDF within the hour.", "The renewal decision is expected by {DATE}. Your manager will discuss it with you.", "The effective date on the amendment is {DATE}, not the signing date.", "Your non-compete is {MONTHS} months limited to {SECTOR} within {COUNTRY}. Monthly compensation is {PERCENT}% of gross.", "For international transfers, you'll receive a new local contract under local law."],
    },
    "Recruitment": {
        "user": ["I applied for the Data Scientist position {NUM} weeks ago. Any update?", "I got my offer letter. I have questions before signing.", "As a hiring manager, how do I open a new headcount?", "I saw an internal posting for ML Engineer. Can I apply without all qualifications?"],
        "agent": ["Your application is under review. You should hear back within {NUM} business days.", "The probation period is standard and not typically negotiated.", "Submit a headcount request through the Talent portal. Approval takes {NUM}-{NUM} business days.", "If you meet 60-70% of requirements, you should apply. We value growth potential."],
    },
    "Training": {
        "user": ["I'd like to enroll in the AWS certification. Is it covered?", "My manager suggested I work on leadership skills. Any programs?", "I requested a Python course {NUM} weeks ago and still haven't heard back.", "Can I use my CPF for company-related training?", "Is there a structured training plan for new joiners?"],
        "agent": ["Yes, cloud certifications are fully covered. Submit a request through the portal.", "We offer a 6-month Leadership Accelerator. Your manager can nominate you.", "Your request got stuck at budget validation. I've escalated it.", "Yes, you can combine CPF credits with company budget for more expensive programs.", "Every new joiner should have a 30-60-90 day onboarding plan."],
    },
    "Performance Management": {
        "user": ["When is the next performance review cycle?", "My rating was 'Exceeds Expectations' but I didn't get a raise.", "My manager put me on a PIP. Am I being fired?", "I want to dispute my performance review.", "I've been rated top performer for {NUM} years. Can I move to a senior role?"],
        "agent": ["The next cycle opens on {DATE}. Self-assessments due by {DATE}.", "Performance ratings are one input into compensation.", "A PIP is not termination — it's structured support lasting {MONTHS} months.", "You can submit a written appeal through the HR portal within 15 days.", "With your track record, you're well-positioned for internal mobility."],
    },
    "Onboarding": {
        "user": ["I'm joining next week. What do I need to bring?", "I finished my first week. When do I get my first payslip?", "I'm onboarding a new team member next Monday. Is there a manager checklist?", "I'm starting as a Legal Counsel. Are there specific documents I need?"],
        "agent": ["Bring a valid photo ID and your signed contract. Arrive at 9 AM at reception.", "First payslip at the end of your first full month. Virtual meal card is active now.", "Yes — the Manager Onboarding Checklist covers IT setup, access requests, 30-60-90 plan.", "We need your valid work permit, passport, and proof of address."],
    },
    "Compliance & Legal": {
        "user": ["I need to complete the GDPR refresher training. Is it mandatory?", "A colleague is sharing confidential salary info externally. What should I do?", "I received an email asking for a criminal background check. Is this legal?", "I want to understand my rights regarding a sabbatical leave.", "A vendor asked me to share internal process documentation. Is this allowed?"],
        "agent": ["Yes, it's mandatory annually. Takes about 20 minutes.", "Please forward evidence to ethics@company.com. You're protected under our whistleblower policy.", "For certain roles, a background check may be requested under French law, but only after a conditional offer.", "With {NUM} years seniority, you're eligible for a sabbatical of 6-11 months.", "Internal documentation cannot be shared without approval from the Data Protection Officer."],
    },
    "Mobility": {
        "user": ["I'm interested in transferring from Engineering to Product. Is this common?", "I'm moving to the Tokyo office in September. How does salary work?", "I'm moving from Madrid to Paris HQ. Will my seniority carry over?", "I've been offered a {MONTHS}-month project in Dubai. Any tax implications?"],
        "agent": ["About 20% of open roles are filled internally each year.", "You'll receive a split salary: base in EUR, local allowance in JPY.", "Seniority recognition depends on the transfer agreement.", "For a {MONTHS}-month assignment, you remain on French payroll. UAE has no income tax."],
    },
    "Harassment": {
        "user": ["I need to file a harassment complaint. What documentation is required?", "If the behavior happened in virtual meetings, how do I document it?", "What protections exist against retaliation after filing?", "Who has access to my complaint details during the investigation?"],
        "agent": ["Submit a written statement with dates, times, locations, descriptions.", "Document the platform, meeting ID, timestamp, participant list, chat logs.", "Retaliation is strictly prohibited. HR monitors for adverse actions.", "Only assigned HR investigators, their supervisor, and legal counsel if required."],
    },
    "Career Development": {
        "user": ["I want to pursue an AWS certification. Is it reimbursed?", "My manager said I need leadership skills for a team lead role.", "Can I use my CPF for an MBA part-time?", "How do I request a learning plan for ML training?"],
        "agent": ["Senior-level certifications are reimbursed up to {AMOUNT} USD per attempt.", "We offer a 6-month Leadership Accelerator. Your manager submits a nomination.", "An MBA would need explicit approval from your HRBP and department head.", "Request a learning plan through the Training portal."],
    },
    "IT & Equipment": {
        "user": ["My laptop keeps crashing. Can I get a replacement?", "I need a second monitor for my home office.", "I forgot my VPN password. How do I reset it?", "What software licenses are available for developers?"],
        "agent": ["Please open an IT ticket.", "Submit an equipment request through the IT portal.", "Use the self-service password reset portal.", "Developers have access to JetBrains, GitHub Pro, Docker Desktop."],
    },
    "Expense Management": {
        "user": ["I submitted an expense report {NUM} weeks ago and still haven't been reimbursed.", "What's the policy on business travel expenses?", "Can I expense my monthly public transport pass?", "I lost the receipt for a client dinner. Can I still claim it?"],
        "agent": ["Let me check the status. It appears stuck at manager approval.", "Business travel covers transport, accommodation up to {AMOUNT} EUR/night, and meals up to {AMOUNT} EUR/day.", "Yes, monthly transport passes are fully reimbursable.", "Without a receipt, the claim may be rejected per policy."],
    },
    "Health": {
        "user": ["I'm going on long-term sick leave. Will this affect my health insurance?", "How do I access the Employee Assistance Programme for mental health support?", "I need a medical certificate for my sick leave. What format is required?"],
        "agent": ["Your health insurance remains fully active during sick leave.", "The EAP provides confidential counseling at no cost.", "A standard medical certificate from any licensed physician is sufficient."],
    },
    "Safety": {
        "user": ["There's a loose cable in the open space that someone could trip on.", "What should I do if there's a fire alarm during work hours?", "I noticed the emergency exit on floor {NUM} is blocked. Who should I report this to?"],
        "agent": ["Please report it immediately through the Safety Portal.", "Follow the evacuation plan. Assembly point is in the parking lot.", "Report blocked exits to facilities@company.com and your floor safety marshal."],
    },
    "Offboarding": {
        "user": ["I'm resigning and my last day is in {NUM} weeks. What do I need to do?", "Will I be paid for my unused vacation days?", "Do I need to return my laptop and badge?", "Can I keep my company email address after leaving?"],
        "agent": ["HR will schedule an offboarding checklist meeting.", "Yes, accrued but unused leave is paid out in your final settlement.", "All IT equipment, access badges, and parking passes must be returned.", "Company email addresses are deactivated on the last working day."],
    },
    "Work Arrangements": {
        "user": ["Can I work from home {NUM} days a week?", "What's the policy on flexible working hours?", "I want to switch to part-time (80%). What's the process?"],
        "agent": ["The hybrid policy allows up to {NUM} remote days per week.", "Core hours are 10:00-16:00. Outside that, you can flex start and end times.", "Submit a part-time request through the HR portal at least 1 month in advance."],
    },
    "Timetracking": {
        "user": ["How do I log overtime hours in the system?", "My timesheet for last week was rejected. What should I fix?", "Are breaks automatically deducted from my tracked hours?"],
        "agent": ["Use the 'Overtime' category in the time tracking tool.", "Common rejections: missing project codes, incorrect dates.", "Yes, a 30-minute lunch break is auto-deducted for any workday over 6 hours."],
    },
    "Diversity, Equity & Inclusion": {
        "user": ["Does the company have employee resource groups?", "I experienced age discrimination in my performance review. What can I do?", "Are diversity metrics published in the annual report?"],
        "agent": ["Yes, we have active ERGs for women in tech, LGBTQ+, parents.", "Age discrimination is prohibited. You can file a formal complaint through the Ethics hotline.", "Yes, our annual report includes diversity metrics."],
    },
}
# Safety net: make sure every label in the taxonomy has a template pool by
# adding a generic user/agent pair for any label missing from TOPIC_POOLS.
for topic in ALL_LABELS:
    if topic in TOPIC_POOLS:
        continue
    lowered = topic.lower()
    TOPIC_POOLS[topic] = {
        "user": [
            f"I have a question about {lowered}.",
            f"Can you help me with {lowered}?",
            f"I need information regarding {lowered}.",
        ],
        "agent": [
            f"I'll look into your {lowered} query right away.",
            f"Here's what you need to know about {lowered}.",
            f"Let me pull up the {lowered} policy for you.",
        ],
    }
# —— Synthetic generation ——————————————————————————————————————————————
# Frequency of each observed (sorted) label combination in the real data;
# used to sample realistic label combos for synthetic conversations.
seen_combos = Counter(tuple(sorted(l)) for l in df_real["label_list"])

# Placeholder -> value-generator table. Hoisted to module level so the dict
# (and its lambdas) is built once, instead of being re-created for every
# generated turn as before. Each lambda is pure randomness, so call order
# and results are unchanged.
_PLACEHOLDER_FILLERS = {
    "{NAME}": lambda: random.choice(FIRSTNAMES) + " " + random.choice(LASTNAMES),
    "{FIRSTNAME}": lambda: random.choice(FIRSTNAMES),
    "{NUM}": lambda: str(random.randint(2, 30)),
    "{DATE}": lambda: f"{random.randint(1,28)} {random.choice(['Jan','Feb','Mar','Apr','May','Jun','Jul','Aug','Sep','Oct','Nov','Dec'])} 2025",
    "{MONTH}": lambda: random.choice(['January','February','March','April','May','June','July','August','September','October','November','December']),
    "{AMOUNT}": lambda: str(random.choice([100,200,300,500,1000,1500,2000])),
    "{PERCENT}": lambda: str(random.randint(30,80)),
    "{DAYS}": lambda: str(random.randint(5,30)),
    "{MONTHS}": lambda: str(random.randint(3,24)),
    "{COUNTRY}": lambda: random.choice(['France','the UK','Germany','Spain','the US']),
    "{SECTOR}": lambda: random.choice(['software development','finance','healthcare','marketing']),
    "{REASON}": lambda: random.choice(['the new health insurance premium','a tax adjustment','a pension contribution','a benefit opt-in']),
    "{EXPLANATION}": lambda: random.choice(['there was a sync error','the cutoff date had passed','the rate was incorrectly updated']),
    "{PARTNER}": lambda: random.choice(['Gymlib','Edenred','Alan','Swile','Amundi']),
    "{LIST}": lambda: random.choice(['gym access, mental health counseling, and wellness workshops','yoga classes, nutrition coaching, and mindfulness apps']),
}

def generate_synthetic(n=5000):
    """Generate *n* synthetic USER/AGENT conversations as a DataFrame.

    Label combinations are sampled proportionally to their frequency in the
    real data (``seen_combos``); message templates come from ``TOPIC_POOLS``
    keyed by the first (primary) label.

    Returns a DataFrame with columns: ``conversation`` (USER/AGENT turns),
    ``labels`` (" | "-joined string), and ``label_list`` (list of labels).
    """
    synth = []
    for _ in range(n):
        primary = random.choices(list(seen_combos.keys()), weights=list(seen_combos.values()), k=1)[0]
        # Keep the full multi-label combo ~53% of the time; otherwise pick one
        # label from it (for single-label combos both branches are equivalent).
        if random.random() < 0.53 and len(primary) > 1:
            labels = list(primary)
        else:
            labels = [random.choice(primary)]
        # 20% chance to append one extra random label (multi-label noise).
        if random.random() < 0.2:
            extra = random.choice(ALL_LABELS)
            if extra not in labels:
                labels.append(extra)
        primary_label = labels[0]
        pool = TOPIC_POOLS[primary_label]
        num_turns = random.randint(1, 3)
        conv_parts = []
        for _ in range(num_turns):
            u_msg = random.choice(pool["user"])
            a_msg = random.choice(pool["agent"])
            # Each message gets its own fill value per placeholder, so USER
            # and AGENT may name different months/amounts (intentional noise);
            # repeated occurrences of the same placeholder within one message
            # all receive the same value (str.replace semantics, as before).
            for token, fill in _PLACEHOLDER_FILLERS.items():
                u_msg = u_msg.replace(token, fill())
                a_msg = a_msg.replace(token, fill())
            conv_parts.append(f"USER: {u_msg}\nAGENT: {a_msg}")
        synth.append({"conversation": "\n\n".join(conv_parts), "labels": " | ".join(labels), "label_list": labels})
    return pd.DataFrame(synth)
print(f"Generating {NUM_SYNTHETIC} synthetic samples...")
df_synth = generate_synthetic(NUM_SYNTHETIC)
# Tag provenance so real-vs-synthetic counts can be reported.
df_real["synthetic"] = False
df_synth["synthetic"] = True
df_combined = pd.concat([df_real[["conversation", "labels", "label_list", "synthetic"]], df_synth], ignore_index=True)
print(f"Total: {len(df_combined)} (real: {(~df_combined['synthetic']).sum()}, synth: {df_combined['synthetic'].sum()})")
# Multi-hot encode the label lists; classes=ALL_LABELS pins column order.
mlb = MultiLabelBinarizer(classes=ALL_LABELS)
y = mlb.fit_transform(df_combined["label_list"])
label_names = list(mlb.classes_)
num_labels = len(label_names)
print(f"Label matrix shape: {y.shape}")
# —— 5-Fold Stratified CV ——————————————————————————————————————————————
# StratifiedKFold expects a single class per row, so stratify on each
# example's primary (first) label rather than the full multi-label set.
primary_labels = [lbls[0] for lbls in df_combined["label_list"]]
skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
tok = AutoTokenizer.from_pretrained(MODEL_ID)

def preprocess(examples):
    """Tokenize a batch of texts (truncated, unpadded) and carry the
    multi-hot label vectors through; padding happens in the collator."""
    encoded = tok(examples["text"], truncation=True, max_length=MAX_LENGTH, padding=False)
    encoded["labels"] = examples["labels"]
    return encoded
# Per-fold evaluation results and running best-score tracking.
cv_metrics = []
fold_best_scores = []
# NOTE(review): synthetic rows were generated from label statistics and
# templates covering the FULL real dataset before splitting, so near-duplicate
# texts can land in both train and validation folds — CV scores may be
# optimistic. Confirm whether splitting should precede augmentation.
for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(df_combined)), primary_labels)):
    print(f"\n{'='*50}\nFold {fold+1}/{N_FOLDS}\n{'='*50}")
    # Per-fold frames: raw text plus multi-hot label vectors as float lists.
    train_df = pd.DataFrame({"text": df_combined.iloc[train_idx]["conversation"].tolist(), "labels": y[train_idx].astype(float).tolist()})
    val_df = pd.DataFrame({"text": df_combined.iloc[val_idx]["conversation"].tolist(), "labels": y[val_idx].astype(float).tolist()})
    print(f" Train: {len(train_df)} Val: {len(val_df)}")
    ds = DatasetDict({"train": Dataset.from_pandas(train_df), "validation": Dataset.from_pandas(val_df)})
    tok_ds = ds.map(preprocess, batched=True, remove_columns=["text"])
    # Label vectors must be float32 for the multi-label (BCE-style) loss.
    tok_ds = tok_ds.cast_column("labels", Sequence(Value("float32")))
    tok_ds.set_format("torch")
    # Fresh model per fold so no weights leak between folds.
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, num_labels=num_labels, problem_type="multi_label_classification")
    def compute_metrics(eval_pred):
        # Multi-label metrics at a fixed 0.5 sigmoid threshold.
        # (Redefined every fold iteration; it captures no fold-local state,
        # so this is harmless but could live at module level.)
        logits, labels = eval_pred
        probs = torch.sigmoid(torch.tensor(logits)).numpy()
        y_pred = (probs >= 0.5).astype(int)
        y_true = labels.astype(int)
        return {
            "f1_micro": float(f1_score(y_true, y_pred, average="micro", zero_division=0)),
            "f1_macro": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
            "precision": float(precision_score(y_true, y_pred, average="micro", zero_division=0)),
            "recall": float(recall_score(y_true, y_pred, average="micro", zero_division=0)),
            # sklearn accuracy on multi-label y is exact-match (subset) accuracy.
            "accuracy": float(accuracy_score(y_true, y_pred)),
            "hamming": float(hamming_loss(y_true, y_pred)),
        }
    fold_dir = f"{OUTPUT_DIR}/fold_{fold}"
    args = TrainingArguments(
        output_dir=fold_dir,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=1,
        learning_rate=LR,
        weight_decay=WEIGHT_DECAY,
        # Roughly a quarter of an epoch of linear warmup.
        warmup_steps=max(1, len(tok_ds["train"]) // BATCH_SIZE // 4),
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=50,
        logging_first_step=True,
        disable_tqdm=True,
        # Reload the checkpoint with the best micro-F1 before evaluate() below.
        load_best_model_at_end=True,
        metric_for_best_model="f1_micro",
        greater_is_better=True,
        report_to=[],
        # Distinct seed per fold so initialization/shuffling differ across folds.
        seed=SEED + fold,
        bf16=False, fp16=False,
    )
    trainer = Trainer(
        model=model, args=args,
        train_dataset=tok_ds["train"],
        eval_dataset=tok_ds["validation"],
        processing_class=tok,
        data_collator=DataCollatorWithPadding(tok),
        compute_metrics=compute_metrics,
    )
    trainer.train()
    m = trainer.evaluate()
    print(f" Fold {fold+1} eval: f1_micro={m['eval_f1_micro']:.4f} f1_macro={m['eval_f1_macro']:.4f}")
    cv_metrics.append(m)
    fold_best_scores.append(m["eval_f1_micro"])
    # The current score was just appended, so equality with the running max
    # means this fold is the best (or tied-best) seen so far.
    if m["eval_f1_micro"] == max(fold_best_scores):
        print(f" -> New best fold! Saving to {OUTPUT_DIR}/best")
        trainer.save_model(f"{OUTPUT_DIR}/best")
        # Persist label order and decision threshold alongside the weights
        # so inference code can reconstruct the output mapping.
        with open(f"{OUTPUT_DIR}/best/label_config.json", "w") as f:
            json.dump({"label_names": label_names, "threshold": 0.5, "num_labels": num_labels}, f)
# —— CV summary ————————————————————————————————————————————————————————
# Print per-fold scores and the mean/std micro-F1 across folds.
# Fixed mojibake in the summary line: "Β±" -> "±".
print(f"\n{'='*50}\nCV Summary\n{'='*50}")
for i, m in enumerate(cv_metrics):
    print(f" Fold {i+1}: f1_micro={m['eval_f1_micro']:.4f} f1_macro={m['eval_f1_macro']:.4f}")
mean_f1_micro = np.mean([m["eval_f1_micro"] for m in cv_metrics])
std_f1_micro = np.std([m["eval_f1_micro"] for m in cv_metrics])
print(f"\n Mean f1_micro: {mean_f1_micro:.4f} ± {std_f1_micro:.4f}")
# Final train on all data (CV above provides the performance estimate; this
# run produces the model that gets pushed to the Hub).
print(f"\nFinal training on all {len(df_combined)} samples...")
all_df = pd.DataFrame({"text": df_combined["conversation"].tolist(), "labels": y.astype(float).tolist()})
# NOTE(review): the "validation" split here is a 200-row sample OF THE
# TRAINING DATA, so its metrics measure fit, not generalization — the honest
# estimate is the CV mean above. Confirm this is intentional.
ds = DatasetDict({"train": Dataset.from_pandas(all_df), "validation": Dataset.from_pandas(all_df.sample(200, random_state=SEED))})
tok_ds = ds.map(preprocess, batched=True, remove_columns=["text"])
# Float32 label vectors for the multi-label loss, as in the CV folds.
tok_ds = tok_ds.cast_column("labels", Sequence(Value("float32")))
tok_ds.set_format("torch")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, num_labels=num_labels, problem_type="multi_label_classification")
args = TrainingArguments(
    output_dir=f"{OUTPUT_DIR}/final",
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    learning_rate=LR,
    weight_decay=WEIGHT_DECAY,
    # Roughly a quarter of an epoch of linear warmup.
    warmup_steps=max(1, len(tok_ds["train"]) // BATCH_SIZE // 4),
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=100,
    logging_first_step=True,
    disable_tqdm=True,
    load_best_model_at_end=True,
    metric_for_best_model="f1_micro",
    greater_is_better=True,
    report_to=[],
    # Upload the final model to the Hub when training ends.
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy="end",
    seed=SEED,
    bf16=False, fp16=False,
)
# Final trainer: fit on all data, evaluate, push to the Hub, save a summary.
# NOTE(review): `compute_metrics` here is the closure left over from the last
# CV fold iteration — it captures no fold state so this works, but moving the
# function to module level would be more robust.
trainer = Trainer(
    model=model, args=args,
    train_dataset=tok_ds["train"],
    eval_dataset=tok_ds["validation"],
    processing_class=tok,
    data_collator=DataCollatorWithPadding(tok),
    compute_metrics=compute_metrics,
)
trainer.train()
final_metrics = trainer.evaluate()
print(f"Final eval: f1_micro={final_metrics['eval_f1_micro']:.4f}")
# Commit message reports the honest CV estimate. Fixed mojibake: "Β±" -> "±".
trainer.push_to_hub(commit_message=f"DistilBERT HR | 5-fold CV f1_micro={mean_f1_micro:.4f}±{std_f1_micro:.4f} | N={len(df_combined)}")
# Persist a machine-readable summary of the whole run next to the checkpoints.
with open(f"{OUTPUT_DIR}/cv_results.json", "w") as f:
    json.dump({
        "cv_f1_micro_mean": float(mean_f1_micro),
        "cv_f1_micro_std": float(std_f1_micro),
        "fold_scores": [float(s) for s in fold_best_scores],
        "final_eval": {k: float(v) for k, v in final_metrics.items()},
        "num_synthetic": NUM_SYNTHETIC,
        "num_real": int(len(df_real)),
        "total": int(len(df_combined)),
    }, f, indent=2)
print("\nDONE!")
|