|
|
|
|
|
""" |
|
|
Train NFQA Classification Model from Scratch |
|
|
|
|
|
Trains a multilingual NFQA classifier using XLM-RoBERTa on LLM-annotated WebFAQ data. |
|
|
|
|
|
Usage (single file with automatic splitting): |
|
|
python train_nfqa_model.py --input data.jsonl --output-dir ./model --epochs 10 |
|
|
|
|
|
Usage (pre-split files): |
|
|
python train_nfqa_model.py --train train.jsonl --val val.jsonl --test test.jsonl --output-dir ./model --epochs 10 |
|
|
|
|
|
Author: Ali |
|
|
Date: December 2024 |
|
|
""" |
|
|
|
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import torch |
|
|
import json |
|
|
import argparse |
|
|
import os |
|
|
from collections import Counter |
|
|
from datetime import datetime |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torch.optim import AdamW |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForSequenceClassification, |
|
|
get_linear_schedule_with_warmup |
|
|
) |
|
|
from sklearn.model_selection import train_test_split |
|
|
from sklearn.metrics import ( |
|
|
classification_report, |
|
|
confusion_matrix, |
|
|
accuracy_score, |
|
|
f1_score |
|
|
) |
|
|
import matplotlib |
|
|
matplotlib.use('Agg') |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# Seed Python-side RNGs (NumPy and Torch CPU) so data splits and training
# runs are reproducible; CUDA RNGs are seeded later in main() when available.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# The eight NFQA answer-type categories. The ORDER of this list defines the
# integer class ids used everywhere else (position == label id).
NFQA_CATEGORIES = [
    'NOT-A-QUESTION',
    'FACTOID',
    'DEBATE',
    'EVIDENCE-BASED',
    'INSTRUCTION',
    'REASON',
    'EXPERIENCE',
    'COMPARISON'
]

# Bidirectional name <-> id mappings derived from the category order above;
# also handed to the HF model config so checkpoints carry readable labels.
LABEL2ID = {label: idx for idx, label in enumerate(NFQA_CATEGORIES)}
ID2LABEL = {idx: label for label, idx in LABEL2ID.items()}
|
|
|
|
|
|
|
|
class NFQADataset(Dataset):
    """Torch dataset over parallel (question, label) lists.

    Tokenization happens lazily in ``__getitem__``: every example is padded
    or truncated to exactly ``max_length`` tokens so batches stack cleanly.
    """

    def __init__(self, questions, labels, tokenizer, max_length=128):
        # Store references only; no up-front tokenization.
        self.questions = questions
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        text = str(self.questions[idx])
        target = int(self.labels[idx])

        # Fixed-length encoding: pad to max_length, truncate anything longer.
        enc = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # Tokenizer output is shaped (1, max_length); flatten to (max_length,)
        # so the DataLoader's default collation produces (batch, max_length).
        return {
            'input_ids': enc['input_ids'].flatten(),
            'attention_mask': enc['attention_mask'].flatten(),
            'labels': torch.tensor(target, dtype=torch.long),
        }
|
|
|
|
|
|
|
|
def train_epoch(model, train_loader, optimizer, scheduler, device):
    """Run one optimization pass over ``train_loader``.

    Returns:
        (avg_loss, accuracy): mean per-batch loss and example-level accuracy
        over everything seen this epoch.
    """
    model.train()
    running_loss = 0.0
    all_preds = []
    all_targets = []

    bar = tqdm(train_loader, desc="Training")
    for batch in bar:
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['labels'].to(device)

        # Forward pass; supplying `labels` makes the model compute the
        # classification loss internally.
        outputs = model(
            input_ids=ids,
            attention_mask=mask,
            labels=targets
        )
        loss = outputs.loss
        running_loss += loss.item()

        # Standard update: clear grads, backprop, clip, then step the
        # optimizer and the per-batch LR scheduler (in that order).
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # Accumulate hard predictions for the epoch-level accuracy.
        batch_preds = torch.argmax(outputs.logits, dim=1)
        all_preds.extend(batch_preds.cpu().numpy())
        all_targets.extend(targets.cpu().numpy())

        bar.set_postfix({'loss': f'{loss.item():.4f}'})

    avg_loss = running_loss / len(train_loader)
    return avg_loss, accuracy_score(all_targets, all_preds)
|
|
|
|
|
|
|
|
def evaluate(model, data_loader, device, languages=None, desc="Evaluating", show_analysis=False):
    """Evaluate model on validation/test set with optional detailed analysis.

    Args:
        model: sequence-classification model; called with `labels` so that
            `outputs.loss` is populated.
        data_loader: batches with 'input_ids', 'attention_mask', 'labels'.
        device: torch device to run inference on.
        languages: per-example language codes aligned with the loader's
            iteration order; only used when `show_analysis` is True.
            NOTE(review): alignment assumes the loader does NOT shuffle —
            confirm callers build val/test loaders without shuffle.
        desc: label for the tqdm progress bar.
        show_analysis: when True (and `languages` given), print per-category,
            per-language, and language-category breakdowns.

    Returns:
        (avg_loss, accuracy, macro_f1, predictions, true_labels) where the
        last two are Python lists of per-example class ids.
    """
    model.eval()
    total_loss = 0
    predictions = []
    true_labels = []

    # Inference only — no gradient tracking.
    with torch.no_grad():
        for batch in tqdm(data_loader, desc=desc):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            total_loss += outputs.loss.item()

            # Hard predictions: argmax over class logits.
            preds = torch.argmax(outputs.logits, dim=1)
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    # avg_loss is the mean per-batch loss (last batch may be smaller).
    avg_loss = total_loss / len(data_loader)
    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average='macro')

    # Optional drill-down diagnostics (used during validation).
    if show_analysis and languages is not None:
        print("\n" + "-"*70)
        print("VALIDATION ANALYSIS")
        print("-"*70)

        analyze_performance_by_category(predictions, true_labels)

        analyze_performance_by_language(predictions, true_labels, languages, top_n=5)

        analyze_language_category_combinations(predictions, true_labels, languages, top_n=10)

        print("-"*70)

    return avg_loss, accuracy, f1, predictions, true_labels
|
|
|
|
|
|
|
|
def load_data(file_path):
    """Load annotated data from a JSONL file.

    Accepts three label layouts, checked in priority order:
      - 'label_id'            : integer category ids, used as-is
      - 'ensemble_prediction' : category names, mapped through LABEL2ID
      - 'label'               : integer ids, or category names (now mapped)

    Rows with missing questions or unmappable labels are dropped.

    Args:
        file_path: path to a JSON-lines file with one example per line.

    Returns:
        (questions, labels, languages): parallel lists; `languages` falls
        back to 'unknown' for every row when no 'language' column exists.

    Raises:
        FileNotFoundError: if `file_path` does not exist.
        ValueError: if the 'question' column or every label column is missing.
    """
    print(f"Loading data from: {file_path}\n")

    try:
        df = pd.read_json(file_path, lines=True)
        print(f"✓ Loaded {len(df)} annotated examples")

        if 'question' not in df.columns:
            raise ValueError("Missing 'question' column")

        # Resolve the label column, mapping category names to ids if needed.
        if 'label_id' in df.columns:
            label_col = 'label_id'
        elif 'ensemble_prediction' in df.columns:
            df['label_id'] = df['ensemble_prediction'].map(LABEL2ID)
            label_col = 'label_id'
        elif 'label' in df.columns:
            # FIX: a string-valued 'label' column (category names) previously
            # crashed later at astype(int); map names to ids up front.
            # Unknown names become NaN and are removed by the dropna below.
            if df['label'].dtype == object:
                df['label'] = df['label'].map(LABEL2ID)
            label_col = 'label'
        else:
            raise ValueError("No label column found (expected: 'label', 'label_id', or 'ensemble_prediction')")

        # Drop rows with missing questions or missing/unmappable labels.
        df = df.dropna(subset=['question', label_col])

        print(f"✓ Data cleaned: {len(df)} examples with valid labels")

        # Show the class balance of what was actually loaded.
        print("\nLabel distribution:")
        label_counts = df[label_col].value_counts().sort_index()
        for label_id, count in label_counts.items():
            cat_name = ID2LABEL.get(int(label_id), f"UNKNOWN_{label_id}")
            print(f" {cat_name:20s}: {count:4d} ({count/len(df)*100:5.1f}%)")

        questions = df['question'].tolist()
        labels = df[label_col].astype(int).tolist()
        languages = df['language'].tolist() if 'language' in df.columns else ['unknown'] * len(df)

        print(f"\n✓ Prepared {len(questions)} question-label pairs")

        return questions, labels, languages

    except FileNotFoundError:
        print(f"❌ Error: File not found: {file_path}")
        raise
    except Exception as e:
        # Log for the console, then re-raise so the caller still fails fast.
        print(f"❌ Error loading data: {e}")
        raise
|
|
|
|
|
|
|
|
def create_data_splits(questions, labels, test_size=0.2, val_size=0.1):
    """Split (questions, labels) into stratified train/val/test partitions.

    Both `test_size` and `val_size` are fractions of the FULL dataset; the
    validation fraction is rescaled before the second split so it comes out
    right relative to the whole.

    Returns:
        (train_q, val_q, test_q, train_y, val_y, test_y)
    """
    print("\nCreating data splits...")

    # First carve off the test set, stratified on the labels.
    rest_q, test_q, rest_y, test_y = train_test_split(
        questions,
        labels,
        test_size=test_size,
        random_state=RANDOM_SEED,
        stratify=labels
    )

    # Then split the remainder into train/val. val_size is relative to the
    # full dataset, hence the rescaling by the remaining fraction.
    train_q, val_q, train_y, val_y = train_test_split(
        rest_q,
        rest_y,
        test_size=val_size / (1 - test_size),
        random_state=RANDOM_SEED,
        stratify=rest_y
    )

    print(f"\nData splits:")
    print(f" Training: {len(train_q):4d} examples ({len(train_q)/len(questions)*100:5.1f}%)")
    print(f" Validation: {len(val_q):4d} examples ({len(val_q)/len(questions)*100:5.1f}%)")
    print(f" Test: {len(test_q):4d} examples ({len(test_q)/len(questions)*100:5.1f}%)")
    print(f" Total: {len(questions):4d} examples")

    # Per-class counts so imbalance across splits is easy to spot.
    print("\nClass distribution per split:")
    for split_name, split_labels in [('Train', train_y), ('Val', val_y), ('Test', test_y)]:
        counts = Counter(split_labels)
        print(f"\n{split_name}:")
        for label_id in sorted(counts.keys()):
            cat_name = ID2LABEL[label_id]
            print(f" {cat_name:20s}: {counts[label_id]:3d}")

    return train_q, val_q, test_q, train_y, val_y, test_y
|
|
|
|
|
|
|
|
def plot_training_curves(history, best_val_f1, output_dir):
    """Render loss / accuracy / F1 curves across epochs and save them as PNG.

    Writes 'training_curves.png' (300 dpi) into `output_dir`.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    xs = range(1, len(history['train_loss']) + 1)

    # Panel spec: (axis index, [(history key, fmt, legend label)], y-label, title).
    panels = [
        (0, [('train_loss', 'b-', 'Train Loss'), ('val_loss', 'r-', 'Val Loss')],
         'Loss', 'Training and Validation Loss'),
        (1, [('train_accuracy', 'b-', 'Train Accuracy'), ('val_accuracy', 'r-', 'Val Accuracy')],
         'Accuracy', 'Training and Validation Accuracy'),
        (2, [('val_f1', 'g-', 'Val F1 (Macro)')],
         'F1 Score', 'Validation F1 Score'),
    ]

    for idx, series, ylabel, title in panels:
        ax = axes[idx]
        for key, fmt, label in series:
            ax.plot(xs, history[key], fmt, label=label, linewidth=2)
        if idx == 2:
            # Mark the best validation F1 reached during training.
            ax.axhline(y=best_val_f1, color='r', linestyle='--', label=f'Best F1: {best_val_f1:.4f}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plot_file = os.path.join(output_dir, 'training_curves.png')
    plt.savefig(plot_file, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"✓ Training curves saved to: {plot_file}")
|
|
|
|
|
|
|
|
def analyze_performance_by_language(predictions, true_labels, languages, top_n=10):
    """Print the `top_n` worst-performing languages (those with >= 5 examples).

    Returns:
        (lang_stats, lang_accuracies): raw correct/total counts per language,
        and per-language accuracy records sorted worst-first.
    """
    from collections import defaultdict

    # Tally correct/total per language.
    lang_stats = defaultdict(lambda: {'correct': 0, 'total': 0})
    for pred, gold, lang in zip(predictions, true_labels, languages):
        bucket = lang_stats[lang]
        bucket['total'] += 1
        if pred == gold:
            bucket['correct'] += 1

    # Build accuracy records, skipping languages with too few examples
    # for the accuracy to be meaningful.
    lang_accuracies = []
    for lang, stats in lang_stats.items():
        if stats['total'] < 5:
            continue
        correct, total = stats['correct'], stats['total']
        lang_accuracies.append({
            'language': lang,
            'accuracy': correct / total,
            'correct': correct,
            'total': total,
            'errors': total - correct
        })

    lang_accuracies.sort(key=lambda rec: rec['accuracy'])

    print(f"\n{'='*70}")
    print(f"WORST {top_n} LANGUAGES (with >= 5 examples)")
    print(f"{'='*70}")
    print(f"{'Language':<12} {'Accuracy':<12} {'Errors':<10} {'Total':<10}")
    print(f"{'-'*70}")
    for rec in lang_accuracies[:top_n]:
        print(f"{rec['language']:<12} {rec['accuracy']:>10.2%} {rec['errors']:>8} {rec['total']:>8}")

    return lang_stats, lang_accuracies
|
|
|
|
|
|
|
|
def analyze_performance_by_category(predictions, true_labels):
    """Print per-category accuracy, worst categories first.

    Returns:
        (cat_stats, cat_accuracies): raw correct/total counts keyed by class
        id, and per-category accuracy records sorted ascending by accuracy.
    """
    from collections import defaultdict

    # Tally correct/total per true class id.
    cat_stats = defaultdict(lambda: {'correct': 0, 'total': 0})
    for pred, gold in zip(predictions, true_labels):
        bucket = cat_stats[gold]
        bucket['total'] += 1
        if pred == gold:
            bucket['correct'] += 1

    cat_accuracies = []
    for cat_id, stats in cat_stats.items():
        correct, total = stats['correct'], stats['total']
        cat_accuracies.append({
            'category': ID2LABEL[cat_id],
            'accuracy': correct / total,
            'correct': correct,
            'total': total,
            'errors': total - correct
        })

    cat_accuracies.sort(key=lambda rec: rec['accuracy'])

    print(f"\n{'='*70}")
    print(f"PERFORMANCE BY CATEGORY")
    print(f"{'='*70}")
    print(f"{'Category':<20} {'Accuracy':<12} {'Errors':<10} {'Total':<10}")
    print(f"{'-'*70}")
    for rec in cat_accuracies:
        print(f"{rec['category']:<20} {rec['accuracy']:>10.2%} {rec['errors']:>8} {rec['total']:>8}")

    return cat_stats, cat_accuracies
|
|
|
|
|
|
|
|
def analyze_language_category_combinations(predictions, true_labels, languages, top_n=15):
    """Print the `top_n` worst (language, category) pairs with >= 3 examples.

    Returns:
        (combo_stats, combo_accuracies): raw counts keyed by
        (language, category-name), plus accuracy records sorted worst-first.
    """
    from collections import defaultdict

    # Tally correct/total per (language, category-name) pair.
    combo_stats = defaultdict(lambda: {'correct': 0, 'total': 0})
    for pred, gold, lang in zip(predictions, true_labels, languages):
        bucket = combo_stats[(lang, ID2LABEL[gold])]
        bucket['total'] += 1
        if pred == gold:
            bucket['correct'] += 1

    # Keep only pairs with enough support for the accuracy to mean anything.
    combo_accuracies = []
    for (lang, cat), stats in combo_stats.items():
        if stats['total'] < 3:
            continue
        correct, total = stats['correct'], stats['total']
        combo_accuracies.append({
            'language': lang,
            'category': cat,
            'accuracy': correct / total,
            'correct': correct,
            'total': total,
            'errors': total - correct
        })

    combo_accuracies.sort(key=lambda rec: rec['accuracy'])

    print(f"\n{'='*80}")
    print(f"WORST {top_n} LANGUAGE-CATEGORY COMBINATIONS (with >= 3 examples)")
    print(f"{'='*80}")
    print(f"{'Language':<12} {'Category':<20} {'Accuracy':<12} {'Errors':<8} {'Total':<8}")
    print(f"{'-'*80}")
    for rec in combo_accuracies[:top_n]:
        print(f"{rec['language']:<12} {rec['category']:<20} {rec['accuracy']:>10.2%} {rec['errors']:>6} {rec['total']:>6}")

    return combo_stats, combo_accuracies
|
|
|
|
|
|
|
|
def plot_confusion_matrix(test_true, test_preds, output_dir):
    """Draw the test-set confusion matrix as an annotated heatmap PNG.

    Writes 'confusion_matrix.png' (300 dpi) into `output_dir`.
    """
    # Fix the label order so rows/columns line up with NFQA_CATEGORIES.
    label_ids = list(range(len(NFQA_CATEGORIES)))
    cm = confusion_matrix(test_true, test_preds, labels=label_ids)

    plt.figure(figsize=(12, 10))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=NFQA_CATEGORIES,
        yticklabels=NFQA_CATEGORIES,
        cbar_kws={'label': 'Count'}
    )
    plt.xlabel('Predicted Category')
    plt.ylabel('True Category')
    plt.title('Confusion Matrix - Test Set')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    out_path = os.path.join(output_dir, 'confusion_matrix.png')
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"✓ Confusion matrix saved to: {out_path}")
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, train, and evaluate the NFQA classifier.

    Pipeline: argument validation -> data loading/splitting -> tokenizer +
    model setup -> epoch loop (train + validate, checkpoint best macro F1) ->
    final test-set evaluation with reports and plots written to --output-dir.
    """
    parser = argparse.ArgumentParser(description='Train NFQA Classification Model')

    # Data inputs: either one file (auto-split) or three pre-split files.
    parser.add_argument('--input', type=str,
                        help='Input JSONL file with annotated data (will be split automatically)')
    parser.add_argument('--train', type=str,
                        help='Training set JSONL file (use with --val and --test)')
    parser.add_argument('--val', type=str,
                        help='Validation set JSONL file (use with --train and --test)')
    parser.add_argument('--test', type=str,
                        help='Test set JSONL file (use with --train and --val)')
    parser.add_argument('--output-dir', type=str, default='./nfqa_model_trained',
                        help='Output directory for model and results')

    # Model / tokenization options.
    parser.add_argument('--model-name', type=str, default='xlm-roberta-base',
                        help='Pretrained model name (default: xlm-roberta-base)')
    parser.add_argument('--max-length', type=int, default=128,
                        help='Maximum sequence length (default: 128)')

    # Optimization hyperparameters.
    parser.add_argument('--batch-size', type=int, default=16,
                        help='Batch size (default: 16)')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Number of epochs (default: 10)')
    parser.add_argument('--learning-rate', type=float, default=2e-5,
                        help='Learning rate (default: 2e-5)')
    parser.add_argument('--warmup-steps', type=int, default=500,
                        help='Warmup steps (default: 500)')
    parser.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay (default: 0.01)')
    parser.add_argument('--dropout', type=float, default=0.1,
                        help='Dropout probability (default: 0.1)')

    # Automatic-splitting fractions (only meaningful with --input).
    parser.add_argument('--test-size', type=float, default=0.2,
                        help='Test set size (default: 0.2)')
    parser.add_argument('--val-size', type=float, default=0.1,
                        help='Validation set size (default: 0.1)')

    parser.add_argument('--device', type=str, default='auto',
                        help='Device to use: cuda, cpu, or auto (default: auto)')

    args = parser.parse_args()

    # Exactly one input mode must be selected.
    has_single_input = args.input is not None
    has_split_inputs = all([args.train, args.val, args.test])

    if not has_single_input and not has_split_inputs:
        parser.error("Either --input OR (--train, --val, --test) must be provided")

    if has_single_input and has_split_inputs:
        parser.error("Cannot use --input together with --train/--val/--test. Choose one approach.")

    # Run-configuration banner.
    print("="*80)
    print("NFQA MODEL TRAINING")
    print("="*80)
    if has_single_input:
        print(f"Input file: {args.input}")
        print(f"Data splitting: automatic (test={args.test_size}, val={args.val_size})")
    else:
        print(f"Train file: {args.train}")
        print(f"Val file: {args.val}")
        print(f"Test file: {args.test}")
        print(f"Data splitting: manual (pre-split)")
    print(f"Output directory: {args.output_dir}")
    print(f"Model: {args.model_name}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.learning_rate}")
    print(f"Max length: {args.max_length}")
    print(f"Weight decay: {args.weight_decay}")
    print(f"Dropout: {args.dropout}")
    print("="*80 + "\n")

    # Device selection ('auto' prefers CUDA when available).
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    if torch.cuda.is_available():
        # Seed GPU RNGs too (CPU/NumPy RNGs are seeded at module level).
        torch.cuda.manual_seed_all(RANDOM_SEED)

    print(f"Device: {device}")
    print(f"PyTorch version: {torch.__version__}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)}\n")

    os.makedirs(args.output_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Data loading and splitting
    # ------------------------------------------------------------------
    if has_single_input:
        questions, labels, languages = load_data(args.input)

        # NOTE(review): this re-import is redundant (train_test_split is
        # already imported at module level), and the inline two-step split
        # below duplicates create_data_splits() — kept separate here because
        # it also threads the per-example language list through both splits.
        from sklearn.model_selection import train_test_split

        # Split 1: hold out a stratified test set.
        train_val_questions, test_questions, train_val_labels, test_labels, train_val_langs, test_langs = train_test_split(
            questions, labels, languages,
            test_size=args.test_size,
            random_state=RANDOM_SEED,
            stratify=labels
        )

        # Split 2: carve validation out of the remainder. val_size is a
        # fraction of the FULL dataset, hence the rescaling.
        train_questions, val_questions, train_labels, val_labels, train_langs, val_langs = train_test_split(
            train_val_questions, train_val_labels, train_val_langs,
            test_size=args.val_size / (1 - args.test_size),
            random_state=RANDOM_SEED,
            stratify=train_val_labels
        )

        print(f"\nData splits:")
        print(f" Training: {len(train_questions):4d} examples ({len(train_questions)/len(questions)*100:5.1f}%)")
        print(f" Validation: {len(val_questions):4d} examples ({len(val_questions)/len(questions)*100:5.1f}%)")
        print(f" Test: {len(test_questions):4d} examples ({len(test_questions)/len(questions)*100:5.1f}%)")
        print(f" Total: {len(questions):4d} examples")
    else:
        # Pre-split mode: each split comes from its own file.
        print("Loading pre-split datasets...\n")
        train_questions, train_labels, train_langs = load_data(args.train)
        val_questions, val_labels, val_langs = load_data(args.val)
        test_questions, test_labels, test_langs = load_data(args.test)

        total_examples = len(train_questions) + len(val_questions) + len(test_questions)
        print(f"\nData splits:")
        print(f" Training: {len(train_questions):4d} examples ({len(train_questions)/total_examples*100:5.1f}%)")
        print(f" Validation: {len(val_questions):4d} examples ({len(val_questions)/total_examples*100:5.1f}%)")
        print(f" Test: {len(test_questions):4d} examples ({len(test_questions)/total_examples*100:5.1f}%)")
        print(f" Total: {total_examples:4d} examples")

    # Per-split class counts, to make imbalance visible.
    print("\nClass distribution per split:")
    for split_name, split_labels in [('Train', train_labels), ('Val', val_labels), ('Test', test_labels)]:
        counts = Counter(split_labels)
        print(f"\n{split_name}:")
        for label_id in sorted(counts.keys()):
            cat_name = ID2LABEL[label_id]
            print(f" {cat_name:20s}: {counts[label_id]:3d}")

    # ------------------------------------------------------------------
    # Tokenizer and model
    # ------------------------------------------------------------------
    print(f"\nLoading tokenizer: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    print("✓ Tokenizer loaded")

    print(f"\nLoading model: {args.model_name}")
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=len(NFQA_CATEGORIES),
        id2label=ID2LABEL,
        label2id=LABEL2ID,
        hidden_dropout_prob=args.dropout,
        attention_probs_dropout_prob=args.dropout,
        classifier_dropout=args.dropout
    )
    model.to(device)

    print(f"✓ Model loaded")
    print(f" Number of parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f" Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # ------------------------------------------------------------------
    # Datasets and loaders
    # ------------------------------------------------------------------
    print("\nCreating datasets...")
    train_dataset = NFQADataset(train_questions, train_labels, tokenizer, args.max_length)
    val_dataset = NFQADataset(val_questions, val_labels, tokenizer, args.max_length)
    test_dataset = NFQADataset(test_questions, test_labels, tokenizer, args.max_length)

    # Val/test loaders keep dataset order (no shuffle) so the language lists
    # stay aligned with the per-example predictions in evaluate().
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)

    print(f"✓ Datasets created")
    print(f" Train: {len(train_dataset)} examples ({len(train_loader)} batches)")
    print(f" Val: {len(val_dataset)} examples ({len(val_loader)} batches)")
    print(f" Test: {len(test_dataset)} examples ({len(test_loader)} batches)")

    # ------------------------------------------------------------------
    # Optimizer and LR schedule (linear warmup, then linear decay)
    # ------------------------------------------------------------------
    optimizer = AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )

    total_steps = len(train_loader) * args.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=total_steps
    )

    print(f"\n✓ Optimizer and scheduler configured")
    print(f" Total training steps: {total_steps}")
    print(f" Warmup steps: {args.warmup_steps}")

    # Per-epoch metric history; dumped to JSON after training.
    history = {
        'train_loss': [],
        'train_accuracy': [],
        'val_loss': [],
        'val_accuracy': [],
        'val_f1': []
    }

    # Best checkpoint is selected by macro validation F1.
    best_val_f1 = 0
    best_epoch = 0

    print("\n" + "="*80)
    print("STARTING TRAINING")
    print("="*80 + "\n")

    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")
        print("-" * 80)

        train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, device)

        # Validate with detailed per-language/per-category diagnostics.
        val_loss, val_acc, val_f1, val_preds, val_true = evaluate(
            model, val_loader, device,
            languages=val_langs,
            desc="Validating",
            show_analysis=True
        )

        history['train_loss'].append(train_loss)
        history['train_accuracy'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_acc)
        history['val_f1'].append(val_f1)

        print(f"\nEpoch {epoch + 1} Summary:")
        print(f" Train Loss: {train_loss:.4f}")
        print(f" Train Accuracy: {train_acc:.4f}")
        print(f" Val Loss: {val_loss:.4f}")
        print(f" Val Accuracy: {val_acc:.4f}")
        print(f" Val F1 (Macro): {val_f1:.4f}")

        # Checkpoint whenever validation F1 improves.
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_epoch = epoch + 1

            model_path = os.path.join(args.output_dir, 'best_model')
            model.save_pretrained(model_path)
            tokenizer.save_pretrained(model_path)

            print(f" ✓ New best model saved! (F1: {val_f1:.4f})")

    print("\n" + "="*80)
    print("TRAINING COMPLETE")
    print("="*80)
    print(f"Best epoch: {best_epoch}")
    print(f"Best validation F1: {best_val_f1:.4f}")
    print("="*80)

    # Persist the metric history.
    history_file = os.path.join(args.output_dir, 'training_history.json')
    with open(history_file, 'w') as f:
        json.dump(history, f, indent=2)
    print(f"\n✓ Training history saved to: {history_file}")

    # Also keep the last-epoch weights alongside the best checkpoint.
    final_model_path = os.path.join(args.output_dir, 'final_model')
    model.save_pretrained(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"✓ Final model saved to: {final_model_path}")

    plot_training_curves(history, best_val_f1, args.output_dir)

    # ------------------------------------------------------------------
    # Final evaluation on the held-out test set using the BEST checkpoint
    # ------------------------------------------------------------------
    print("\nLoading best model for final evaluation...")
    best_model_path = os.path.join(args.output_dir, 'best_model')
    model = AutoModelForSequenceClassification.from_pretrained(best_model_path)
    model.to(device)

    test_loss, test_acc, test_f1, test_preds, test_true = evaluate(model, test_loader, device, desc="Testing")

    print("\n" + "="*80)
    print("FINAL TEST SET RESULTS")
    print("="*80)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test F1 (Macro): {test_f1:.4f}")
    print("="*80)

    # Per-class precision/recall/F1 table.
    print("\n" + "="*80)
    print("PER-CATEGORY PERFORMANCE")
    print("="*80 + "\n")

    report = classification_report(
        test_true,
        test_preds,
        labels=list(range(len(NFQA_CATEGORIES))),
        target_names=NFQA_CATEGORIES,
        zero_division=0
    )
    print(report)

    report_file = os.path.join(args.output_dir, 'classification_report.txt')
    with open(report_file, 'w') as f:
        f.write(report)
    print(f"✓ Classification report saved to: {report_file}")

    plot_confusion_matrix(test_true, test_preds, args.output_dir)

    # Drill-down error analysis on the test predictions.
    print("\n" + "="*80)
    print("DETAILED PERFORMANCE ANALYSIS")
    print("="*80)

    analyze_performance_by_category(test_preds, test_true)

    analyze_performance_by_language(test_preds, test_true, test_langs, top_n=10)

    analyze_language_category_combinations(test_preds, test_true, test_langs, top_n=15)

    print("\n" + "="*80)

    # Machine-readable results + the full run configuration.
    test_results = {
        'test_loss': float(test_loss),
        'test_accuracy': float(test_acc),
        'test_f1_macro': float(test_f1),
        'best_epoch': int(best_epoch),
        'best_val_f1': float(best_val_f1),
        'num_train_examples': len(train_questions),
        'num_val_examples': len(val_questions),
        'num_test_examples': len(test_questions),
        'config': {
            'model_name': args.model_name,
            'max_length': args.max_length,
            'batch_size': args.batch_size,
            'learning_rate': args.learning_rate,
            'num_epochs': args.epochs,
            'warmup_steps': args.warmup_steps,
            'weight_decay': args.weight_decay,
            'dropout': args.dropout,
            'data_source': 'pre-split' if has_split_inputs else 'single_file',
            'train_file': args.train if has_split_inputs else args.input,
            'val_file': args.val if has_split_inputs else None,
            'test_file': args.test if has_split_inputs else None,
            'auto_split': not has_split_inputs,
            'test_size': args.test_size if not has_split_inputs else None,
            'val_size': args.val_size if not has_split_inputs else None
        },
        'timestamp': datetime.now().isoformat()
    }

    results_file = os.path.join(args.output_dir, 'test_results.json')
    with open(results_file, 'w') as f:
        json.dump(test_results, f, indent=2)
    print(f"✓ Test results saved to: {results_file}")

    # Human-readable end-of-run summary.
    print("\n" + "="*80)
    print("TRAINING SUMMARY")
    print("="*80)
    print(f"\nModel: {args.model_name}")
    print(f"Training examples: {len(train_questions)}")
    print(f"Validation examples: {len(val_questions)}")
    print(f"Test examples: {len(test_questions)}")
    print(f"\nBest epoch: {best_epoch}/{args.epochs}")
    print(f"Best validation F1: {best_val_f1:.4f}")
    print(f"\nFinal test results:")
    print(f" Accuracy: {test_acc:.4f}")
    print(f" F1 Score (Macro): {test_f1:.4f}")
    print(f"\nModel saved to: {args.output_dir}")
    print(f"\nGenerated files:")
    print(f" - best_model/ (best checkpoint)")
    print(f" - final_model/ (last epoch)")
    print(f" - training_history.json")
    print(f" - training_curves.png")
    print(f" - test_results.json")
    print(f" - classification_report.txt")
    print(f" - confusion_matrix.png")
    print("\n" + "="*80)
    print("✅ Training complete! Model ready for deployment.")
    print("="*80)
|
|
# Script entry point.
if __name__ == '__main__':
    main()
|
|
|