#!/usr/bin/env python3
"""
HR Conversations: data augmentation + 5-fold stratified cross-validation + fine-tuning.
Copy this script and run it in Google Colab (free T4) or any Python environment.

Steps:
  1. Download the real labelled conversations (``;``-separated CSV whose
     ``labels`` column joins label names with ``|``).
  2. Generate NUM_SYNTHETIC template-based conversations whose label sets are
     sampled from the label combinations seen in the real data.
  3. Run N_FOLDS-fold stratified cross-validation. Folds are built over the
     REAL rows only and every synthetic row is always in the training split,
     so validation scores measure performance on held-out real conversations
     instead of on template text that also appears in training.
  4. Train a final model on all real + synthetic rows and push it to the Hub.
"""
import json
import os
import random
import re
import sys
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict, Sequence, Value
from huggingface_hub import hf_hub_download
from sklearn.metrics import accuracy_score, f1_score, hamming_loss, precision_score, recall_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

# A shebang of the form `#!/usr/bin/env python3 -u` fails on Linux (the kernel
# hands "python3 -u" to env as ONE argument), so ask for line-buffered stdout
# here instead. Some stdout replacements (e.g. notebook streams) have no
# `reconfigure`, hence the guard.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(line_buffering=True)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

MODEL_ID = "distilbert/distilbert-base-uncased"
OUTPUT_DIR = "./hr-distilbert-cv"
HUB_MODEL_ID = "AurelPx/hr-conversations-classifier"
NUM_SYNTHETIC = 5000
MAX_LENGTH = 512
BATCH_SIZE = 16
LR = 3e-5
WEIGHT_DECAY = 0.1
EPOCHS = 4
N_FOLDS = 5
THRESHOLD = 0.5  # sigmoid probability above which a label is predicted
# Probability of keeping a sampled multi-label combination intact (otherwise a
# single label is drawn from it). NOTE(review): presumably tuned to match the
# real multi-label rate — confirm.
MULTI_LABEL_KEEP_PROB = 0.53
EXTRA_LABEL_PROB = 0.2  # probability of appending one extra random label
FINAL_SANITY_SAMPLES = 200  # training rows re-evaluated after the final run

ALL_LABELS = [
    "Benefits", "Career Development", "Compliance & Legal", "Contracts",
    "Diversity, Equity & Inclusion", "Expense Management", "Harassment", "Health",
    "IT & Equipment", "Leave & Absence", "Mobility", "Offboarding", "Onboarding",
    "Payroll", "Performance Management", "Recruitment", "Safety", "Timetracking",
    "Training", "Work Arrangements",
]


# ── Real data ─────────────────────────────────────────────────────────
def parse_labels(raw):
    """Split a ``"A | B"`` label string into stripped, non-empty label names."""
    return [part.strip() for part in str(raw).split("|") if part.strip()]


csv_path = hf_hub_download(
    repo_id="AurelPx/ml-intern-a2d69eee-datasets",
    filename="uploads/76ee47c7699c/HRDatasetConv-English.csv",
    repo_type="dataset",
)
df_real = pd.read_csv(csv_path, sep=";", encoding="utf-8-sig", quoting=1)  # 1 == csv.QUOTE_ALL
df_real["conversation"] = df_real["conversation"].fillna("").astype(str)
df_real["label_list"] = df_real["labels"].apply(parse_labels)

has_labels = df_real["label_list"].map(len) > 0
if not has_labels.all():
    print(f"Dropping {int((~has_labels).sum())} real rows with no labels")
    df_real = df_real[has_labels].reset_index(drop=True)

unknown_labels = sorted({l for lbls in df_real["label_list"] for l in lbls} - set(ALL_LABELS))
if unknown_labels:
    raise ValueError(f"Real data contains labels missing from ALL_LABELS: {unknown_labels}")

print(f"Real samples: {len(df_real)}")

# ── Template pools ────────────────────────────────────────────────────
FIRSTNAMES = ["Emma", "Lucas", "Sophie", "Thomas", "Léa", "Alexandre", "Julia", "Maxime", "Camille", "Nathan",
              "Chloé", "Antoine", "Manon", "Louis", "Sarah", "Hugo", "Zoé", "Gabriel", "Inès", "Raphaël"]
LASTNAMES = ["Martin", "Bernard", "Dubois", "Thomas", "Robert", "Petit", "Durand", "Leroy", "Moreau", "Simon",
             "Laurent", "Lefebvre", "Michel", "Garcia", "Roux", "Bonnet", "André", "François", "Mercier", "Dupont"]

TOPIC_POOLS = {
    "Payroll": {
        "user": ["My payslip for {MONTH} seems incorrect. Can you check?",
                 "I haven't received my salary for {MONTH}. What's happening?",
                 "There's a {NUM} EUR deduction on my payslip I don't understand.",
                 "When is the next payroll run?",
                 "I need to update my bank details for salary payments.",
                 "My withholding tax rate looks wrong on the latest payslip."],
        "agent": ["Let me pull up your payroll record. I see the issue — {EXPLANATION}.",
                  "The payroll for {MONTH} was processed on {DATE}. It should arrive within 24 hours.",
                  "The deduction corresponds to {REASON}. This was communicated on {DATE}.",
                  "You can update your banking info in the HR portal.",
                  "I'll flag this for the payroll team and you'll receive a corrected statement."],
    },
    "Benefits": {
        "user": ["Does the company offer gym membership as part of benefits?",
                 "I want to understand the retirement savings plan.",
                 "How does health insurance coverage work during maternity leave?",
                 "I opted out of meal vouchers but still see a deduction.",
                 "What wellness benefits are available?"],
        "agent": ["Yes, we partner with {PARTNER}. The company subsidizes {PERCENT}% up to {AMOUNT} EUR/month.",
                  "The PERCO allows voluntary contributions. The company matches up to {AMOUNT} EUR/year.",
                  "Your health insurance remains active during leave. Meal vouchers pause for non-worked days.",
                  "Your opt-out was processed after the payroll cutoff. The overcharge will be reimbursed.",
                  "We offer {LIST} through the benefits portal."],
    },
    "Leave & Absence": {
        "user": ["I need to take sick leave starting today.",
                 "I'd like to understand the rules around parental leave in France.",
                 "My leave balance is wrong. It should be {NUM} days, not {NUM}.",
                 "Can I split my paternity leave into two periods?",
                 "I'm going on maternity leave in {NUM} months. How does coverage work?"],
        "agent": ["Please submit a sick leave request in the HR portal and upload your medical certificate within 48 hours.",
                  "Under French law, the second parent is entitled to {DAYS} calendar days.",
                  "I can see a {NUM}-day absence was double-counted. I'll submit the correction.",
                  "Yes, the first {NUM} days must be taken immediately after birth leave.",
                  "Your health insurance remains fully active. Meal vouchers pause for non-worked days."],
    },
    "Contracts": {
        "user": ["I would like a copy of my current employment contract.",
                 "My fixed-term contract ends in {NUM} months. Will it be renewed?",
                 "I signed an amendment for a raise but my salary hasn't changed.",
                 "What does the non-compete clause in my contract mean?",
                 "I'm transferring to the London office. Do I get a new contract?"],
        "agent": ["I'll send your current contract to your registered email as a PDF within the hour.",
                  "The renewal decision is expected by {DATE}. Your manager will discuss it with you.",
                  "The effective date on the amendment is {DATE}, not the signing date.",
                  "Your non-compete is {MONTHS} months limited to {SECTOR} within {COUNTRY}. Monthly compensation is {PERCENT}% of gross.",
                  "For international transfers, you'll receive a new local contract under local law."],
    },
    "Recruitment": {
        "user": ["I applied for the Data Scientist position {NUM} weeks ago. Any update?",
                 "I got my offer letter. I have questions before signing.",
                 "As a hiring manager, how do I open a new headcount?",
                 "I saw an internal posting for ML Engineer. Can I apply without all qualifications?"],
        "agent": ["Your application is under review. You should hear back within {NUM} business days.",
                  "The probation period is standard and not typically negotiated.",
                  "Submit a headcount request through the Talent portal. Approval takes {NUM}-{NUM} business days.",
                  "If you meet 60-70% of requirements, you should apply. We value growth potential."],
    },
    "Training": {
        "user": ["I'd like to enroll in the AWS certification. Is it covered?",
                 "My manager suggested I work on leadership skills. Any programs?",
                 "I requested a Python course {NUM} weeks ago and still haven't heard back.",
                 "Can I use my CPF for company-related training?",
                 "Is there a structured training plan for new joiners?"],
        "agent": ["Yes, cloud certifications are fully covered. Submit a request through the portal.",
                  "We offer a 6-month Leadership Accelerator. Your manager can nominate you.",
                  "Your request got stuck at budget validation. I've escalated it.",
                  "Yes, you can combine CPF credits with company budget for more expensive programs.",
                  "Every new joiner should have a 30-60-90 day onboarding plan."],
    },
    "Performance Management": {
        "user": ["When is the next performance review cycle?",
                 "My rating was 'Exceeds Expectations' but I didn't get a raise.",
                 "My manager put me on a PIP. Am I being fired?",
                 "I want to dispute my performance review.",
                 "I've been rated top performer for {NUM} years. Can I move to a senior role?"],
        "agent": ["The next cycle opens on {DATE}. Self-assessments due by {DATE}.",
                  "Performance ratings are one input into compensation.",
                  "A PIP is not termination — it's structured support lasting {MONTHS} months.",
                  "You can submit a written appeal through the HR portal within 15 days.",
                  "With your track record, you're well-positioned for internal mobility."],
    },
    "Onboarding": {
        "user": ["I'm joining next week. What do I need to bring?",
                 "I finished my first week. When do I get my first payslip?",
                 "I'm onboarding a new team member next Monday. Is there a manager checklist?",
                 "I'm starting as a Legal Counsel. Are there specific documents I need?"],
        "agent": ["Bring a valid photo ID and your signed contract. Arrive at 9 AM at reception.",
                  "First payslip at the end of your first full month. Virtual meal card is active now.",
                  "Yes — the Manager Onboarding Checklist covers IT setup, access requests, 30-60-90 plan.",
                  "We need your valid work permit, passport, and proof of address."],
    },
    "Compliance & Legal": {
        "user": ["I need to complete the GDPR refresher training. Is it mandatory?",
                 "A colleague is sharing confidential salary info externally. What should I do?",
                 "I received an email asking for a criminal background check. Is this legal?",
                 "I want to understand my rights regarding a sabbatical leave.",
                 "A vendor asked me to share internal process documentation. Is this allowed?"],
        "agent": ["Yes, it's mandatory annually. Takes about 20 minutes.",
                  "Please forward evidence to ethics@company.com. You're protected under our whistleblower policy.",
                  "For certain roles, a background check may be requested under French law, but only after a conditional offer.",
                  "With {NUM} years seniority, you're eligible for a sabbatical of 6-11 months.",
                  "Internal documentation cannot be shared without approval from the Data Protection Officer."],
    },
    "Mobility": {
        "user": ["I'm interested in transferring from Engineering to Product. Is this common?",
                 "I'm moving to the Tokyo office in September. How does salary work?",
                 "I'm moving from Madrid to Paris HQ. Will my seniority carry over?",
                 "I've been offered a {MONTHS}-month project in Dubai. Any tax implications?"],
        "agent": ["About 20% of open roles are filled internally each year.",
                  "You'll receive a split salary: base in EUR, local allowance in JPY.",
                  "Seniority recognition depends on the transfer agreement.",
                  "For a {MONTHS}-month assignment, you remain on French payroll. UAE has no income tax."],
    },
    "Harassment": {
        "user": ["I need to file a harassment complaint. What documentation is required?",
                 "If the behavior happened in virtual meetings, how do I document it?",
                 "What protections exist against retaliation after filing?",
                 "Who has access to my complaint details during the investigation?"],
        "agent": ["Submit a written statement with dates, times, locations, descriptions.",
                  "Document the platform, meeting ID, timestamp, participant list, chat logs.",
                  "Retaliation is strictly prohibited. HR monitors for adverse actions.",
                  "Only assigned HR investigators, their supervisor, and legal counsel if required."],
    },
    "Career Development": {
        "user": ["I want to pursue an AWS certification. Is it reimbursed?",
                 "My manager said I need leadership skills for a team lead role.",
                 "Can I use my CPF for an MBA part-time?",
                 "How do I request a learning plan for ML training?"],
        "agent": ["Senior-level certifications are reimbursed up to {AMOUNT} USD per attempt.",
                  "We offer a 6-month Leadership Accelerator. Your manager submits a nomination.",
                  "An MBA would need explicit approval from your HRBP and department head.",
                  "Request a learning plan through the Training portal."],
    },
    "IT & Equipment": {
        "user": ["My laptop keeps crashing. Can I get a replacement?",
                 "I need a second monitor for my home office.",
                 "I forgot my VPN password. How do I reset it?",
                 "What software licenses are available for developers?"],
        "agent": ["Please open an IT ticket.",
                  "Submit an equipment request through the IT portal.",
                  "Use the self-service password reset portal.",
                  "Developers have access to JetBrains, GitHub Pro, Docker Desktop."],
    },
    "Expense Management": {
        "user": ["I submitted an expense report {NUM} weeks ago and still haven't been reimbursed.",
                 "What's the policy on business travel expenses?",
                 "Can I expense my monthly public transport pass?",
                 "I lost the receipt for a client dinner. Can I still claim it?"],
        "agent": ["Let me check the status. It appears stuck at manager approval.",
                  "Business travel covers transport, accommodation up to {AMOUNT} EUR/night, and meals up to {AMOUNT} EUR/day.",
                  "Yes, monthly transport passes are fully reimbursable.",
                  "Without a receipt, the claim may be rejected per policy."],
    },
    "Health": {
        "user": ["I'm going on long-term sick leave. Will this affect my health insurance?",
                 "How do I access the Employee Assistance Programme for mental health support?",
                 "I need a medical certificate for my sick leave. What format is required?"],
        "agent": ["Your health insurance remains fully active during sick leave.",
                  "The EAP provides confidential counseling at no cost.",
                  "A standard medical certificate from any licensed physician is sufficient."],
    },
    "Safety": {
        "user": ["There's a loose cable in the open space that someone could trip on.",
                 "What should I do if there's a fire alarm during work hours?",
                 "I noticed the emergency exit on floor {NUM} is blocked. Who should I report this to?"],
        "agent": ["Please report it immediately through the Safety Portal.",
                  "Follow the evacuation plan. Assembly point is in the parking lot.",
                  "Report blocked exits to facilities@company.com and your floor safety marshal."],
    },
    "Offboarding": {
        "user": ["I'm resigning and my last day is in {NUM} weeks. What do I need to do?",
                 "Will I be paid for my unused vacation days?",
                 "Do I need to return my laptop and badge?",
                 "Can I keep my company email address after leaving?"],
        "agent": ["HR will schedule an offboarding checklist meeting.",
                  "Yes, accrued but unused leave is paid out in your final settlement.",
                  "All IT equipment, access badges, and parking passes must be returned.",
                  "Company email addresses are deactivated on the last working day."],
    },
    "Work Arrangements": {
        "user": ["Can I work from home {NUM} days a week?",
                 "What's the policy on flexible working hours?",
                 "I want to switch to part-time (80%). What's the process?"],
        "agent": ["The hybrid policy allows up to {NUM} remote days per week.",
                  "Core hours are 10:00-16:00. Outside that, you can flex start and end times.",
                  "Submit a part-time request through the HR portal at least 1 month in advance."],
    },
    "Timetracking": {
        "user": ["How do I log overtime hours in the system?",
                 "My timesheet for last week was rejected. What should I fix?",
                 "Are breaks automatically deducted from my tracked hours?"],
        "agent": ["Use the 'Overtime' category in the time tracking tool.",
                  "Common rejections: missing project codes, incorrect dates.",
                  "Yes, a 30-minute lunch break is auto-deducted for any workday over 6 hours."],
    },
    "Diversity, Equity & Inclusion": {
        "user": ["Does the company have employee resource groups?",
                 "I experienced age discrimination in my performance review. What can I do?",
                 "Are diversity metrics published in the annual report?"],
        "agent": ["Yes, we have active ERGs for women in tech, LGBTQ+, parents.",
                  "Age discrimination is prohibited. You can file a formal complaint through the Ethics hotline.",
                  "Yes, our annual report includes diversity metrics."],
    },
}

# Generic fallback pools so every label in ALL_LABELS can be sampled.
for label in ALL_LABELS:
    if label not in TOPIC_POOLS:
        TOPIC_POOLS[label] = {
            "user": [f"I have a question about {label.lower()}.",
                     f"Can you help me with {label.lower()}?",
                     f"I need information regarding {label.lower()}."],
            "agent": [f"I'll look into your {label.lower()} query right away.",
                      f"Here's what you need to know about {label.lower()}.",
                      f"Let me pull up the {label.lower()} policy for you."],
        }

# Matches template placeholders such as "{NUM}" or "{MONTHS}"; group 1 is the key.
PLACEHOLDER_RE = re.compile(r"\{([A-Z]+)\}")
PLACEHOLDER_FILLERS = {
    "NAME": lambda: random.choice(FIRSTNAMES) + " " + random.choice(LASTNAMES),
    "FIRSTNAME": lambda: random.choice(FIRSTNAMES),
    "NUM": lambda: str(random.randint(2, 30)),
    "DATE": lambda: f"{random.randint(1, 28)} {random.choice(['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'])} 2025",
    "MONTH": lambda: random.choice(['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']),
    "AMOUNT": lambda: str(random.choice([100, 200, 300, 500, 1000, 1500, 2000])),
    "PERCENT": lambda: str(random.randint(30, 80)),
    "DAYS": lambda: str(random.randint(5, 30)),
    "MONTHS": lambda: str(random.randint(3, 24)),
    "COUNTRY": lambda: random.choice(['France', 'the UK', 'Germany', 'Spain', 'the US']),
    "SECTOR": lambda: random.choice(['software development', 'finance', 'healthcare', 'marketing']),
    "REASON": lambda: random.choice(['the new health insurance premium', 'a tax adjustment', 'a pension contribution', 'a benefit opt-in']),
    "EXPLANATION": lambda: random.choice(['there was a sync error', 'the cutoff date had passed', 'the rate was incorrectly updated']),
    "PARTNER": lambda: random.choice(['Gymlib', 'Edenred', 'Alan', 'Swile', 'Amundi']),
    "LIST": lambda: random.choice(['gym access, mental health counseling, and wellness workshops', 'yoga classes, nutrition coaching, and mindfulness apps']),
}


def fill_placeholders(template):
    """Replace every ``{KEY}`` in *template* with a value from PLACEHOLDER_FILLERS.

    Each occurrence is sampled independently, so a template that repeats a
    placeholder (e.g. ``"{NUM}-{NUM} business days"``) gets independently
    drawn values. Placeholders with no filler are left as-is.
    """
    def _sub(match):
        filler = PLACEHOLDER_FILLERS.get(match.group(1))
        return filler() if filler is not None else match.group(0)

    return PLACEHOLDER_RE.sub(_sub, template)


# ── Synthetic generation ──────────────────────────────────────────────
# Frequencies of (sorted) label combinations in the real data. These come from
# all real rows, including rows later held out for validation; only label
# priors are shared this way — the synthetic text comes from TOPIC_POOLS.
seen_combos = Counter(tuple(sorted(l)) for l in df_real["label_list"])


def generate_synthetic(n=5000):
    """Generate *n* template-based conversations labelled like the real data.

    Label sets are drawn from ``seen_combos`` (weighted by frequency). The full
    combination is kept with probability MULTI_LABEL_KEEP_PROB, otherwise a
    single label from it is used; with probability EXTRA_LABEL_PROB one more
    random label from ALL_LABELS is appended. Every label in the final set
    contributes at least one USER/AGENT turn, so no label is assigned without
    supporting text.

    Returns a DataFrame with columns ``conversation``, ``labels``
    (``" | "``-joined) and ``label_list``.
    """
    combos = list(seen_combos.keys())
    weights = list(seen_combos.values())
    rows = []
    for _ in range(n):
        combo = random.choices(combos, weights=weights, k=1)[0]
        if random.random() < MULTI_LABEL_KEEP_PROB and len(combo) > 1:
            labels = list(combo)
        else:
            labels = [random.choice(combo)]
        if random.random() < EXTRA_LABEL_PROB:
            extra = random.choice(ALL_LABELS)
            if extra not in labels:
                labels.append(extra)

        # One turn per label, padded with turns on already-chosen labels up to
        # a random 1-3 turn length.
        num_turns = max(len(labels), random.randint(1, 3))
        turn_topics = labels + [random.choice(labels) for _ in range(num_turns - len(labels))]
        turns = []
        for topic in turn_topics:
            pool = TOPIC_POOLS[topic]
            user_msg = fill_placeholders(random.choice(pool["user"]))
            agent_msg = fill_placeholders(random.choice(pool["agent"]))
            turns.append(f"USER: {user_msg}\nAGENT: {agent_msg}")

        rows.append({"conversation": "\n\n".join(turns), "labels": " | ".join(labels), "label_list": labels})
    return pd.DataFrame(rows, columns=["conversation", "labels", "label_list"])


print(f"Generating {NUM_SYNTHETIC} synthetic samples...")
df_synth = generate_synthetic(NUM_SYNTHETIC)
df_real["synthetic"] = False
df_synth["synthetic"] = True
df_combined = pd.concat([df_real[["conversation", "labels", "label_list", "synthetic"]], df_synth], ignore_index=True)
print(f"Total: {len(df_combined)} (real: {(~df_combined['synthetic']).sum()}, synth: {df_combined['synthetic'].sum()})")

mlb = MultiLabelBinarizer(classes=ALL_LABELS)
y = mlb.fit_transform(df_combined["label_list"])
label_names = list(mlb.classes_)
num_labels = len(label_names)
id2label = dict(enumerate(label_names))
label2id = {name: idx for idx, name in id2label.items()}
print(f"Label matrix shape: {y.shape}")

tok = AutoTokenizer.from_pretrained(MODEL_ID)


# ── Training helpers ──────────────────────────────────────────────────
def preprocess(examples):
    """Tokenize a batch of ``text`` (no padding; the collator pads) and keep ``labels``."""
    e = tok(examples["text"], truncation=True, max_length=MAX_LENGTH, padding=False)
    e["labels"] = examples["labels"]
    return e


def make_frame(idx):
    """Build a ``text``/``labels`` DataFrame for the df_combined rows at positions *idx*."""
    return pd.DataFrame({
        "text": df_combined["conversation"].iloc[idx].tolist(),
        "labels": y[idx].astype(float).tolist(),
    })


def build_dataset(splits):
    """Tokenize ``{split_name: DataFrame(text, labels)}`` into a torch-formatted DatasetDict."""
    # preserve_index=False: sampled/filtered frames would otherwise carry an
    # extra "__index_level_0__" column.
    ds = DatasetDict({name: Dataset.from_pandas(frame, preserve_index=False) for name, frame in splits.items()})
    tok_ds = ds.map(preprocess, batched=True, remove_columns=["text"])
    tok_ds = tok_ds.cast_column("labels", Sequence(Value("float32")))
    tok_ds.set_format("torch")
    return tok_ds


def load_model():
    """Load a fresh MODEL_ID checkpoint with a multi-label head named after ``label_names``."""
    return AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        num_labels=num_labels,
        problem_type="multi_label_classification",
        id2label=id2label,
        label2id=label2id,
    )


def compute_metrics(eval_pred):
    """Micro/macro multi-label metrics after thresholding sigmoid outputs at THRESHOLD."""
    logits, labels = eval_pred
    probs = torch.sigmoid(torch.tensor(logits)).numpy()
    y_pred = (probs >= THRESHOLD).astype(int)
    y_true = labels.astype(int)
    return {
        "f1_micro": float(f1_score(y_true, y_pred, average="micro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
        "precision": float(precision_score(y_true, y_pred, average="micro", zero_division=0)),
        "recall": float(recall_score(y_true, y_pred, average="micro", zero_division=0)),
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "hamming": float(hamming_loss(y_true, y_pred)),
    }


def common_training_kwargs(n_train):
    """TrainingArguments settings shared by the CV folds and the final run."""
    return dict(
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=1,
        learning_rate=LR,
        weight_decay=WEIGHT_DECAY,
        # About a quarter of one epoch's optimizer steps on a single device.
        warmup_steps=max(1, n_train // BATCH_SIZE // 4),
        logging_strategy="steps",
        logging_first_step=True,
        disable_tqdm=True,
        report_to=[],
        bf16=False,
        fp16=False,
    )


def write_label_config(directory):
    """Write label names, decision threshold and label count to ``<directory>/label_config.json``."""
    os.makedirs(directory, exist_ok=True)
    with open(os.path.join(directory, "label_config.json"), "w", encoding="utf-8") as fh:
        json.dump({"label_names": label_names, "threshold": THRESHOLD, "num_labels": num_labels}, fh)


# ── N-Fold Stratified CV (validation on real samples only) ────────────
is_synthetic = df_combined["synthetic"].to_numpy(dtype=bool)
real_idx = np.flatnonzero(~is_synthetic)
synth_idx = np.flatnonzero(is_synthetic)
# Stratify real rows by their first label.
real_strata = [lbls[0] for lbls in df_combined["label_list"].iloc[real_idx]]

skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
cv_metrics = []
fold_scores = []
best_score = -1.0
best_dir = f"{OUTPUT_DIR}/best"

for fold, (real_train, real_val) in enumerate(skf.split(np.zeros(len(real_idx)), real_strata)):
    print(f"\n{'=' * 50}\nFold {fold + 1}/{N_FOLDS}\n{'=' * 50}")
    # Synthetic rows only ever train: their template text would otherwise
    # appear on both sides of the split and inflate validation scores.
    train_idx = np.concatenate([real_idx[real_train], synth_idx])
    val_idx = real_idx[real_val]
    train_df = make_frame(train_idx)
    val_df = make_frame(val_idx)
    print(f"  Train: {len(train_df)}  Val (real only): {len(val_df)}")

    tok_ds = build_dataset({"train": train_df, "validation": val_df})
    model = load_model()

    args = TrainingArguments(
        output_dir=f"{OUTPUT_DIR}/fold_{fold}",
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,  # the best checkpoint is still retained for load_best_model_at_end
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model="f1_micro",
        greater_is_better=True,
        seed=SEED + fold,
        **common_training_kwargs(len(tok_ds["train"])),
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=tok_ds["train"],
        eval_dataset=tok_ds["validation"],
        processing_class=tok,
        data_collator=DataCollatorWithPadding(tok),
        compute_metrics=compute_metrics,
    )
    trainer.train()
    m = trainer.evaluate()
    score = m["eval_f1_micro"]
    print(f"  Fold {fold + 1} eval: f1_micro={score:.4f}  f1_macro={m['eval_f1_macro']:.4f}")
    cv_metrics.append(m)
    fold_scores.append(score)

    if score >= best_score:  # ties go to the later fold
        best_score = score
        print(f"  -> New best fold! Saving to {best_dir}")
        trainer.save_model(best_dir)
        write_label_config(best_dir)

    # Release this fold's model before the next one is loaded.
    del trainer, model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print(f"\n{'=' * 50}\nCV Summary\n{'=' * 50}")
for i, m in enumerate(cv_metrics):
    print(f"  Fold {i + 1}: f1_micro={m['eval_f1_micro']:.4f}  f1_macro={m['eval_f1_macro']:.4f}")
mean_f1_micro = np.mean(fold_scores)
std_f1_micro = np.std(fold_scores)
print(f"\n  Mean f1_micro: {mean_f1_micro:.4f} ± {std_f1_micro:.4f}")

# ── Final train on all data ───────────────────────────────────────────
print(f"\nFinal training on all {len(df_combined)} samples...")
all_df = make_frame(np.arange(len(df_combined)))
sanity_df = all_df.sample(min(FINAL_SANITY_SAMPLES, len(all_df)), random_state=SEED)
tok_ds = build_dataset({"train": all_df, "train_subset": sanity_df})
model = load_model()

# No held-out data remains, so there is no per-epoch evaluation or best-model
# selection: the model after the last epoch is the one pushed.
args = TrainingArguments(
    output_dir=f"{OUTPUT_DIR}/final",
    eval_strategy="no",
    save_strategy="no",
    logging_steps=100,
    load_best_model_at_end=False,
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy="end",
    seed=SEED,
    **common_training_kwargs(len(tok_ds["train"])),
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tok_ds["train"],
    eval_dataset=tok_ds["train_subset"],
    processing_class=tok,
    data_collator=DataCollatorWithPadding(tok),
    compute_metrics=compute_metrics,
)
trainer.train()
# Evaluated on rows the model was trained on: a fit sanity check only. The CV
# numbers above are the generalization estimate.
final_metrics = trainer.evaluate()
print(f"Final eval (training subset): f1_micro={final_metrics['eval_f1_micro']:.4f}")

write_label_config(args.output_dir)
trainer.push_to_hub(
    commit_message=f"DistilBERT HR | {N_FOLDS}-fold CV f1_micro={mean_f1_micro:.4f}±{std_f1_micro:.4f} | N={len(df_combined)}"
)

os.makedirs(OUTPUT_DIR, exist_ok=True)
with open(f"{OUTPUT_DIR}/cv_results.json", "w", encoding="utf-8") as f:
    json.dump({
        "cv_f1_micro_mean": float(mean_f1_micro),
        "cv_f1_micro_std": float(std_f1_micro),
        "fold_scores": [float(s) for s in fold_scores],
        "cv_validation": "held-out real samples only; synthetic samples always used for training",
        "final_eval": {k: float(v) for k, v in final_metrics.items()},
        "final_eval_note": "computed on a subset of the training data (fit sanity check, not a generalization estimate)",
        "num_synthetic": NUM_SYNTHETIC,
        "num_real": int(len(df_real)),
        "total": int(len(df_combined)),
    }, f, indent=2)

print("\nDONE!")