File size: 21,327 Bytes
cbb1b1a
 
 
 
 
 
 
 
 
 
4c85df9
cbb1b1a
 
4c85df9
 
cbb1b1a
4c85df9
 
 
cbb1b1a
4c85df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b016462
 
 
 
 
 
 
 
4c85df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbb1b1a
 
4c85df9
 
 
cbb1b1a
 
4c85df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49a9433
4c85df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbb1b1a
 
4c85df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49a9433
4c85df9
 
 
 
 
 
 
 
 
 
 
 
 
 
49a9433
4c85df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbb1b1a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
"""
EvidenceNER training script.

Data source: ~4 000 synthetic complaint sentences generated in-memory by
             build_synthetic_dataset().  No download required.
             Optionally augmented with CoNLL-2003 (via HuggingFace datasets)
             when internet is available; PER→PERSON, ORG→ORG; LOC/MISC discarded.

CLI usage:
    python -m src.ner.train --output_dir models/evidence_ner
    python -m src.ner.train --output_dir models/evidence_ner --n_samples 6000 --epochs 5
"""

from __future__ import annotations

import argparse
import logging
import random
import re

from datasets import Dataset, DatasetDict, concatenate_datasets
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
)

from src.ner.model import BIO_LABELS, ID2LABEL, LABEL2ID, NUM_LABELS, NER_LABELS

# Module-level logger named after this module; train() installs the root
# handler via logging.basicConfig(level=logging.INFO).
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Entity value banks
# ---------------------------------------------------------------------------

# Candidate surface strings for each entity label.  build_synthetic_dataset()
# draws one value per slot with rng.choice(); the keys must match the labels
# accepted by _SLOT_RE.
# NOTE(review): some DATE values already begin with a preposition ("on ...",
# "in ...") and some PERSON values with a determiner ("the ...", "your ..."),
# so templates such as "on {DATE}" or "The {PERSON}" can produce doubled words
# (e.g. "on on 15th February 2024") — confirm this noise is intended.
# NOTE(review): PERSON values here are role descriptions, whereas CoNLL "PER"
# (mapped to PERSON in _CONLL_LABEL_MAP) is presumably personal names — verify
# that mixing the two under one label is desired.
ENTITY_VALUES: dict[str, list[str]] = {
    "ORG": [
        "Flipkart", "Amazon India", "Myntra", "Snapdeal", "Meesho",
        "HDFC Bank", "ICICI Bank", "State Bank of India", "Axis Bank", "Kotak Mahindra Bank",
        "Punjab National Bank", "Bank of Baroda",
        "Airtel", "Reliance Jio", "Vodafone Idea", "BSNL",
        "LIC of India", "Star Health Insurance", "New India Assurance", "ICICI Lombard",
        "CIBIL", "Experian India",
        "Swiggy", "Zomato", "Ola Cabs", "Uber India", "IRCTC",
        "MakeMyTrip", "Paytm", "PhonePe",
        # "Indian Bank" — public sector bank (HQ Chennai); distinct from the generic phrase
        "Indian Bank",
        # Health insurance providers
        "Niva Bupa", "Care Health Insurance", "Bajaj Allianz",
        # Travel / hospitality OTAs
        "Agoda", "OYO Rooms", "Booking.com India",
        # Fintech
        "Razorpay", "BharatPe", "CRED",
    ],
    "AMOUNT": [
        "₹4,299", "₹1,200", "₹50,000", "₹10,500", "₹2,500",
        "Rs. 8,900", "Rs 15,000", "₹3,499", "₹1,00,000", "Rs. 499",
        "₹25,000", "₹750", "Rs. 1,50,000", "₹12,000", "₹5,999",
        "₹35,000", "₹800", "Rs. 2,000", "₹18,500", "₹9,999",
    ],
    "DATE": [
        "12 March 2024", "5 January 2024", "20 April 2024", "8 November 2023",
        "3 weeks ago", "two months ago", "last Tuesday", "last Friday",
        "on 15th February 2024", "15/01/2024", "on 5th April",
        "in December 2023", "last month", "three days ago",
        "on 22 May 2024", "on 30 June 2024",
    ],
    "REF_ID": [
        "Order #OD-2930291", "transaction ID TXN987654321",
        "reference number REF-20240312-001", "ticket ID TKT-9876543",
        "complaint number CMP-2024-001", "booking ID BK-56789",
        "claim number CLM-2024-12345", "case ID CASE-789012",
        "loan reference LN-20240101", "policy number POL-123456789",
        "complaint reference CR-20240415",
    ],
    "ACCOUNT": [
        "account ending in 4521", "loan account 9876543210",
        "savings account XXXX1234", "credit card ending 9087",
        "account number XXXX-4321", "demat account IN12345678",
        "current account", "fixed deposit account",
        "joint savings account", "salary account",
    ],
    "PERSON": [
        "customer care executive", "branch manager",
        "relationship manager", "loan officer",
        "insurance agent", "delivery executive",
        "the support agent", "your representative",
        "technical support executive", "grievance officer",
        "account manager",
    ],
}

# ---------------------------------------------------------------------------
# Sentence templates
# Placeholders: {ORG}  {AMOUNT}  {DATE}  {REF_ID}  {ACCOUNT}  {PERSON}
# ---------------------------------------------------------------------------

# Complaint sentence patterns, grouped by the entity slots they contain.
# _fill_template() replaces every {LABEL} placeholder with a value chosen from
# ENTITY_VALUES; build_synthetic_dataset() picks a template with rng.choice()
# for each generated sentence.
TEMPLATES: list[str] = [
    # --- Single entity ---
    "I want to file a complaint against {ORG}.",
    "{ORG} has been completely unresponsive to my grievances.",
    "I am a long-standing customer of {ORG} and I am deeply dissatisfied.",
    "A deduction of {AMOUNT} was made from my account without authorization.",
    "I am owed a refund of {AMOUNT} which has not been credited.",
    "Please reverse the incorrect charge of {AMOUNT} immediately.",
    "The incident occurred on {DATE} and has not been resolved since.",
    "I filed a formal complaint on {DATE} but have received no response.",
    "I am writing with reference to {REF_ID} which remains unresolved.",
    "Please look into complaint {REF_ID} at the earliest.",
    "My {ACCOUNT} has been showing incorrect transactions for several weeks.",
    "The {ACCOUNT} was blocked without any prior notice.",
    "The {PERSON} I spoke to was completely unhelpful and dismissive.",
    "I was promised by a {PERSON} that the issue would be resolved within 24 hours.",

    # --- ORG + AMOUNT ---
    "I ordered from {ORG} but was incorrectly charged {AMOUNT}.",
    "{ORG} has deducted {AMOUNT} from my account without my consent.",
    "I am requesting a refund of {AMOUNT} from {ORG} for a defective product.",
    "{ORG} charged me {AMOUNT} for a service I never subscribed to.",
    "Despite cancellation {ORG} has not refunded {AMOUNT} to date.",
    "{ORG} owes me {AMOUNT} as compensation for the inconvenience caused.",
    "I was billed {AMOUNT} by {ORG} in error and seek immediate correction.",

    # --- ORG + DATE ---
    "I filed a complaint with {ORG} on {DATE} but have received no update.",
    "{ORG} failed to deliver my order by the promised date of {DATE}.",
    "I visited the {ORG} branch on {DATE} but the issue was not resolved.",
    "Since {DATE} {ORG} has not responded to any of my communications.",

    # --- ORG + REF_ID ---
    "My complaint {REF_ID} with {ORG} has been unresolved for several weeks.",
    "I am following up on {REF_ID} raised with {ORG}.",
    "{ORG} has not taken any action on my ticket {REF_ID}.",

    # --- ORG + ACCOUNT ---
    "{ORG} debited funds from my {ACCOUNT} without my knowledge.",
    "I noticed that {ORG} had incorrectly blocked my {ACCOUNT}.",
    "The {ACCOUNT} with {ORG} has been showing erroneous entries.",

    # --- ORG + PERSON ---
    "The {PERSON} at {ORG} was rude and refused to address my concern.",
    "A {PERSON} from {ORG} promised to resolve the issue but never followed up.",
    "I spoke to a {PERSON} at {ORG} who assured me of a refund.",
    "The {PERSON} at {ORG} refused to process my refund request.",

    # --- AMOUNT + DATE ---
    "An unauthorized transaction of {AMOUNT} occurred on {DATE}.",
    "The refund of {AMOUNT} promised for {DATE} was never processed.",
    "On {DATE} a deduction of {AMOUNT} appeared on my account without reason.",

    # --- AMOUNT + ACCOUNT ---
    "{AMOUNT} was wrongly debited from my {ACCOUNT} and I request an immediate refund.",
    "My {ACCOUNT} shows an erroneous charge of {AMOUNT} that I did not authorize.",
    "{AMOUNT} was deducted from my {ACCOUNT} without my knowledge.",

    # --- AMOUNT + REF_ID ---
    "Transaction {REF_ID} of {AMOUNT} is disputed and I seek reversal.",
    "I raised complaint {REF_ID} against an incorrect charge of {AMOUNT}.",

    # --- DATE + REF_ID ---
    "I raised {REF_ID} on {DATE} and have not received any resolution.",
    "As of {DATE} my complaint {REF_ID} remains open with no action taken.",

    # --- PERSON + ACCOUNT ---
    "The {PERSON} disconnected my call without resolving the issue with my {ACCOUNT}.",
    "A {PERSON} assured me that my {ACCOUNT} would be unblocked within 24 hours.",

    # --- ORG + AMOUNT + DATE ---
    "I placed an order with {ORG} for {AMOUNT} on {DATE} but it was never delivered.",
    "{ORG} charged {AMOUNT} to my account on {DATE} without any authorization.",
    "On {DATE} I paid {AMOUNT} to {ORG} but the service was not provided as promised.",
    "I cancelled my subscription with {ORG} on {DATE} but the refund of {AMOUNT} has not been credited.",
    "{ORG} promised to refund {AMOUNT} by {DATE} but has failed to do so.",
    "I was billed {AMOUNT} by {ORG} for a service cancelled on {DATE}.",
    "A duplicate payment of {AMOUNT} to {ORG} made on {DATE} has not been reversed.",

    # --- ORG + AMOUNT + REF_ID ---
    "My order {REF_ID} from {ORG} worth {AMOUNT} was returned but the refund was not received.",
    "I raised complaint {REF_ID} with {ORG} regarding an erroneous charge of {AMOUNT}.",
    "{ORG} owes me {AMOUNT} against transaction reference {REF_ID}.",
    "Despite follow-up on ticket {REF_ID} {ORG} has not refunded {AMOUNT}.",

    # --- ORG + ACCOUNT + AMOUNT ---
    "{ORG} debited {AMOUNT} from my {ACCOUNT} without my consent.",
    "I noticed an unauthorized charge of {AMOUNT} on my {ACCOUNT} with {ORG}.",
    "{ORG} has applied a penalty of {AMOUNT} to my {ACCOUNT} without prior notice.",
    "Funds amounting to {AMOUNT} were withdrawn from my {ACCOUNT} at {ORG} without authorization.",

    # --- ORG + DATE + REF_ID ---
    "I filed complaint {REF_ID} with {ORG} on {DATE} and request immediate resolution.",
    "My ticket {REF_ID} raised with {ORG} on {DATE} has not been addressed.",
    "Since {DATE} {ORG} has not responded to my complaint {REF_ID}.",

    # --- ORG + ACCOUNT + DATE ---
    "{ORG} deducted funds from my {ACCOUNT} on {DATE} without any prior notification.",
    "I noticed on {DATE} that {ORG} had placed an incorrect hold on my {ACCOUNT}.",

    # --- PERSON + ORG + AMOUNT ---
    "The {PERSON} at {ORG} assured me of a refund of {AMOUNT} which I am yet to receive.",
    "A {PERSON} from {ORG} processed an unauthorized deduction of {AMOUNT} from my account.",

    # --- AMOUNT + DATE + REF_ID ---
    "Transaction {REF_ID} of {AMOUNT} made on {DATE} was unauthorized and I seek reversal.",
    "On {DATE} I filed complaint {REF_ID} for the recovery of {AMOUNT}.",

    # --- ORG + AMOUNT + DATE + REF_ID ---
    "My order {REF_ID} from {ORG} placed on {DATE} for {AMOUNT} was cancelled without a refund.",
    "I paid {AMOUNT} to {ORG} on {DATE} against reference {REF_ID} but the service was not rendered.",
    "{ORG} charged {AMOUNT} on {DATE} referencing {REF_ID} without my authorization.",
    "Despite raising {REF_ID} with {ORG} on {DATE} the refund of {AMOUNT} remains pending.",

    # --- ORG + ACCOUNT + AMOUNT + DATE ---
    "On {DATE} {ORG} debited {AMOUNT} from my {ACCOUNT} without authorization.",
    "{ORG} withdrew {AMOUNT} from my {ACCOUNT} on {DATE} citing a technical error.",

    # --- PERSON + ORG + AMOUNT + DATE ---
    "The {PERSON} at {ORG} processed a transaction of {AMOUNT} on {DATE} without my knowledge.",
    "On {DATE} a {PERSON} from {ORG} assured me the disputed {AMOUNT} would be refunded.",

    # --- ORG + ACCOUNT + AMOUNT + REF_ID ---
    "I raised {REF_ID} with {ORG} regarding {AMOUNT} wrongly debited from my {ACCOUNT}.",
    "{ORG} has not reversed {AMOUNT} credited to {ACCOUNT} as per complaint {REF_ID}.",

    # --- Five / six entities ---
    "On {DATE} the {PERSON} at {ORG} deducted {AMOUNT} from my {ACCOUNT} against {REF_ID}.",
    "I spoke to a {PERSON} from {ORG} on {DATE} regarding {REF_ID} worth {AMOUNT} debited from my {ACCOUNT}.",
    "The {PERSON} at {ORG} confirmed on {DATE} that {REF_ID} of {AMOUNT} debited from {ACCOUNT} would be reversed.",
]

# ---------------------------------------------------------------------------
# Template filling helpers
# ---------------------------------------------------------------------------

# Matches a "{LABEL}" placeholder.  The single capturing group is the label,
# so findall() yields labels and split() interleaves literal text with labels.
_SLOT_RE = re.compile(r"\{(ORG|AMOUNT|DATE|REF_ID|ACCOUNT|PERSON)\}")
# Matches one maximal run of non-whitespace characters, i.e. a
# whitespace-delimited word with any adjacent punctuation still attached.
_WORD_RE = re.compile(r"\S+")


def _extract_slots(template: str) -> list[str]:
    """List the placeholder labels of *template* in left-to-right order.

    A label that occurs more than once in the template is listed once per
    occurrence.
    """
    return [placeholder.group(1) for placeholder in _SLOT_RE.finditer(template)]


def _fill_template(
    template: str, slot_values: dict[str, str]
) -> tuple[str, list[dict]]:
    """
    Fill the slots of *template* and return (sentence, entity_spans).

    Args:
        template: Sentence pattern containing ``{LABEL}`` placeholders.
        slot_values: Replacement string for every label used in *template*.

    Returns:
        The filled sentence and one span per placeholder, in order of
        appearance:
        [{"start": int, "end": int, "label": str, "text": str}, ...]
        ``start``/``end`` are character offsets into the returned sentence
        (end exclusive) and ``text`` is the inserted value, so
        ``sentence[start:end] == text``.

    Raises:
        KeyError: if *template* uses a label missing from *slot_values*.
    """
    # re.split with a capturing group interleaves: [text, label, text, label, ..., text]
    pieces: list[str] = []
    spans: list[dict] = []
    cursor = 0  # length of the sentence assembled so far
    for index, part in enumerate(_SLOT_RE.split(template)):
        if index % 2:
            # Odd positions hold the captured slot label; swap in its value.
            label = part
            part = slot_values[label]
            spans.append({
                "start": cursor,
                "end": cursor + len(part),
                "label": label,
                "text": part,
            })
        pieces.append(part)
        cursor += len(part)
    return "".join(pieces), spans


def _word_tokenize(sentence: str) -> list[tuple[str, int, int]]:
    """Split *sentence* on whitespace, keeping each word's character offsets.

    Returns one ``(word, char_start, char_end)`` tuple per run of
    non-whitespace characters, with ``char_end`` exclusive.
    """
    tokens: list[tuple[str, int, int]] = []
    for match in _WORD_RE.finditer(sentence):
        begin, stop = match.span()
        tokens.append((match.group(), begin, stop))
    return tokens


def _assign_bio_labels(
    words: list[tuple[str, int, int]], entity_spans: list[dict]
) -> list[int]:
    """
    Map every word token to a BIO label id.

    A word belongs to an entity when the word's start offset lies in the
    half-open range [span.start, span.end).  Within a span the first such
    word is tagged ``B-<label>`` and the rest ``I-<label>``; words covered by
    no span are tagged ``O``.  Tags are converted to ids through LABEL2ID.
    """
    tags = ["O"] * len(words)
    for entity in entity_spans:
        # Positions of the words whose first character sits inside this span.
        covered = [
            position
            for position, (_text, begin, _stop) in enumerate(words)
            if entity["start"] <= begin < entity["end"]
        ]
        for rank, position in enumerate(covered):
            prefix = "I" if rank else "B"
            tags[position] = f"{prefix}-{entity['label']}"
    return [LABEL2ID[tag] for tag in tags]


# ---------------------------------------------------------------------------
# Synthetic dataset builder
# ---------------------------------------------------------------------------

def build_synthetic_dataset(n_samples: int = 4000, seed: int = 42) -> Dataset:
    """
    Generate up to *n_samples* distinct labelled complaint sentences in memory.

    Each attempt picks a random template and a random value per slot; a
    sentence that has already been produced is discarded.  Generation stops
    after ``n_samples * 8`` attempts, so fewer than *n_samples* rows may be
    returned when the templates cannot supply enough distinct sentences — a
    warning is logged in that case.

    Args:
        n_samples: Target number of unique sentences.
        seed: Seed for the private random generator, making the output
            reproducible for a given (n_samples, seed) pair.

    Returns a HuggingFace Dataset with columns:
        words    : list[str]   — whitespace-split word tokens
        ner_tags : list[int]   — BIO label id per word
    """
    rng = random.Random(seed)
    examples: list[dict] = []
    seen: set[str] = set()
    # Cap the attempts so the loop still ends once duplicates dominate.
    max_attempts = n_samples * 8

    for _ in range(max_attempts):
        if len(examples) >= n_samples:
            break
        template = rng.choice(TEMPLATES)
        slots = _extract_slots(template)
        slot_values = {s: rng.choice(ENTITY_VALUES[s]) for s in slots}
        sentence, spans = _fill_template(template, slot_values)
        if sentence in seen:
            continue
        seen.add(sentence)
        words = _word_tokenize(sentence)
        examples.append({
            "words": [w for w, _, _ in words],
            "ner_tags": _assign_bio_labels(words, spans),
        })

    if len(examples) < n_samples:
        # Make the shortfall visible instead of quietly training on less data.
        logger.warning(
            "Synthetic dataset: only %d of %d requested examples were unique "
            "after %d attempts.",
            len(examples), n_samples, max_attempts,
        )
    logger.info("Synthetic dataset: %d examples generated.", len(examples))
    return Dataset.from_list(examples)


# ---------------------------------------------------------------------------
# CoNLL-2003 augmentation (optional — silently skipped if unavailable)
# ---------------------------------------------------------------------------

# CoNLL-2003 entity type -> this project's entity label.  Types without an
# entry are relabelled "O" by _try_load_conll().
_CONLL_LABEL_MAP = {"PER": "PERSON", "ORG": "ORG"}  # LOC / MISC discarded


def _try_load_conll() -> Dataset | None:
    """
    Attempt to load the CoNLL-2003 train split and remap it to this project's labels.

    Entity types listed in _CONLL_LABEL_MAP keep their B-/I- prefix and take
    the mapped label; every other tag becomes "O".

    Returns:
        A Dataset with ``words`` and ``ner_tags`` columns, or None when the
        dataset could not be loaded or remapped (no internet, auth error,
        etc.).  The reason for the failure is included in the log message.
    """
    try:
        from datasets import load_dataset

        train_split = load_dataset("conll2003")["train"]
        tag_names = train_split.features["ner_tags"].feature.names

        # Translate each CoNLL tag id exactly once instead of re-parsing the
        # label string for every token of every sentence.
        outside = LABEL2ID["O"]
        tag_remap: dict[int, int] = {}
        for tag_id, conll_label in enumerate(tag_names):  # e.g. "B-PER", "I-ORG", "O"
            bio, _, etype = conll_label.partition("-")
            mapped = _CONLL_LABEL_MAP.get(etype)
            # "O" has no entity type, and LOC / MISC are unmapped: both -> "O".
            tag_remap[tag_id] = (
                outside if mapped is None else LABEL2ID[f"{bio}-{mapped}"]
            )

        remapped: list[dict] = [
            {
                "words": example["tokens"],
                "ner_tags": [tag_remap[tag_id] for tag_id in example["ner_tags"]],
            }
            for example in train_split
        ]

        logger.info("CoNLL-2003 augmentation: %d examples loaded.", len(remapped))
        return Dataset.from_list(remapped)

    except Exception as exc:
        # Deliberate best-effort boundary: augmentation is optional, so any
        # failure is tolerated — but report why so real bugs are not hidden.
        logger.info(
            "CoNLL-2003 unavailable — skipping augmentation (%s: %s).",
            type(exc).__name__, exc,
        )
        return None


# ---------------------------------------------------------------------------
# Tokeniser alignment
# ---------------------------------------------------------------------------

def _make_tokenise_fn(tokenizer):
    """
    Return a batched map function that tokenises word sequences and aligns
    BIO labels to subword tokens using word_ids().

    Only the first subword of each word receives its word's label; remaining
    subwords receive -100 (ignored by CrossEntropyLoss).
    """
    def tokenise_and_align(examples):
        tokenized = tokenizer(
            examples["words"],
            truncation=True,
            max_length=512,
            is_split_into_words=True,
        )
        all_labels: list[list[int]] = []
        for i, word_labels in enumerate(examples["ner_tags"]):
            word_ids = tokenized.word_ids(batch_index=i)
            prev_word_id = None
            labels: list[int] = []
            for word_id in word_ids:
                if word_id is None:
                    labels.append(-100)         # special token
                elif word_id != prev_word_id:
                    labels.append(word_labels[word_id])  # first subword → real label
                else:
                    labels.append(-100)         # continuation subword → ignored
                prev_word_id = word_id
            all_labels.append(labels)
        tokenized["labels"] = all_labels
        return tokenized

    return tokenise_and_align


# ---------------------------------------------------------------------------
# Training entry point
# ---------------------------------------------------------------------------

def train(args: argparse.Namespace) -> None:
    """
    Fine-tune distilbert-base-uncased for BIO token classification.

    Reads ``n_samples``, ``epochs``, ``batch_size`` and ``output_dir`` from
    *args*, trains on the synthetic corpus (plus CoNLL-2003 when it can be
    loaded) with a 90/10 train/eval split, and saves the model and tokenizer
    to ``args.output_dir``.
    """
    logging.basicConfig(level=logging.INFO)
    checkpoint = "distilbert-base-uncased"

    # --- Data: synthetic sentences, optionally mixed with CoNLL-2003 ---
    corpus = build_synthetic_dataset(n_samples=args.n_samples)
    extra = _try_load_conll()
    if extra is not None:
        corpus = concatenate_datasets([corpus, extra]).shuffle(seed=42)
    parts = corpus.train_test_split(test_size=0.1, seed=42)

    # --- Tokenisation: the raw columns are dropped once labels are aligned ---
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    align = _make_tokenise_fn(tokenizer)
    encoded = DatasetDict({
        target: parts[source].map(
            align, batched=True, remove_columns=["words", "ner_tags"]
        )
        for target, source in (("train", "train"), ("eval", "test"))
    })

    # --- Model: pretrained encoder with a token-classification head ---
    model = AutoModelForTokenClassification.from_pretrained(
        checkpoint,
        num_labels=NUM_LABELS,
        id2label=ID2LABEL,
        label2id=LABEL2ID,
        ignore_mismatched_sizes=True,
    )

    # --- Trainer configuration: evaluate and checkpoint once per epoch, then
    #     reload the checkpoint selected by eval_loss ---
    settings = {
        "output_dir": args.output_dir,
        "num_train_epochs": args.epochs,
        "per_device_train_batch_size": args.batch_size,
        "per_device_eval_batch_size": args.batch_size,
        "eval_strategy": "epoch",
        "save_strategy": "epoch",
        "load_best_model_at_end": True,
        "metric_for_best_model": "eval_loss",
        "logging_steps": 50,
        "report_to": "none",
    }
    trainer = Trainer(
        model=model,
        args=TrainingArguments(**settings),
        train_dataset=encoded["train"],
        eval_dataset=encoded["eval"],
        data_collator=DataCollatorForTokenClassification(tokenizer),
        processing_class=tokenizer,
    )

    # --- Train and persist ---
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    logger.info("EvidenceNER checkpoint saved to %s", args.output_dir)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Train EvidenceNER")
    p.add_argument("--output_dir", default="models/evidence_ner",
                   help="Directory to save the fine-tuned checkpoint")
    p.add_argument("--n_samples", type=int, default=4000,
                   help="Number of synthetic training sentences to generate")
    p.add_argument("--epochs", type=int, default=4,
                   help="Number of fine-tuning epochs")
    p.add_argument("--batch_size", type=int, default=16,
                   help="Per-device train/eval batch size")
    return p.parse_args()


# Script entry point: parse the CLI flags and start training, but only when
# this file is executed directly (e.g. `python -m src.ner.train`), not when
# it is imported.
if __name__ == "__main__":
    train(parse_args())