Spaces:
Sleeping
Sleeping
| """ | |
| DomainClassifier training script. | |
| Data sources: | |
| 1. CFPB Consumer Complaint Database (3M+ rows) — place as data/raw/complaints.csv. | |
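| Covers the banking and cibil domains; every mappable CFPB product is remapped to one of those two. | |
| 2. Synthetic supplement — up to --supplement_per_class (default 5 000) unique in-memory examples | |
| per domain (fewer if a domain's templates yield fewer distinct sentences), so that all 6 | |
| classes are represented, especially ecommerce / telecom / insurance. | |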
| Estimated training time: ~30 min CPU / ~5 min GPU (T4). | |
| CLI usage: | |
| python -m src.classifier.train \\ | |
| --cfpb_csv data/raw/complaints.csv \\ | |
| --output_dir models/domain_classifier | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import logging | |
| import random | |
| import re | |
| import pandas as pd | |
| from datasets import Dataset, Features, Value, concatenate_datasets | |
| from transformers import ( | |
| AutoModelForSequenceClassification, | |
| AutoTokenizer, | |
| DataCollatorWithPadding, | |
| Trainer, | |
| TrainingArguments, | |
| ) | |
| from src.classifier.model import DOMAIN2ID, DOMAIN_LABELS, ID2DOMAIN, NUM_CLASSES | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # CFPB product → G.U.I.D.E. domain mapping | |
| # --------------------------------------------------------------------------- | |
| CFPB_PRODUCT_MAP: dict[str, str] = { | |
| # banking | |
| "Checking or savings account": "banking", | |
| "Bank account or service": "banking", | |
| "Mortgage": "banking", | |
| "Student loan": "banking", | |
| "Vehicle loan or lease": "banking", | |
| "Payday loan, title loan, or personal loan": "banking", | |
| "Payday loan": "banking", | |
| "Consumer Loan": "banking", | |
| "Money transfer, virtual currency, or money service": "banking", | |
| "Money transfers": "banking", | |
| "Other financial service": "banking", | |
| # cibil | |
| "Credit card or prepaid card": "cibil", | |
| "Credit card": "cibil", | |
| "Prepaid card": "cibil", | |
| "Credit reporting, credit repair services, or other personal consumer reports": "cibil", | |
| "Credit reporting": "cibil", | |
| "Debt collection": "cibil", | |
| "Debt Collection": "cibil", | |
| "Virtual currency": "cibil", | |
| } | |
| # Keyword-based fallback used when the exact CFPB product string is not in the map | |
| _BANKING_KWS = ("mortgage", "loan", "checking", "savings", "bank account", "money transfer") | |
| _CIBIL_KWS = ("credit report", "credit repair", "credit card", "debt collection", "prepaid") | |
| def _map_product(product: str) -> str | None: | |
| """Map a CFPB product string to a G.U.I.D.E. domain, or None if unmappable.""" | |
| product = product.strip() | |
| if product in CFPB_PRODUCT_MAP: | |
| return CFPB_PRODUCT_MAP[product] | |
| lower = product.lower() | |
| if any(kw in lower for kw in _BANKING_KWS): | |
| return "banking" | |
| if any(kw in lower for kw in _CIBIL_KWS): | |
| return "cibil" | |
| return None | |
| # --------------------------------------------------------------------------- | |
| # CFPB dataset loader | |
| # --------------------------------------------------------------------------- | |
def load_and_remap_cfpb(csv_path: str, max_per_class: int = 50_000) -> Dataset:
    """
    Load *csv_path*, remap CFPB product labels → G.U.I.D.E. domains, and return
    a HuggingFace Dataset with columns ``text`` and ``labels``.

    Reads in 200k-row chunks and filters each chunk before keeping it. Rows
    with a missing or near-empty narrative, or an unmappable product, are
    dropped, so only usable rows accumulate instead of the whole 3M-row file.
    Caps each domain at *max_per_class* rows to reduce label imbalance.

    Args:
        csv_path: Path to the CFPB complaints CSV. Column names are matched
            after stripping whitespace and lower-casing.
        max_per_class: Maximum number of rows kept per remapped domain.

    Returns:
        A ``Dataset`` with a string ``text`` column and an int64 ``labels``
        column holding ``DOMAIN2ID`` ids.

    Raises:
        ValueError: If no row survives filtering and remapping.
    """
    logger.info("Loading CFPB CSV from %s …", csv_path)

    # Product strings repeat across many rows, so memoise the mapping rather
    # than re-running the keyword fallback for every complaint.
    label_cache: dict[str, str | None] = {}

    def _label_for(product: object) -> str | None:
        if pd.isna(product):
            return None
        key = str(product)
        if key not in label_cache:
            label_cache[key] = _map_product(key)
        return label_cache[key]

    kept_chunks: list[pd.DataFrame] = []
    for chunk in pd.read_csv(
        csv_path,
        usecols=lambda c: c.strip().lower() in (
            "product", "consumer complaint narrative"
        ),
        chunksize=200_000,
        low_memory=False,
        encoding="utf-8",
        on_bad_lines="skip",
    ):
        # Normalise column names across CFPB dataset versions
        chunk.columns = [c.strip().lower() for c in chunk.columns]
        text = chunk["consumer complaint narrative"].dropna().astype(str).str.strip()
        text = text[text.str.len() > 20]  # drop near-empty narratives
        labels = chunk.loc[text.index, "product"].map(_label_for)
        kept = pd.DataFrame({"text": text, "labels": labels}).dropna(subset=["labels"])
        kept_chunks.append(kept)

    df = (
        pd.concat(kept_chunks, ignore_index=True)
        if kept_chunks
        else pd.DataFrame(columns=["text", "labels"])
    )
    if df.empty:
        raise ValueError(
            f"No usable CFPB complaints in {csv_path!r}: every row lacked a "
            "narrative longer than 20 characters or a mappable product."
        )

    # Cap per class (pandas 3.x compatible — groupby/apply dropped include_groups)
    df = pd.concat(
        [g.sample(min(len(g), max_per_class), random_state=42)
         for _, g in df.groupby("labels")],
        ignore_index=True,
    )
    logger.info(
        "CFPB dataset: %d examples.\n%s",
        len(df),
        df["labels"].value_counts().to_string(),
    )
    df["labels"] = df["labels"].map(DOMAIN2ID)
    features = Features({"text": Value("string"), "labels": Value("int64")})
    return Dataset.from_pandas(df[["text", "labels"]], features=features, preserve_index=False)
| # --------------------------------------------------------------------------- | |
| # Synthetic supplement for under-represented domains | |
| # --------------------------------------------------------------------------- | |
# Sentence templates per domain. Placeholders ({amount}, {date}, {ref}, {days})
# are filled from _FILLERS by _fill(). {days} renders a bare number, so it is
# also used for month counts and credit-score point changes; {amount} renders
# a rupee value and must not be used for non-currency quantities.
_SUPPLEMENT_TEMPLATES: dict[str, list[str]] = {
    "ecommerce": [
        "My order from Flipkart placed {date} has not been delivered.",
        "Amazon India has not processed my refund of {amount} for the returned product.",
        "I received a damaged product from Myntra; my return request has been ignored.",
        "Snapdeal charged me {amount} twice for order {ref}.",
        "Meesho delivered the wrong item and has not responded to my complaint.",
        "The product ordered from Flipkart shows shipped but was not received after {days} days.",
        "Amazon India rejected my return request for a defective item worth {amount}.",
        "I have been waiting for a refund of {amount} from Myntra since {date}.",
        "My COD order from Flipkart was marked delivered but I never received it.",
        "Zomato charged me for an order that was never delivered on {date}.",
        "Swiggy refunded the wrong amount; I am owed {amount} more.",
        "Meesho has blocked my account without any notice or explanation.",
        "Snapdeal took {amount} from my wallet but the order was never placed.",
        "The product I received from Amazon India differs from what I ordered.",
        "Myntra has not honoured the discount of {amount} applied at checkout.",
        "My Flipkart subscription was renewed without consent, charging me {amount}.",
        "MakeMyTrip cancelled my booking and has not refunded {amount} since {date}.",
        "IRCTC charged my account {amount} but the ticket was not booked.",
        "The Uber ride on {date} overcharged me by {amount}.",
        "Ola Cabs deducted {amount} for a trip I never took.",
        # Hotel / travel OTAs (same escalation path as MakeMyTrip/IRCTC)
        "Agoda confirmed my hotel booking but the property has no record of it.",
        "OYO Rooms cancelled my booking on {date} without refunding {amount}.",
        "Booking.com India did not process my refund of {amount} after hotel cancellation.",
        "I booked a hotel via Agoda for {amount} on {date} but check-in was denied.",
        "OYO Rooms charged {amount} extra at check-out beyond the confirmed booking price.",
    ],
    "telecom": [
        "Airtel deducted {amount} from my prepaid account without explanation.",
        "My Jio plan was not activated despite recharging {amount} on {date}.",
        "Vodafone Idea is billing me for services I never subscribed to.",
        "BSNL has not resolved my broadband issue since {date}.",
        "Airtel ported my number without my consent.",
        "Reliance Jio throttled my internet speed without any prior notice.",
        "My mobile number with Vodafone Idea was disconnected incorrectly.",
        "Airtel charged me {amount} extra every month without justification.",
        "I submitted a porting request to Jio on {date} but it has not been processed.",
        "BSNL customer care has been unresponsive to my repeated complaints.",
        "My Airtel postpaid bill of {amount} includes services I did not activate.",
        "Jio deducted {amount} from my account citing roaming charges during domestic travel.",
        "BSNL did not credit the cashback of {amount} promised on recharge.",
        "My Vodafone Idea sim was deactivated without notice on {date}.",
        "Airtel is charging me for a TV subscription I cancelled on {date}.",
        "I was not informed of the tariff change that increased my bill by {amount}.",
        "Jio's technical team promised resolution by {date} but has not followed up.",
        "My BSNL landline has been non-functional since {date}.",
        "Vodafone Idea charged me for international calls I never made.",
        "Airtel did not process my DND registration despite multiple requests.",
        # SIM activation and prepaid recharge scenarios (Trace 5 gap)
        "I bought a new SIM card from BSNL on {date} but it has not been activated yet.",
        "My prepaid recharge of {amount} on Airtel was successful but services are not reflecting.",
        "BSNL SIM activation is pending for {days} days despite multiple follow-ups.",
        "I recharged my Jio number with {amount} on {date} but the plan is not active.",
        "My new Vodafone Idea SIM was issued on {date} but calls and data are not working.",
        "Airtel prepaid recharge of {amount} deducted but balance not updated.",
        "I ported my number to BSNL on {date} but incoming calls are not connecting.",
        "My BSNL mobile number was deactivated after recharging {amount} on {date}.",
        "New SIM card from Jio purchased for {amount} but activation failed without reason.",
        "Prepaid plan of {amount} recharged on Vodafone Idea is not reflecting after {days} days.",
    ],
    "insurance": [
        "My health insurance claim with LIC for {amount} was rejected without reason.",
        "Star Health Insurance has been delaying my settlement since {date}.",
        "New India Assurance denied my motor claim citing incorrect reasons.",
        "I have not received policy documents from ICICI Lombard despite paying the premium.",
        "An LIC agent collected {amount} as premium but my policy was never issued.",
        "My mediclaim with Star Health was rejected over a pre-existing condition I disclosed.",
        "HDFC ERGO has not settled claim {ref} filed on {date}.",
        "The LIC maturity amount has not been credited after {days} months.",
        "New India Assurance is not responding to my motor accident claim from {date}.",
        "I was mis-sold a policy worth {amount} by an ICICI Lombard agent.",
        "Star Health deducted the renewal premium but sent no confirmation.",
        "My LIC policy was lapsed without notice despite timely premium payments.",
        "IRDAI registered my complaint but the insurer has still not responded after {days} days.",
        "The surveyor from New India Assurance undervalued my claim by {amount}.",
        "ICICI Lombard cancelled my policy without refunding the unused premium of {amount}.",
        "I am unable to get cashless treatment because Star Health blocked my policy.",
        "LIC branch refused to accept my premium payment on {date}.",
        "My nominee was denied the claim on a valid life insurance policy.",
        "The health insurance top-up of {amount} was not applied despite the rider.",
        "New India Assurance took {days} days to appoint a surveyor for my claim.",
    ],
    "cibil": [
        "My CIBIL score is incorrect due to a wrong default entry by HDFC Bank.",
        "Experian India shows a closed loan as active on my credit report.",
        "My credit score dropped because ICICI Bank filed an incorrect report.",
        "CIBIL has not updated my score after loan closure on {date}.",
        "There is a fraudulent credit card entry in my CIBIL report I never applied for.",
        "My credit report shows an outstanding of {amount} which I have already paid.",
        "TransUnion CIBIL has not removed the settled debt entry.",
        "Bank of Baroda incorrectly reported my account as an NPA.",
        "My Experian credit report has errors that are hurting my loan eligibility.",
        "Kotak Bank reported me as a defaulter despite on-time EMI payments.",
        "I disputed the incorrect entry with CIBIL on {date} but received no response.",
        "My CIBIL report still shows a write-off that was settled {days} months ago.",
        "HDFC Bank is reporting a closed personal loan as outstanding on my Experian report.",
        "The credit card I never applied for from Axis Bank is damaging my CIBIL score.",
        # Score change is a bare number ({days}); {amount} rendered "₹4,299 points".
        "My credit score fell by {days} points due to an error in the CIBIL database.",
        "SBI has not updated the repayment history for my home loan with CIBIL.",
        "Two different banks are reporting the same debt, inflating my outstanding balance.",
        "I requested a free credit report from CIBIL on {date} but did not receive it.",
        "My loan account at PNB shows a missed payment that I have a receipt for.",
        "The debt collector listed on my Experian report does not match any known account.",
    ],
    "banking": [
        "HDFC Bank deducted {amount} from my account without authorization.",
        "SBI has not processed my NEFT transfer of {amount} sent on {date}.",
        "ICICI Bank is charging hidden fees on my savings account.",
        "Axis Bank rejected my loan application without providing a reason.",
        "My ATM card was blocked by Kotak Bank without prior notice.",
        "PNB deducted maintenance charges from my zero-balance account.",
        "My fixed deposit with SBI matured on {date} but the amount has not been credited.",
        "HDFC Bank charged {amount} as processing fee but rejected my loan.",
        "Unauthorized transactions appeared in my Bank of Baroda account.",
        "ICICI Bank has not reversed the duplicate EMI of {amount}.",
        "My SBI account was frozen on {date} without any explanation.",
        "HDFC Bank did not credit the cashback of {amount} on my credit card.",
        "Axis Bank charged me a penalty of {amount} for an EMI that was auto-debited.",
        "I was charged {amount} for a Kotak Bank service I never requested.",
        "My Bank of Baroda ATM card was cloned and {amount} was withdrawn.",
        "PNB issued a cheque return charge of {amount} without valid reason.",
        "ICICI Bank approved my loan but the disbursement has been pending since {date}.",
        "The interest rate on my SBI home loan was changed without informing me.",
        "Axis Bank debited my account twice for the same EMI of {amount} on {date}.",
        "HDFC Bank has not closed my account despite a written request on {date}.",
    ],
    "general": [
        "The company has not responded to my complaints since {date}.",
        "I paid {amount} for a service that was never delivered.",
        "The customer care team has been completely unresponsive to my grievance.",
        "I want to lodge a formal complaint about the defective product I purchased.",
        "Despite multiple follow-ups my refund of {amount} has not been processed.",
        "The company promised to resolve my issue by {date} but has not done so.",
        "I am being harassed by the collections team for a payment I already made.",
        "The service provider charged me {amount} for services not rendered.",
        "I have been trying to reach customer support for {days} days with no response.",
        "The company delivered a completely different product than what I ordered.",
        "My complaint {ref} has been closed without being resolved.",
        "I filed a grievance on {date} but there has been no acknowledgement.",
        "The company charged me {amount} for a warranty claim that should be free.",
        "I have written to the grievance officer but have not received a reply.",
        "The product warranty was denied despite the item being within the warranty period.",
        "I was promised a callback by {date} but no one has reached out.",
        "The company has deducted {amount} from my wallet and is not crediting it back.",
        "My subscription was renewed without consent, billing me {amount}.",
        "The service outage on {date} caused me a financial loss of {amount}.",
        "I never gave consent for the company to auto-renew and charge me {amount}.",
    ],
}
# Values substituted into template placeholders by _fill(). Templates supply
# their own preposition ("on {date}", "since {date}", "by {date}"), so date
# values must not start with "on"; doing so produced "filed on on 30 June 2024".
# NOTE(review): relative dates ("three weeks ago", "last month") still read
# awkwardly after "on"/"since"/"by". This is harmless for domain labels, but
# consider separate absolute/relative lists if text quality matters.
_FILLERS: dict[str, list[str]] = {
    "amount": [
        "₹4,299", "₹1,200", "₹50,000", "₹10,500", "₹2,500",
        "Rs. 8,900", "Rs 15,000", "₹3,499", "₹25,000", "₹9,999",
        "₹35,000", "₹750", "Rs. 2,000", "₹18,500", "₹5,999",
    ],
    "date": [
        "12 March 2024", "5 January 2024", "20 April 2024", "8 November 2023",
        "three weeks ago", "two months ago", "last Tuesday",
        "15th February 2024", "last month", "30 June 2024",
    ],
    "ref": [
        "#OD-2930291", "TXN987654321", "REF-20240312", "CMP-2024-001",
        "TKT-9876543", "CLM-2024-12345", "BK-56789",
    ],
    # Bare numbers; templates reuse this for days, months and score points.
    "days": ["7", "10", "15", "20", "30", "45", "60"],
}

# Matches "{name}" placeholders; the name is captured for the _FILLERS lookup.
_PLACEHOLDER_RE = re.compile(r"\{(\w+)\}")


def _fill(template: str, rng: random.Random) -> str:
    """Return *template* with every ``{key}`` placeholder filled in.

    Each occurrence is replaced by an independent ``rng.choice`` from
    ``_FILLERS[key]``. Placeholders whose key has no (or an empty) filler
    list are left verbatim.

    Args:
        template: Sentence containing zero or more ``{key}`` placeholders.
        rng: Random source, passed in so callers control reproducibility.

    Returns:
        The filled-in sentence.
    """
    def _replace(match: re.Match) -> str:
        options = _FILLERS.get(match.group(1))
        return rng.choice(options) if options else match.group(0)

    return _PLACEHOLDER_RE.sub(_replace, template)
def _build_supplement(n_per_class: int = 5_000, seed: int = 42) -> Dataset:
    """
    Generate up to *n_per_class* unique synthetic examples for each domain in
    ``_SUPPLEMENT_TEMPLATES``.

    Used to ensure all domains are represented in training regardless of how
    CFPB data maps — particularly important for ecommerce, telecom, insurance.

    Texts are de-duplicated within each domain. A domain whose templates ×
    fillers admit fewer than *n_per_class* distinct sentences therefore yields
    fewer rows; a warning is logged in that case, and the real per-domain
    counts are reported rather than the requested target.

    Args:
        n_per_class: Target number of unique examples per domain.
        seed: Seed for a local ``random.Random`` so output is reproducible.

    Returns:
        A ``Dataset`` with ``text`` (string) and ``labels`` (int64,
        ``DOMAIN2ID`` ids) columns.
    """
    rng = random.Random(seed)
    rows: list[dict] = []
    counts: dict[str, int] = {}
    for domain, templates in _SUPPLEMENT_TEMPLATES.items():
        label = DOMAIN2ID[domain]
        seen: set[str] = set()
        # Duplicates are skipped, so bound the number of draws: stop after
        # 10x the target even if the target has not been reached.
        for _ in range(n_per_class * 10):
            if len(seen) >= n_per_class:
                break
            text = _fill(rng.choice(templates), rng)
            if text not in seen:
                seen.add(text)
                rows.append({"text": text, "labels": label})
        counts[domain] = len(seen)
        if len(seen) < n_per_class:
            logger.warning(
                "Synthetic supplement: domain %r produced only %d unique examples "
                "(requested %d); its templates are exhausted.",
                domain, len(seen), n_per_class,
            )
    logger.info(
        "Synthetic supplement: %d examples (requested %d per class). Per domain: %s",
        len(rows), n_per_class, counts,
    )
    features = Features({"text": Value("string"), "labels": Value("int64")})
    return Dataset.from_list(rows, features=features)
| # --------------------------------------------------------------------------- | |
| # Tokeniser map function | |
| # --------------------------------------------------------------------------- | |
def _make_tokenise_fn(tokenizer, pad_to_max: bool = False, max_length: int = 512):
    """Return a batched ``Dataset.map`` function that tokenises ``examples["text"]``.

    Args:
        tokenizer: Callable tokenizer, invoked as
            ``tokenizer(texts, truncation=True, max_length=..., padding=...)``.
        pad_to_max: If True, pad every sequence to *max_length* so all batches
            share one shape. If False, no padding is applied here and it is
            left to the data collator.
        max_length: Truncation length (and padding length when *pad_to_max*).
            Must not exceed the model's maximum input length.

    Returns:
        A function mapping a batch dict to the tokenizer's output.
    """
    padding = "max_length" if pad_to_max else False

    def tokenise(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
            padding=padding,
        )

    return tokenise
| # --------------------------------------------------------------------------- | |
| # Training entry point | |
| # --------------------------------------------------------------------------- | |
def train(args: argparse.Namespace) -> None:
    """Fine-tune distilbert-base-uncased for 6-class domain classification.

    The run proceeds in these steps:
      1. Build the CFPB + synthetic dataset and hold out 10 % for evaluation.
      2. Tokenise it.
      3. Fine-tune with ``Trainer``, keeping the best checkpoint by ``eval_loss``.
      4. Write the model and tokenizer to ``args.output_dir``.

    Args:
        args: Parsed CLI namespace. Reads ``cfpb_csv``, ``max_per_class``,
            ``supplement_per_class``, ``output_dir``, ``epochs`` and
            ``batch_size``.
    """
    logging.basicConfig(level=logging.INFO)
    # 1. Data
    cfpb_ds = load_and_remap_cfpb(args.cfpb_csv, max_per_class=args.max_per_class)
    suppl_ds = _build_supplement(n_per_class=args.supplement_per_class)
    full_ds = concatenate_datasets([cfpb_ds, suppl_ds]).shuffle(seed=42)
    split = full_ds.train_test_split(test_size=0.1, seed=42)
    # 2. Tokenise
    import torch as _torch
    _use_cuda = _torch.cuda.is_available()
    # The MPS path is taken only when no CUDA device is present.
    _use_mps = _torch.backends.mps.is_available() and not _use_cuda
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    # Fixed-length padding on MPS avoids graph recompilation on every new input shape
    tokenise_fn = _make_tokenise_fn(tokenizer, pad_to_max=_use_mps)
    tokenized = split.map(
        tokenise_fn,
        batched=True,
        remove_columns=["text"],
    )
    # 3. Model
    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased",
        num_labels=NUM_CLASSES,
        id2label=ID2DOMAIN,
        label2id=DOMAIN2ID,
        ignore_mismatched_sizes=True,
    )
    # 4. Training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        logging_steps=100,
        report_to="none",
        # Pinned (page-locked) host memory only speeds up copies to a CUDA
        # device; it was previously disabled even on CUDA.
        dataloader_pin_memory=_use_cuda,
        torch_compile=_use_mps,
        torch_compile_backend="aot_eager" if _use_mps else "inductor",
    )
    # 5. Train. On MPS every sequence is already padded to max_length, so no
    # collator is passed explicitly. NOTE(review): with data_collator=None and
    # a tokenizer as processing_class, Trainer falls back to
    # DataCollatorWithPadding anyway, which is a no-op on equal-length inputs.
    _collator = None if _use_mps else DataCollatorWithPadding(tokenizer)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["test"],
        data_collator=_collator,
        processing_class=tokenizer,
    )
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("DomainClassifier checkpoint saved to %s", args.output_dir)
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
| def parse_args() -> argparse.Namespace: | |
| p = argparse.ArgumentParser(description="Train DomainClassifier") | |
| p.add_argument( | |
| "--cfpb_csv", required=True, | |
| help="Path to the CFPB Consumer Complaint CSV (data/raw/complaints.csv)", | |
| ) | |
| p.add_argument( | |
| "--output_dir", default="models/domain_classifier", | |
| help="Directory to save the fine-tuned checkpoint", | |
| ) | |
| p.add_argument( | |
| "--max_per_class", type=int, default=50_000, | |
| help="Maximum CFPB rows to use per remapped domain (caps class imbalance)", | |
| ) | |
| p.add_argument( | |
| "--supplement_per_class", type=int, default=5_000, | |
| help="Synthetic examples per class added to ensure all 6 domains are covered", | |
| ) | |
| p.add_argument("--epochs", type=int, default=3, help="Fine-tuning epochs") | |
| p.add_argument("--batch_size", type=int, default=32, help="Per-device batch size") | |
| return p.parse_args() | |
| if __name__ == "__main__": | |
| train(parse_args()) | |