Spaces:
Sleeping
Sleeping
File size: 21,327 Bytes
cbb1b1a 4c85df9 cbb1b1a 4c85df9 cbb1b1a 4c85df9 cbb1b1a 4c85df9 b016462 4c85df9 cbb1b1a 4c85df9 cbb1b1a 4c85df9 49a9433 4c85df9 cbb1b1a 4c85df9 49a9433 4c85df9 49a9433 4c85df9 cbb1b1a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 
481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 | """
EvidenceNER training script.
Data source: ~4,000 synthetic complaint sentences (count set by --n_samples) generated in-memory by
build_synthetic_dataset(). No download required.
Optionally augmented with CoNLL-2003 (via HuggingFace datasets)
when internet is available; PER→PERSON, ORG→ORG; LOC/MISC discarded.
CLI usage:
python -m src.ner.train --output_dir models/evidence_ner
python -m src.ner.train --output_dir models/evidence_ner --n_samples 6000 --epochs 5
"""
from __future__ import annotations
import argparse
import logging
import random
import re
from datasets import Dataset, DatasetDict, concatenate_datasets
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
DataCollatorForTokenClassification,
Trainer,
TrainingArguments,
)
from src.ner.model import BIO_LABELS, ID2LABEL, LABEL2ID, NUM_LABELS, NER_LABELS
# Module logger, named after this module; configuration is left to the caller.
logger = logging.getLogger(name=__name__)
# ---------------------------------------------------------------------------
# Entity value banks
# ---------------------------------------------------------------------------
# Every template already supplies its own article before {PERSON} ("The", "A",
# "a", "the") and its own preposition before {DATE} ("on", "On", "Since",
# "by", "As of"). Values therefore must NOT begin with a determiner or a
# preposition; otherwise sentences such as "The the support agent" or
# "on on 22 May 2024" are generated and the stray word is labelled as part of
# the entity span.
# The order of values within each list matters: build_synthetic_dataset()
# samples them with rng.choice, so reordering changes the dataset per seed.
ENTITY_VALUES: dict[str, list[str]] = {
    "ORG": [
        "Flipkart", "Amazon India", "Myntra", "Snapdeal", "Meesho",
        "HDFC Bank", "ICICI Bank", "State Bank of India", "Axis Bank", "Kotak Mahindra Bank",
        "Punjab National Bank", "Bank of Baroda",
        "Airtel", "Reliance Jio", "Vodafone Idea", "BSNL",
        "LIC of India", "Star Health Insurance", "New India Assurance", "ICICI Lombard",
        "CIBIL", "Experian India",
        "Swiggy", "Zomato", "Ola Cabs", "Uber India", "IRCTC",
        "MakeMyTrip", "Paytm", "PhonePe",
        # "Indian Bank" — public sector bank (HQ Chennai); listed as a proper
        # name, not the generic phrase "an Indian bank".
        "Indian Bank",
        # Health insurance providers
        "Niva Bupa", "Care Health Insurance", "Bajaj Allianz",
        # Travel / hospitality OTAs
        "Agoda", "OYO Rooms", "Booking.com India",
        # Fintech
        "Razorpay", "BharatPe", "CRED",
    ],
    "AMOUNT": [
        "₹4,299", "₹1,200", "₹50,000", "₹10,500", "₹2,500",
        "Rs. 8,900", "Rs 15,000", "₹3,499", "₹1,00,000", "Rs. 499",
        "₹25,000", "₹750", "Rs. 1,50,000", "₹12,000", "₹5,999",
        "₹35,000", "₹800", "Rs. 2,000", "₹18,500", "₹9,999",
    ],
    "DATE": [
        "12 March 2024", "5 January 2024", "20 April 2024", "8 November 2023",
        "3 weeks ago", "two months ago", "last Tuesday", "last Friday",
        # No leading "on"/"in": the templates provide the preposition.
        "15th February 2024", "15/01/2024", "5th April",
        "December 2023", "last month", "three days ago",
        "22 May 2024", "30 June 2024",
    ],
    "REF_ID": [
        "Order #OD-2930291", "transaction ID TXN987654321",
        "reference number REF-20240312-001", "ticket ID TKT-9876543",
        "complaint number CMP-2024-001", "booking ID BK-56789",
        "claim number CLM-2024-12345", "case ID CASE-789012",
        "loan reference LN-20240101", "policy number POL-123456789",
        "complaint reference CR-20240415",
    ],
    "ACCOUNT": [
        "account ending in 4521", "loan account 9876543210",
        "savings account XXXX1234", "credit card ending 9087",
        "account number XXXX-4321", "demat account IN12345678",
        "current account", "fixed deposit account",
        "joint savings account", "salary account",
    ],
    "PERSON": [
        "customer care executive", "branch manager",
        "relationship manager", "loan officer",
        "insurance agent", "delivery executive",
        # No leading "the"/"your": the templates provide the article.
        "support agent", "customer service representative",
        "technical support executive", "grievance officer",
        "account manager",
    ],
}
# ---------------------------------------------------------------------------
# Sentence templates
# Placeholders: {ORG} {AMOUNT} {DATE} {REF_ID} {ACCOUNT} {PERSON}
# ---------------------------------------------------------------------------
# Sentence patterns for the synthetic corpus. Each "{LABEL}" placeholder is
# replaced by a value from ENTITY_VALUES[LABEL]; the replaced span becomes the
# entity annotation.
# NOTE: build_synthetic_dataset() draws from this list with rng.choice, so
# reordering, adding or removing entries changes the generated dataset for a
# given seed.
# NOTE(review): several REF_ID values already carry their own noun (e.g.
# "ticket ID TKT-9876543", "Order #OD-2930291", "transaction ID ..."), so
# templates such as "my ticket {REF_ID}", "My order {REF_ID}" or
# "Transaction {REF_ID}" can produce repeated words. Likewise "a {PERSON}" /
# "A {PERSON}" does not adjust the article for vowel-initial roles such as
# "insurance agent" or "account manager".
TEMPLATES: list[str] = [
    # --- Single entity ---
    "I want to file a complaint against {ORG}.",
    "{ORG} has been completely unresponsive to my grievances.",
    "I am a long-standing customer of {ORG} and I am deeply dissatisfied.",
    "A deduction of {AMOUNT} was made from my account without authorization.",
    "I am owed a refund of {AMOUNT} which has not been credited.",
    "Please reverse the incorrect charge of {AMOUNT} immediately.",
    "The incident occurred on {DATE} and has not been resolved since.",
    "I filed a formal complaint on {DATE} but have received no response.",
    "I am writing with reference to {REF_ID} which remains unresolved.",
    "Please look into complaint {REF_ID} at the earliest.",
    "My {ACCOUNT} has been showing incorrect transactions for several weeks.",
    "The {ACCOUNT} was blocked without any prior notice.",
    "The {PERSON} I spoke to was completely unhelpful and dismissive.",
    "I was promised by a {PERSON} that the issue would be resolved within 24 hours.",
    # --- ORG + AMOUNT ---
    "I ordered from {ORG} but was incorrectly charged {AMOUNT}.",
    "{ORG} has deducted {AMOUNT} from my account without my consent.",
    "I am requesting a refund of {AMOUNT} from {ORG} for a defective product.",
    "{ORG} charged me {AMOUNT} for a service I never subscribed to.",
    "Despite cancellation {ORG} has not refunded {AMOUNT} to date.",
    "{ORG} owes me {AMOUNT} as compensation for the inconvenience caused.",
    "I was billed {AMOUNT} by {ORG} in error and seek immediate correction.",
    # --- ORG + DATE ---
    "I filed a complaint with {ORG} on {DATE} but have received no update.",
    "{ORG} failed to deliver my order by the promised date of {DATE}.",
    "I visited the {ORG} branch on {DATE} but the issue was not resolved.",
    "Since {DATE} {ORG} has not responded to any of my communications.",
    # --- ORG + REF_ID ---
    "My complaint {REF_ID} with {ORG} has been unresolved for several weeks.",
    "I am following up on {REF_ID} raised with {ORG}.",
    "{ORG} has not taken any action on my ticket {REF_ID}.",
    # --- ORG + ACCOUNT ---
    "{ORG} debited funds from my {ACCOUNT} without my knowledge.",
    "I noticed that {ORG} had incorrectly blocked my {ACCOUNT}.",
    "The {ACCOUNT} with {ORG} has been showing erroneous entries.",
    # --- ORG + PERSON ---
    "The {PERSON} at {ORG} was rude and refused to address my concern.",
    "A {PERSON} from {ORG} promised to resolve the issue but never followed up.",
    "I spoke to a {PERSON} at {ORG} who assured me of a refund.",
    "The {PERSON} at {ORG} refused to process my refund request.",
    # --- AMOUNT + DATE ---
    "An unauthorized transaction of {AMOUNT} occurred on {DATE}.",
    "The refund of {AMOUNT} promised for {DATE} was never processed.",
    "On {DATE} a deduction of {AMOUNT} appeared on my account without reason.",
    # --- AMOUNT + ACCOUNT ---
    "{AMOUNT} was wrongly debited from my {ACCOUNT} and I request an immediate refund.",
    "My {ACCOUNT} shows an erroneous charge of {AMOUNT} that I did not authorize.",
    "{AMOUNT} was deducted from my {ACCOUNT} without my knowledge.",
    # --- AMOUNT + REF_ID ---
    "Transaction {REF_ID} of {AMOUNT} is disputed and I seek reversal.",
    "I raised complaint {REF_ID} against an incorrect charge of {AMOUNT}.",
    # --- DATE + REF_ID ---
    "I raised {REF_ID} on {DATE} and have not received any resolution.",
    "As of {DATE} my complaint {REF_ID} remains open with no action taken.",
    # --- PERSON + ACCOUNT ---
    "The {PERSON} disconnected my call without resolving the issue with my {ACCOUNT}.",
    "A {PERSON} assured me that my {ACCOUNT} would be unblocked within 24 hours.",
    # --- ORG + AMOUNT + DATE ---
    "I placed an order with {ORG} for {AMOUNT} on {DATE} but it was never delivered.",
    "{ORG} charged {AMOUNT} to my account on {DATE} without any authorization.",
    "On {DATE} I paid {AMOUNT} to {ORG} but the service was not provided as promised.",
    "I cancelled my subscription with {ORG} on {DATE} but the refund of {AMOUNT} has not been credited.",
    "{ORG} promised to refund {AMOUNT} by {DATE} but has failed to do so.",
    "I was billed {AMOUNT} by {ORG} for a service cancelled on {DATE}.",
    "A duplicate payment of {AMOUNT} to {ORG} made on {DATE} has not been reversed.",
    # --- ORG + AMOUNT + REF_ID ---
    "My order {REF_ID} from {ORG} worth {AMOUNT} was returned but the refund was not received.",
    "I raised complaint {REF_ID} with {ORG} regarding an erroneous charge of {AMOUNT}.",
    "{ORG} owes me {AMOUNT} against transaction reference {REF_ID}.",
    "Despite follow-up on ticket {REF_ID} {ORG} has not refunded {AMOUNT}.",
    # --- ORG + ACCOUNT + AMOUNT ---
    "{ORG} debited {AMOUNT} from my {ACCOUNT} without my consent.",
    "I noticed an unauthorized charge of {AMOUNT} on my {ACCOUNT} with {ORG}.",
    "{ORG} has applied a penalty of {AMOUNT} to my {ACCOUNT} without prior notice.",
    "Funds amounting to {AMOUNT} were withdrawn from my {ACCOUNT} at {ORG} without authorization.",
    # --- ORG + DATE + REF_ID ---
    "I filed complaint {REF_ID} with {ORG} on {DATE} and request immediate resolution.",
    "My ticket {REF_ID} raised with {ORG} on {DATE} has not been addressed.",
    "Since {DATE} {ORG} has not responded to my complaint {REF_ID}.",
    # --- ORG + ACCOUNT + DATE ---
    "{ORG} deducted funds from my {ACCOUNT} on {DATE} without any prior notification.",
    "I noticed on {DATE} that {ORG} had placed an incorrect hold on my {ACCOUNT}.",
    # --- PERSON + ORG + AMOUNT ---
    "The {PERSON} at {ORG} assured me of a refund of {AMOUNT} which I am yet to receive.",
    "A {PERSON} from {ORG} processed an unauthorized deduction of {AMOUNT} from my account.",
    # --- AMOUNT + DATE + REF_ID ---
    "Transaction {REF_ID} of {AMOUNT} made on {DATE} was unauthorized and I seek reversal.",
    "On {DATE} I filed complaint {REF_ID} for the recovery of {AMOUNT}.",
    # --- ORG + AMOUNT + DATE + REF_ID ---
    "My order {REF_ID} from {ORG} placed on {DATE} for {AMOUNT} was cancelled without a refund.",
    "I paid {AMOUNT} to {ORG} on {DATE} against reference {REF_ID} but the service was not rendered.",
    "{ORG} charged {AMOUNT} on {DATE} referencing {REF_ID} without my authorization.",
    "Despite raising {REF_ID} with {ORG} on {DATE} the refund of {AMOUNT} remains pending.",
    # --- ORG + ACCOUNT + AMOUNT + DATE ---
    "On {DATE} {ORG} debited {AMOUNT} from my {ACCOUNT} without authorization.",
    "{ORG} withdrew {AMOUNT} from my {ACCOUNT} on {DATE} citing a technical error.",
    # --- PERSON + ORG + AMOUNT + DATE ---
    "The {PERSON} at {ORG} processed a transaction of {AMOUNT} on {DATE} without my knowledge.",
    "On {DATE} a {PERSON} from {ORG} assured me the disputed {AMOUNT} would be refunded.",
    # --- ORG + ACCOUNT + AMOUNT + REF_ID ---
    "I raised {REF_ID} with {ORG} regarding {AMOUNT} wrongly debited from my {ACCOUNT}.",
    "{ORG} has not reversed {AMOUNT} credited to {ACCOUNT} as per complaint {REF_ID}.",
    # --- Five / six entities ---
    "On {DATE} the {PERSON} at {ORG} deducted {AMOUNT} from my {ACCOUNT} against {REF_ID}.",
    "I spoke to a {PERSON} from {ORG} on {DATE} regarding {REF_ID} worth {AMOUNT} debited from my {ACCOUNT}.",
    "The {PERSON} at {ORG} confirmed on {DATE} that {REF_ID} of {AMOUNT} debited from {ACCOUNT} would be reversed.",
]
# ---------------------------------------------------------------------------
# Template filling helpers
# ---------------------------------------------------------------------------
_SLOT_RE = re.compile(r"\{(ORG|AMOUNT|DATE|REF_ID|ACCOUNT|PERSON)\}")
_WORD_RE = re.compile(r"\S+")
def _extract_slots(template: str) -> list[str]:
    """List the slot labels that occur in *template*, left to right, repeats kept."""
    return [match.group(1) for match in _SLOT_RE.finditer(template)]
def _fill_template(
    template: str, slot_values: dict[str, str]
) -> tuple[str, list[dict]]:
    """
    Substitute every slot in *template* and record where each value landed.

    Args:
        template: sentence pattern whose placeholders are matched by _SLOT_RE.
        slot_values: maps each slot label to the text inserted for it; every
            label occurring in *template* must be present.

    Returns:
        (sentence, entity_spans) where entity_spans is
        [{"start": int, "end": int, "label": str, "text": str}, ...],
        one entry per slot in left-to-right order, with
        sentence[start:end] == text.

    Raises:
        KeyError: if a slot label in *template* is missing from *slot_values*.
    """
    # With exactly one capturing group, re.split alternates literal text
    # (even indices) with the captured slot label (odd indices).
    parts = _SLOT_RE.split(template)
    pieces: list[str] = []
    spans: list[dict] = []
    offset = 0  # length of the sentence assembled so far
    for i, part in enumerate(parts):
        if i % 2 == 0:
            pieces.append(part)
            offset += len(part)
            continue
        value = slot_values[part]
        spans.append({
            "start": offset,
            "end": offset + len(value),
            "label": part,
            "text": value,
        })
        pieces.append(value)
        offset += len(value)
    return "".join(pieces), spans
def _word_tokenize(sentence: str) -> list[tuple[str, int, int]]:
    """
    Split *sentence* on whitespace into (word, char_start, char_end) triples.

    char_end is exclusive, so sentence[char_start:char_end] == word.
    """
    return [(m.group(0), *m.span()) for m in _WORD_RE.finditer(sentence)]
def _assign_bio_labels(
    words: list[tuple[str, int, int]], entity_spans: list[dict]
) -> list[int]:
    """
    Map each whitespace word to a BIO label id.

    A word belongs to a span when its *start* offset satisfies
    span["start"] <= word_start < span["end"]. Within one span the earliest
    such word is tagged B-<label> and the rest I-<label>; words covered by no
    span stay "O". Spans are applied in order, so a later span would overwrite
    an earlier one if they overlapped. Tag strings are converted to ids via
    LABEL2ID.
    """
    tags = ["O"] * len(words)
    for span in entity_spans:
        covered = [
            idx
            for idx, (_, word_start, _) in enumerate(words)
            if span["start"] <= word_start < span["end"]
        ]
        for rank, idx in enumerate(covered):
            prefix = "B" if rank == 0 else "I"
            tags[idx] = f"{prefix}-{span['label']}"
    return [LABEL2ID[tag] for tag in tags]
# ---------------------------------------------------------------------------
# Synthetic dataset builder
# ---------------------------------------------------------------------------
def build_synthetic_dataset(n_samples: int = 4000, seed: int = 42) -> Dataset:
    """
    Generate up to *n_samples* unique labelled complaint sentences in memory.

    Each attempt picks a random template, fills every slot with a random
    entity value, and keeps the sentence only if it has not been produced
    before. Sampling stops after n_samples * 8 attempts; if that many draws do
    not yield n_samples distinct sentences, fewer examples are returned and a
    warning is logged.

    Args:
        n_samples: number of distinct sentences to aim for.
        seed: seed for the private random.Random instance (deterministic output).

    Returns a HuggingFace Dataset with columns:
        words    : list[str] — whitespace-split word tokens
        ner_tags : list[int] — BIO label id per word
    """
    rng = random.Random(seed)
    examples: list[dict] = []
    seen: set[str] = set()
    max_attempts = n_samples * 8
    for _ in range(max_attempts):
        if len(examples) >= n_samples:
            break
        template = rng.choice(TEMPLATES)
        slot_values = {s: rng.choice(ENTITY_VALUES[s]) for s in _extract_slots(template)}
        sentence, spans = _fill_template(template, slot_values)
        if sentence in seen:
            continue
        seen.add(sentence)
        words = _word_tokenize(sentence)
        examples.append({
            "words": [w for w, _, _ in words],
            "ner_tags": _assign_bio_labels(words, spans),
        })
    if len(examples) < n_samples:
        # Previously silent: callers asked for n_samples but got fewer.
        logger.warning(
            "Synthetic dataset: requested %d examples but found only %d unique "
            "sentences in %d attempts.",
            n_samples, len(examples), max_attempts,
        )
    logger.info("Synthetic dataset: %d examples generated.", len(examples))
    return Dataset.from_list(examples)
# ---------------------------------------------------------------------------
# CoNLL-2003 augmentation (optional — skipped, with a log message, if unavailable)
# ---------------------------------------------------------------------------
_CONLL_LABEL_MAP = {"PER": "PERSON", "ORG": "ORG"} # LOC / MISC discarded
def _try_load_conll() -> Dataset | None:
    """
    Attempt to load CoNLL-2003 from HuggingFace Hub and remap to G.U.I.D.E. labels.

    CoNLL tags whose entity type appears in _CONLL_LABEL_MAP keep their BIO
    prefix and get the mapped type (B-PER -> B-PERSON, I-ORG -> I-ORG); every
    other tag (LOC, MISC) becomes "O".

    Returns:
        A Dataset with "words" / "ner_tags" columns built from the CoNLL train
        split, or None if the dataset cannot be loaded (no internet, auth
        error, etc.) or cannot be remapped. Either failure is logged together
        with its reason, rather than being reported only as "unavailable".
    """
    try:
        from datasets import load_dataset

        train_split = load_dataset("conll2003")["train"]
    except Exception as exc:  # best-effort: any hub/network/auth failure means "skip"
        logger.info(
            "CoNLL-2003 unavailable — skipping augmentation. (%s: %s)",
            type(exc).__name__, exc,
        )
        return None

    try:
        conll_names = train_split.features["ner_tags"].feature.names
        # Translate each CoNLL tag id to our label id once, instead of
        # re-parsing the tag string for every token in the corpus.
        conll_id_to_ours: list[int] = []
        for name in conll_names:  # e.g. "B-PER", "I-ORG", "O"
            if name == "O":
                conll_id_to_ours.append(LABEL2ID["O"])
                continue
            bio, etype = name.split("-", 1)
            mapped = _CONLL_LABEL_MAP.get(etype)
            if mapped is None:
                conll_id_to_ours.append(LABEL2ID["O"])  # discard LOC / MISC
            else:
                conll_id_to_ours.append(LABEL2ID[f"{bio}-{mapped}"])

        remapped = [
            {
                "words": example["tokens"],
                "ner_tags": [conll_id_to_ours[tag_id] for tag_id in example["ner_tags"]],
            }
            for example in train_split
        ]
    except (AttributeError, IndexError, KeyError, TypeError, ValueError) as exc:
        # Loaded, but the schema or label mapping did not match expectations:
        # still best-effort, but make the cause visible.
        logger.warning(
            "CoNLL-2003 loaded but could not be remapped — skipping augmentation. (%s: %s)",
            type(exc).__name__, exc,
        )
        return None

    logger.info("CoNLL-2003 augmentation: %d examples loaded.", len(remapped))
    return Dataset.from_list(remapped)
# ---------------------------------------------------------------------------
# Tokeniser alignment
# ---------------------------------------------------------------------------
def _make_tokenise_fn(tokenizer):
    """
    Build the batched Dataset.map callback that tokenises pre-split words and
    aligns word-level BIO ids to subword tokens.

    The callback calls *tokenizer* with is_split_into_words=True, truncation
    and max_length=512, then walks encoded.word_ids(batch_index=row). The first
    subword of each word receives that word's label; special tokens (word id
    None) and continuation subwords receive -100 so CrossEntropyLoss ignores
    them. The result is stored under the "labels" key of the encoding.
    """
    ignore_index = -100

    def tokenise_and_align(examples):
        encoded = tokenizer(
            examples["words"],
            is_split_into_words=True,
            truncation=True,
            max_length=512,
        )
        aligned: list[list[int]] = []
        for row, word_labels in enumerate(examples["ner_tags"]):
            row_labels: list[int] = []
            previous = None
            for wid in encoded.word_ids(batch_index=row):
                # Special token, or another piece of the word already labelled.
                if wid is None or wid == previous:
                    row_labels.append(ignore_index)
                else:
                    row_labels.append(word_labels[wid])
                previous = wid
            aligned.append(row_labels)
        encoded["labels"] = aligned
        return encoded

    return tokenise_and_align
# ---------------------------------------------------------------------------
# Training entry point
# ---------------------------------------------------------------------------
def train(args: argparse.Namespace) -> None:
    """
    Fine-tune distilbert-base-uncased for BIO token classification.

    Reads args.n_samples, args.output_dir, args.epochs and args.batch_size.
    The synthetic corpus is merged with remapped CoNLL-2003 when
    _try_load_conll() returns a dataset, and 10% is held out for per-epoch
    evaluation. The checkpoint with the lowest eval_loss is reloaded at the
    end, and the model plus tokenizer are saved to args.output_dir.
    """
    logging.basicConfig(level=logging.INFO)
    checkpoint = "distilbert-base-uncased"

    # Step 1: assemble the corpus and carve out a 10% evaluation split.
    synthetic_ds = build_synthetic_dataset(n_samples=args.n_samples)
    conll_ds = _try_load_conll()
    if conll_ds is None:
        corpus = synthetic_ds
    else:
        corpus = concatenate_datasets([synthetic_ds, conll_ds]).shuffle(seed=42)
    split = corpus.train_test_split(test_size=0.1, seed=42)

    # Step 2: subword-tokenise both splits; the "test" split is exposed as "eval".
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    align_labels = _make_tokenise_fn(tokenizer)
    tokenized = DatasetDict({
        target: split[source].map(
            align_labels, batched=True, remove_columns=["words", "ner_tags"]
        )
        for target, source in (("train", "train"), ("eval", "test"))
    })

    # Step 3: token-classification model from the base checkpoint, configured
    # with this project's label maps.
    model = AutoModelForTokenClassification.from_pretrained(
        checkpoint,
        num_labels=NUM_LABELS,
        id2label=ID2LABEL,
        label2id=LABEL2ID,
        ignore_mismatched_sizes=True,
    )

    # Step 4: evaluate and checkpoint every epoch; keep the lowest eval loss.
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        logging_steps=50,
        report_to="none",
    )

    # Step 5: train, then persist model and tokenizer side by side.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["eval"],
        data_collator=DataCollatorForTokenClassification(tokenizer),
        processing_class=tokenizer,
    )
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("EvidenceNER checkpoint saved to %s", args.output_dir)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Train EvidenceNER")
p.add_argument("--output_dir", default="models/evidence_ner",
help="Directory to save the fine-tuned checkpoint")
p.add_argument("--n_samples", type=int, default=4000,
help="Number of synthetic training sentences to generate")
p.add_argument("--epochs", type=int, default=4,
help="Number of fine-tuning epochs")
p.add_argument("--batch_size", type=int, default=16,
help="Per-device train/eval batch size")
return p.parse_args()
if __name__ == "__main__":
    # Script entry point: parse sys.argv, then run training.
    cli_args = parse_args()
    train(cli_args)
|