Spaces:
Sleeping
Sleeping
| """ | |
| EvidenceNER training script. | |
| Data source: ~4 000 synthetic complaint sentences generated in-memory by | |
| build_synthetic_dataset(). No download required. | |
| Optionally augmented with CoNLL-2003 (via HuggingFace datasets) | |
| when internet is available; PER→PERSON, ORG→ORG; LOC/MISC discarded. | |
| CLI usage: | |
| python -m src.ner.train --output_dir models/evidence_ner | |
| python -m src.ner.train --output_dir models/evidence_ner --n_samples 6000 --epochs 5 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import logging | |
| import random | |
| import re | |
| from datasets import Dataset, DatasetDict, concatenate_datasets | |
| from transformers import ( | |
| AutoModelForTokenClassification, | |
| AutoTokenizer, | |
| DataCollatorForTokenClassification, | |
| Trainer, | |
| TrainingArguments, | |
| ) | |
| from src.ner.model import BIO_LABELS, ID2LABEL, LABEL2ID, NUM_LABELS, NER_LABELS | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Entity value banks | |
| # --------------------------------------------------------------------------- | |
# Surface forms used to fill template slots, keyed by entity label. The keys
# are the slot labels that build_synthetic_dataset() looks up after extracting
# them from a template, so every label in _SLOT_RE needs an entry here.
#
# NOTE(review): some values carry their own leading article, preposition or
# descriptor ("the support agent", "your representative",
# "on 15th February 2024", "in December 2023", "complaint number CMP-2024-001").
# Templates that already supply such a word ("The {PERSON}", "on {DATE}",
# "complaint {REF_ID}") then render doubled words, e.g. "The the support agent",
# and because the whole value is one span the extra word is tagged as part of
# the entity. Confirm this is intended before relying on span boundaries.
ENTITY_VALUES: dict[str, list[str]] = {
    # Companies, banks, insurers and platforms that complaints are filed against.
    "ORG": [
        "Flipkart", "Amazon India", "Myntra", "Snapdeal", "Meesho",
        "HDFC Bank", "ICICI Bank", "State Bank of India", "Axis Bank", "Kotak Mahindra Bank",
        "Punjab National Bank", "Bank of Baroda",
        "Airtel", "Reliance Jio", "Vodafone Idea", "BSNL",
        "LIC of India", "Star Health Insurance", "New India Assurance", "ICICI Lombard",
        "CIBIL", "Experian India",
        "Swiggy", "Zomato", "Ola Cabs", "Uber India", "IRCTC",
        "MakeMyTrip", "Paytm", "PhonePe",
        # "Indian Bank" — public sector bank (HQ Chennai); distinct from the generic phrase
        "Indian Bank",
        # Health insurance providers
        "Niva Bupa", "Care Health Insurance", "Bajaj Allianz",
        # Travel / hospitality OTAs
        "Agoda", "OYO Rooms", "Booking.com India",
        # Fintech
        "Razorpay", "BharatPe", "CRED",
    ],
    # Rupee amounts in both "₹" and "Rs"/"Rs." notations, including lakh-style grouping.
    "AMOUNT": [
        "₹4,299", "₹1,200", "₹50,000", "₹10,500", "₹2,500",
        "Rs. 8,900", "Rs 15,000", "₹3,499", "₹1,00,000", "Rs. 499",
        "₹25,000", "₹750", "Rs. 1,50,000", "₹12,000", "₹5,999",
        "₹35,000", "₹800", "Rs. 2,000", "₹18,500", "₹9,999",
    ],
    # Absolute and relative date expressions; some include a leading "on"/"in".
    "DATE": [
        "12 March 2024", "5 January 2024", "20 April 2024", "8 November 2023",
        "3 weeks ago", "two months ago", "last Tuesday", "last Friday",
        "on 15th February 2024", "15/01/2024", "on 5th April",
        "in December 2023", "last month", "three days ago",
        "on 22 May 2024", "on 30 June 2024",
    ],
    # Reference identifiers; each value includes its descriptor words ("ticket ID", ...).
    "REF_ID": [
        "Order #OD-2930291", "transaction ID TXN987654321",
        "reference number REF-20240312-001", "ticket ID TKT-9876543",
        "complaint number CMP-2024-001", "booking ID BK-56789",
        "claim number CLM-2024-12345", "case ID CASE-789012",
        "loan reference LN-20240101", "policy number POL-123456789",
        "complaint reference CR-20240415",
    ],
    # Account mentions, both with identifying digits and as generic account types.
    "ACCOUNT": [
        "account ending in 4521", "loan account 9876543210",
        "savings account XXXX1234", "credit card ending 9087",
        "account number XXXX-4321", "demat account IN12345678",
        "current account", "fixed deposit account",
        "joint savings account", "salary account",
    ],
    # Role descriptions rather than personal names.
    "PERSON": [
        "customer care executive", "branch manager",
        "relationship manager", "loan officer",
        "insurance agent", "delivery executive",
        "the support agent", "your representative",
        "technical support executive", "grievance officer",
        "account manager",
    ],
}
| # --------------------------------------------------------------------------- | |
| # Sentence templates | |
| # Placeholders: {ORG} {AMOUNT} {DATE} {REF_ID} {ACCOUNT} {PERSON} | |
| # --------------------------------------------------------------------------- | |
# Complaint sentence skeletons. Each "{LABEL}" placeholder is replaced with a
# value drawn from ENTITY_VALUES[LABEL]; the section comments below name the
# entity combination each group of templates exercises.
#
# NOTE(review): several templates place an article, preposition or descriptor
# directly before a slot ("The {PERSON}", "A {PERSON}", "on {DATE}",
# "complaint {REF_ID}", "ticket {REF_ID}") while some ENTITY_VALUES entries
# begin with the same kind of word, so rendered sentences can contain doubled
# words such as "on on 5th April" — confirm this is acceptable training noise.
TEMPLATES: list[str] = [
    # --- Single entity ---
    "I want to file a complaint against {ORG}.",
    "{ORG} has been completely unresponsive to my grievances.",
    "I am a long-standing customer of {ORG} and I am deeply dissatisfied.",
    "A deduction of {AMOUNT} was made from my account without authorization.",
    "I am owed a refund of {AMOUNT} which has not been credited.",
    "Please reverse the incorrect charge of {AMOUNT} immediately.",
    "The incident occurred on {DATE} and has not been resolved since.",
    "I filed a formal complaint on {DATE} but have received no response.",
    "I am writing with reference to {REF_ID} which remains unresolved.",
    "Please look into complaint {REF_ID} at the earliest.",
    "My {ACCOUNT} has been showing incorrect transactions for several weeks.",
    "The {ACCOUNT} was blocked without any prior notice.",
    "The {PERSON} I spoke to was completely unhelpful and dismissive.",
    "I was promised by a {PERSON} that the issue would be resolved within 24 hours.",
    # --- ORG + AMOUNT ---
    "I ordered from {ORG} but was incorrectly charged {AMOUNT}.",
    "{ORG} has deducted {AMOUNT} from my account without my consent.",
    "I am requesting a refund of {AMOUNT} from {ORG} for a defective product.",
    "{ORG} charged me {AMOUNT} for a service I never subscribed to.",
    "Despite cancellation {ORG} has not refunded {AMOUNT} to date.",
    "{ORG} owes me {AMOUNT} as compensation for the inconvenience caused.",
    "I was billed {AMOUNT} by {ORG} in error and seek immediate correction.",
    # --- ORG + DATE ---
    "I filed a complaint with {ORG} on {DATE} but have received no update.",
    "{ORG} failed to deliver my order by the promised date of {DATE}.",
    "I visited the {ORG} branch on {DATE} but the issue was not resolved.",
    "Since {DATE} {ORG} has not responded to any of my communications.",
    # --- ORG + REF_ID ---
    "My complaint {REF_ID} with {ORG} has been unresolved for several weeks.",
    "I am following up on {REF_ID} raised with {ORG}.",
    "{ORG} has not taken any action on my ticket {REF_ID}.",
    # --- ORG + ACCOUNT ---
    "{ORG} debited funds from my {ACCOUNT} without my knowledge.",
    "I noticed that {ORG} had incorrectly blocked my {ACCOUNT}.",
    "The {ACCOUNT} with {ORG} has been showing erroneous entries.",
    # --- ORG + PERSON ---
    "The {PERSON} at {ORG} was rude and refused to address my concern.",
    "A {PERSON} from {ORG} promised to resolve the issue but never followed up.",
    "I spoke to a {PERSON} at {ORG} who assured me of a refund.",
    "The {PERSON} at {ORG} refused to process my refund request.",
    # --- AMOUNT + DATE ---
    "An unauthorized transaction of {AMOUNT} occurred on {DATE}.",
    "The refund of {AMOUNT} promised for {DATE} was never processed.",
    "On {DATE} a deduction of {AMOUNT} appeared on my account without reason.",
    # --- AMOUNT + ACCOUNT ---
    "{AMOUNT} was wrongly debited from my {ACCOUNT} and I request an immediate refund.",
    "My {ACCOUNT} shows an erroneous charge of {AMOUNT} that I did not authorize.",
    "{AMOUNT} was deducted from my {ACCOUNT} without my knowledge.",
    # --- AMOUNT + REF_ID ---
    "Transaction {REF_ID} of {AMOUNT} is disputed and I seek reversal.",
    "I raised complaint {REF_ID} against an incorrect charge of {AMOUNT}.",
    # --- DATE + REF_ID ---
    "I raised {REF_ID} on {DATE} and have not received any resolution.",
    "As of {DATE} my complaint {REF_ID} remains open with no action taken.",
    # --- PERSON + ACCOUNT ---
    "The {PERSON} disconnected my call without resolving the issue with my {ACCOUNT}.",
    "A {PERSON} assured me that my {ACCOUNT} would be unblocked within 24 hours.",
    # --- ORG + AMOUNT + DATE ---
    "I placed an order with {ORG} for {AMOUNT} on {DATE} but it was never delivered.",
    "{ORG} charged {AMOUNT} to my account on {DATE} without any authorization.",
    "On {DATE} I paid {AMOUNT} to {ORG} but the service was not provided as promised.",
    "I cancelled my subscription with {ORG} on {DATE} but the refund of {AMOUNT} has not been credited.",
    "{ORG} promised to refund {AMOUNT} by {DATE} but has failed to do so.",
    "I was billed {AMOUNT} by {ORG} for a service cancelled on {DATE}.",
    "A duplicate payment of {AMOUNT} to {ORG} made on {DATE} has not been reversed.",
    # --- ORG + AMOUNT + REF_ID ---
    "My order {REF_ID} from {ORG} worth {AMOUNT} was returned but the refund was not received.",
    "I raised complaint {REF_ID} with {ORG} regarding an erroneous charge of {AMOUNT}.",
    "{ORG} owes me {AMOUNT} against transaction reference {REF_ID}.",
    "Despite follow-up on ticket {REF_ID} {ORG} has not refunded {AMOUNT}.",
    # --- ORG + ACCOUNT + AMOUNT ---
    "{ORG} debited {AMOUNT} from my {ACCOUNT} without my consent.",
    "I noticed an unauthorized charge of {AMOUNT} on my {ACCOUNT} with {ORG}.",
    "{ORG} has applied a penalty of {AMOUNT} to my {ACCOUNT} without prior notice.",
    "Funds amounting to {AMOUNT} were withdrawn from my {ACCOUNT} at {ORG} without authorization.",
    # --- ORG + DATE + REF_ID ---
    "I filed complaint {REF_ID} with {ORG} on {DATE} and request immediate resolution.",
    "My ticket {REF_ID} raised with {ORG} on {DATE} has not been addressed.",
    "Since {DATE} {ORG} has not responded to my complaint {REF_ID}.",
    # --- ORG + ACCOUNT + DATE ---
    "{ORG} deducted funds from my {ACCOUNT} on {DATE} without any prior notification.",
    "I noticed on {DATE} that {ORG} had placed an incorrect hold on my {ACCOUNT}.",
    # --- PERSON + ORG + AMOUNT ---
    "The {PERSON} at {ORG} assured me of a refund of {AMOUNT} which I am yet to receive.",
    "A {PERSON} from {ORG} processed an unauthorized deduction of {AMOUNT} from my account.",
    # --- AMOUNT + DATE + REF_ID ---
    "Transaction {REF_ID} of {AMOUNT} made on {DATE} was unauthorized and I seek reversal.",
    "On {DATE} I filed complaint {REF_ID} for the recovery of {AMOUNT}.",
    # --- ORG + AMOUNT + DATE + REF_ID ---
    "My order {REF_ID} from {ORG} placed on {DATE} for {AMOUNT} was cancelled without a refund.",
    "I paid {AMOUNT} to {ORG} on {DATE} against reference {REF_ID} but the service was not rendered.",
    "{ORG} charged {AMOUNT} on {DATE} referencing {REF_ID} without my authorization.",
    "Despite raising {REF_ID} with {ORG} on {DATE} the refund of {AMOUNT} remains pending.",
    # --- ORG + ACCOUNT + AMOUNT + DATE ---
    "On {DATE} {ORG} debited {AMOUNT} from my {ACCOUNT} without authorization.",
    "{ORG} withdrew {AMOUNT} from my {ACCOUNT} on {DATE} citing a technical error.",
    # --- PERSON + ORG + AMOUNT + DATE ---
    "The {PERSON} at {ORG} processed a transaction of {AMOUNT} on {DATE} without my knowledge.",
    "On {DATE} a {PERSON} from {ORG} assured me the disputed {AMOUNT} would be refunded.",
    # --- ORG + ACCOUNT + AMOUNT + REF_ID ---
    "I raised {REF_ID} with {ORG} regarding {AMOUNT} wrongly debited from my {ACCOUNT}.",
    "{ORG} has not reversed {AMOUNT} credited to {ACCOUNT} as per complaint {REF_ID}.",
    # --- Five / six entities ---
    "On {DATE} the {PERSON} at {ORG} deducted {AMOUNT} from my {ACCOUNT} against {REF_ID}.",
    "I spoke to a {PERSON} from {ORG} on {DATE} regarding {REF_ID} worth {AMOUNT} debited from my {ACCOUNT}.",
    "The {PERSON} at {ORG} confirmed on {DATE} that {REF_ID} of {AMOUNT} debited from {ACCOUNT} would be reversed.",
]
| # --------------------------------------------------------------------------- | |
| # Template filling helpers | |
| # --------------------------------------------------------------------------- | |
# Matches a single "{LABEL}" placeholder; the one capture group is the label.
_SLOT_RE = re.compile(r"\{(ORG|AMOUNT|DATE|REF_ID|ACCOUNT|PERSON)\}")
# Matches one maximal run of non-whitespace characters.
_WORD_RE = re.compile(r"\S+")


def _extract_slots(template: str) -> list[str]:
    """List the placeholder labels in *template*, left to right, repeats kept."""
    return [match.group(1) for match in _SLOT_RE.finditer(template)]
def _fill_template(
    template: str, slot_values: dict[str, str]
) -> tuple[str, list[dict]]:
    """
    Substitute slot values into *template* and record where each one lands.

    Args:
        template: Sentence containing ``{LABEL}`` placeholders matched by ``_SLOT_RE``.
        slot_values: Replacement text for every label that occurs in *template*.

    Returns:
        ``(sentence, entity_spans)`` where entity_spans is
        ``[{"start": int, "end": int, "label": str, "text": str}, ...]``.
        Offsets are character positions in *sentence* with ``end`` exclusive, so
        ``sentence[start:end] == text``.

    Raises:
        KeyError: if *template* uses a label that is missing from *slot_values*.
    """
    # re.split with a capturing group interleaves: [text, label, text, label, ..., text]
    parts = _SLOT_RE.split(template)
    pieces: list[str] = []
    spans: list[dict] = []
    cursor = 0  # length of the sentence assembled so far
    for i, part in enumerate(parts):
        if i % 2:
            # Odd positions hold slot labels: swap in the value and record its span.
            label = part
            part = slot_values[label]
            spans.append({
                "start": cursor,
                "end": cursor + len(part),
                "label": label,
                "text": part,
            })
        pieces.append(part)
        cursor += len(part)
    return "".join(pieces), spans
def _word_tokenize(sentence: str) -> list[tuple[str, int, int]]:
    """Split *sentence* on whitespace, keeping each word's character offsets.

    Returns one ``(word, char_start, char_end)`` tuple per run of
    non-whitespace characters, with ``char_end`` exclusive.
    """
    tokens: list[tuple[str, int, int]] = []
    for match in _WORD_RE.finditer(sentence):
        begin, stop = match.span()
        tokens.append((sentence[begin:stop], begin, stop))
    return tokens
def _assign_bio_labels(
    words: list[tuple[str, int, int]], entity_spans: list[dict]
) -> list[int]:
    """
    Convert character-level entity spans into one BIO label id per word.

    A word belongs to a span when the word's start offset lies in
    ``[span["start"], span["end"])``. Within a span the first matching word is
    tagged ``B-<label>`` and the rest ``I-<label>``; words matched by no span
    stay ``O``. Tags are translated to ids through ``LABEL2ID``.
    """
    tags = ["O"] * len(words)
    for span in entity_spans:
        # Indices of the words whose first character falls inside this span.
        covered = [
            idx
            for idx, (_token, begin, _stop) in enumerate(words)
            if span["start"] <= begin < span["end"]
        ]
        for rank, idx in enumerate(covered):
            prefix = "I" if rank else "B"
            tags[idx] = f"{prefix}-{span['label']}"
    return [LABEL2ID[tag] for tag in tags]
| # --------------------------------------------------------------------------- | |
| # Synthetic dataset builder | |
| # --------------------------------------------------------------------------- | |
def build_synthetic_dataset(n_samples: int = 4000, seed: int = 42) -> Dataset:
    """
    Generate up to *n_samples* unique labelled complaint sentences in memory.

    Each attempt picks a random template, fills every slot with a random value
    from ENTITY_VALUES, and keeps the sentence only if it has not been produced
    before. At most ``n_samples * 8`` attempts are made, so fewer than
    *n_samples* examples may be returned when unique sentences run out; a
    warning is logged in that case.

    Args:
        n_samples: Target number of unique sentences.
        seed: Seed for the private random generator; the same seed reproduces
            the same dataset.

    Returns:
        A HuggingFace Dataset with columns:
            words    : list[str] — whitespace-split word tokens
            ner_tags : list[int] — BIO label id per word
    """
    rng = random.Random(seed)
    examples: list[dict] = []
    seen: set[str] = set()
    # Duplicates are rejected, so budget several draws per requested sample.
    max_attempts = n_samples * 8
    for _ in range(max_attempts):
        if len(examples) >= n_samples:
            break
        template = rng.choice(TEMPLATES)
        slots = _extract_slots(template)
        slot_values = {s: rng.choice(ENTITY_VALUES[s]) for s in slots}
        sentence, spans = _fill_template(template, slot_values)
        if sentence in seen:
            continue
        seen.add(sentence)
        words = _word_tokenize(sentence)
        examples.append({
            "words": [w for w, _, _ in words],
            "ner_tags": _assign_bio_labels(words, spans),
        })
    if len(examples) < n_samples:
        # Make the shortfall visible instead of quietly training on less data.
        logger.warning(
            "Synthetic dataset: only %d of %d requested unique examples "
            "generated after %d attempts.",
            len(examples), n_samples, max_attempts,
        )
    logger.info("Synthetic dataset: %d examples generated.", len(examples))
    return Dataset.from_list(examples)
| # --------------------------------------------------------------------------- | |
| # CoNLL-2003 augmentation (optional — silently skipped if unavailable) | |
| # --------------------------------------------------------------------------- | |
# CoNLL-2003 entity types that have a counterpart in this project's label set.
_CONLL_LABEL_MAP = {"PER": "PERSON", "ORG": "ORG"}  # LOC / MISC discarded


def _try_load_conll() -> Dataset | None:
    """
    Attempt to load CoNLL-2003 from HuggingFace Hub and remap to G.U.I.D.E. labels.

    PER tags become PERSON, ORG tags stay ORG, and every other tag (LOC, MISC,
    O) becomes "O". Only the "train" split is used.

    Returns:
        A Dataset with ``words`` and ``ner_tags`` columns, or None if loading or
        remapping fails for any reason (no internet, auth error, etc.). The
        failure reason is logged so that a genuine bug is not mistaken for a
        missing network connection.
    """
    try:
        from datasets import load_dataset

        conll = load_dataset("conll2003")
        train_split = conll["train"]
        tag_names = train_split.features["ner_tags"].feature.names

        # Resolve each CoNLL tag id to a target label id once, rather than
        # re-parsing the tag string for every token in the corpus.
        outside_id = LABEL2ID["O"]
        tag_remap: dict[int, int] = {}
        for tag_id, conll_label in enumerate(tag_names):  # e.g. "B-PER", "I-ORG", "O"
            bio, _, etype = conll_label.partition("-")
            mapped = _CONLL_LABEL_MAP.get(etype)
            if mapped is None:
                tag_remap[tag_id] = outside_id  # "O" itself, or discarded LOC / MISC
            else:
                tag_remap[tag_id] = LABEL2ID[f"{bio}-{mapped}"]

        remapped: list[dict] = [
            {
                "words": example["tokens"],
                "ner_tags": [tag_remap[tag_id] for tag_id in example["ner_tags"]],
            }
            for example in train_split
        ]
        logger.info("CoNLL-2003 augmentation: %d examples loaded.", len(remapped))
        return Dataset.from_list(remapped)
    except Exception as exc:
        # Augmentation is best-effort by design, so never propagate — but do
        # say why it was skipped.
        logger.info(
            "CoNLL-2003 unavailable — skipping augmentation (%s: %s).",
            type(exc).__name__, exc,
        )
        logger.debug("CoNLL-2003 load failure traceback:", exc_info=True)
        return None
| # --------------------------------------------------------------------------- | |
| # Tokeniser alignment | |
| # --------------------------------------------------------------------------- | |
def _make_tokenise_fn(tokenizer):
    """
    Build the batched ``Dataset.map`` callback for *tokenizer*.

    The callback tokenises pre-split word sequences (truncating at 512 tokens)
    and spreads each word's BIO label over the resulting subword tokens using
    ``word_ids()``: the first subword of a word carries the word's label, while
    special tokens and continuation subwords get -100 so that
    CrossEntropyLoss ignores them.
    """

    def _align(word_ids, word_labels):
        # Expand per-word labels into per-token labels for a single example.
        aligned: list[int] = []
        previous = None
        for current in word_ids:
            if current is not None and current != previous:
                aligned.append(word_labels[current])  # first subword of a new word
            else:
                aligned.append(-100)  # special token or continuation subword
            previous = current
        return aligned

    def tokenise_and_align(examples):
        tokenized = tokenizer(
            examples["words"],
            truncation=True,
            max_length=512,
            is_split_into_words=True,
        )
        tokenized["labels"] = [
            _align(tokenized.word_ids(batch_index=row), word_labels)
            for row, word_labels in enumerate(examples["ner_tags"])
        ]
        return tokenized

    return tokenise_and_align
| # --------------------------------------------------------------------------- | |
| # Training entry point | |
| # --------------------------------------------------------------------------- | |
def train(args: argparse.Namespace) -> None:
    """
    Run the full fine-tuning pipeline for the BIO token classifier.

    Reads ``n_samples``, ``output_dir``, ``epochs`` and ``batch_size`` from
    *args*. Builds the synthetic dataset (plus CoNLL-2003 when it can be
    loaded), holds out 10% for evaluation, fine-tunes distilbert-base-uncased,
    and writes the model and tokenizer to ``args.output_dir``.
    """
    logging.basicConfig(level=logging.INFO)
    checkpoint = "distilbert-base-uncased"
    raw_columns = ["words", "ner_tags"]

    # 1. Assemble the corpus; CoNLL rows are mixed in only when available.
    corpus = build_synthetic_dataset(n_samples=args.n_samples)
    conll_ds = _try_load_conll()
    if conll_ds is not None:
        corpus = concatenate_datasets([corpus, conll_ds]).shuffle(seed=42)
    split = corpus.train_test_split(test_size=0.1, seed=42)

    # 2. Tokenise both partitions and drop the raw word-level columns.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    tokenise_fn = _make_tokenise_fn(tokenizer)
    tokenized = DatasetDict({
        target: split[source].map(
            tokenise_fn, batched=True, remove_columns=raw_columns
        )
        for target, source in (("train", "train"), ("eval", "test"))
    })

    # 3. Pretrained encoder with a token-classification head sized for our labels.
    model = AutoModelForTokenClassification.from_pretrained(
        checkpoint,
        num_labels=NUM_LABELS,
        id2label=ID2LABEL,
        label2id=LABEL2ID,
        ignore_mismatched_sizes=True,
    )

    # 4. Evaluate and checkpoint every epoch; reload the lowest-eval-loss one at the end.
    batch_size = args.batch_size
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        logging_steps=50,
        report_to="none",
    )

    # 5. Train, then persist model and tokenizer side by side.
    collator = DataCollatorForTokenClassification(tokenizer)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["eval"],
        data_collator=collator,
        processing_class=tokenizer,
    )
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("EvidenceNER checkpoint saved to %s", args.output_dir)
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """
    Parse the EvidenceNER training command-line options.

    Args:
        argv: Argument strings to parse. When None (the default) argparse reads
            ``sys.argv[1:]``, so existing ``parse_args()`` calls are unaffected;
            passing a list allows programmatic use and testing.

    Returns:
        Namespace with ``output_dir`` (str), ``n_samples`` (int), ``epochs``
        (int) and ``batch_size`` (int).
    """
    p = argparse.ArgumentParser(description="Train EvidenceNER")
    p.add_argument("--output_dir", default="models/evidence_ner",
                   help="Directory to save the fine-tuned checkpoint")
    p.add_argument("--n_samples", type=int, default=4000,
                   help="Number of synthetic training sentences to generate")
    p.add_argument("--epochs", type=int, default=4,
                   help="Number of fine-tuning epochs")
    p.add_argument("--batch_size", type=int, default=16,
                   help="Per-device train/eval batch size")
    return p.parse_args(argv)
# Script entry point: when the module is executed directly (for example
# ``python -m src.ner.train``), parse the CLI flags and run training.
if __name__ == "__main__":
    train(parse_args())