# Uploaded by AliSalman29 — commit db6aa40 ("feat: update model").
# (Web-page residue converted to a comment so the file is valid Python.)
#!/usr/bin/env python3
"""
Train NFQA Classification Model from Scratch
Trains a multilingual NFQA classifier using XLM-RoBERTa on LLM-annotated WebFAQ data.
Usage (single file with automatic splitting):
python train_nfqa_model.py --input data.jsonl --output-dir ./model --epochs 10
Usage (pre-split files):
python train_nfqa_model.py --train train.jsonl --val val.jsonl --test test.jsonl --output-dir ./model --epochs 10
Author: Ali
Date: December 2024
"""
import pandas as pd
import numpy as np
import torch
import json
import argparse
import os
from collections import Counter
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
get_linear_schedule_with_warmup
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
classification_report,
confusion_matrix,
accuracy_score,
f1_score
)
import matplotlib
matplotlib.use('Agg') # Non-interactive backend for server
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
# Set random seed for reproducibility across numpy and torch.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# The 8 NFQA categories; list ORDER fixes the integer label ids used by the
# model head, the datasets, and all reports below.
NFQA_CATEGORIES = [
    'NOT-A-QUESTION',
    'FACTOID',
    'DEBATE',
    'EVIDENCE-BASED',
    'INSTRUCTION',
    'REASON',
    'EXPERIENCE',
    'COMPARISON'
]
# Label mappings: category name <-> integer id (id = position in NFQA_CATEGORIES).
LABEL2ID = {label: idx for idx, label in enumerate(NFQA_CATEGORIES)}
ID2LABEL = {idx: label for label, idx in LABEL2ID.items()}
class NFQADataset(Dataset):
    """Torch dataset of (question, label) pairs for NFQA classification.

    Each item is tokenized on the fly into fixed-length padded tensors
    suitable for a HuggingFace sequence-classification model.
    """

    def __init__(self, questions, labels, tokenizer, max_length=128):
        self.questions = questions
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        text = str(self.questions[idx])
        target = int(self.labels[idx])
        # Tokenize to a fixed-length encoding (pad/truncate to max_length).
        enc = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # Flatten the (1, max_length) tensors to 1-D for the DataLoader.
        item = {
            'input_ids': enc['input_ids'].flatten(),
            'attention_mask': enc['attention_mask'].flatten(),
            'labels': torch.tensor(target, dtype=torch.long),
        }
        return item
def train_epoch(model, train_loader, optimizer, scheduler, device):
    """Run one training epoch; return (mean batch loss, epoch accuracy)."""
    model.train()
    running_loss = 0.0
    all_preds, all_targets = [], []
    bar = tqdm(train_loader, desc="Training")
    for batch in bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass — the model computes the classification loss
        # internally when `labels` is supplied.
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss
        running_loss += loss.item()

        # Backward pass with gradient clipping, then optimizer + LR schedule.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # Accumulate predictions for the epoch-level accuracy.
        all_preds.extend(torch.argmax(outputs.logits, dim=1).cpu().numpy())
        all_targets.extend(labels.cpu().numpy())
        bar.set_postfix({'loss': f'{loss.item():.4f}'})

    return running_loss / len(train_loader), accuracy_score(all_targets, all_preds)
def evaluate(model, data_loader, device, languages=None, desc="Evaluating", show_analysis=False):
    """Evaluate the model on a dataset.

    Returns (mean batch loss, accuracy, macro F1, predictions, true labels).
    When `show_analysis` is true and per-example `languages` are supplied,
    also prints per-category / per-language / per-combination breakdowns.
    """
    model.eval()
    running_loss = 0.0
    all_preds, all_targets = [], []
    with torch.no_grad():
        for batch in tqdm(data_loader, desc=desc):
            labels = batch['labels'].to(device)
            out = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                labels=labels
            )
            running_loss += out.loss.item()
            all_preds.extend(torch.argmax(out.logits, dim=1).cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

    avg_loss = running_loss / len(data_loader)
    acc = accuracy_score(all_targets, all_preds)
    macro_f1 = f1_score(all_targets, all_preds, average='macro')

    # Optional detailed breakdown (used for the validation set each epoch).
    if show_analysis and languages is not None:
        print("\n" + "-"*70)
        print("VALIDATION ANALYSIS")
        print("-"*70)
        analyze_performance_by_category(all_preds, all_targets)
        analyze_performance_by_language(all_preds, all_targets, languages, top_n=5)
        analyze_language_category_combinations(all_preds, all_targets, languages, top_n=10)
        print("-"*70)

    return avg_loss, acc, macro_f1, all_preds, all_targets
def load_data(file_path):
    """Load LLM-annotated examples from a JSONL file.

    Accepts a 'label_id' column, an 'ensemble_prediction' column of category
    names (mapped to ids), or a 'label' column. Returns
    (questions, label_ids, languages); `languages` falls back to 'unknown'
    when the file has no 'language' column.
    """
    print(f"Loading data from: {file_path}\n")
    try:
        frame = pd.read_json(file_path, lines=True)
        print(f"✓ Loaded {len(frame)} annotated examples")
        if 'question' not in frame.columns:
            raise ValueError("Missing 'question' column")
        # Resolve which column holds the integer label, materializing
        # 'label_id' from category names when necessary.
        if 'label_id' not in frame.columns and 'ensemble_prediction' in frame.columns:
            frame['label_id'] = frame['ensemble_prediction'].map(LABEL2ID)
        if 'label_id' in frame.columns:
            label_col = 'label_id'
        elif 'label' in frame.columns:
            label_col = 'label'
        else:
            raise ValueError("No label column found (expected: 'label', 'label_id', or 'ensemble_prediction')")
        # Drop rows lacking a question or a usable label (unmapped category
        # names become NaN above and are removed here).
        frame = frame.dropna(subset=['question', label_col])
        print(f"✓ Data cleaned: {len(frame)} examples with valid labels")
        print("\nLabel distribution:")
        for label_id, count in frame[label_col].value_counts().sort_index().items():
            cat_name = ID2LABEL.get(int(label_id), f"UNKNOWN_{label_id}")
            print(f" {cat_name:20s}: {count:4d} ({count/len(frame)*100:5.1f}%)")
        questions = frame['question'].tolist()
        labels = frame[label_col].astype(int).tolist()
        languages = frame['language'].tolist() if 'language' in frame.columns else ['unknown'] * len(frame)
        print(f"\n✓ Prepared {len(questions)} question-label pairs")
        return questions, labels, languages
    except FileNotFoundError:
        print(f"❌ Error: File not found: {file_path}")
        raise
    except Exception as e:
        print(f"❌ Error loading data: {e}")
        raise
def create_data_splits(questions, labels, test_size=0.2, val_size=0.1):
    """Stratified train/val/test split of (questions, labels).

    `test_size` and `val_size` are fractions of the FULL dataset; the
    validation fraction is rescaled for the second split so it stays a
    fraction of the original total.
    Returns (train_q, val_q, test_q, train_y, val_y, test_y).
    """
    print("\nCreating data splits...")
    # Split 1: carve out the held-out test set.
    tv_q, test_q, tv_y, test_y = train_test_split(
        questions,
        labels,
        test_size=test_size,
        random_state=RANDOM_SEED,
        stratify=labels
    )
    # Split 2: carve validation out of the remainder (rescaled fraction).
    train_q, val_q, train_y, val_y = train_test_split(
        tv_q,
        tv_y,
        test_size=val_size / (1 - test_size),
        random_state=RANDOM_SEED,
        stratify=tv_y
    )
    n_total = len(questions)
    print(f"\nData splits:")
    print(f" Training: {len(train_q):4d} examples ({len(train_q)/n_total*100:5.1f}%)")
    print(f" Validation: {len(val_q):4d} examples ({len(val_q)/n_total*100:5.1f}%)")
    print(f" Test: {len(test_q):4d} examples ({len(test_q)/n_total*100:5.1f}%)")
    print(f" Total: {n_total:4d} examples")
    # Sanity-check the label balance of each split.
    print("\nClass distribution per split:")
    for split_name, split_labels in (('Train', train_y), ('Val', val_y), ('Test', test_y)):
        counts = Counter(split_labels)
        print(f"\n{split_name}:")
        for label_id in sorted(counts.keys()):
            cat_name = ID2LABEL[label_id]
            print(f" {cat_name:20s}: {counts[label_id]:3d}")
    return train_q, val_q, test_q, train_y, val_y, test_y
def plot_training_curves(history, best_val_f1, output_dir):
    """Render loss / accuracy / F1 curves per epoch and save them as a PNG."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    xs = range(1, len(history['train_loss']) + 1)

    # Panel 1: losses.
    axes[0].plot(xs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    axes[0].plot(xs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Validation Loss')

    # Panel 2: accuracies.
    axes[1].plot(xs, history['train_accuracy'], 'b-', label='Train Accuracy', linewidth=2)
    axes[1].plot(xs, history['val_accuracy'], 'r-', label='Val Accuracy', linewidth=2)
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Training and Validation Accuracy')

    # Panel 3: macro F1 with the best value marked by a dashed line.
    axes[2].plot(xs, history['val_f1'], 'g-', label='Val F1 (Macro)', linewidth=2)
    axes[2].axhline(y=best_val_f1, color='r', linestyle='--', label=f'Best F1: {best_val_f1:.4f}')
    axes[2].set_ylabel('F1 Score')
    axes[2].set_title('Validation F1 Score')

    # Shared per-axis cosmetics.
    for ax in axes:
        ax.set_xlabel('Epoch')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plot_file = os.path.join(output_dir, 'training_curves.png')
    plt.savefig(plot_file, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Training curves saved to: {plot_file}")
def analyze_performance_by_language(predictions, true_labels, languages, top_n=10):
    """Report per-language accuracy and print the `top_n` worst languages.

    Languages with fewer than 5 examples are excluded from the ranking
    (but still appear in the raw stats).
    Returns (raw per-language stats, ranked accuracy records, worst first).
    """
    from collections import defaultdict
    lang_stats = defaultdict(lambda: {'correct': 0, 'total': 0})
    for pred, gold, lang in zip(predictions, true_labels, languages):
        entry = lang_stats[lang]
        entry['total'] += 1
        entry['correct'] += int(pred == gold)
    # Build ranked records for languages with enough support.
    lang_accuracies = []
    for lang, stats in lang_stats.items():
        if stats['total'] < 5:  # require at least 5 examples
            continue
        lang_accuracies.append({
            'language': lang,
            'accuracy': stats['correct'] / stats['total'],
            'correct': stats['correct'],
            'total': stats['total'],
            'errors': stats['total'] - stats['correct'],
        })
    lang_accuracies.sort(key=lambda rec: rec['accuracy'])
    print(f"\n{'='*70}")
    print(f"WORST {top_n} LANGUAGES (with >= 5 examples)")
    print(f"{'='*70}")
    print(f"{'Language':<12} {'Accuracy':<12} {'Errors':<10} {'Total':<10}")
    print(f"{'-'*70}")
    for rec in lang_accuracies[:top_n]:
        print(f"{rec['language']:<12} {rec['accuracy']:>10.2%} {rec['errors']:>8} {rec['total']:>8}")
    return lang_stats, lang_accuracies
def analyze_performance_by_category(predictions, true_labels):
    """Report accuracy per gold category, printed worst-first.

    Returns (raw per-category stats keyed by label id, ranked records).
    """
    from collections import defaultdict
    cat_stats = defaultdict(lambda: {'correct': 0, 'total': 0})
    for pred, gold in zip(predictions, true_labels):
        entry = cat_stats[gold]
        entry['total'] += 1
        entry['correct'] += int(pred == gold)
    # Build ranked records (every category is included).
    cat_accuracies = []
    for cat_id, stats in cat_stats.items():
        cat_accuracies.append({
            'category': ID2LABEL[cat_id],
            'accuracy': stats['correct'] / stats['total'],
            'correct': stats['correct'],
            'total': stats['total'],
            'errors': stats['total'] - stats['correct'],
        })
    cat_accuracies.sort(key=lambda rec: rec['accuracy'])
    print(f"\n{'='*70}")
    print(f"PERFORMANCE BY CATEGORY")
    print(f"{'='*70}")
    print(f"{'Category':<20} {'Accuracy':<12} {'Errors':<10} {'Total':<10}")
    print(f"{'-'*70}")
    for rec in cat_accuracies:
        print(f"{rec['category']:<20} {rec['accuracy']:>10.2%} {rec['errors']:>8} {rec['total']:>8}")
    return cat_stats, cat_accuracies
def analyze_language_category_combinations(predictions, true_labels, languages, top_n=15):
    """Report accuracy per (language, gold category) pair, printed worst-first.

    Pairs with fewer than 3 examples are excluded from the ranking
    (but still appear in the raw stats).
    Returns (raw pair stats, ranked records).
    """
    from collections import defaultdict
    combo_stats = defaultdict(lambda: {'correct': 0, 'total': 0})
    for pred, gold, lang in zip(predictions, true_labels, languages):
        entry = combo_stats[(lang, ID2LABEL[gold])]
        entry['total'] += 1
        entry['correct'] += int(pred == gold)
    # Build ranked records for pairs with enough support.
    combo_accuracies = []
    for (lang, cat), stats in combo_stats.items():
        if stats['total'] < 3:  # require at least 3 examples
            continue
        combo_accuracies.append({
            'language': lang,
            'category': cat,
            'accuracy': stats['correct'] / stats['total'],
            'correct': stats['correct'],
            'total': stats['total'],
            'errors': stats['total'] - stats['correct'],
        })
    combo_accuracies.sort(key=lambda rec: rec['accuracy'])
    print(f"\n{'='*80}")
    print(f"WORST {top_n} LANGUAGE-CATEGORY COMBINATIONS (with >= 3 examples)")
    print(f"{'='*80}")
    print(f"{'Language':<12} {'Category':<20} {'Accuracy':<12} {'Errors':<8} {'Total':<8}")
    print(f"{'-'*80}")
    for rec in combo_accuracies[:top_n]:
        print(f"{rec['language']:<12} {rec['category']:<20} {rec['accuracy']:>10.2%} {rec['errors']:>6} {rec['total']:>6}")
    return combo_stats, combo_accuracies
def plot_confusion_matrix(test_true, test_preds, output_dir):
    """Render the test-set confusion matrix as an annotated heatmap PNG."""
    # Fix the label order to the full category list so all 8 classes appear
    # even if absent from the test set.
    matrix = confusion_matrix(test_true, test_preds, labels=list(range(len(NFQA_CATEGORIES))))
    plt.figure(figsize=(12, 10))
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=NFQA_CATEGORIES,
        yticklabels=NFQA_CATEGORIES,
        cbar_kws={'label': 'Count'},
    )
    plt.xlabel('Predicted Category')
    plt.ylabel('True Category')
    plt.title('Confusion Matrix - Test Set')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    out_path = os.path.join(output_dir, 'confusion_matrix.png')
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Confusion matrix saved to: {out_path}")
def main():
    """CLI entry point: parse args, load/split data, train, evaluate, save artifacts."""
    parser = argparse.ArgumentParser(description='Train NFQA Classification Model')
    # Data arguments - either single input file OR separate train/val/test files
    parser.add_argument('--input', type=str,
                        help='Input JSONL file with annotated data (will be split automatically)')
    parser.add_argument('--train', type=str,
                        help='Training set JSONL file (use with --val and --test)')
    parser.add_argument('--val', type=str,
                        help='Validation set JSONL file (use with --train and --test)')
    parser.add_argument('--test', type=str,
                        help='Test set JSONL file (use with --train and --val)')
    parser.add_argument('--output-dir', type=str, default='./nfqa_model_trained',
                        help='Output directory for model and results')
    # Model arguments
    parser.add_argument('--model-name', type=str, default='xlm-roberta-base',
                        help='Pretrained model name (default: xlm-roberta-base)')
    parser.add_argument('--max-length', type=int, default=128,
                        help='Maximum sequence length (default: 128)')
    # Training arguments
    parser.add_argument('--batch-size', type=int, default=16,
                        help='Batch size (default: 16)')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Number of epochs (default: 10)')
    parser.add_argument('--learning-rate', type=float, default=2e-5,
                        help='Learning rate (default: 2e-5)')
    parser.add_argument('--warmup-steps', type=int, default=500,
                        help='Warmup steps (default: 500)')
    parser.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay (default: 0.01)')
    parser.add_argument('--dropout', type=float, default=0.1,
                        help='Dropout probability (default: 0.1)')
    # Split arguments (only used with --input)
    parser.add_argument('--test-size', type=float, default=0.2,
                        help='Test set size (default: 0.2)')
    parser.add_argument('--val-size', type=float, default=0.1,
                        help='Validation set size (default: 0.1)')
    # Device argument
    parser.add_argument('--device', type=str, default='auto',
                        help='Device to use: cuda, cpu, or auto (default: auto)')
    args = parser.parse_args()
    # Validate arguments: exactly one input mode must be selected.
    has_single_input = args.input is not None
    has_split_inputs = all([args.train, args.val, args.test])
    if not has_single_input and not has_split_inputs:
        parser.error("Either --input OR (--train, --val, --test) must be provided")
    if has_single_input and has_split_inputs:
        parser.error("Cannot use --input together with --train/--val/--test. Choose one approach.")
    # Print configuration
    print("="*80)
    print("NFQA MODEL TRAINING")
    print("="*80)
    if has_single_input:
        print(f"Input file: {args.input}")
        print(f"Data splitting: automatic (test={args.test_size}, val={args.val_size})")
    else:
        print(f"Train file: {args.train}")
        print(f"Val file: {args.val}")
        print(f"Test file: {args.test}")
        print(f"Data splitting: manual (pre-split)")
    print(f"Output directory: {args.output_dir}")
    print(f"Model: {args.model_name}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.learning_rate}")
    print(f"Max length: {args.max_length}")
    print(f"Weight decay: {args.weight_decay}")
    print(f"Dropout: {args.dropout}")
    print("="*80 + "\n")
    # Set device
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    # Seed all CUDA devices for reproducibility when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)
    print(f"Device: {device}")
    print(f"PyTorch version: {torch.__version__}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)}\n")
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    # Load data - either from single file or pre-split files
    if has_single_input:
        # Load single file and create splits
        questions, labels, languages = load_data(args.input)
        # Create splits (stratify by labels, keep languages aligned).
        # NOTE: create_data_splits() is not used here because it does not
        # carry the per-example languages through the split.
        from sklearn.model_selection import train_test_split
        # First split: separate test set
        train_val_questions, test_questions, train_val_labels, test_labels, train_val_langs, test_langs = train_test_split(
            questions, labels, languages,
            test_size=args.test_size,
            random_state=RANDOM_SEED,
            stratify=labels
        )
        # Second split: separate validation from training (fraction rescaled
        # so --val-size stays a fraction of the full dataset)
        train_questions, val_questions, train_labels, val_labels, train_langs, val_langs = train_test_split(
            train_val_questions, train_val_labels, train_val_langs,
            test_size=args.val_size / (1 - args.test_size),
            random_state=RANDOM_SEED,
            stratify=train_val_labels
        )
        print(f"\nData splits:")
        print(f" Training: {len(train_questions):4d} examples ({len(train_questions)/len(questions)*100:5.1f}%)")
        print(f" Validation: {len(val_questions):4d} examples ({len(val_questions)/len(questions)*100:5.1f}%)")
        print(f" Test: {len(test_questions):4d} examples ({len(test_questions)/len(questions)*100:5.1f}%)")
        print(f" Total: {len(questions):4d} examples")
    else:
        # Load pre-split files
        print("Loading pre-split datasets...\n")
        train_questions, train_labels, train_langs = load_data(args.train)
        val_questions, val_labels, val_langs = load_data(args.val)
        test_questions, test_labels, test_langs = load_data(args.test)
        # Print split summary
        total_examples = len(train_questions) + len(val_questions) + len(test_questions)
        print(f"\nData splits:")
        print(f" Training: {len(train_questions):4d} examples ({len(train_questions)/total_examples*100:5.1f}%)")
        print(f" Validation: {len(val_questions):4d} examples ({len(val_questions)/total_examples*100:5.1f}%)")
        print(f" Test: {len(test_questions):4d} examples ({len(test_questions)/total_examples*100:5.1f}%)")
        print(f" Total: {total_examples:4d} examples")
    # Show class distribution per split
    print("\nClass distribution per split:")
    for split_name, split_labels in [('Train', train_labels), ('Val', val_labels), ('Test', test_labels)]:
        counts = Counter(split_labels)
        print(f"\n{split_name}:")
        for label_id in sorted(counts.keys()):
            cat_name = ID2LABEL[label_id]
            print(f" {cat_name:20s}: {counts[label_id]:3d}")
    # Load tokenizer and model
    print(f"\nLoading tokenizer: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    print("✓ Tokenizer loaded")
    print(f"\nLoading model: {args.model_name}")
    # NOTE(review): `classifier_dropout` / `hidden_dropout_prob` are config
    # keys for RoBERTa-family models — confirm they apply if --model-name is
    # changed to a different architecture.
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=len(NFQA_CATEGORIES),
        id2label=ID2LABEL,
        label2id=LABEL2ID,
        hidden_dropout_prob=args.dropout,
        attention_probs_dropout_prob=args.dropout,
        classifier_dropout=args.dropout
    )
    model.to(device)
    print(f"✓ Model loaded")
    print(f" Number of parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f" Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    # Create datasets
    print("\nCreating datasets...")
    train_dataset = NFQADataset(train_questions, train_labels, tokenizer, args.max_length)
    val_dataset = NFQADataset(val_questions, val_labels, tokenizer, args.max_length)
    test_dataset = NFQADataset(test_questions, test_labels, tokenizer, args.max_length)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)
    print(f"✓ Datasets created")
    print(f" Train: {len(train_dataset)} examples ({len(train_loader)} batches)")
    print(f" Val: {len(val_dataset)} examples ({len(val_loader)} batches)")
    print(f" Test: {len(test_dataset)} examples ({len(test_loader)} batches)")
    # Setup optimizer and linear-warmup/decay scheduler
    optimizer = AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )
    total_steps = len(train_loader) * args.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=total_steps
    )
    print(f"\n✓ Optimizer and scheduler configured")
    print(f" Total training steps: {total_steps}")
    print(f" Warmup steps: {args.warmup_steps}")
    # Training loop: track per-epoch metrics for later plotting
    history = {
        'train_loss': [],
        'train_accuracy': [],
        'val_loss': [],
        'val_accuracy': [],
        'val_f1': []
    }
    best_val_f1 = 0
    best_epoch = 0
    print("\n" + "="*80)
    print("STARTING TRAINING")
    print("="*80 + "\n")
    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")
        print("-" * 80)
        # Train
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, device)
        # Validate with detailed analysis
        val_loss, val_acc, val_f1, val_preds, val_true = evaluate(
            model, val_loader, device,
            languages=val_langs,
            desc="Validating",
            show_analysis=True
        )
        # Update history
        history['train_loss'].append(train_loss)
        history['train_accuracy'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_acc)
        history['val_f1'].append(val_f1)
        # Print metrics
        print(f"\nEpoch {epoch + 1} Summary:")
        print(f" Train Loss: {train_loss:.4f}")
        print(f" Train Accuracy: {train_acc:.4f}")
        print(f" Val Loss: {val_loss:.4f}")
        print(f" Val Accuracy: {val_acc:.4f}")
        print(f" Val F1 (Macro): {val_f1:.4f}")
        # Save best model (checkpoint whenever macro F1 improves)
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_epoch = epoch + 1
            # Save model
            model_path = os.path.join(args.output_dir, 'best_model')
            model.save_pretrained(model_path)
            tokenizer.save_pretrained(model_path)
            print(f" ✓ New best model saved! (F1: {val_f1:.4f})")
    print("\n" + "="*80)
    print("TRAINING COMPLETE")
    print("="*80)
    print(f"Best epoch: {best_epoch}")
    print(f"Best validation F1: {best_val_f1:.4f}")
    print("="*80)
    # Save training history
    history_file = os.path.join(args.output_dir, 'training_history.json')
    with open(history_file, 'w') as f:
        json.dump(history, f, indent=2)
    print(f"\n✓ Training history saved to: {history_file}")
    # Save final model (last epoch, regardless of validation score)
    final_model_path = os.path.join(args.output_dir, 'final_model')
    model.save_pretrained(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"✓ Final model saved to: {final_model_path}")
    # Plot training curves
    plot_training_curves(history, best_val_f1, args.output_dir)
    # Load best model and evaluate on test set.
    # NOTE(review): assumes at least one epoch achieved val_f1 > 0 so that
    # best_model/ exists — confirm for degenerate runs.
    print("\nLoading best model for final evaluation...")
    best_model_path = os.path.join(args.output_dir, 'best_model')
    model = AutoModelForSequenceClassification.from_pretrained(best_model_path)
    model.to(device)
    test_loss, test_acc, test_f1, test_preds, test_true = evaluate(model, test_loader, device, desc="Testing")
    print("\n" + "="*80)
    print("FINAL TEST SET RESULTS")
    print("="*80)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test F1 (Macro): {test_f1:.4f}")
    print("="*80)
    # Classification report
    print("\n" + "="*80)
    print("PER-CATEGORY PERFORMANCE")
    print("="*80 + "\n")
    report = classification_report(
        test_true,
        test_preds,
        labels=list(range(len(NFQA_CATEGORIES))),
        target_names=NFQA_CATEGORIES,
        zero_division=0
    )
    print(report)
    # Save report
    report_file = os.path.join(args.output_dir, 'classification_report.txt')
    with open(report_file, 'w') as f:
        f.write(report)
    print(f"✓ Classification report saved to: {report_file}")
    # Plot confusion matrix
    plot_confusion_matrix(test_true, test_preds, args.output_dir)
    # Detailed performance analysis (per category / language / combination)
    print("\n" + "="*80)
    print("DETAILED PERFORMANCE ANALYSIS")
    print("="*80)
    # Analyze by category
    analyze_performance_by_category(test_preds, test_true)
    # Analyze by language
    analyze_performance_by_language(test_preds, test_true, test_langs, top_n=10)
    # Analyze language-category combinations
    analyze_language_category_combinations(test_preds, test_true, test_langs, top_n=15)
    print("\n" + "="*80)
    # Save test results (metrics + full run configuration for provenance)
    test_results = {
        'test_loss': float(test_loss),
        'test_accuracy': float(test_acc),
        'test_f1_macro': float(test_f1),
        'best_epoch': int(best_epoch),
        'best_val_f1': float(best_val_f1),
        'num_train_examples': len(train_questions),
        'num_val_examples': len(val_questions),
        'num_test_examples': len(test_questions),
        'config': {
            'model_name': args.model_name,
            'max_length': args.max_length,
            'batch_size': args.batch_size,
            'learning_rate': args.learning_rate,
            'num_epochs': args.epochs,
            'warmup_steps': args.warmup_steps,
            'weight_decay': args.weight_decay,
            'dropout': args.dropout,
            'data_source': 'pre-split' if has_split_inputs else 'single_file',
            'train_file': args.train if has_split_inputs else args.input,
            'val_file': args.val if has_split_inputs else None,
            'test_file': args.test if has_split_inputs else None,
            'auto_split': not has_split_inputs,
            'test_size': args.test_size if not has_split_inputs else None,
            'val_size': args.val_size if not has_split_inputs else None
        },
        'timestamp': datetime.now().isoformat()
    }
    results_file = os.path.join(args.output_dir, 'test_results.json')
    with open(results_file, 'w') as f:
        json.dump(test_results, f, indent=2)
    print(f"✓ Test results saved to: {results_file}")
    # Summary
    print("\n" + "="*80)
    print("TRAINING SUMMARY")
    print("="*80)
    print(f"\nModel: {args.model_name}")
    print(f"Training examples: {len(train_questions)}")
    print(f"Validation examples: {len(val_questions)}")
    print(f"Test examples: {len(test_questions)}")
    print(f"\nBest epoch: {best_epoch}/{args.epochs}")
    print(f"Best validation F1: {best_val_f1:.4f}")
    print(f"\nFinal test results:")
    print(f" Accuracy: {test_acc:.4f}")
    print(f" F1 Score (Macro): {test_f1:.4f}")
    print(f"\nModel saved to: {args.output_dir}")
    print(f"\nGenerated files:")
    print(f" - best_model/ (best checkpoint)")
    print(f" - final_model/ (last epoch)")
    print(f" - training_history.json")
    print(f" - training_curves.png")
    print(f" - test_results.json")
    print(f" - classification_report.txt")
    print(f" - confusion_matrix.png")
    print("\n" + "="*80)
    print("✅ Training complete! Model ready for deployment.")
    print("="*80)


if __name__ == '__main__':
    main()