pangweijlu's picture
Fix __file__ for inline execution
100c45f verified
"""
Training script for multimodal fraudulent paper detection - GPU optimized.
Uses mixed precision (fp16), class weighting, and saves full metrics.
"""
import os
import sys
SCRIPT_DIR = os.getcwd()
if SCRIPT_DIR not in sys.path:
sys.path.insert(0, SCRIPT_DIR)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from transformers import get_linear_schedule_with_warmup
from torch.cuda.amp import autocast, GradScaler
import numpy as np
from sklearn.metrics import (
accuracy_score, precision_recall_fscore_support,
roc_auc_score, classification_report, confusion_matrix
)
from tqdm import tqdm
import json
import time
from model import MultimodalFraudDetector
from data_loader import FraudPaperDataset, collate_fn
def compute_metrics(predictions, labels, probs):
    """Compute binary classification metrics from model outputs.

    Args:
        predictions: 2-D array of per-class scores (e.g. logits); the
            predicted class is the argmax over axis 1.
        labels: 1-D array of ground-truth class ids. ``average='binary'``
            treats class 1 as the positive class (fraudulent, per the rest
            of this script).
        probs: 2-D array of per-class probabilities; column 1 is used as the
            positive-class score for ROC AUC.

    Returns:
        dict with 'accuracy', 'precision', 'recall', 'f1' and 'auc' as plain
        Python floats. 'auc' falls back to 0.5 when ``roc_auc_score`` raises
        ValueError (for example when ``labels`` contains a single class).
    """
    preds = np.argmax(predictions, axis=1)
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )
    try:
        auc = roc_auc_score(labels, probs[:, 1])
    except ValueError:
        # AUC is undefined when only one class is present; report chance
        # level instead of aborting. Other errors (and Ctrl-C) propagate.
        auc = 0.5
    # Plain floats keep the dict JSON- and torch.load(weights_only)-friendly.
    return {
        'accuracy': float(accuracy),
        'precision': float(precision),
        'recall': float(recall),
        'f1': float(f1),
        'auc': float(auc),
    }
def train_epoch(model, dataloader, optimizer, scheduler, device, epoch, scaler, class_weights):
    """Run one mixed-precision training epoch and return its metrics.

    The loss is the class-weighted cross-entropy plus two auxiliary terms:
    a modality-consistency penalty pulling ``outputs['modality_scores']``
    towards 0.5 (weight 0.1), and an anomaly penalty (weight 0.1) pulling
    ``outputs['anomaly_score']`` towards 1 for label-1 (fraudulent) samples
    and towards 0 for the others.

    Args:
        model: Called with text_input_ids, text_attention_mask,
            tabular_features and metadata_features; must return a dict with
            'logits', 'modality_scores' and 'anomaly_score'.
        dataloader: Yields batch dicts with 'input_ids', 'attention_mask',
            'tabular_features', 'metadata_features' and 'labels'.
        optimizer: Optimizer stepped through ``scaler``.
        scheduler: LR scheduler stepped once per batch.
        device: Device the batch tensors are moved to.
        epoch: Epoch number; only used for the progress-bar label.
        scaler: GradScaler used for fp16 loss scaling.
        class_weights: Per-class weight tensor for the cross-entropy term.

    Returns:
        dict from ``compute_metrics`` plus 'loss' (mean batch loss) and
        'time' (wall-clock seconds spent in the epoch).
    """
    model.train()
    # The criterion only depends on class_weights, so build it once.
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    total_loss = 0.0
    all_preds, all_labels, all_probs = [], [], []
    start_time = time.time()
    pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        tabular = batch['tabular_features'].to(device)
        metadata = batch['metadata_features'].to(device)
        labels = batch['labels'].to(device)
        optimizer.zero_grad()
        with autocast():
            outputs = model(
                text_input_ids=input_ids,
                text_attention_mask=attention_mask,
                tabular_features=tabular,
                metadata_features=metadata,
            )
            logits = outputs['logits']
            ce_loss = criterion(logits, labels)
            # Modality consistency: penalize modality scores drifting from 0.5.
            modality_scores = outputs['modality_scores']
            consistency_loss = torch.mean((modality_scores - 0.5) ** 2) * 0.1
            # Anomaly loss: each class term is added only when that class is
            # present in the batch. torch.mean over an empty selection is NaN,
            # so an all-fraud batch must skip the authentic term; conversely a
            # batch with no fraud still supervises the authentic scores.
            anomaly_score = outputs['anomaly_score']
            fraud_mask = labels == 1
            authentic_mask = ~fraud_mask
            anomaly_loss = torch.tensor(0.0, device=device)
            if fraud_mask.any():
                anomaly_loss = anomaly_loss + torch.mean((anomaly_score[fraud_mask] - 1.0) ** 2)
            if authentic_mask.any():
                anomaly_loss = anomaly_loss + torch.mean(anomaly_score[authentic_mask] ** 2)
            loss = ce_loss + consistency_loss + 0.1 * anomaly_loss
        scaler.scale(loss).backward()
        # Unscale first so clipping sees the true (unscaled) gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        batch_loss = loss.item()
        total_loss += batch_loss
        probs = torch.softmax(logits, dim=1).detach().cpu().numpy()
        all_preds.append(logits.detach().cpu().numpy())
        all_labels.append(labels.cpu().numpy())
        all_probs.append(probs)
        pbar.set_postfix({'loss': batch_loss})
    elapsed = time.time() - start_time
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)
    metrics = compute_metrics(all_preds, all_labels, all_probs)
    metrics['loss'] = total_loss / len(dataloader)
    metrics['time'] = elapsed
    return metrics
def evaluate(model, dataloader, device):
    """Score ``model`` on every batch of ``dataloader`` without gradients.

    Returns:
        Tuple ``(metrics, embeddings, anomaly_scores, labels, predictions)``.
        ``metrics`` is the ``compute_metrics`` dict plus 'loss', the mean
        (unweighted) cross-entropy per batch. The arrays are concatenated
        over all batches from ``outputs['fused_embedding']``,
        ``outputs['anomaly_score']`` and the batch labels, and
        ``predictions`` is the argmax of the logits over axis 1.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    running_loss = 0
    # Per-batch chunks, concatenated once at the end (insertion order matters
    # only for which concatenation fails first on an empty loader).
    chunks = {'logits': [], 'labels': [], 'probs': [], 'embedding': [], 'anomaly': []}
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            model_inputs = {
                'text_input_ids': batch['input_ids'].to(device),
                'text_attention_mask': batch['attention_mask'].to(device),
                'tabular_features': batch['tabular_features'].to(device),
                'metadata_features': batch['metadata_features'].to(device),
            }
            target = batch['labels'].to(device)
            with autocast():
                out = model(**model_inputs)
                batch_logits = out['logits']
                running_loss += criterion(batch_logits, target).item()
            chunks['logits'].append(batch_logits.cpu().numpy())
            chunks['labels'].append(target.cpu().numpy())
            chunks['probs'].append(torch.softmax(batch_logits, dim=1).cpu().numpy())
            chunks['embedding'].append(out['fused_embedding'].cpu().numpy())
            chunks['anomaly'].append(out['anomaly_score'].cpu().numpy())
    merged = {name: np.concatenate(parts) for name, parts in chunks.items()}
    metrics = compute_metrics(merged['logits'], merged['labels'], merged['probs'])
    metrics['loss'] = running_loss / len(dataloader)
    predicted = np.argmax(merged['logits'], axis=1)
    return metrics, merged['embedding'], merged['anomaly'], merged['labels'], predicted
def main():
    """Train, select and evaluate the multimodal fraud detector end to end.

    Loads the "Lihuchen/pubmed_retraction" train split and holds out 20% for
    validation. Trains for ``num_epochs`` epochs with class-weighted loss and
    mixed precision, and keeps the checkpoint with the best validation F1.
    It then re-evaluates that checkpoint and writes best_model.pt,
    val_embeddings.npy, val_anomaly.npy, val_labels.npy, val_preds.npy and
    metrics.json to ``./outputs``.
    """
    print("=" * 70)
    print("MULTIMODAL FRAUD DETECTION - GPU OPTIMIZED TRAINING")
    print("=" * 70)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_cuda = device.type == 'cuda'
    print(f"Device: {device}")
    if use_cuda:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    output_dir = './outputs'
    os.makedirs(output_dir, exist_ok=True)
    checkpoint_path = os.path.join(output_dir, 'best_model.pt')
    num_epochs = 5
    split_seed = 42

    # Load data
    print("\nLoading dataset...")
    dataset = FraudPaperDataset("Lihuchen/pubmed_retraction", split="train", max_length=256)

    # 80/20 split. A seeded generator makes the split (and therefore the saved
    # validation arrays) reproducible across runs.
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_ds, val_ds = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(split_seed),
    )
    # Pinned host memory only speeds up host->GPU copies; on CPU it just warns.
    loader_kwargs = dict(batch_size=16, num_workers=4, collate_fn=collate_fn, pin_memory=use_cuda)
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
    print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

    # Class weights for imbalance: the fraud class is up-weighted by the
    # authentic:fraud ratio of the training subset.
    train_labels = [dataset.labels[idx] for idx in train_ds.indices]
    fraud_count = sum(train_labels)
    authentic_count = len(train_labels) - fraud_count
    print(f"Class distribution - Authentic: {authentic_count}, Fraudulent: {fraud_count}")
    weight_fraud = authentic_count / max(fraud_count, 1)
    class_weights = torch.tensor([1.0, weight_fraud]).to(device)
    print(f"Class weights: [1.0, {weight_fraud:.2f}]")

    # Infer feature dimensions from one collated batch.
    sample = next(iter(train_loader))
    tabular_dim = sample['tabular_features'].shape[1]
    metadata_dim = sample['metadata_features'].shape[1]
    print(f"Tabular: {tabular_dim}, Metadata: {metadata_dim}")

    # Model
    print("\nBuilding model...")
    model = MultimodalFraudDetector(
        text_model="allenai/scibert_scivocab_uncased",
        tabular_features=tabular_dim,
        metadata_features=metadata_dim,
        fused_dim=256,
        freeze_text_layers=8,
    ).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params: {total_params:,}")
    print(f"Trainable: {trainable:,}")
    print(f"Frozen: {total_params - trainable:,}")

    # Optimizer with linear warmup (10% of steps) then linear decay.
    optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
    total_steps = len(train_loader) * num_epochs
    warmup = int(total_steps * 0.1)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup, num_training_steps=total_steps)

    # Mixed precision
    scaler = GradScaler()

    # Training
    print("\n" + "=" * 70)
    print("TRAINING")
    print("=" * 70)
    # Start below any achievable F1 so the first epoch always writes a
    # checkpoint; with 0 and a strict '>', a run whose F1 never rises above 0
    # would save nothing and the final torch.load would fail.
    best_f1 = -1.0
    history = []
    for epoch in range(1, num_epochs + 1):
        print(f"\n>>> Epoch {epoch}/{num_epochs} <<<")
        train_metrics = train_epoch(model, train_loader, optimizer, scheduler, device, epoch, scaler, class_weights)
        print(f"Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}, "
              f"Prec: {train_metrics['precision']:.4f}, Rec: {train_metrics['recall']:.4f}, "
              f"F1: {train_metrics['f1']:.4f}, AUC: {train_metrics['auc']:.4f} "
              f"({train_metrics['time']:.1f}s)")
        val_metrics, val_emb, val_anom, val_labels, val_preds = evaluate(model, val_loader, device)
        print(f"Val - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, "
              f"Prec: {val_metrics['precision']:.4f}, Rec: {val_metrics['recall']:.4f}, "
              f"F1: {val_metrics['f1']:.4f}, AUC: {val_metrics['auc']:.4f}")
        history.append({
            'epoch': epoch,
            'train': train_metrics,
            'val': val_metrics,
        })
        if val_metrics['f1'] > best_f1:
            best_f1 = float(val_metrics['f1'])
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'f1': best_f1,
                # Plain floats rather than NumPy scalars, so the checkpoint can
                # be read by torch.load's weights_only unpickler.
                'metrics': {key: float(value) for key, value in val_metrics.items()},
                'tabular_dim': tabular_dim,
                'metadata_dim': metadata_dim,
            }, checkpoint_path)
            print(f"*** Saved best model (F1: {best_f1:.4f}) ***")

    # Final evaluation
    print("\n" + "=" * 70)
    print("FINAL EVALUATION")
    print("=" * 70)
    # Load best model onto the current device regardless of where it was saved.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    final_metrics, final_emb, final_anom, final_labels, final_preds = evaluate(model, val_loader, device)
    # Pin both classes so the report and the 2x2 matrix below stay valid even
    # when the validation labels/predictions contain only one class.
    class_ids = [0, 1]
    print("\nClassification Report:")
    print(classification_report(final_labels, final_preds, labels=class_ids,
                                target_names=['authentic', 'fraudulent'], zero_division=0))
    print("Confusion Matrix:")
    cm = confusion_matrix(final_labels, final_preds, labels=class_ids)
    print("                Predicted")
    print("              Auth   Fraud")
    print(f"Actual Auth  [{cm[0,0]:4d}   {cm[0,1]:4d}]")
    print(f"       Fraud [{cm[1,0]:4d}   {cm[1,1]:4d}]")
    print("\nFinal Metrics:")
    print(f"  Accuracy:  {final_metrics['accuracy']:.4f}")
    print(f"  Precision: {final_metrics['precision']:.4f}")
    print(f"  Recall:    {final_metrics['recall']:.4f}")
    print(f"  F1-Score:  {final_metrics['f1']:.4f}")
    print(f"  AUC:       {final_metrics['auc']:.4f}")

    # Save everything
    np.save(os.path.join(output_dir, 'val_embeddings.npy'), final_emb)
    np.save(os.path.join(output_dir, 'val_anomaly.npy'), final_anom)
    np.save(os.path.join(output_dir, 'val_labels.npy'), final_labels)
    np.save(os.path.join(output_dir, 'val_preds.npy'), final_preds)
    with open(os.path.join(output_dir, 'metrics.json'), 'w') as f:
        json.dump({
            'best_f1': best_f1,
            'final_metrics': final_metrics,
            'history': history,
            'model_info': {
                'total_params': total_params,
                'trainable_params': trainable,
                'tabular_dim': tabular_dim,
                'metadata_dim': metadata_dim,
            },
        }, f, indent=2)
    print(f"\n{'=' * 70}")
    print("TRAINING COMPLETE!")
    print(f"Best F1: {best_f1:.4f}")
    print(f"All outputs saved to: {output_dir}")
    print(f"{'=' * 70}")
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()