#!/usr/bin/env python3
"""
Train NFQA Classification Model from Scratch

Trains a multilingual NFQA (non-factoid question answering) classifier using
XLM-RoBERTa on LLM-annotated WebFAQ data.

Usage (single file with automatic splitting):
    python train_nfqa_model.py --input data.jsonl --output-dir ./model --epochs 10

Usage (pre-split files):
    python train_nfqa_model.py --train train.jsonl --val val.jsonl --test test.jsonl --output-dir ./model --epochs 10

Author: Ali
Date: December 2024
"""
import pandas as pd
import numpy as np
import torch
import json
import argparse
import os
from collections import Counter
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    classification_report, confusion_matrix, accuracy_score, f1_score
)
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend for server
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Set random seed for numpy and torch so runs are reproducible
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# The 8 NFQA categories; list order defines the integer label ids below.
NFQA_CATEGORIES = [
    'NOT-A-QUESTION', 'FACTOID', 'DEBATE', 'EVIDENCE-BASED',
    'INSTRUCTION', 'REASON', 'EXPERIENCE', 'COMPARISON'
]

# Label mappings (category name <-> integer id)
LABEL2ID = {label: idx for idx, label in enumerate(NFQA_CATEGORIES)}
ID2LABEL = {idx: label for label, idx in LABEL2ID.items()}


class NFQADataset(Dataset):
    """Custom dataset for NFQA classification.

    Tokenizes each question lazily in ``__getitem__`` and returns tensors
    shaped for ``AutoModelForSequenceClassification`` (input_ids,
    attention_mask, labels).
    """

    def __init__(self, questions, labels, tokenizer, max_length=128):
        # questions/labels are parallel lists; labels are integer category ids.
        self.questions = questions
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        question = str(self.questions[idx])
        label = int(self.labels[idx])

        # Tokenize a single question, padded/truncated to a fixed length so
        # batches stack without a custom collate function.
        encoding = self.tokenizer(
            question,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            # flatten() drops the batch dim the tokenizer adds with
            # return_tensors='pt' (shape (1, max_length) -> (max_length,)).
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }


def train_epoch(model, train_loader, optimizer, scheduler, device):
    """Train for one epoch.

    Returns:
        (avg_loss, accuracy) over the whole epoch.
    """
    model.train()
    total_loss = 0
    predictions = []
    true_labels = []

    progress_bar = tqdm(train_loader, desc="Training")

    for batch in progress_bar:
        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass (passing labels makes the model compute the loss)
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        loss = outputs.loss
        total_loss += loss.item()

        # Backward pass with gradient clipping; scheduler steps per batch
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # Track predictions for epoch-level accuracy
        preds = torch.argmax(outputs.logits, dim=1)
        predictions.extend(preds.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

        # Update progress bar
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

    avg_loss = total_loss / len(train_loader)
    accuracy = accuracy_score(true_labels, predictions)

    return avg_loss, accuracy


def evaluate(model, data_loader, device, languages=None, desc="Evaluating", show_analysis=False):
    """Evaluate model on validation/test set with optional detailed analysis.

    Args:
        model: sequence-classification model (already on ``device``).
        data_loader: DataLoader over an ``NFQADataset``.
        device: torch device for the batches.
        languages: optional list of language codes aligned with the dataset
            order. NOTE(review): the per-language analysis assumes
            ``data_loader`` does not shuffle, so predictions stay aligned
            with ``languages`` — the val/test loaders in main() satisfy this.
        desc: tqdm progress-bar label.
        show_analysis: if True (and ``languages`` given), print per-category,
            per-language and per-combination breakdowns.

    Returns:
        (avg_loss, accuracy, macro_f1, predictions, true_labels)
    """
    model.eval()
    total_loss = 0
    predictions = []
    true_labels = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc=desc):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            total_loss += outputs.loss.item()

            preds = torch.argmax(outputs.logits, dim=1)
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(data_loader)
    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average='macro')

    # Run detailed analysis if requested
    if show_analysis and languages is not None:
        print("\n" + "-"*70)
        print("VALIDATION ANALYSIS")
        print("-"*70)

        # Analyze by category
        analyze_performance_by_category(predictions, true_labels)

        # Analyze by language (top 5)
        analyze_performance_by_language(predictions, true_labels, languages, top_n=5)

        # Analyze combinations (top 10)
        analyze_language_category_combinations(predictions, true_labels, languages, top_n=10)

        print("-"*70)

    return avg_loss, accuracy, f1, predictions, true_labels


def load_data(file_path):
    """Load annotated data from a JSONL file.

    Accepts three label-column variants, in priority order: ``label_id``
    (integer ids), ``ensemble_prediction`` (category names, mapped via
    LABEL2ID), or ``label``. Rows with a missing question or label are
    dropped.

    Returns:
        (questions, labels, languages) as parallel lists; ``languages``
        falls back to ``'unknown'`` per row when no ``language`` column
        exists.

    Raises:
        ValueError: if the ``question`` column or any label column is missing.
        FileNotFoundError: if ``file_path`` does not exist.
    """
    print(f"Loading data from: {file_path}\n")

    try:
        df = pd.read_json(file_path, lines=True)
        print(f"✓ Loaded {len(df)} annotated examples")

        # Check required columns
        if 'question' not in df.columns:
            raise ValueError("Missing 'question' column")

        # Determine label column
        if 'label_id' in df.columns:
            label_col = 'label_id'
        elif 'ensemble_prediction' in df.columns:
            # Convert category names to IDs; unknown names map to NaN and
            # are removed by the dropna below.
            df['label_id'] = df['ensemble_prediction'].map(LABEL2ID)
            label_col = 'label_id'
        elif 'label' in df.columns:
            label_col = 'label'
        else:
            raise ValueError("No label column found (expected: 'label', 'label_id', or 'ensemble_prediction')")

        # Remove any rows with missing labels
        df = df.dropna(subset=['question', label_col])
        print(f"✓ Data cleaned: {len(df)} examples with valid labels")

        # Show statistics
        print("\nLabel distribution:")
        label_counts = df[label_col].value_counts().sort_index()
        for label_id, count in label_counts.items():
            cat_name = ID2LABEL.get(int(label_id), f"UNKNOWN_{label_id}")
            print(f" {cat_name:20s}: {count:4d} ({count/len(df)*100:5.1f}%)")

        # Prepare final dataset with language info
        questions = df['question'].tolist()
        labels = df[label_col].astype(int).tolist()
        languages = df['language'].tolist() if 'language' in df.columns else ['unknown'] * len(df)

        print(f"\n✓ Prepared {len(questions)} question-label pairs")

        return questions, labels, languages

    except FileNotFoundError:
        print(f"❌ Error: File not found: {file_path}")
        raise
    except Exception as e:
        print(f"❌ Error loading data: {e}")
        raise


def create_data_splits(questions, labels, test_size=0.2, val_size=0.1):
    """Create stratified train/val/test splits.

    ``val_size`` and ``test_size`` are fractions of the FULL dataset; the
    second split rescales val_size relative to the remaining train+val pool.

    NOTE(review): main() does not call this helper — it re-implements the
    same two-stage split inline so the ``languages`` list stays aligned with
    the splits. This function is currently dead code kept for external use.

    Returns:
        (train_questions, val_questions, test_questions,
         train_labels, val_labels, test_labels)
    """
    print("\nCreating data splits...")

    # First split: separate test set
    train_val_questions, test_questions, train_val_labels, test_labels = train_test_split(
        questions, labels,
        test_size=test_size,
        random_state=RANDOM_SEED,
        stratify=labels
    )

    # Second split: separate validation from training.
    # val_size / (1 - test_size) converts the global val fraction into a
    # fraction of the remaining train+val pool.
    train_questions, val_questions, train_labels, val_labels = train_test_split(
        train_val_questions, train_val_labels,
        test_size=val_size / (1 - test_size),
        random_state=RANDOM_SEED,
        stratify=train_val_labels
    )

    print(f"\nData splits:")
    print(f" Training: {len(train_questions):4d} examples ({len(train_questions)/len(questions)*100:5.1f}%)")
    print(f" Validation: {len(val_questions):4d} examples ({len(val_questions)/len(questions)*100:5.1f}%)")
    print(f" Test: {len(test_questions):4d} examples ({len(test_questions)/len(questions)*100:5.1f}%)")
    print(f" Total: {len(questions):4d} examples")

    # Verify class distribution
    print("\nClass distribution per split:")
    for split_name, split_labels in [('Train', train_labels), ('Val', val_labels), ('Test', test_labels)]:
        counts = Counter(split_labels)
        print(f"\n{split_name}:")
        for label_id in sorted(counts.keys()):
            cat_name = ID2LABEL[label_id]
            print(f" {cat_name:20s}: {counts[label_id]:3d}")

    return train_questions, val_questions, test_questions, train_labels, val_labels, test_labels


def plot_training_curves(history, best_val_f1, output_dir):
    """Plot loss/accuracy/F1 curves and save them as training_curves.png."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    epochs = range(1, len(history['train_loss']) + 1)

    # Plot 1: Loss
    axes[0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    axes[0].plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Plot 2: Accuracy
    axes[1].plot(epochs, history['train_accuracy'], 'b-', label='Train Accuracy', linewidth=2)
    axes[1].plot(epochs, history['val_accuracy'], 'r-', label='Val Accuracy', linewidth=2)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Training and Validation Accuracy')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    # Plot 3: F1 Score (validation only), with the best value marked
    axes[2].plot(epochs, history['val_f1'], 'g-', label='Val F1 (Macro)', linewidth=2)
    axes[2].axhline(y=best_val_f1, color='r', linestyle='--', label=f'Best F1: {best_val_f1:.4f}')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('F1 Score')
    axes[2].set_title('Validation F1 Score')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plot_file = os.path.join(output_dir, 'training_curves.png')
    plt.savefig(plot_file, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Training curves saved to: {plot_file}")


def analyze_performance_by_language(predictions, true_labels, languages, top_n=10):
    """Print the worst-performing languages (accuracy ascending).

    Languages with fewer than 5 examples are skipped to avoid noisy stats.

    Returns:
        (lang_stats, lang_accuracies) — raw per-language counts and the
        sorted list of per-language summary dicts.
    """
    from collections import defaultdict

    lang_stats = defaultdict(lambda: {'correct': 0, 'total': 0})

    for pred, true, lang in zip(predictions, true_labels, languages):
        lang_stats[lang]['total'] += 1
        if pred == true:
            lang_stats[lang]['correct'] += 1

    # Calculate accuracy per language
    lang_accuracies = []
    for lang, stats in lang_stats.items():
        if stats['total'] >= 5:  # Only show languages with at least 5 examples
            acc = stats['correct'] / stats['total']
            lang_accuracies.append({
                'language': lang,
                'accuracy': acc,
                'correct': stats['correct'],
                'total': stats['total'],
                'errors': stats['total'] - stats['correct']
            })

    # Sort ascending so the worst languages come first
    lang_accuracies.sort(key=lambda x: x['accuracy'])

    print(f"\n{'='*70}")
    print(f"WORST {top_n} LANGUAGES (with >= 5 examples)")
    print(f"{'='*70}")
    print(f"{'Language':<12} {'Accuracy':<12} {'Errors':<10} {'Total':<10}")
    print(f"{'-'*70}")
    for item in lang_accuracies[:top_n]:
        print(f"{item['language']:<12} {item['accuracy']:>10.2%} {item['errors']:>8} {item['total']:>8}")

    return lang_stats, lang_accuracies


def analyze_performance_by_category(predictions, true_labels):
    """Print per-category accuracy (sorted worst-first).

    Returns:
        (cat_stats, cat_accuracies) — raw per-category counts keyed by
        label id, and the sorted list of per-category summary dicts.
    """
    from collections import defaultdict

    cat_stats = defaultdict(lambda: {'correct': 0, 'total': 0})

    # Stats are keyed by the TRUE label, so accuracy here is per-class recall.
    for pred, true in zip(predictions, true_labels):
        cat_stats[true]['total'] += 1
        if pred == true:
            cat_stats[true]['correct'] += 1

    cat_accuracies = []
    for cat_id, stats in cat_stats.items():
        acc = stats['correct'] / stats['total']
        cat_accuracies.append({
            'category': ID2LABEL[cat_id],
            'accuracy': acc,
            'correct': stats['correct'],
            'total': stats['total'],
            'errors': stats['total'] - stats['correct']
        })

    cat_accuracies.sort(key=lambda x: x['accuracy'])

    print(f"\n{'='*70}")
    print(f"PERFORMANCE BY CATEGORY")
    print(f"{'='*70}")
    print(f"{'Category':<20} {'Accuracy':<12} {'Errors':<10} {'Total':<10}")
    print(f"{'-'*70}")
    for item in cat_accuracies:
        print(f"{item['category']:<20} {item['accuracy']:>10.2%} {item['errors']:>8} {item['total']:>8}")

    return cat_stats, cat_accuracies


def analyze_language_category_combinations(predictions, true_labels, languages, top_n=15):
    """Print the worst (language, category) combinations.

    Combinations with fewer than 3 examples are skipped.

    Returns:
        (combo_stats, combo_accuracies) — raw counts keyed by
        (language, category-name) tuples, and the sorted summary list.
    """
    from collections import defaultdict

    combo_stats = defaultdict(lambda: {'correct': 0, 'total': 0})

    for pred, true, lang in zip(predictions, true_labels, languages):
        key = (lang, ID2LABEL[true])
        combo_stats[key]['total'] += 1
        if pred == true:
            combo_stats[key]['correct'] += 1

    combo_accuracies = []
    for (lang, cat), stats in combo_stats.items():
        if stats['total'] >= 3:  # Only show combinations with at least 3 examples
            acc = stats['correct'] / stats['total']
            combo_accuracies.append({
                'language': lang,
                'category': cat,
                'accuracy': acc,
                'correct': stats['correct'],
                'total': stats['total'],
                'errors': stats['total'] - stats['correct']
            })

    combo_accuracies.sort(key=lambda x: x['accuracy'])

    print(f"\n{'='*80}")
    print(f"WORST {top_n} LANGUAGE-CATEGORY COMBINATIONS (with >= 3 examples)")
    print(f"{'='*80}")
    print(f"{'Language':<12} {'Category':<20} {'Accuracy':<12} {'Errors':<8} {'Total':<8}")
    print(f"{'-'*80}")
    for item in combo_accuracies[:top_n]:
        print(f"{item['language']:<12} {item['category']:<20} {item['accuracy']:>10.2%} {item['errors']:>6} {item['total']:>6}")

    return combo_stats, combo_accuracies


def plot_confusion_matrix(test_true, test_preds, output_dir):
    """Plot the test-set confusion matrix and save it as confusion_matrix.png."""
    # Fix the label set so all 8 categories appear even if absent from the
    # test predictions.
    cm = confusion_matrix(test_true, test_preds, labels=list(range(len(NFQA_CATEGORIES))))

    plt.figure(figsize=(12, 10))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=NFQA_CATEGORIES,
        yticklabels=NFQA_CATEGORIES,
        cbar_kws={'label': 'Count'}
    )
    plt.xlabel('Predicted Category')
    plt.ylabel('True Category')
    plt.title('Confusion Matrix - Test Set')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    cm_file = os.path.join(output_dir, 'confusion_matrix.png')
    plt.savefig(cm_file, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Confusion matrix saved to: {cm_file}")


def main():
    """CLI entry point: parse args, train, evaluate, and save all artifacts."""
    parser = argparse.ArgumentParser(description='Train NFQA Classification Model')

    # Data arguments - either single input file OR separate train/val/test files
    parser.add_argument('--input', type=str, help='Input JSONL file with annotated data (will be split automatically)')
    parser.add_argument('--train', type=str, help='Training set JSONL file (use with --val and --test)')
    parser.add_argument('--val', type=str, help='Validation set JSONL file (use with --train and --test)')
    parser.add_argument('--test', type=str, help='Test set JSONL file (use with --train and --val)')
    parser.add_argument('--output-dir', type=str, default='./nfqa_model_trained', help='Output directory for model and results')

    # Model arguments
    parser.add_argument('--model-name', type=str, default='xlm-roberta-base', help='Pretrained model name (default: xlm-roberta-base)')
    parser.add_argument('--max-length', type=int, default=128, help='Maximum sequence length (default: 128)')

    # Training arguments
    parser.add_argument('--batch-size', type=int, default=16, help='Batch size (default: 16)')
    parser.add_argument('--epochs', type=int, default=10, help='Number of epochs (default: 10)')
    parser.add_argument('--learning-rate', type=float, default=2e-5, help='Learning rate (default: 2e-5)')
    parser.add_argument('--warmup-steps', type=int, default=500, help='Warmup steps (default: 500)')
    parser.add_argument('--weight-decay', type=float, default=0.01, help='Weight decay (default: 0.01)')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout probability (default: 0.1)')

    # Split arguments (only used with --input)
    parser.add_argument('--test-size', type=float, default=0.2, help='Test set size (default: 0.2)')
    parser.add_argument('--val-size', type=float, default=0.1, help='Validation set size (default: 0.1)')

    # Device argument
    parser.add_argument('--device', type=str, default='auto', help='Device to use: cuda, cpu, or auto (default: auto)')

    args = parser.parse_args()

    # Validate arguments: exactly one of the two input modes must be used
    has_single_input = args.input is not None
    has_split_inputs = all([args.train, args.val, args.test])

    if not has_single_input and not has_split_inputs:
        parser.error("Either --input OR (--train, --val, --test) must be provided")
    if has_single_input and has_split_inputs:
        parser.error("Cannot use --input together with --train/--val/--test. Choose one approach.")

    # Print configuration
    print("="*80)
    print("NFQA MODEL TRAINING")
    print("="*80)
    if has_single_input:
        print(f"Input file: {args.input}")
        print(f"Data splitting: automatic (test={args.test_size}, val={args.val_size})")
    else:
        print(f"Train file: {args.train}")
        print(f"Val file: {args.val}")
        print(f"Test file: {args.test}")
        print(f"Data splitting: manual (pre-split)")
    print(f"Output directory: {args.output_dir}")
    print(f"Model: {args.model_name}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.learning_rate}")
    print(f"Max length: {args.max_length}")
    print(f"Weight decay: {args.weight_decay}")
    print(f"Dropout: {args.dropout}")
    print("="*80 + "\n")

    # Set device
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)
    print(f"Device: {device}")
    print(f"PyTorch version: {torch.__version__}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)}\n")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load data - either from single file or pre-split files
    if has_single_input:
        # Load single file and create splits
        questions, labels, languages = load_data(args.input)

        # Create splits (stratify by labels, keep languages aligned).
        # This duplicates create_data_splits() but also threads the
        # languages list through both splits.
        from sklearn.model_selection import train_test_split

        # First split: separate test set
        train_val_questions, test_questions, train_val_labels, test_labels, train_val_langs, test_langs = train_test_split(
            questions, labels, languages,
            test_size=args.test_size,
            random_state=RANDOM_SEED,
            stratify=labels
        )

        # Second split: separate validation from training (val fraction is
        # rescaled relative to the remaining train+val pool)
        train_questions, val_questions, train_labels, val_labels, train_langs, val_langs = train_test_split(
            train_val_questions, train_val_labels, train_val_langs,
            test_size=args.val_size / (1 - args.test_size),
            random_state=RANDOM_SEED,
            stratify=train_val_labels
        )

        print(f"\nData splits:")
        print(f" Training: {len(train_questions):4d} examples ({len(train_questions)/len(questions)*100:5.1f}%)")
        print(f" Validation: {len(val_questions):4d} examples ({len(val_questions)/len(questions)*100:5.1f}%)")
        print(f" Test: {len(test_questions):4d} examples ({len(test_questions)/len(questions)*100:5.1f}%)")
        print(f" Total: {len(questions):4d} examples")
    else:
        # Load pre-split files
        print("Loading pre-split datasets...\n")
        train_questions, train_labels, train_langs = load_data(args.train)
        val_questions, val_labels, val_langs = load_data(args.val)
        test_questions, test_labels, test_langs = load_data(args.test)

        # Print split summary
        total_examples = len(train_questions) + len(val_questions) + len(test_questions)
        print(f"\nData splits:")
        print(f" Training: {len(train_questions):4d} examples ({len(train_questions)/total_examples*100:5.1f}%)")
        print(f" Validation: {len(val_questions):4d} examples ({len(val_questions)/total_examples*100:5.1f}%)")
        print(f" Test: {len(test_questions):4d} examples ({len(test_questions)/total_examples*100:5.1f}%)")
        print(f" Total: {total_examples:4d} examples")

    # Show class distribution per split
    print("\nClass distribution per split:")
    for split_name, split_labels in [('Train', train_labels), ('Val', val_labels), ('Test', test_labels)]:
        counts = Counter(split_labels)
        print(f"\n{split_name}:")
        for label_id in sorted(counts.keys()):
            cat_name = ID2LABEL[label_id]
            print(f" {cat_name:20s}: {counts[label_id]:3d}")

    # Load tokenizer and model
    print(f"\nLoading tokenizer: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    print("✓ Tokenizer loaded")

    print(f"\nLoading model: {args.model_name}")
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=len(NFQA_CATEGORIES),
        id2label=ID2LABEL,
        label2id=LABEL2ID,
        hidden_dropout_prob=args.dropout,
        attention_probs_dropout_prob=args.dropout,
        classifier_dropout=args.dropout
    )
    model.to(device)
    print(f"✓ Model loaded")
    print(f" Number of parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f" Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # Create datasets and loaders. Only the train loader shuffles — val/test
    # must preserve order so predictions stay aligned with val_langs/test_langs.
    print("\nCreating datasets...")
    train_dataset = NFQADataset(train_questions, train_labels, tokenizer, args.max_length)
    val_dataset = NFQADataset(val_questions, val_labels, tokenizer, args.max_length)
    test_dataset = NFQADataset(test_questions, test_labels, tokenizer, args.max_length)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)

    print(f"✓ Datasets created")
    print(f" Train: {len(train_dataset)} examples ({len(train_loader)} batches)")
    print(f" Val: {len(val_dataset)} examples ({len(val_loader)} batches)")
    print(f" Test: {len(test_dataset)} examples ({len(test_loader)} batches)")

    # Setup optimizer and scheduler (linear decay with warmup)
    optimizer = AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )

    total_steps = len(train_loader) * args.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=total_steps
    )

    print(f"\n✓ Optimizer and scheduler configured")
    print(f" Total training steps: {total_steps}")
    print(f" Warmup steps: {args.warmup_steps}")

    # Training loop: track per-epoch metrics for curves and history JSON
    history = {
        'train_loss': [],
        'train_accuracy': [],
        'val_loss': [],
        'val_accuracy': [],
        'val_f1': []
    }

    best_val_f1 = 0
    best_epoch = 0

    print("\n" + "="*80)
    print("STARTING TRAINING")
    print("="*80 + "\n")

    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")
        print("-" * 80)

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, device)

        # Validate with detailed analysis
        val_loss, val_acc, val_f1, val_preds, val_true = evaluate(
            model, val_loader, device,
            languages=val_langs,
            desc="Validating",
            show_analysis=True
        )

        # Update history
        history['train_loss'].append(train_loss)
        history['train_accuracy'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_acc)
        history['val_f1'].append(val_f1)

        # Print metrics
        print(f"\nEpoch {epoch + 1} Summary:")
        print(f" Train Loss: {train_loss:.4f}")
        print(f" Train Accuracy: {train_acc:.4f}")
        print(f" Val Loss: {val_loss:.4f}")
        print(f" Val Accuracy: {val_acc:.4f}")
        print(f" Val F1 (Macro): {val_f1:.4f}")

        # Save best model (selection metric: macro F1 on validation)
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_epoch = epoch + 1

            # Save model checkpoint alongside the tokenizer
            model_path = os.path.join(args.output_dir, 'best_model')
            model.save_pretrained(model_path)
            tokenizer.save_pretrained(model_path)
            print(f" ✓ New best model saved! (F1: {val_f1:.4f})")

    print("\n" + "="*80)
    print("TRAINING COMPLETE")
    print("="*80)
    print(f"Best epoch: {best_epoch}")
    print(f"Best validation F1: {best_val_f1:.4f}")
    print("="*80)

    # Save training history
    history_file = os.path.join(args.output_dir, 'training_history.json')
    with open(history_file, 'w') as f:
        json.dump(history, f, indent=2)
    print(f"\n✓ Training history saved to: {history_file}")

    # Save final model (last epoch, regardless of validation score)
    final_model_path = os.path.join(args.output_dir, 'final_model')
    model.save_pretrained(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"✓ Final model saved to: {final_model_path}")

    # Plot training curves
    plot_training_curves(history, best_val_f1, args.output_dir)

    # Load best model and evaluate on test set (the best checkpoint replaces
    # the in-memory final-epoch model from here on)
    print("\nLoading best model for final evaluation...")
    best_model_path = os.path.join(args.output_dir, 'best_model')
    model = AutoModelForSequenceClassification.from_pretrained(best_model_path)
    model.to(device)

    test_loss, test_acc, test_f1, test_preds, test_true = evaluate(model, test_loader, device, desc="Testing")

    print("\n" + "="*80)
    print("FINAL TEST SET RESULTS")
    print("="*80)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test F1 (Macro): {test_f1:.4f}")
    print("="*80)

    # Classification report
    print("\n" + "="*80)
    print("PER-CATEGORY PERFORMANCE")
    print("="*80 + "\n")
    report = classification_report(
        test_true, test_preds,
        labels=list(range(len(NFQA_CATEGORIES))),
        target_names=NFQA_CATEGORIES,
        zero_division=0
    )
    print(report)

    # Save report
    report_file = os.path.join(args.output_dir, 'classification_report.txt')
    with open(report_file, 'w') as f:
        f.write(report)
    print(f"✓ Classification report saved to: {report_file}")

    # Plot confusion matrix
    plot_confusion_matrix(test_true, test_preds, args.output_dir)

    # Detailed performance analysis
    print("\n" + "="*80)
    print("DETAILED PERFORMANCE ANALYSIS")
    print("="*80)

    # Analyze by category
    analyze_performance_by_category(test_preds, test_true)

    # Analyze by language
    analyze_performance_by_language(test_preds, test_true, test_langs, top_n=10)

    # Analyze language-category combinations
    analyze_language_category_combinations(test_preds, test_true, test_langs, top_n=15)

    print("\n" + "="*80)

    # Save test results (metrics + full run configuration for provenance)
    test_results = {
        'test_loss': float(test_loss),
        'test_accuracy': float(test_acc),
        'test_f1_macro': float(test_f1),
        'best_epoch': int(best_epoch),
        'best_val_f1': float(best_val_f1),
        'num_train_examples': len(train_questions),
        'num_val_examples': len(val_questions),
        'num_test_examples': len(test_questions),
        'config': {
            'model_name': args.model_name,
            'max_length': args.max_length,
            'batch_size': args.batch_size,
            'learning_rate': args.learning_rate,
            'num_epochs': args.epochs,
            'warmup_steps': args.warmup_steps,
            'weight_decay': args.weight_decay,
            'dropout': args.dropout,
            'data_source': 'pre-split' if has_split_inputs else 'single_file',
            'train_file': args.train if has_split_inputs else args.input,
            'val_file': args.val if has_split_inputs else None,
            'test_file': args.test if has_split_inputs else None,
            'auto_split': not has_split_inputs,
            'test_size': args.test_size if not has_split_inputs else None,
            'val_size': args.val_size if not has_split_inputs else None
        },
        'timestamp': datetime.now().isoformat()
    }

    results_file = os.path.join(args.output_dir, 'test_results.json')
    with open(results_file, 'w') as f:
        json.dump(test_results, f, indent=2)
    print(f"✓ Test results saved to: {results_file}")

    # Summary
    print("\n" + "="*80)
    print("TRAINING SUMMARY")
    print("="*80)
    print(f"\nModel: {args.model_name}")
    print(f"Training examples: {len(train_questions)}")
    print(f"Validation examples: {len(val_questions)}")
    print(f"Test examples: {len(test_questions)}")
    print(f"\nBest epoch: {best_epoch}/{args.epochs}")
    print(f"Best validation F1: {best_val_f1:.4f}")
    print(f"\nFinal test results:")
    print(f" Accuracy: {test_acc:.4f}")
    print(f" F1 Score (Macro): {test_f1:.4f}")
    print(f"\nModel saved to: {args.output_dir}")
    print(f"\nGenerated files:")
    print(f" - best_model/ (best checkpoint)")
    print(f" - final_model/ (last epoch)")
    print(f" - training_history.json")
    print(f" - training_curves.png")
    print(f" - test_results.json")
    print(f" - classification_report.txt")
    print(f" - confusion_matrix.png")
    print("\n" + "="*80)
    print("✅ Training complete! Model ready for deployment.")
    print("="*80)


if __name__ == '__main__':
    main()