# guide / src /ner /train.py
# Saravanakumar R
# initial traces bug fixes commit
# b016462
# Raw
# History Blame Contribute Delete
# 21.3 kB
"""
EvidenceNER training script.
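Data source: ~4 000 synthetic complaint sentences (the --n_samples default)
generated in memory by build_synthetic_dataset(); no dataset download is
required. The distilbert-base-uncased tokenizer and weights are still loaded
with from_pretrained().
Optionally augmented with the CoNLL-2003 train split (via HuggingFace datasets)
when it can be loaded; PER→PERSON, ORG→ORG; LOC/MISC are relabelled O.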
CLI usage:
python -m src.ner.train --output_dir models/evidence_ner
python -m src.ner.train --output_dir models/evidence_ner --n_samples 6000 --epochs 5
"""
from __future__ import annotations
import argparse
import logging
import random
import re
from datasets import Dataset, DatasetDict, concatenate_datasets
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
DataCollatorForTokenClassification,
Trainer,
TrainingArguments,
)
from src.ner.model import BIO_LABELS, ID2LABEL, LABEL2ID, NUM_LABELS, NER_LABELS
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Entity value banks
# ---------------------------------------------------------------------------
# Surface strings substituted into the {LABEL} placeholders of TEMPLATES; the
# whole substituted string becomes the entity span.
# Every template already supplies the word in front of a slot ("on {DATE}",
# "by {DATE}", "The {PERSON}", "a {PERSON}"), so values must NOT start with a
# preposition/article — otherwise sentences read "occurred on on 5th April" /
# "The the support agent" and the stray function word is tagged as entity.
ENTITY_VALUES: dict[str, list[str]] = {
    "ORG": [
        "Flipkart", "Amazon India", "Myntra", "Snapdeal", "Meesho",
        "HDFC Bank", "ICICI Bank", "State Bank of India", "Axis Bank", "Kotak Mahindra Bank",
        "Punjab National Bank", "Bank of Baroda",
        "Airtel", "Reliance Jio", "Vodafone Idea", "BSNL",
        "LIC of India", "Star Health Insurance", "New India Assurance", "ICICI Lombard",
        "CIBIL", "Experian India",
        "Swiggy", "Zomato", "Ola Cabs", "Uber India", "IRCTC",
        "MakeMyTrip", "Paytm", "PhonePe",
        # "Indian Bank" — public sector bank (HQ Chennai); distinct from the generic phrase
        "Indian Bank",
        # Health insurance providers
        "Niva Bupa", "Care Health Insurance", "Bajaj Allianz",
        # Travel / hospitality OTAs
        "Agoda", "OYO Rooms", "Booking.com India",
        # Fintech
        "Razorpay", "BharatPe", "CRED",
    ],
    "AMOUNT": [
        "₹4,299", "₹1,200", "₹50,000", "₹10,500", "₹2,500",
        "Rs. 8,900", "Rs 15,000", "₹3,499", "₹1,00,000", "Rs. 499",
        "₹25,000", "₹750", "Rs. 1,50,000", "₹12,000", "₹5,999",
        "₹35,000", "₹800", "Rs. 2,000", "₹18,500", "₹9,999",
    ],
    "DATE": [
        # No leading "on"/"in": templates provide the preposition themselves.
        # NOTE(review): relative dates still read awkwardly after "on"
        # ("on last month"); the span labels are nevertheless correct.
        "12 March 2024", "5 January 2024", "20 April 2024", "8 November 2023",
        "3 weeks ago", "two months ago", "last Tuesday", "last Friday",
        "15th February 2024", "15/01/2024", "5th April",
        "December 2023", "last month", "three days ago",
        "22 May 2024", "30 June 2024",
    ],
    "REF_ID": [
        # NOTE(review): these values carry their own descriptor ("complaint
        # number ...", "Order #..."), and some templates also put a noun before
        # {REF_ID} ("complaint {REF_ID}"), giving "complaint complaint number".
        "Order #OD-2930291", "transaction ID TXN987654321",
        "reference number REF-20240312-001", "ticket ID TKT-9876543",
        "complaint number CMP-2024-001", "booking ID BK-56789",
        "claim number CLM-2024-12345", "case ID CASE-789012",
        "loan reference LN-20240101", "policy number POL-123456789",
        "complaint reference CR-20240415",
    ],
    "ACCOUNT": [
        "account ending in 4521", "loan account 9876543210",
        "savings account XXXX1234", "credit card ending 9087",
        "account number XXXX-4321", "demat account IN12345678",
        "current account", "fixed deposit account",
        "joint savings account", "salary account",
    ],
    "PERSON": [
        # Role nouns only; templates supply the article ("The {PERSON}", "a {PERSON}").
        "customer care executive", "branch manager",
        "relationship manager", "loan officer",
        "insurance agent", "delivery executive",
        "support agent", "customer service representative",
        "technical support executive", "grievance officer",
        "account manager",
    ],
}
# ---------------------------------------------------------------------------
# Sentence templates
# Placeholders: {ORG} {AMOUNT} {DATE} {REF_ID} {ACCOUNT} {PERSON}
# Each placeholder is replaced by a random ENTITY_VALUES entry of that label
# (see _fill_template / build_synthetic_dataset). Groups below are named by the
# entity types each template contains.
# The ORDER of this list matters: build_synthetic_dataset draws templates with
# rng.choice() under a fixed seed, so reordering/inserting changes the corpus.
# NOTE(review): sentences are later split on whitespace only, so punctuation
# right after a slot ("{ORG}.") ends up inside the entity's last word.
# NOTE(review): several templates put "complaint"/"ticket"/"Transaction" before
# {REF_ID}, while REF_ID values already start with a descriptor — confirm the
# resulting repetition ("complaint complaint number ...") is acceptable.
# ---------------------------------------------------------------------------
TEMPLATES: list[str] = [
    # --- Single entity ---
    "I want to file a complaint against {ORG}.",
    "{ORG} has been completely unresponsive to my grievances.",
    "I am a long-standing customer of {ORG} and I am deeply dissatisfied.",
    "A deduction of {AMOUNT} was made from my account without authorization.",
    "I am owed a refund of {AMOUNT} which has not been credited.",
    "Please reverse the incorrect charge of {AMOUNT} immediately.",
    "The incident occurred on {DATE} and has not been resolved since.",
    "I filed a formal complaint on {DATE} but have received no response.",
    "I am writing with reference to {REF_ID} which remains unresolved.",
    "Please look into complaint {REF_ID} at the earliest.",
    "My {ACCOUNT} has been showing incorrect transactions for several weeks.",
    "The {ACCOUNT} was blocked without any prior notice.",
    "The {PERSON} I spoke to was completely unhelpful and dismissive.",
    "I was promised by a {PERSON} that the issue would be resolved within 24 hours.",
    # --- ORG + AMOUNT ---
    "I ordered from {ORG} but was incorrectly charged {AMOUNT}.",
    "{ORG} has deducted {AMOUNT} from my account without my consent.",
    "I am requesting a refund of {AMOUNT} from {ORG} for a defective product.",
    "{ORG} charged me {AMOUNT} for a service I never subscribed to.",
    "Despite cancellation {ORG} has not refunded {AMOUNT} to date.",
    "{ORG} owes me {AMOUNT} as compensation for the inconvenience caused.",
    "I was billed {AMOUNT} by {ORG} in error and seek immediate correction.",
    # --- ORG + DATE ---
    "I filed a complaint with {ORG} on {DATE} but have received no update.",
    "{ORG} failed to deliver my order by the promised date of {DATE}.",
    "I visited the {ORG} branch on {DATE} but the issue was not resolved.",
    "Since {DATE} {ORG} has not responded to any of my communications.",
    # --- ORG + REF_ID ---
    "My complaint {REF_ID} with {ORG} has been unresolved for several weeks.",
    "I am following up on {REF_ID} raised with {ORG}.",
    "{ORG} has not taken any action on my ticket {REF_ID}.",
    # --- ORG + ACCOUNT ---
    "{ORG} debited funds from my {ACCOUNT} without my knowledge.",
    "I noticed that {ORG} had incorrectly blocked my {ACCOUNT}.",
    "The {ACCOUNT} with {ORG} has been showing erroneous entries.",
    # --- ORG + PERSON ---
    "The {PERSON} at {ORG} was rude and refused to address my concern.",
    "A {PERSON} from {ORG} promised to resolve the issue but never followed up.",
    "I spoke to a {PERSON} at {ORG} who assured me of a refund.",
    "The {PERSON} at {ORG} refused to process my refund request.",
    # --- AMOUNT + DATE ---
    "An unauthorized transaction of {AMOUNT} occurred on {DATE}.",
    "The refund of {AMOUNT} promised for {DATE} was never processed.",
    "On {DATE} a deduction of {AMOUNT} appeared on my account without reason.",
    # --- AMOUNT + ACCOUNT ---
    "{AMOUNT} was wrongly debited from my {ACCOUNT} and I request an immediate refund.",
    "My {ACCOUNT} shows an erroneous charge of {AMOUNT} that I did not authorize.",
    "{AMOUNT} was deducted from my {ACCOUNT} without my knowledge.",
    # --- AMOUNT + REF_ID ---
    "Transaction {REF_ID} of {AMOUNT} is disputed and I seek reversal.",
    "I raised complaint {REF_ID} against an incorrect charge of {AMOUNT}.",
    # --- DATE + REF_ID ---
    "I raised {REF_ID} on {DATE} and have not received any resolution.",
    "As of {DATE} my complaint {REF_ID} remains open with no action taken.",
    # --- PERSON + ACCOUNT ---
    "The {PERSON} disconnected my call without resolving the issue with my {ACCOUNT}.",
    "A {PERSON} assured me that my {ACCOUNT} would be unblocked within 24 hours.",
    # --- ORG + AMOUNT + DATE ---
    "I placed an order with {ORG} for {AMOUNT} on {DATE} but it was never delivered.",
    "{ORG} charged {AMOUNT} to my account on {DATE} without any authorization.",
    "On {DATE} I paid {AMOUNT} to {ORG} but the service was not provided as promised.",
    "I cancelled my subscription with {ORG} on {DATE} but the refund of {AMOUNT} has not been credited.",
    "{ORG} promised to refund {AMOUNT} by {DATE} but has failed to do so.",
    "I was billed {AMOUNT} by {ORG} for a service cancelled on {DATE}.",
    "A duplicate payment of {AMOUNT} to {ORG} made on {DATE} has not been reversed.",
    # --- ORG + AMOUNT + REF_ID ---
    "My order {REF_ID} from {ORG} worth {AMOUNT} was returned but the refund was not received.",
    "I raised complaint {REF_ID} with {ORG} regarding an erroneous charge of {AMOUNT}.",
    "{ORG} owes me {AMOUNT} against transaction reference {REF_ID}.",
    "Despite follow-up on ticket {REF_ID} {ORG} has not refunded {AMOUNT}.",
    # --- ORG + ACCOUNT + AMOUNT ---
    "{ORG} debited {AMOUNT} from my {ACCOUNT} without my consent.",
    "I noticed an unauthorized charge of {AMOUNT} on my {ACCOUNT} with {ORG}.",
    "{ORG} has applied a penalty of {AMOUNT} to my {ACCOUNT} without prior notice.",
    "Funds amounting to {AMOUNT} were withdrawn from my {ACCOUNT} at {ORG} without authorization.",
    # --- ORG + DATE + REF_ID ---
    "I filed complaint {REF_ID} with {ORG} on {DATE} and request immediate resolution.",
    "My ticket {REF_ID} raised with {ORG} on {DATE} has not been addressed.",
    "Since {DATE} {ORG} has not responded to my complaint {REF_ID}.",
    # --- ORG + ACCOUNT + DATE ---
    "{ORG} deducted funds from my {ACCOUNT} on {DATE} without any prior notification.",
    "I noticed on {DATE} that {ORG} had placed an incorrect hold on my {ACCOUNT}.",
    # --- PERSON + ORG + AMOUNT ---
    "The {PERSON} at {ORG} assured me of a refund of {AMOUNT} which I am yet to receive.",
    "A {PERSON} from {ORG} processed an unauthorized deduction of {AMOUNT} from my account.",
    # --- AMOUNT + DATE + REF_ID ---
    "Transaction {REF_ID} of {AMOUNT} made on {DATE} was unauthorized and I seek reversal.",
    "On {DATE} I filed complaint {REF_ID} for the recovery of {AMOUNT}.",
    # --- ORG + AMOUNT + DATE + REF_ID ---
    "My order {REF_ID} from {ORG} placed on {DATE} for {AMOUNT} was cancelled without a refund.",
    "I paid {AMOUNT} to {ORG} on {DATE} against reference {REF_ID} but the service was not rendered.",
    "{ORG} charged {AMOUNT} on {DATE} referencing {REF_ID} without my authorization.",
    "Despite raising {REF_ID} with {ORG} on {DATE} the refund of {AMOUNT} remains pending.",
    # --- ORG + ACCOUNT + AMOUNT + DATE ---
    "On {DATE} {ORG} debited {AMOUNT} from my {ACCOUNT} without authorization.",
    "{ORG} withdrew {AMOUNT} from my {ACCOUNT} on {DATE} citing a technical error.",
    # --- PERSON + ORG + AMOUNT + DATE ---
    "The {PERSON} at {ORG} processed a transaction of {AMOUNT} on {DATE} without my knowledge.",
    "On {DATE} a {PERSON} from {ORG} assured me the disputed {AMOUNT} would be refunded.",
    # --- ORG + ACCOUNT + AMOUNT + REF_ID ---
    "I raised {REF_ID} with {ORG} regarding {AMOUNT} wrongly debited from my {ACCOUNT}.",
    "{ORG} has not reversed {AMOUNT} credited to {ACCOUNT} as per complaint {REF_ID}.",
    # --- Five / six entities ---
    "On {DATE} the {PERSON} at {ORG} deducted {AMOUNT} from my {ACCOUNT} against {REF_ID}.",
    "I spoke to a {PERSON} from {ORG} on {DATE} regarding {REF_ID} worth {AMOUNT} debited from my {ACCOUNT}.",
    "The {PERSON} at {ORG} confirmed on {DATE} that {REF_ID} of {AMOUNT} debited from {ACCOUNT} would be reversed.",
]
# ---------------------------------------------------------------------------
# Template filling helpers
# ---------------------------------------------------------------------------
_SLOT_RE = re.compile(r"\{(ORG|AMOUNT|DATE|REF_ID|ACCOUNT|PERSON)\}")
_WORD_RE = re.compile(r"\S+")
def _extract_slots(template: str) -> list[str]:
"""Return ordered list of slot labels in *template*."""
return _SLOT_RE.findall(template)
def _fill_template(
template: str, slot_values: dict[str, str]
) -> tuple[str, list[dict]]:
"""
Fill slots and return (sentence, entity_spans).
entity_spans: [{"start": int, "end": int, "label": str, "text": str}, ...]
"""
parts = _SLOT_RE.split(template)
# re.split with a capturing group interleaves: [text, label, text, label, ..., text]
sentence = ""
spans: list[dict] = []
for i, part in enumerate(parts):
if i % 2 == 0:
sentence += part
else:
label = part
value = slot_values[label]
start = len(sentence)
sentence += value
spans.append({"start": start, "end": start + len(value), "label": label})
return sentence, spans
def _word_tokenize(sentence: str) -> list[tuple[str, int, int]]:
"""Tokenise *sentence* into (word, char_start, char_end) tuples."""
return [(m.group(), m.start(), m.end()) for m in _WORD_RE.finditer(sentence)]
def _assign_bio_labels(
words: list[tuple[str, int, int]], entity_spans: list[dict]
) -> list[int]:
"""
Assign a BIO label id to each word token.
A word is "inside" an entity when its start character falls within
[span.start, span.end). The first such word in each span gets B-,
subsequent words get I-.
"""
labels = ["O"] * len(words)
for span in entity_spans:
first_in_span = True
for i, (_word, wstart, _wend) in enumerate(words):
if span["start"] <= wstart < span["end"]:
bio = "B" if first_in_span else "I"
labels[i] = f"{bio}-{span['label']}"
first_in_span = False
return [LABEL2ID[lbl] for lbl in labels]
# ---------------------------------------------------------------------------
# Synthetic dataset builder
# ---------------------------------------------------------------------------
def build_synthetic_dataset(n_samples: int = 4000, seed: int = 42) -> Dataset:
    """
    Generate *n_samples* labelled complaint sentences in memory.

    Templates and slot values are drawn with a ``random.Random(seed)`` instance,
    so the output is deterministic for a given seed. Duplicate sentences are
    dropped; if the template/value space cannot yield *n_samples* unique
    sentences within ``8 * n_samples`` draws, fewer examples are returned and a
    warning is logged.

    Returns a HuggingFace Dataset with columns:
        words    : list[str] — whitespace-split word tokens
        ner_tags : list[int] — BIO label id per word
    """
    rng = random.Random(seed)
    examples: list[dict] = []
    seen: set[str] = set()
    # Draws are with replacement, so duplicates are expected; cap the number of
    # draws so a small template/value space cannot loop forever.
    max_attempts = n_samples * 8
    for _ in range(max_attempts):
        if len(examples) >= n_samples:
            break
        template = rng.choice(TEMPLATES)
        slots = _extract_slots(template)
        slot_values = {s: rng.choice(ENTITY_VALUES[s]) for s in slots}
        sentence, spans = _fill_template(template, slot_values)
        if sentence in seen:
            continue
        seen.add(sentence)
        words = _word_tokenize(sentence)
        examples.append({
            "words": [w for w, _, _ in words],
            "ner_tags": _assign_bio_labels(words, spans),
        })
    if len(examples) < n_samples:
        # Previously this shortfall was silent; make it visible to the caller.
        logger.warning(
            "Synthetic dataset: only %d unique examples after %d attempts "
            "(requested %d).",
            len(examples), max_attempts, n_samples,
        )
    logger.info("Synthetic dataset: %d examples generated.", len(examples))
    return Dataset.from_list(examples)
# ---------------------------------------------------------------------------
# CoNLL-2003 augmentation (optional — silently skipped if unavailable)
# ---------------------------------------------------------------------------
# CoNLL-2003 entity type -> local entity type. Types missing here (LOC, MISC)
# are relabelled "O" by _try_load_conll.
_CONLL_LABEL_MAP = {"PER": "PERSON", "ORG": "ORG"}  # LOC / MISC discarded


def _try_load_conll() -> Dataset | None:
    """
    Attempt to load CoNLL-2003 from HuggingFace Hub and remap to G.U.I.D.E. labels.

    Only the ``train`` split is used. Each CoNLL tag keeps its B-/I- prefix with
    the entity type renamed through ``_CONLL_LABEL_MAP``; unmapped types
    (LOC, MISC) become "O".

    Returns:
        A Dataset with ``words`` / ``ner_tags`` columns, or None if the dataset
        cannot be loaded (no internet, auth error, etc.). The failure reason is
        logged.

    Raises:
        KeyError: if a remapped tag (e.g. "B-PERSON") is missing from LABEL2ID.
            That is a label-schema bug and is deliberately not reported as
            "dataset unavailable".
    """
    try:
        from datasets import load_dataset

        train_split = load_dataset("conll2003")["train"]
        conll_names: list[str] = train_split.features["ner_tags"].feature.names
    except Exception as exc:  # best-effort: any loading failure only disables augmentation
        logger.info("CoNLL-2003 unavailable (%s) — skipping augmentation.", exc)
        return None

    # Build the CoNLL-id -> local-id table once instead of re-parsing every token.
    tag_map: dict[int, int] = {}
    for conll_id, conll_label in enumerate(conll_names):  # e.g. "B-PER", "I-ORG", "O"
        if conll_label == "O":
            tag_map[conll_id] = LABEL2ID["O"]
            continue
        bio, etype = conll_label.split("-", 1)
        mapped = _CONLL_LABEL_MAP.get(etype)
        tag_map[conll_id] = (
            LABEL2ID["O"] if mapped is None else LABEL2ID[f"{bio}-{mapped}"]
        )

    remapped = [
        {"words": ex["tokens"], "ner_tags": [tag_map[t] for t in ex["ner_tags"]]}
        for ex in train_split
    ]
    logger.info("CoNLL-2003 augmentation: %d examples loaded.", len(remapped))
    return Dataset.from_list(remapped)
# ---------------------------------------------------------------------------
# Tokeniser alignment
# ---------------------------------------------------------------------------
def _make_tokenise_fn(tokenizer):
"""
Return a batched map function that tokenises word sequences and aligns
BIO labels to subword tokens using word_ids().
Only the first subword of each word receives its word's label; remaining
subwords receive -100 (ignored by CrossEntropyLoss).
"""
def tokenise_and_align(examples):
tokenized = tokenizer(
examples["words"],
truncation=True,
max_length=512,
is_split_into_words=True,
)
all_labels: list[list[int]] = []
for i, word_labels in enumerate(examples["ner_tags"]):
word_ids = tokenized.word_ids(batch_index=i)
prev_word_id = None
labels: list[int] = []
for word_id in word_ids:
if word_id is None:
labels.append(-100) # special token
elif word_id != prev_word_id:
labels.append(word_labels[word_id]) # first subword → real label
else:
labels.append(-100) # continuation subword → ignored
prev_word_id = word_id
all_labels.append(labels)
tokenized["labels"] = all_labels
return tokenized
return tokenise_and_align
# ---------------------------------------------------------------------------
# Training entry point
# ---------------------------------------------------------------------------
def train(args: argparse.Namespace) -> None:
    """
    Fine-tune distilbert-base-uncased for BIO token classification.

    Pipeline: synthetic corpus (+ CoNLL-2003 when it loads) -> 90/10
    train/eval split -> subword tokenisation with label alignment ->
    Hugging Face ``Trainer`` (evaluate and checkpoint every epoch, reload the
    best checkpoint by ``eval_loss`` at the end) -> model and tokenizer saved
    to ``args.output_dir``.

    Args:
        args: namespace providing ``output_dir``, ``n_samples``, ``epochs`` and
            ``batch_size`` (see ``parse_args``).
    """
    logging.basicConfig(level=logging.INFO)
    checkpoint = "distilbert-base-uncased"

    # --- Data: interleave the two sources (only when CoNLL loaded), then hold out 10%.
    corpus = build_synthetic_dataset(n_samples=args.n_samples)
    conll = _try_load_conll()
    if conll is not None:
        corpus = concatenate_datasets([corpus, conll]).shuffle(seed=42)
    holdout = corpus.train_test_split(test_size=0.1, seed=42)

    # --- Tokenisation: raw columns are dropped so only model inputs remain.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    align_fn = _make_tokenise_fn(tokenizer)
    encoded = DatasetDict({
        target: holdout[source].map(
            align_fn, batched=True, remove_columns=["words", "ner_tags"]
        )
        for target, source in (("train", "train"), ("eval", "test"))
    })

    # --- Model: a fresh token-classification head sized to the BIO label set.
    # NOTE(review): ignore_mismatched_sizes presumably guards against a head-size
    # mismatch with the base checkpoint — confirm it is still needed.
    model = AutoModelForTokenClassification.from_pretrained(
        checkpoint,
        num_labels=NUM_LABELS,
        id2label=ID2LABEL,
        label2id=LABEL2ID,
        ignore_mismatched_sizes=True,
    )

    # --- Optimisation schedule and checkpointing policy.
    hparams = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        logging_steps=50,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=hparams,
        train_dataset=encoded["train"],
        eval_dataset=encoded["eval"],
        data_collator=DataCollatorForTokenClassification(tokenizer),
        processing_class=tokenizer,
    )
    trainer.train()

    # --- Persist the (best) model together with its tokenizer.
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("EvidenceNER checkpoint saved to %s", args.output_dir)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Train EvidenceNER")
p.add_argument("--output_dir", default="models/evidence_ner",
help="Directory to save the fine-tuned checkpoint")
p.add_argument("--n_samples", type=int, default=4000,
help="Number of synthetic training sentences to generate")
p.add_argument("--epochs", type=int, default=4,
help="Number of fine-tuning epochs")
p.add_argument("--batch_size", type=int, default=16,
help="Per-device train/eval batch size")
return p.parse_args()
if __name__ == "__main__":
train(parse_args())