guide / src /classifier /train.py
Saravanakumar R
intial traces bug fixes commit
b016462
Raw
History Blame Contribute Delete
23.7 kB
"""
DomainClassifier training script.
Data sources:
1. CFPB Consumer Complaint Database (3M+ rows) — place as data/raw/complaints.csv.
Covers banking and cibil domains well; other domains remapped to those two.
2. Synthetic supplement — up to 5 000 (``--supplement_per_class``) unique templated examples per domain to ensure all 6
classes are represented, especially ecommerce / telecom / insurance.
Estimated training time: ~30 min CPU / ~5 min GPU (T4).
CLI usage:
python -m src.classifier.train \\
--cfpb_csv data/raw/complaints.csv \\
--output_dir models/domain_classifier
"""
from __future__ import annotations
import argparse
import logging
import random
import re
import pandas as pd
from datasets import Dataset, Features, Value, concatenate_datasets
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
Trainer,
TrainingArguments,
)
from src.classifier.model import DOMAIN2ID, DOMAIN_LABELS, ID2DOMAIN, NUM_CLASSES
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# CFPB product → G.U.I.D.E. domain mapping
# ---------------------------------------------------------------------------
CFPB_PRODUCT_MAP: dict[str, str] = {
# banking
"Checking or savings account": "banking",
"Bank account or service": "banking",
"Mortgage": "banking",
"Student loan": "banking",
"Vehicle loan or lease": "banking",
"Payday loan, title loan, or personal loan": "banking",
"Payday loan": "banking",
"Consumer Loan": "banking",
"Money transfer, virtual currency, or money service": "banking",
"Money transfers": "banking",
"Other financial service": "banking",
# cibil
"Credit card or prepaid card": "cibil",
"Credit card": "cibil",
"Prepaid card": "cibil",
"Credit reporting, credit repair services, or other personal consumer reports": "cibil",
"Credit reporting": "cibil",
"Debt collection": "cibil",
"Debt Collection": "cibil",
"Virtual currency": "cibil",
}
# Keyword-based fallback used when the exact CFPB product string is not in the map
_BANKING_KWS = ("mortgage", "loan", "checking", "savings", "bank account", "money transfer")
_CIBIL_KWS = ("credit report", "credit repair", "credit card", "debt collection", "prepaid")
def _map_product(product: str) -> str | None:
"""Map a CFPB product string to a G.U.I.D.E. domain, or None if unmappable."""
product = product.strip()
if product in CFPB_PRODUCT_MAP:
return CFPB_PRODUCT_MAP[product]
lower = product.lower()
if any(kw in lower for kw in _BANKING_KWS):
return "banking"
if any(kw in lower for kw in _CIBIL_KWS):
return "cibil"
return None
# ---------------------------------------------------------------------------
# CFPB dataset loader
# ---------------------------------------------------------------------------
def _clean_cfpb_chunk(chunk: pd.DataFrame) -> pd.DataFrame:
    """Keep only rows of one column-normalised CFPB chunk that are usable for training.

    A row survives when its narrative is present and longer than 20 characters
    after stripping, and its product maps to a domain via ``_map_product``.
    Returns a frame with exactly the columns ``text`` (stripped narrative) and
    ``labels`` (domain name).
    """
    chunk = chunk.rename(columns={"consumer complaint narrative": "text"})
    chunk = chunk.dropna(subset=["text"])
    chunk = chunk.assign(text=chunk["text"].astype(str).str.strip())
    chunk = chunk[chunk["text"].str.len() > 20]  # drop near-empty narratives
    labels = chunk["product"].map(
        lambda p: _map_product(str(p)) if pd.notna(p) else None
    )
    # assign() returns a new frame, avoiding chained assignment on a filtered slice
    chunk = chunk.assign(labels=labels).dropna(subset=["labels"])
    return chunk[["text", "labels"]]


def load_and_remap_cfpb(csv_path: str, max_per_class: int = 50_000) -> Dataset:
    """
    Load *csv_path*, remap CFPB product labels → G.U.I.D.E. domains, and return
    a HuggingFace Dataset with columns ``text`` and ``labels``.

    The CSV is read in 200k-row chunks and each chunk is filtered as soon as it
    is read (see ``_clean_cfpb_chunk``), so only surviving rows are held in
    memory. Each domain is then capped at *max_per_class* rows (random sample,
    ``random_state=42``) to reduce label imbalance, and domain names are
    converted to integer ids with ``DOMAIN2ID``.

    Args:
        csv_path: Path to the CFPB complaints CSV. Column names are matched
            case-insensitively after stripping whitespace.
        max_per_class: Maximum rows kept per remapped domain.

    Returns:
        Dataset with ``text`` (string) and ``labels`` (int64) columns.

    Raises:
        ValueError: If the CSV lacks a ``product`` or
            ``consumer complaint narrative`` column, or if no row survives
            filtering and remapping.
    """
    logger.info("Loading CFPB CSV from %s …", csv_path)
    wanted = ("product", "consumer complaint narrative")
    chunks: list[pd.DataFrame] = []
    # Context manager closes the underlying file even if a chunk raises.
    with pd.read_csv(
        csv_path,
        usecols=lambda c: c.strip().lower() in wanted,
        chunksize=200_000,
        low_memory=False,
        encoding="utf-8",
        on_bad_lines="skip",
    ) as reader:
        for chunk in reader:
            # Normalise column names across CFPB dataset versions
            chunk.columns = [c.strip().lower() for c in chunk.columns]
            missing = [c for c in wanted if c not in chunk.columns]
            if missing:
                raise ValueError(
                    f"{csv_path} is missing required CFPB column(s): {missing}"
                )
            cleaned = _clean_cfpb_chunk(chunk)
            if not cleaned.empty:
                chunks.append(cleaned)
    if not chunks:
        raise ValueError(
            f"No CFPB rows with a usable narrative and a mappable product in {csv_path}"
        )
    df = pd.concat(chunks, ignore_index=True)
    # Cap per class (pandas 3.x compatible — groupby/apply dropped include_groups)
    df = pd.concat(
        [g.sample(min(len(g), max_per_class), random_state=42)
         for _, g in df.groupby("labels")],
        ignore_index=True,
    )
    logger.info(
        "CFPB dataset: %d examples.\n%s",
        len(df),
        df["labels"].value_counts().to_string(),
    )
    df["labels"] = df["labels"].map(DOMAIN2ID)
    features = Features({"text": Value("string"), "labels": Value("int64")})
    return Dataset.from_pandas(df[["text", "labels"]], features=features, preserve_index=False)
# ---------------------------------------------------------------------------
# Synthetic supplement for under-represented domains
# ---------------------------------------------------------------------------
# Per-domain complaint templates. ``{amount}``, ``{date}``, ``{ref}`` and
# ``{days}`` placeholders are filled by ``_fill`` from ``_FILLERS``.
_SUPPLEMENT_TEMPLATES: dict[str, list[str]] = {
    "ecommerce": [
        "My order from Flipkart placed {date} has not been delivered.",
        "Amazon India has not processed my refund of {amount} for the returned product.",
        "I received a damaged product from Myntra; my return request has been ignored.",
        "Snapdeal charged me {amount} twice for order {ref}.",
        "Meesho delivered the wrong item and has not responded to my complaint.",
        "The product ordered from Flipkart shows shipped but was not received after {days} days.",
        "Amazon India rejected my return request for a defective item worth {amount}.",
        "I have been waiting for a refund of {amount} from Myntra since {date}.",
        "My COD order from Flipkart was marked delivered but I never received it.",
        "Zomato charged me for an order that was never delivered on {date}.",
        "Swiggy refunded the wrong amount; I am owed {amount} more.",
        "Meesho has blocked my account without any notice or explanation.",
        "Snapdeal took {amount} from my wallet but the order was never placed.",
        "The product I received from Amazon India differs from what I ordered.",
        "Myntra has not honoured the discount of {amount} applied at checkout.",
        "My Flipkart subscription was renewed without consent, charging me {amount}.",
        "MakeMyTrip cancelled my booking and has not refunded {amount} since {date}.",
        "IRCTC charged my account {amount} but the ticket was not booked.",
        "The Uber ride on {date} overcharged me by {amount}.",
        "Ola Cabs deducted {amount} for a trip I never took.",
        # Hotel / travel OTAs (same escalation path as MakeMyTrip/IRCTC)
        "Agoda confirmed my hotel booking but the property has no record of it.",
        "OYO Rooms cancelled my booking on {date} without refunding {amount}.",
        "Booking.com India did not process my refund of {amount} after hotel cancellation.",
        "I booked a hotel via Agoda for {amount} on {date} but check-in was denied.",
        "OYO Rooms charged {amount} extra at check-out beyond the confirmed booking price.",
    ],
    "telecom": [
        "Airtel deducted {amount} from my prepaid account without explanation.",
        "My Jio plan was not activated despite recharging {amount} on {date}.",
        "Vodafone Idea is billing me for services I never subscribed to.",
        "BSNL has not resolved my broadband issue since {date}.",
        "Airtel ported my number without my consent.",
        "Reliance Jio throttled my internet speed without any prior notice.",
        "My mobile number with Vodafone Idea was disconnected incorrectly.",
        "Airtel charged me {amount} extra every month without justification.",
        "I submitted a porting request to Jio on {date} but it has not been processed.",
        "BSNL customer care has been unresponsive to my repeated complaints.",
        "My Airtel postpaid bill of {amount} includes services I did not activate.",
        "Jio deducted {amount} from my account citing roaming charges during domestic travel.",
        "BSNL did not credit the cashback of {amount} promised on recharge.",
        "My Vodafone Idea sim was deactivated without notice on {date}.",
        "Airtel is charging me for a TV subscription I cancelled on {date}.",
        "I was not informed of the tariff change that increased my bill by {amount}.",
        "Jio's technical team promised resolution by {date} but has not followed up.",
        "My BSNL landline has been non-functional since {date}.",
        "Vodafone Idea charged me for international calls I never made.",
        "Airtel did not process my DND registration despite multiple requests.",
        # SIM activation and prepaid recharge scenarios (Trace 5 gap)
        "I bought a new SIM card from BSNL on {date} but it has not been activated yet.",
        "My prepaid recharge of {amount} on Airtel was successful but services are not reflecting.",
        "BSNL SIM activation is pending for {days} days despite multiple follow-ups.",
        "I recharged my Jio number with {amount} on {date} but the plan is not active.",
        "My new Vodafone Idea SIM was issued on {date} but calls and data are not working.",
        "Airtel prepaid recharge of {amount} deducted but balance not updated.",
        "I ported my number to BSNL on {date} but incoming calls are not connecting.",
        "My BSNL mobile number was deactivated after recharging {amount} on {date}.",
        "New SIM card from Jio purchased for {amount} but activation failed without reason.",
        "Prepaid plan of {amount} recharged on Vodafone Idea is not reflecting after {days} days.",
    ],
    "insurance": [
        "My health insurance claim with LIC for {amount} was rejected without reason.",
        "Star Health Insurance has been delaying my settlement since {date}.",
        "New India Assurance denied my motor claim citing incorrect reasons.",
        "I have not received policy documents from ICICI Lombard despite paying the premium.",
        "An LIC agent collected {amount} as premium but my policy was never issued.",
        "My mediclaim with Star Health was rejected over a pre-existing condition I disclosed.",
        "HDFC ERGO has not settled claim {ref} filed on {date}.",
        "The LIC maturity amount has not been credited after {days} months.",
        "New India Assurance is not responding to my motor accident claim from {date}.",
        "I was mis-sold a policy worth {amount} by an ICICI Lombard agent.",
        "Star Health deducted the renewal premium but sent no confirmation.",
        "My LIC policy was lapsed without notice despite timely premium payments.",
        "IRDAI registered my complaint but the insurer has still not responded after {days} days.",
        "The surveyor from New India Assurance undervalued my claim by {amount}.",
        "ICICI Lombard cancelled my policy without refunding the unused premium of {amount}.",
        "I am unable to get cashless treatment because Star Health blocked my policy.",
        "LIC branch refused to accept my premium payment on {date}.",
        "My nominee was denied the claim on a valid life insurance policy.",
        "The health insurance top-up of {amount} was not applied despite the rider.",
        "New India Assurance took {days} days to appoint a surveyor for my claim.",
    ],
    "cibil": [
        "My CIBIL score is incorrect due to a wrong default entry by HDFC Bank.",
        "Experian India shows a closed loan as active on my credit report.",
        "My credit score dropped because ICICI Bank filed an incorrect report.",
        "CIBIL has not updated my score after loan closure on {date}.",
        "There is a fraudulent credit card entry in my CIBIL report I never applied for.",
        "My credit report shows an outstanding of {amount} which I have already paid.",
        "TransUnion CIBIL has not removed the settled debt entry.",
        "Bank of Baroda incorrectly reported my account as an NPA.",
        "My Experian credit report has errors that are hurting my loan eligibility.",
        "Kotak Bank reported me as a defaulter despite on-time EMI payments.",
        "I disputed the incorrect entry with CIBIL on {date} but received no response.",
        "My CIBIL report still shows a write-off that was settled {days} months ago.",
        "HDFC Bank is reporting a closed personal loan as outstanding on my Experian report.",
        "The credit card I never applied for from Axis Bank is damaging my CIBIL score.",
        # A point count, not a rupee amount: {amount} would yield "fell by ₹4,299
        # points". NOTE: reuses the small-integer {days} filler as the number.
        "My credit score fell by {days} points due to an error in the CIBIL database.",
        "SBI has not updated the repayment history for my home loan with CIBIL.",
        "Two different banks are reporting the same debt, inflating my outstanding balance.",
        "I requested a free credit report from CIBIL on {date} but did not receive it.",
        "My loan account at PNB shows a missed payment that I have a receipt for.",
        "The debt collector listed on my Experian report does not match any known account.",
    ],
    "banking": [
        "HDFC Bank deducted {amount} from my account without authorization.",
        "SBI has not processed my NEFT transfer of {amount} sent on {date}.",
        "ICICI Bank is charging hidden fees on my savings account.",
        "Axis Bank rejected my loan application without providing a reason.",
        "My ATM card was blocked by Kotak Bank without prior notice.",
        "PNB deducted maintenance charges from my zero-balance account.",
        "My fixed deposit with SBI matured on {date} but the amount has not been credited.",
        "HDFC Bank charged {amount} as processing fee but rejected my loan.",
        "Unauthorized transactions appeared in my Bank of Baroda account.",
        "ICICI Bank has not reversed the duplicate EMI of {amount}.",
        "My SBI account was frozen on {date} without any explanation.",
        "HDFC Bank did not credit the cashback of {amount} on my credit card.",
        "Axis Bank charged me a penalty of {amount} for an EMI that was auto-debited.",
        "I was charged {amount} for a Kotak Bank service I never requested.",
        "My Bank of Baroda ATM card was cloned and {amount} was withdrawn.",
        "PNB issued a cheque return charge of {amount} without valid reason.",
        "ICICI Bank approved my loan but the disbursement has been pending since {date}.",
        "The interest rate on my SBI home loan was changed without informing me.",
        "Axis Bank debited my account twice for the same EMI of {amount} on {date}.",
        "HDFC Bank has not closed my account despite a written request on {date}.",
    ],
    "general": [
        "The company has not responded to my complaints since {date}.",
        "I paid {amount} for a service that was never delivered.",
        "The customer care team has been completely unresponsive to my grievance.",
        "I want to lodge a formal complaint about the defective product I purchased.",
        "Despite multiple follow-ups my refund of {amount} has not been processed.",
        "The company promised to resolve my issue by {date} but has not done so.",
        "I am being harassed by the collections team for a payment I already made.",
        "The service provider charged me {amount} for services not rendered.",
        "I have been trying to reach customer support for {days} days with no response.",
        "The company delivered a completely different product than what I ordered.",
        "My complaint {ref} has been closed without being resolved.",
        "I filed a grievance on {date} but there has been no acknowledgement.",
        "The company charged me {amount} for a warranty claim that should be free.",
        "I have written to the grievance officer but have not received a reply.",
        "The product warranty was denied despite the item being within the warranty period.",
        "I was promised a callback by {date} but no one has reached out.",
        "The company has deducted {amount} from my wallet and is not crediting it back.",
        "My subscription was renewed without consent, billing me {amount}.",
        "The service outage on {date} caused me a financial loss of {amount}.",
        "I never gave consent for the company to auto-renew and charge me {amount}.",
    ],
}
# Values substituted for ``{placeholder}`` markers in the synthetic templates.
# Dates carry no leading "on": many templates already read "... on {date}",
# and a prefixed value would render as "on on 30 June 2024".
_FILLERS: dict[str, list[str]] = {
    "amount": [
        "₹4,299", "₹1,200", "₹50,000", "₹10,500", "₹2,500",
        "Rs. 8,900", "Rs 15,000", "₹3,499", "₹25,000", "₹9,999",
        "₹35,000", "₹750", "Rs. 2,000", "₹18,500", "₹5,999",
    ],
    "date": [
        "12 March 2024", "5 January 2024", "20 April 2024", "8 November 2023",
        "three weeks ago", "two months ago", "last Tuesday",
        "15th February 2024", "last month", "30 June 2024",
    ],
    "ref": [
        "#OD-2930291", "TXN987654321", "REF-20240312", "CMP-2024-001",
        "TKT-9876543", "CLM-2024-12345", "BK-56789",
    ],
    "days": ["7", "10", "15", "20", "30", "45", "60"],
}
_PLACEHOLDER_RE = re.compile(r"\{(\w+)\}")


def _fill(template: str, rng: random.Random) -> str:
    """Return *template* with every ``{key}`` replaced by a random ``_FILLERS[key]`` value.

    Each placeholder occurrence is drawn independently from *rng*. Placeholders
    whose key has no (or an empty) filler list are left verbatim.
    """
    def _replace(m: re.Match) -> str:
        key = m.group(1)
        options = _FILLERS.get(key)
        return rng.choice(options) if options else m.group(0)

    return _PLACEHOLDER_RE.sub(_replace, template)
def _build_supplement(n_per_class: int = 5_000, seed: int = 42) -> Dataset:
    """
    Generate up to *n_per_class* unique synthetic examples for each domain.

    Used to ensure all domains are represented in training regardless of how
    CFPB data maps — particularly important for ecommerce, telecom, insurance.

    For every domain in ``_SUPPLEMENT_TEMPLATES`` a template is drawn at random
    and its placeholders filled via ``_fill``. Texts already produced for that
    domain are discarded, and at most ``n_per_class * 10`` draws are made per
    domain. A domain whose templates admit fewer distinct fillings than
    requested therefore ends up with fewer rows. That shortfall is logged as a
    warning, and the final log line reports the actual per-domain counts.

    Args:
        n_per_class: Target number of distinct examples per domain.
        seed: Seed for the private ``random.Random``; output is deterministic
            for a given seed.

    Returns:
        Dataset with ``text`` (string) and ``labels`` (int64, via ``DOMAIN2ID``).
    """
    rng = random.Random(seed)
    rows: list[dict] = []
    per_domain: dict[str, int] = {}
    max_attempts = n_per_class * 10  # bounds the search once unique variants run out
    for domain, templates in _SUPPLEMENT_TEMPLATES.items():
        seen: set[str] = set()
        attempts = 0
        while len(seen) < n_per_class and attempts < max_attempts:
            attempts += 1
            text = _fill(rng.choice(templates), rng)
            if text not in seen:
                seen.add(text)
                rows.append({"text": text, "labels": DOMAIN2ID[domain]})
        per_domain[domain] = len(seen)
        if len(seen) < n_per_class:
            logger.warning(
                "Synthetic supplement: only %d unique examples for %r (requested %d); "
                "add templates or fillers for more variety.",
                len(seen), domain, n_per_class,
            )
    logger.info(
        "Synthetic supplement: %d examples (requested %d per class): %s",
        len(rows), n_per_class, per_domain,
    )
    features = Features({"text": Value("string"), "labels": Value("int64")})
    return Dataset.from_list(rows, features=features)
# ---------------------------------------------------------------------------
# Tokeniser map function
# ---------------------------------------------------------------------------
def _make_tokenise_fn(tokenizer, pad_to_max: bool = False):
    """Build a batched ``Dataset.map`` callback that tokenises the ``text`` column.

    Sequences are truncated to 512 tokens. With *pad_to_max* every sequence is
    padded to that length; otherwise no padding is applied at this stage.
    """
    tokenizer_kwargs = {
        "truncation": True,
        "max_length": 512,
        "padding": "max_length" if pad_to_max else False,
    }

    def tokenise(examples):
        return tokenizer(examples["text"], **tokenizer_kwargs)

    return tokenise
# ---------------------------------------------------------------------------
# Training entry point
# ---------------------------------------------------------------------------
def train(args: argparse.Namespace) -> None:
    """Fine-tune distilbert-base-uncased for 6-class domain classification.

    Reads ``cfpb_csv``, ``max_per_class``, ``supplement_per_class``,
    ``output_dir``, ``epochs`` and ``batch_size`` from *args* (see
    ``parse_args``). It combines the remapped CFPB data with the synthetic
    supplement and holds out 10 % for evaluation. After fine-tuning it writes
    the best checkpoint (lowest eval loss) and the tokenizer to
    ``args.output_dir``.
    """
    logging.basicConfig(level=logging.INFO)
    # 1. Data
    cfpb_ds = load_and_remap_cfpb(args.cfpb_csv, max_per_class=args.max_per_class)
    suppl_ds = _build_supplement(n_per_class=args.supplement_per_class)
    full_ds = concatenate_datasets([cfpb_ds, suppl_ds]).shuffle(seed=42)
    split = full_ds.train_test_split(test_size=0.1, seed=42)
    # 2. Tokenise
    import torch as _torch
    # MPS-specific settings below apply only when MPS is present and CUDA is not.
    _use_mps = _torch.backends.mps.is_available() and not _torch.cuda.is_available()
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    # Fixed-length padding on MPS avoids graph recompilation on every new input shape
    tokenise_fn = _make_tokenise_fn(tokenizer, pad_to_max=_use_mps)
    tokenized = split.map(
        tokenise_fn,
        batched=True,
        remove_columns=["text"],
    )
    # 3. Model
    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased",
        num_labels=NUM_CLASSES,
        id2label=ID2DOMAIN,
        label2id=DOMAIN2ID,
        ignore_mismatched_sizes=True,
    )
    # 4. Training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        logging_steps=100,
        report_to="none",
        dataloader_pin_memory=False,
        torch_compile=_use_mps,
        # Leave the backend unset off-MPS: TrainingArguments turns torch_compile
        # on whenever a backend is given, so passing "inductor" here would
        # silently compile on CUDA/CPU as well.
        torch_compile_backend="aot_eager" if _use_mps else None,
    )
    # 5. Train. On MPS the inputs are already padded to max_length; passing
    # None lets Trainer use its default collator instead of an explicit one.
    _collator = None if _use_mps else DataCollatorWithPadding(tokenizer)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["test"],
        data_collator=_collator,
        processing_class=tokenizer,
    )
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("DomainClassifier checkpoint saved to %s", args.output_dir)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a DomainClassifier training run.

    Only ``--cfpb_csv`` is mandatory; every other option has a default.
    """
    option_specs: list[tuple[str, dict]] = [
        (
            "--cfpb_csv",
            {
                "required": True,
                "help": "Path to the CFPB Consumer Complaint CSV (data/raw/complaints.csv)",
            },
        ),
        (
            "--output_dir",
            {
                "default": "models/domain_classifier",
                "help": "Directory to save the fine-tuned checkpoint",
            },
        ),
        (
            "--max_per_class",
            {
                "type": int,
                "default": 50_000,
                "help": "Maximum CFPB rows to use per remapped domain (caps class imbalance)",
            },
        ),
        (
            "--supplement_per_class",
            {
                "type": int,
                "default": 5_000,
                "help": "Synthetic examples per class added to ensure all 6 domains are covered",
            },
        ),
        ("--epochs", {"type": int, "default": 3, "help": "Fine-tuning epochs"}),
        ("--batch_size", {"type": int, "default": 32, "help": "Per-device batch size"}),
    ]
    parser = argparse.ArgumentParser(description="Train DomainClassifier")
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


if __name__ == "__main__":
    train(parse_args())