File size: 31,440 Bytes
db6aa40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
#!/usr/bin/env python3
"""
Train NFQA Classification Model from Scratch

Trains a multilingual NFQA classifier using XLM-RoBERTa on LLM-annotated WebFAQ data.

Usage (single file with automatic splitting):
    python train_nfqa_model.py --input data.jsonl --output-dir ./model --epochs 10

Usage (pre-split files):
    python train_nfqa_model.py --train train.jsonl --val val.jsonl --test test.jsonl --output-dir ./model --epochs 10

Author: Ali
Date: December 2024
"""

import pandas as pd
import numpy as np
import torch
import json
import argparse
import os
from collections import Counter
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    accuracy_score,
    f1_score
)
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend for server
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Seed the NumPy and PyTorch RNGs so repeated runs are reproducible.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Ordered NFQA taxonomy; a category's position in this list is its class id.
NFQA_CATEGORIES = [
    'NOT-A-QUESTION',
    'FACTOID',
    'DEBATE',
    'EVIDENCE-BASED',
    'INSTRUCTION',
    'REASON',
    'EXPERIENCE',
    'COMPARISON',
]

# Id <-> name lookups, both derived from the list order above.
ID2LABEL = dict(enumerate(NFQA_CATEGORIES))
LABEL2ID = {name: idx for idx, name in ID2LABEL.items()}


class NFQADataset(Dataset):
    """Map-style dataset of tokenized NFQA questions paired with integer labels.

    Each item is a dict holding ``input_ids`` and ``attention_mask`` tensors
    that are padded/truncated to ``max_length``, plus a scalar ``labels`` long
    tensor. Because every item has the same length, items can be stacked into
    a batch without a custom collate function.
    """

    def __init__(self, questions, labels, tokenizer, max_length=128):
        # questions and labels are expected to be index-aligned (not checked here).
        self.questions = questions
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        text = str(self.questions[idx])
        target = int(self.labels[idx])

        # Pad and truncate to exactly max_length so all items share one shape.
        encoded = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # return_tensors='pt' on a single string yields a leading batch
        # dimension; flatten it away so the DataLoader adds its own.
        item = {
            field: encoded[field].flatten()
            for field in ('input_ids', 'attention_mask')
        }
        item['labels'] = torch.tensor(target, dtype=torch.long)
        return item


def train_epoch(model, train_loader, optimizer, scheduler, device):
    """Run one optimization pass over ``train_loader``.

    For every batch this runs a forward pass with labels (so the model
    returns ``.loss``), backpropagates, clips the gradient norm to 1.0, and
    then steps the optimizer and the scheduler. Predictions are read from the
    same forward pass, so they reflect the weights *before* that batch's update.

    Returns:
        Tuple ``(mean batch loss, accuracy over all examples seen)``.
    """
    model.train()
    running_loss = 0
    epoch_preds = []
    epoch_targets = []

    bar = tqdm(train_loader, desc="Training")

    for batch in bar:
        inputs = {
            key: batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'labels')
        }
        targets = inputs['labels']

        outputs = model(**inputs)
        loss = outputs.loss
        batch_loss = loss.item()
        running_loss += batch_loss

        # Gradients are cleared immediately before backward, so no stale
        # gradients from the previous batch are accumulated.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        epoch_preds.extend(outputs.logits.argmax(dim=1).cpu().numpy())
        epoch_targets.extend(targets.cpu().numpy())

        bar.set_postfix({'loss': f'{batch_loss:.4f}'})

    return running_loss / len(train_loader), accuracy_score(epoch_targets, epoch_preds)


def evaluate(model, data_loader, device, languages=None, desc="Evaluating", show_analysis=False):
    """Score ``model`` on ``data_loader`` with gradient tracking disabled.

    Args:
        model: classifier whose forward(..., labels=...) returns ``.loss``
            and ``.logits``.
        data_loader: iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        device: device every batch tensor is moved to.
        languages: optional per-example language tags. They must line up
            with the loader's iteration order (presumably the loader is not
            shuffled; confirm at the call site). Only used when
            ``show_analysis`` is True.
        desc: tqdm progress-bar label.
        show_analysis: if True, print per-category, per-language and
            language-category breakdowns after scoring.

    Returns:
        Tuple ``(mean batch loss, accuracy, macro F1, predictions, true labels)``.
    """
    model.eval()
    loss_sum = 0
    pred_ids = []
    gold_ids = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=desc):
            inputs = {
                key: batch[key].to(device)
                for key in ('input_ids', 'attention_mask', 'labels')
            }
            outputs = model(**inputs)
            loss_sum += outputs.loss.item()
            pred_ids.extend(outputs.logits.argmax(dim=1).cpu().numpy())
            gold_ids.extend(inputs['labels'].cpu().numpy())

    mean_loss = loss_sum / len(data_loader)
    acc = accuracy_score(gold_ids, pred_ids)
    macro_f1 = f1_score(gold_ids, pred_ids, average='macro')

    # The breakdowns need language tags, so they are skipped without them.
    if show_analysis and languages is not None:
        rule = "-" * 70
        print("\n" + rule)
        print("VALIDATION ANALYSIS")
        print(rule)
        analyze_performance_by_category(pred_ids, gold_ids)
        analyze_performance_by_language(pred_ids, gold_ids, languages, top_n=5)
        analyze_language_category_combinations(pred_ids, gold_ids, languages, top_n=10)
        print(rule)

    return mean_loss, acc, macro_f1, pred_ids, gold_ids


def load_data(file_path):
    """Load annotated NFQA examples from a JSON-lines file.

    The label is read from the first of these columns that exists:
    ``label_id``, ``ensemble_prediction`` (category names mapped through
    ``LABEL2ID``), or ``label``. Any of them may hold numeric class ids or
    category names (e.g. ``"FACTOID"``); names are converted via ``LABEL2ID``.

    Rows are dropped when:
      * the question or the label is missing or unparseable, or
      * the label is not a known class id (0 .. len(NFQA_CATEGORIES) - 1).
    Unknown ids are rejected here because they would otherwise only surface
    later, as an out-of-range target during training.

    Args:
        file_path: path to a JSONL file with at least a ``question`` column.

    Returns:
        Tuple ``(questions, labels, languages)`` of index-aligned lists.
        ``labels`` are ints. ``languages`` comes from the optional
        ``language`` column, or is all ``'unknown'`` when that column is absent.

    Raises:
        FileNotFoundError: propagated from ``pd.read_json`` for a missing file.
        ValueError: if the ``question`` column or every label column is missing.
    """
    print(f"Loading data from: {file_path}\n")

    try:
        df = pd.read_json(file_path, lines=True)
        print(f"✓ Loaded {len(df)} annotated examples")

        # Check required columns
        if 'question' not in df.columns:
            raise ValueError("Missing 'question' column")

        # Determine label column (first match wins)
        if 'label_id' in df.columns:
            label_col = 'label_id'
        elif 'ensemble_prediction' in df.columns:
            # Convert category names to IDs; unknown names become NaN and are dropped below
            df['label_id'] = df['ensemble_prediction'].map(LABEL2ID)
            label_col = 'label_id'
        elif 'label' in df.columns:
            label_col = 'label'
        else:
            raise ValueError("No label column found (expected: 'label', 'label_id', or 'ensemble_prediction')")

        # Normalise labels to numbers. Category names map through LABEL2ID,
        # numeric values (and numeric strings) pass through, anything else
        # becomes NaN.
        raw_labels = df[label_col]
        df[label_col] = raw_labels.map(LABEL2ID).fillna(pd.to_numeric(raw_labels, errors='coerce'))

        # Remove any rows with missing labels
        df = df.dropna(subset=['question', label_col])

        # Keep only integral ids that index NFQA_CATEGORIES.
        ids = df[label_col]
        known = (ids >= 0) & (ids < len(NFQA_CATEGORIES)) & (ids % 1 == 0)
        n_unknown = int((~known).sum())
        if n_unknown:
            print(f"⚠ Dropping {n_unknown} examples with labels outside 0..{len(NFQA_CATEGORIES) - 1}")
            df = df[known]

        print(f"✓ Data cleaned: {len(df)} examples with valid labels")

        # Show statistics (every remaining id is a known category)
        print("\nLabel distribution:")
        label_counts = df[label_col].value_counts().sort_index()
        for label_id, count in label_counts.items():
            cat_name = ID2LABEL[int(label_id)]
            print(f"  {cat_name:20s}: {count:4d} ({count/len(df)*100:5.1f}%)")

        # Prepare final dataset with language info
        questions = df['question'].tolist()
        labels = df[label_col].astype(int).tolist()
        languages = df['language'].tolist() if 'language' in df.columns else ['unknown'] * len(df)

        print(f"\n✓ Prepared {len(questions)} question-label pairs")

        return questions, labels, languages

    except FileNotFoundError:
        print(f"❌ Error: File not found: {file_path}")
        raise
    except Exception as e:
        print(f"❌ Error loading data: {e}")
        raise


def create_data_splits(questions, labels, test_size=0.2, val_size=0.1):
    """Split examples into stratified train / validation / test partitions.

    ``test_size`` and ``val_size`` are fractions of the FULL dataset. The test
    set is carved off first. Validation is then taken from the remainder,
    with its fraction rescaled by ``1 / (1 - test_size)`` so that it still
    amounts to ``val_size`` of the whole. Both splits stratify on ``labels``
    and use ``RANDOM_SEED``, so the partition is reproducible.

    Returns:
        Tuple ``(train_questions, val_questions, test_questions,
        train_labels, val_labels, test_labels)``.
    """
    print("\nCreating data splits...")

    # Stage 1: hold out the test set.
    rest_q, test_q, rest_y, test_y = train_test_split(
        questions,
        labels,
        test_size=test_size,
        random_state=RANDOM_SEED,
        stratify=labels,
    )

    # Stage 2: take validation from the remainder (fraction rescaled to the full set).
    relative_val = val_size / (1 - test_size)
    train_q, val_q, train_y, val_y = train_test_split(
        rest_q,
        rest_y,
        test_size=relative_val,
        random_state=RANDOM_SEED,
        stratify=rest_y,
    )

    total = len(questions)
    print("\nData splits:")
    for caption, part in (("Training:  ", train_q), ("Validation:", val_q), ("Test:      ", test_q)):
        print(f"  {caption} {len(part):4d} examples ({len(part)/total*100:5.1f}%)")
    print(f"  Total:      {total:4d} examples")

    # Report per-class counts so stratification can be eyeballed.
    print("\nClass distribution per split:")
    for caption, part_labels in (('Train', train_y), ('Val', val_y), ('Test', test_y)):
        tally = Counter(part_labels)
        print(f"\n{caption}:")
        for label_id in sorted(tally):
            print(f"  {ID2LABEL[label_id]:20s}: {tally[label_id]:3d}")

    return train_q, val_q, test_q, train_y, val_y, test_y


def plot_training_curves(history, best_val_f1, output_dir):
    """Plot per-epoch loss, accuracy and validation F1, and save them as one PNG.

    Args:
        history: dict of per-epoch lists with keys ``train_loss``,
            ``val_loss``, ``train_accuracy``, ``val_accuracy`` and ``val_f1``.
            The epoch axis length is taken from ``train_loss``.
        best_val_f1: value drawn as a dashed reference line on the F1 panel.
        output_dir: directory the figure is written to, as
            ``training_curves.png`` at 300 dpi.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    try:
        epochs = range(1, len(history['train_loss']) + 1)

        # Plot 1: Loss
        axes[0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        axes[0].plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Training and Validation Loss')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Plot 2: Accuracy
        axes[1].plot(epochs, history['train_accuracy'], 'b-', label='Train Accuracy', linewidth=2)
        axes[1].plot(epochs, history['val_accuracy'], 'r-', label='Val Accuracy', linewidth=2)
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Accuracy')
        axes[1].set_title('Training and Validation Accuracy')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

        # Plot 3: F1 Score
        axes[2].plot(epochs, history['val_f1'], 'g-', label='Val F1 (Macro)', linewidth=2)
        axes[2].axhline(y=best_val_f1, color='r', linestyle='--', label=f'Best F1: {best_val_f1:.4f}')
        axes[2].set_xlabel('Epoch')
        axes[2].set_ylabel('F1 Score')
        axes[2].set_title('Validation F1 Score')
        axes[2].legend()
        axes[2].grid(True, alpha=0.3)

        fig.tight_layout()
        plot_file = os.path.join(output_dir, 'training_curves.png')
        fig.savefig(plot_file, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when plotting or saving raises (e.g. a
        # missing output_dir). Otherwise pyplot keeps it open.
        plt.close(fig)

    print(f"✓ Training curves saved to: {plot_file}")


def analyze_performance_by_language(predictions, true_labels, languages, top_n=10, min_examples=5):
    """Print the ``top_n`` languages with the lowest accuracy.

    Languages with fewer than ``min_examples`` examples are left out of the
    ranking, but they still appear in the returned ``lang_stats``.

    Args:
        predictions: predicted class ids, index-aligned with ``true_labels``.
        true_labels: gold class ids.
        languages: language tag per example (any hashable value, including
            ``None`` for examples with no language).
        top_n: number of worst languages to print.
        min_examples: minimum number of examples a language needs to be ranked.

    Returns:
        Tuple ``(lang_stats, lang_accuracies)``. ``lang_stats`` maps each
        language to ``{'correct', 'total'}``. ``lang_accuracies`` is the list
        of ranked per-language dicts, sorted by ascending accuracy.
    """
    from collections import defaultdict

    lang_stats = defaultdict(lambda: {'correct': 0, 'total': 0})

    for pred, true, lang in zip(predictions, true_labels, languages):
        lang_stats[lang]['total'] += 1
        if pred == true:
            lang_stats[lang]['correct'] += 1

    # Calculate accuracy per language (only sufficiently represented ones)
    lang_accuracies = [
        {
            'language': lang,
            'accuracy': stats['correct'] / stats['total'],
            'correct': stats['correct'],
            'total': stats['total'],
            'errors': stats['total'] - stats['correct'],
        }
        for lang, stats in lang_stats.items()
        if stats['total'] >= min_examples
    ]
    lang_accuracies.sort(key=lambda x: x['accuracy'])

    print(f"\n{'='*70}")
    print(f"WORST {top_n} LANGUAGES (with >= {min_examples} examples)")
    print(f"{'='*70}")
    print(f"{'Language':<12} {'Accuracy':<12} {'Errors':<10} {'Total':<10}")
    print(f"{'-'*70}")

    for item in lang_accuracies[:top_n]:
        # str(): a language may be None (null in the data), and None does not
        # support the '<12' format spec (format(None, '<12') raises TypeError).
        print(f"{str(item['language']):<12} {item['accuracy']:>10.2%}   {item['errors']:>8}   {item['total']:>8}")

    return lang_stats, lang_accuracies


def analyze_performance_by_category(predictions, true_labels):
    """Print per-category accuracy, lowest first, and return the numbers.

    Only categories that occur in ``true_labels`` are reported. A category's
    accuracy is the fraction of its gold examples that were predicted
    correctly (i.e. its recall).

    Returns:
        Tuple ``(cat_stats, cat_accuracies)``. ``cat_stats`` maps each gold
        class id to ``{'correct', 'total'}``. ``cat_accuracies`` is a list of
        per-category dicts sorted by ascending accuracy.
    """
    from collections import defaultdict

    tallies = defaultdict(lambda: {'correct': 0, 'total': 0})
    for guess, gold in zip(predictions, true_labels):
        bucket = tallies[gold]
        bucket['total'] += 1
        if guess == gold:
            bucket['correct'] += 1

    ranked = sorted(
        (
            {
                'category': ID2LABEL[cat_id],
                'accuracy': counts['correct'] / counts['total'],
                'correct': counts['correct'],
                'total': counts['total'],
                'errors': counts['total'] - counts['correct'],
            }
            for cat_id, counts in tallies.items()
        ),
        key=lambda row: row['accuracy'],
    )

    banner = '=' * 70
    print(f"\n{banner}")
    print("PERFORMANCE BY CATEGORY")
    print(banner)
    print(f"{'Category':<20} {'Accuracy':<12} {'Errors':<10} {'Total':<10}")
    print('-' * 70)

    for row in ranked:
        print(f"{row['category']:<20} {row['accuracy']:>10.2%}   {row['errors']:>8}   {row['total']:>8}")

    return tallies, ranked


def analyze_language_category_combinations(predictions, true_labels, languages, top_n=15, min_examples=3):
    """Print the ``top_n`` (language, gold category) pairs with the lowest accuracy.

    Pairs with fewer than ``min_examples`` examples are left out of the
    ranking, but they still appear in the returned ``combo_stats``.

    Args:
        predictions: predicted class ids, index-aligned with ``true_labels``.
        true_labels: gold class ids (keys of ``ID2LABEL``).
        languages: language tag per example (any hashable value, including
            ``None`` for examples with no language).
        top_n: number of worst pairs to print.
        min_examples: minimum number of examples a pair needs to be ranked.

    Returns:
        Tuple ``(combo_stats, combo_accuracies)``. ``combo_stats`` maps
        ``(language, category name)`` to ``{'correct', 'total'}``.
        ``combo_accuracies`` is the list of ranked dicts, sorted by ascending
        accuracy.
    """
    from collections import defaultdict

    combo_stats = defaultdict(lambda: {'correct': 0, 'total': 0})

    for pred, true, lang in zip(predictions, true_labels, languages):
        key = (lang, ID2LABEL[true])
        combo_stats[key]['total'] += 1
        if pred == true:
            combo_stats[key]['correct'] += 1

    combo_accuracies = [
        {
            'language': lang,
            'category': cat,
            'accuracy': stats['correct'] / stats['total'],
            'correct': stats['correct'],
            'total': stats['total'],
            'errors': stats['total'] - stats['correct'],
        }
        for (lang, cat), stats in combo_stats.items()
        if stats['total'] >= min_examples
    ]
    combo_accuracies.sort(key=lambda x: x['accuracy'])

    print(f"\n{'='*80}")
    print(f"WORST {top_n} LANGUAGE-CATEGORY COMBINATIONS (with >= {min_examples} examples)")
    print(f"{'='*80}")
    print(f"{'Language':<12} {'Category':<20} {'Accuracy':<12} {'Errors':<8} {'Total':<8}")
    print(f"{'-'*80}")

    for item in combo_accuracies[:top_n]:
        # str(): a language may be None (null in the data), and None does not
        # support the '<12' format spec (format(None, '<12') raises TypeError).
        print(f"{str(item['language']):<12} {item['category']:<20} {item['accuracy']:>10.2%}   {item['errors']:>6}   {item['total']:>6}")

    return combo_stats, combo_accuracies


def plot_confusion_matrix(test_true, test_preds, output_dir):
    """Plot the test-set confusion matrix as an annotated heatmap and save it.

    Rows are true categories and columns are predicted ones, both in
    ``NFQA_CATEGORIES`` order. Passing ``labels=`` explicitly keeps every
    category on both axes, even one that never occurs in the data.

    Args:
        test_true: gold class ids.
        test_preds: predicted class ids, index-aligned with ``test_true``.
        output_dir: directory the figure is written to, as
            ``confusion_matrix.png`` at 300 dpi.
    """
    cm = confusion_matrix(test_true, test_preds, labels=list(range(len(NFQA_CATEGORIES))))

    fig = plt.figure(figsize=(12, 10))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=NFQA_CATEGORIES,
            yticklabels=NFQA_CATEGORIES,
            cbar_kws={'label': 'Count'}
        )
        plt.xlabel('Predicted Category')
        plt.ylabel('True Category')
        plt.title('Confusion Matrix - Test Set')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        fig.tight_layout()

        cm_file = os.path.join(output_dir, 'confusion_matrix.png')
        fig.savefig(cm_file, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when plotting or saving raises (e.g. a
        # missing output_dir). Otherwise pyplot keeps it open.
        plt.close(fig)

    print(f"✓ Confusion matrix saved to: {cm_file}")


def _build_arg_parser():
    """Build the command-line parser for NFQA model training.

    Training data comes from one of two sources:
    - a single annotated JSONL file (``--input``), which is split automatically;
    - three pre-split JSONL files (``--train``/``--val``/``--test``).

    ``_validate_args`` enforces that exactly one of these modes is used.
    """
    parser = argparse.ArgumentParser(description='Train NFQA Classification Model')

    # Data arguments - either single input file OR separate train/val/test files
    parser.add_argument('--input', type=str,
                        help='Input JSONL file with annotated data (will be split automatically)')
    parser.add_argument('--train', type=str,
                        help='Training set JSONL file (use with --val and --test)')
    parser.add_argument('--val', type=str,
                        help='Validation set JSONL file (use with --train and --test)')
    parser.add_argument('--test', type=str,
                        help='Test set JSONL file (use with --train and --val)')
    parser.add_argument('--output-dir', type=str, default='./nfqa_model_trained',
                        help='Output directory for model and results')

    # Model arguments
    parser.add_argument('--model-name', type=str, default='xlm-roberta-base',
                        help='Pretrained model name (default: xlm-roberta-base)')
    parser.add_argument('--max-length', type=int, default=128,
                        help='Maximum sequence length (default: 128)')

    # Training arguments
    parser.add_argument('--batch-size', type=int, default=16,
                        help='Batch size (default: 16)')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Number of epochs (default: 10)')
    parser.add_argument('--learning-rate', type=float, default=2e-5,
                        help='Learning rate (default: 2e-5)')
    parser.add_argument('--warmup-steps', type=int, default=500,
                        help='Warmup steps (default: 500)')
    parser.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay (default: 0.01)')
    parser.add_argument('--dropout', type=float, default=0.1,
                        help='Dropout probability (default: 0.1)')

    # Split arguments (only used together with --input)
    parser.add_argument('--test-size', type=float, default=0.2,
                        help='Test set size (default: 0.2)')
    parser.add_argument('--val-size', type=float, default=0.1,
                        help='Validation set size (default: 0.1)')

    # Device argument
    parser.add_argument('--device', type=str, default='auto',
                        help='Device to use: cuda, cpu, or auto (default: auto)')

    return parser


def _validate_args(parser, args):
    """Validate the data-source selection and numeric ranges.

    Args:
        parser: The parser that produced ``args``. Problems are reported through
            ``parser.error``, which prints usage and exits the process.
        args: Parsed command-line namespace.

    Returns:
        ``(has_single_input, has_split_inputs)`` describing the chosen data mode.
    """
    split_inputs = [args.train, args.val, args.test]
    has_single_input = args.input is not None
    has_split_inputs = all(split_inputs)

    if not has_single_input and not has_split_inputs:
        parser.error("Either --input OR (--train, --val, --test) must be provided")

    # Reject --input combined with *any* split file. Checking only the complete
    # triple would silently ignore a stray --train/--val/--test.
    if has_single_input and any(split_inputs):
        parser.error("Cannot use --input together with --train/--val/--test. Choose one approach.")

    # At least one epoch is needed so that a best checkpoint exists for the
    # final test-set evaluation.
    if args.epochs < 1:
        parser.error("--epochs must be at least 1")

    if has_single_input:
        # The second split uses val_size / (1 - test_size) as a fraction of
        # what remains after the test split. Both sizes must therefore be
        # proper fractions, and their sum must stay below 1.
        if not (0 < args.test_size < 1 and 0 < args.val_size < 1
                and args.test_size + args.val_size < 1):
            parser.error("--test-size and --val-size must each be between 0 and 1 "
                         "and their sum must be less than 1")

    return has_single_input, has_split_inputs


def _print_config(args, has_single_input):
    """Print the run-configuration banner."""
    print("="*80)
    print("NFQA MODEL TRAINING")
    print("="*80)
    if has_single_input:
        print(f"Input file: {args.input}")
        print(f"Data splitting: automatic (test={args.test_size}, val={args.val_size})")
    else:
        print(f"Train file: {args.train}")
        print(f"Val file: {args.val}")
        print(f"Test file: {args.test}")
        print("Data splitting: manual (pre-split)")
    print(f"Output directory: {args.output_dir}")
    print(f"Model: {args.model_name}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.learning_rate}")
    print(f"Max length: {args.max_length}")
    print(f"Weight decay: {args.weight_decay}")
    print(f"Dropout: {args.dropout}")
    print("="*80 + "\n")


def _load_splits(args, has_single_input):
    """Load the train/val/test data and print a split summary.

    With ``--input`` the file is split twice with ``train_test_split``.
    Both splits are stratified on labels and seeded with ``RANDOM_SEED``:
    the first takes the test fraction, the second takes the validation
    fraction rescaled to the remaining data. Otherwise the three pre-split
    files are loaded as-is and their per-class counts are printed.

    Returns:
        ``(train, val, test)``, each a ``(questions, labels, languages)``
        tuple of aligned sequences.
    """
    if has_single_input:
        # Load single file and create splits
        questions, labels, languages = load_data(args.input)

        # Create splits (stratify by labels, keep languages aligned)
        from sklearn.model_selection import train_test_split
        # First split: separate test set
        train_val_questions, test_questions, train_val_labels, test_labels, train_val_langs, test_langs = train_test_split(
            questions, labels, languages,
            test_size=args.test_size,
            random_state=RANDOM_SEED,
            stratify=labels
        )

        # Second split: separate validation from training. val_size is a
        # fraction of the full dataset, so rescale it to the remaining part.
        train_questions, val_questions, train_labels, val_labels, train_langs, val_langs = train_test_split(
            train_val_questions, train_val_labels, train_val_langs,
            test_size=args.val_size / (1 - args.test_size),
            random_state=RANDOM_SEED,
            stratify=train_val_labels
        )

        print("\nData splits:")
        print(f"  Training:   {len(train_questions):4d} examples ({len(train_questions)/len(questions)*100:5.1f}%)")
        print(f"  Validation: {len(val_questions):4d} examples ({len(val_questions)/len(questions)*100:5.1f}%)")
        print(f"  Test:       {len(test_questions):4d} examples ({len(test_questions)/len(questions)*100:5.1f}%)")
        print(f"  Total:      {len(questions):4d} examples")
    else:
        # Load pre-split files
        print("Loading pre-split datasets...\n")
        train_questions, train_labels, train_langs = load_data(args.train)
        val_questions, val_labels, val_langs = load_data(args.val)
        test_questions, test_labels, test_langs = load_data(args.test)

        # Print split summary
        total_examples = len(train_questions) + len(val_questions) + len(test_questions)
        print("\nData splits:")
        print(f"  Training:   {len(train_questions):4d} examples ({len(train_questions)/total_examples*100:5.1f}%)")
        print(f"  Validation: {len(val_questions):4d} examples ({len(val_questions)/total_examples*100:5.1f}%)")
        print(f"  Test:       {len(test_questions):4d} examples ({len(test_questions)/total_examples*100:5.1f}%)")
        print(f"  Total:      {total_examples:4d} examples")

        # Show class distribution per split
        print("\nClass distribution per split:")
        for split_name, split_labels in [('Train', train_labels), ('Val', val_labels), ('Test', test_labels)]:
            counts = Counter(split_labels)
            print(f"\n{split_name}:")
            for label_id in sorted(counts.keys()):
                cat_name = ID2LABEL[label_id]
                print(f"  {cat_name:20s}: {counts[label_id]:3d}")

    return ((train_questions, train_labels, train_langs),
            (val_questions, val_labels, val_langs),
            (test_questions, test_labels, test_langs))


def _train(model, tokenizer, train_loader, val_loader, val_langs,
           optimizer, scheduler, device, args):
    """Train for ``args.epochs`` epochs, checkpointing the best model.

    After each epoch the model is evaluated on the validation set. The model
    and tokenizer are saved to ``<output_dir>/best_model`` whenever
    validation macro-F1 improves, and always after the first epoch.

    Returns:
        ``(history, best_val_f1, best_epoch)``. ``history`` maps each metric
        name to its per-epoch values; ``best_epoch`` is 1-based.
    """
    history = {
        'train_loss': [],
        'train_accuracy': [],
        'val_loss': [],
        'val_accuracy': [],
        'val_f1': []
    }

    best_val_f1 = 0
    best_epoch = 0
    best_model_path = os.path.join(args.output_dir, 'best_model')

    print("\n" + "="*80)
    print("STARTING TRAINING")
    print("="*80 + "\n")

    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")
        print("-" * 80)

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, device)

        # Validate with detailed analysis (per-example predictions are not needed here)
        val_loss, val_acc, val_f1, _, _ = evaluate(
            model, val_loader, device,
            languages=val_langs,
            desc="Validating",
            show_analysis=True
        )

        # Update history
        history['train_loss'].append(train_loss)
        history['train_accuracy'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_acc)
        history['val_f1'].append(val_f1)

        # Print metrics
        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"  Train Loss:     {train_loss:.4f}")
        print(f"  Train Accuracy: {train_acc:.4f}")
        print(f"  Val Loss:       {val_loss:.4f}")
        print(f"  Val Accuracy:   {val_acc:.4f}")
        print(f"  Val F1 (Macro): {val_f1:.4f}")

        # Save best model. The first epoch is always saved so that best_model/
        # exists for the final test evaluation, even if macro-F1 never rises
        # above its initial value of 0.
        if best_epoch == 0 or val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_epoch = epoch + 1

            model.save_pretrained(best_model_path)
            tokenizer.save_pretrained(best_model_path)

            print(f"  βœ“ New best model saved! (F1: {val_f1:.4f})")

    return history, best_val_f1, best_epoch


def main():
    """Train an NFQA classifier from the command line and evaluate it on the test set.

    Writes these files to ``--output-dir`` directly: ``best_model/``,
    ``final_model/``, ``training_history.json``,
    ``classification_report.txt`` and ``test_results.json``. Plots are
    delegated to ``plot_training_curves`` and ``plot_confusion_matrix``.
    Invalid argument combinations exit via ``argparse``'s ``parser.error``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    has_single_input, has_split_inputs = _validate_args(parser, args)

    _print_config(args, has_single_input)

    # Set device
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    # NOTE(review): only the CUDA RNGs are seeded here; presumably the CPU /
    # Python seeds are set elsewhere in the file. Confirm.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)

    print(f"Device: {device}")
    print(f"PyTorch version: {torch.__version__}")
    # Report a GPU only when it is actually used (not e.g. with --device cpu).
    if device.type == 'cuda':
        print(f"CUDA device: {torch.cuda.get_device_name(device)}\n")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load data - either from single file or pre-split files
    train_split, val_split, test_split = _load_splits(args, has_single_input)
    train_questions, train_labels, _ = train_split
    val_questions, val_labels, val_langs = val_split
    test_questions, test_labels, test_langs = test_split

    # Load tokenizer and model
    print(f"\nLoading tokenizer: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    print("βœ“ Tokenizer loaded")

    print(f"\nLoading model: {args.model_name}")
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=len(NFQA_CATEGORIES),
        id2label=ID2LABEL,
        label2id=LABEL2ID,
        hidden_dropout_prob=args.dropout,
        attention_probs_dropout_prob=args.dropout,
        classifier_dropout=args.dropout
    )
    model.to(device)

    print("βœ“ Model loaded")
    print(f"  Number of parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"  Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # Create datasets
    print("\nCreating datasets...")
    train_dataset = NFQADataset(train_questions, train_labels, tokenizer, args.max_length)
    val_dataset = NFQADataset(val_questions, val_labels, tokenizer, args.max_length)
    test_dataset = NFQADataset(test_questions, test_labels, tokenizer, args.max_length)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)

    print("βœ“ Datasets created")
    print(f"  Train: {len(train_dataset)} examples ({len(train_loader)} batches)")
    print(f"  Val:   {len(val_dataset)} examples ({len(val_loader)} batches)")
    print(f"  Test:  {len(test_dataset)} examples ({len(test_loader)} batches)")

    # Setup optimizer and scheduler
    optimizer = AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )

    # The scheduler is stepped once per training batch (presumably inside
    # train_epoch), so its horizon is batches-per-epoch times epochs.
    total_steps = len(train_loader) * args.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=total_steps
    )

    print("\nβœ“ Optimizer and scheduler configured")
    print(f"  Total training steps: {total_steps}")
    print(f"  Warmup steps: {args.warmup_steps}")

    # Training loop
    history, best_val_f1, best_epoch = _train(
        model, tokenizer, train_loader, val_loader, val_langs,
        optimizer, scheduler, device, args
    )

    print("\n" + "="*80)
    print("TRAINING COMPLETE")
    print("="*80)
    print(f"Best epoch: {best_epoch}")
    print(f"Best validation F1: {best_val_f1:.4f}")
    print("="*80)

    # Save training history
    history_file = os.path.join(args.output_dir, 'training_history.json')
    with open(history_file, 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2)
    print(f"\nβœ“ Training history saved to: {history_file}")

    # Save final model (weights after the last epoch)
    final_model_path = os.path.join(args.output_dir, 'final_model')
    model.save_pretrained(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"βœ“ Final model saved to: {final_model_path}")

    # Plot training curves
    plot_training_curves(history, best_val_f1, args.output_dir)

    # Load best model and evaluate on test set
    print("\nLoading best model for final evaluation...")
    best_model_path = os.path.join(args.output_dir, 'best_model')
    model = AutoModelForSequenceClassification.from_pretrained(best_model_path)
    model.to(device)

    test_loss, test_acc, test_f1, test_preds, test_true = evaluate(model, test_loader, device, desc="Testing")

    print("\n" + "="*80)
    print("FINAL TEST SET RESULTS")
    print("="*80)
    print(f"Test Loss:       {test_loss:.4f}")
    print(f"Test Accuracy:   {test_acc:.4f}")
    print(f"Test F1 (Macro): {test_f1:.4f}")
    print("="*80)

    # Classification report
    print("\n" + "="*80)
    print("PER-CATEGORY PERFORMANCE")
    print("="*80 + "\n")

    # Pass every label id explicitly so categories absent from the test set
    # still appear in the report (with zero_division=0 instead of warnings).
    report = classification_report(
        test_true,
        test_preds,
        labels=list(range(len(NFQA_CATEGORIES))),
        target_names=NFQA_CATEGORIES,
        zero_division=0
    )
    print(report)

    # Save report
    report_file = os.path.join(args.output_dir, 'classification_report.txt')
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write(report)
    print(f"βœ“ Classification report saved to: {report_file}")

    # Plot confusion matrix
    plot_confusion_matrix(test_true, test_preds, args.output_dir)

    # Detailed performance analysis
    print("\n" + "="*80)
    print("DETAILED PERFORMANCE ANALYSIS")
    print("="*80)

    # Analyze by category
    analyze_performance_by_category(test_preds, test_true)

    # Analyze by language
    analyze_performance_by_language(test_preds, test_true, test_langs, top_n=10)

    # Analyze language-category combinations
    analyze_language_category_combinations(test_preds, test_true, test_langs, top_n=15)

    print("\n" + "="*80)

    # Save test results
    test_results = {
        'test_loss': float(test_loss),
        'test_accuracy': float(test_acc),
        'test_f1_macro': float(test_f1),
        'best_epoch': int(best_epoch),
        'best_val_f1': float(best_val_f1),
        'num_train_examples': len(train_questions),
        'num_val_examples': len(val_questions),
        'num_test_examples': len(test_questions),
        'config': {
            'model_name': args.model_name,
            'max_length': args.max_length,
            'batch_size': args.batch_size,
            'learning_rate': args.learning_rate,
            'num_epochs': args.epochs,
            'warmup_steps': args.warmup_steps,
            'weight_decay': args.weight_decay,
            'dropout': args.dropout,
            'data_source': 'pre-split' if has_split_inputs else 'single_file',
            # For single-file runs this records the --input path.
            'train_file': args.train if has_split_inputs else args.input,
            'val_file': args.val if has_split_inputs else None,
            'test_file': args.test if has_split_inputs else None,
            'auto_split': not has_split_inputs,
            'test_size': args.test_size if not has_split_inputs else None,
            'val_size': args.val_size if not has_split_inputs else None
        },
        'timestamp': datetime.now().isoformat()
    }

    results_file = os.path.join(args.output_dir, 'test_results.json')
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(test_results, f, indent=2)
    print(f"βœ“ Test results saved to: {results_file}")

    # Summary
    print("\n" + "="*80)
    print("TRAINING SUMMARY")
    print("="*80)
    print(f"\nModel: {args.model_name}")
    print(f"Training examples: {len(train_questions)}")
    print(f"Validation examples: {len(val_questions)}")
    print(f"Test examples: {len(test_questions)}")
    print(f"\nBest epoch: {best_epoch}/{args.epochs}")
    print(f"Best validation F1: {best_val_f1:.4f}")
    print("\nFinal test results:")
    print(f"  Accuracy: {test_acc:.4f}")
    print(f"  F1 Score (Macro): {test_f1:.4f}")
    print(f"\nModel saved to: {args.output_dir}")
    print("\nGenerated files:")
    print("  - best_model/ (best checkpoint)")
    print("  - final_model/ (last epoch)")
    print("  - training_history.json")
    print("  - training_curves.png")
    print("  - test_results.json")
    print("  - classification_report.txt")
    print("  - confusion_matrix.png")
    print("\n" + "="*80)
    print("βœ… Training complete! Model ready for deployment.")
    print("="*80)


if __name__ == '__main__':
    # Run training only when executed as a script, not when imported.
    main()