"""
Training script for multimodal fraudulent paper detection - GPU optimized.
Uses mixed precision (fp16), class weighting, and saves full metrics.
"""

import os
import sys

# NOTE: despite its name, SCRIPT_DIR is the current working directory, not the
# directory containing this file. Inserting it at the front of sys.path makes
# the project-local `model` / `data_loader` imports below resolve from there.
SCRIPT_DIR = os.getcwd()
if SCRIPT_DIR not in sys.path:
    sys.path.insert(0, SCRIPT_DIR)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from transformers import get_linear_schedule_with_warmup
from torch.cuda.amp import autocast, GradScaler
import numpy as np
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score,
    classification_report, confusion_matrix
)
from tqdm import tqdm
import json
import time

from model import MultimodalFraudDetector
from data_loader import FraudPaperDataset, collate_fn


def compute_metrics(predictions, labels, probs):
    """Compute binary classification metrics for one pass over a dataset.

    Args:
        predictions: 2-D array of per-class scores (logits); the predicted
            class is the argmax over axis 1.
        labels: 1-D array of integer ground-truth labels (1 = fraudulent).
        probs: 2-D array of per-class probabilities; column 1 is used as the
            positive-class score for ROC AUC.

    Returns:
        Dict with 'accuracy', 'precision', 'recall', 'f1' and 'auc' as plain
        Python floats. Plain floats keep the checkpoint that embeds these
        metrics loadable with ``torch.load(weights_only=True)`` (the default
        since PyTorch 2.6), which rejects numpy scalars.
    """
    preds = np.argmax(predictions, axis=1)
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )
    try:
        auc = roc_auc_score(labels, probs[:, 1])
    except ValueError:
        # Raised when only one class is present in `labels`; report chance
        # level instead of aborting the run.
        auc = 0.5
    if not np.isfinite(auc):
        # Defensive: treat an undefined score the same way.
        auc = 0.5
    return {
        'accuracy': float(accuracy),
        'precision': float(precision),
        'recall': float(recall),
        'f1': float(f1),
        'auc': float(auc),
    }


def train_epoch(model, dataloader, optimizer, scheduler, device, epoch, scaler, class_weights):
    """Run one training epoch and return metrics computed on the training batches.

    The optimized loss is class-weighted cross-entropy plus two auxiliary
    terms: a consistency penalty pulling ``modality_scores`` toward 0.5
    (weight 0.1), and an anomaly penalty pulling ``anomaly_score`` toward 1
    for fraudulent and 0 for authentic samples (weight 0.1).

    Args:
        model: network returning a dict with 'logits', 'modality_scores' and
            'anomaly_score'.
        dataloader: yields dict batches with 'input_ids', 'attention_mask',
            'tabular_features', 'metadata_features' and 'labels'.
        optimizer: stepped once per batch through ``scaler``.
        scheduler: learning-rate scheduler stepped once per batch.
        device: torch.device; autocast is enabled only when it is CUDA.
        epoch: epoch number, used only for the progress-bar label.
        scaler: GradScaler for the mixed-precision backward pass.
        class_weights: per-class weights for the cross-entropy term.

    Returns:
        The ``compute_metrics`` dict plus 'loss' (mean per-batch loss) and
        'time' (wall-clock seconds for the epoch).
    """
    model.train()
    use_amp = device.type == 'cuda'
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    total_loss = 0.0
    all_preds, all_labels, all_probs = [], [], []
    start_time = time.time()

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        tabular = batch['tabular_features'].to(device)
        metadata = batch['metadata_features'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        with autocast(enabled=use_amp):
            outputs = model(
                text_input_ids=input_ids,
                text_attention_mask=attention_mask,
                tabular_features=tabular,
                metadata_features=metadata
            )
            logits = outputs['logits']
            ce_loss = criterion(logits, labels)

            # Modality consistency: keep per-modality scores near 0.5.
            modality_scores = outputs['modality_scores']
            consistency_loss = torch.mean((modality_scores - 0.5) ** 2) * 0.1

            # Anomaly loss: target 1 for fraudulent, 0 for authentic samples.
            anomaly_score = outputs['anomaly_score']
            fraud_mask = labels == 1
            authentic_mask = ~fraud_mask
            if fraud_mask.any():
                anomaly_loss = torch.mean((anomaly_score[fraud_mask] - 1.0) ** 2)
                # Guard the authentic term too: torch.mean over an empty
                # selection is NaN, which would poison the loss and gradients
                # on an all-fraud batch (e.g. a size-1 final batch).
                if authentic_mask.any():
                    anomaly_loss = anomaly_loss + torch.mean(anomaly_score[authentic_mask] ** 2)
            else:
                # NOTE(review): all-authentic batches skip the anomaly term, so
                # authentic scores are only pushed toward 0 in batches that also
                # contain fraud -- confirm this asymmetry is intended.
                anomaly_loss = torch.tensor(0.0, device=device)

            loss = ce_loss + consistency_loss + 0.1 * anomaly_loss

        scaler.scale(loss).backward()
        # Unscale first so the max-norm of 1.0 applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        loss_value = loss.item()
        total_loss += loss_value
        # Upcast (possibly fp16) logits so softmax/metrics run in fp32.
        batch_logits = logits.detach().float()
        all_preds.append(batch_logits.cpu().numpy())
        all_labels.append(labels.cpu().numpy())
        all_probs.append(torch.softmax(batch_logits, dim=1).cpu().numpy())

        pbar.set_postfix({'loss': loss_value})

    elapsed = time.time() - start_time
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)

    metrics = compute_metrics(all_preds, all_labels, all_probs)
    metrics['loss'] = total_loss / len(dataloader)
    metrics['time'] = elapsed
    return metrics


def evaluate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    The reported loss is *unweighted* cross-entropy, so it is not directly
    comparable to the class-weighted training loss.

    Args:
        model: network returning a dict with 'logits', 'fused_embedding' and
            'anomaly_score'.
        dataloader: yields batches with the same keys as in ``train_epoch``.
        device: torch.device; autocast is enabled only when it is CUDA.

    Returns:
        Tuple ``(metrics, embeddings, anomaly_scores, labels, predicted_labels)``;
        the arrays are concatenated over all batches in loader order.
    """
    model.eval()
    use_amp = device.type == 'cuda'
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    all_preds, all_labels, all_probs = [], [], []
    all_embeddings, all_anomaly = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            tabular = batch['tabular_features'].to(device)
            metadata = batch['metadata_features'].to(device)
            labels = batch['labels'].to(device)

            with autocast(enabled=use_amp):
                outputs = model(
                    text_input_ids=input_ids,
                    text_attention_mask=attention_mask,
                    tabular_features=tabular,
                    metadata_features=metadata
                )
                logits = outputs['logits']
                loss = criterion(logits, labels)

            total_loss += loss.item()
            # Upcast (possibly fp16) logits so softmax/metrics run in fp32.
            batch_logits = logits.float()
            all_preds.append(batch_logits.cpu().numpy())
            all_labels.append(labels.cpu().numpy())
            all_probs.append(torch.softmax(batch_logits, dim=1).cpu().numpy())
            all_embeddings.append(outputs['fused_embedding'].cpu().numpy())
            all_anomaly.append(outputs['anomaly_score'].cpu().numpy())

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)
    all_embeddings = np.concatenate(all_embeddings)
    all_anomaly = np.concatenate(all_anomaly)

    metrics = compute_metrics(all_preds, all_labels, all_probs)
    metrics['loss'] = total_loss / len(dataloader)
    return metrics, all_embeddings, all_anomaly, all_labels, np.argmax(all_preds, axis=1)


def main():
    """Train the detector, keep the best-validation-F1 checkpoint, and export results.

    Writes ``best_model.pt``, the validation embeddings / anomaly scores /
    labels / predictions as ``.npy`` files, and ``metrics.json`` into
    ``./outputs``.
    """
    num_epochs = 5
    batch_size = 16
    split_seed = 42

    print("=" * 70)
    print("MULTIMODAL FRAUD DETECTION - GPU OPTIMIZED TRAINING")
    print("=" * 70)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_cuda = device.type == 'cuda'
    print(f"Device: {device}")
    if use_cuda:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    output_dir = './outputs'
    os.makedirs(output_dir, exist_ok=True)
    checkpoint_path = os.path.join(output_dir, 'best_model.pt')

    # Load data
    print("\nLoading dataset...")
    dataset = FraudPaperDataset("Lihuchen/pubmed_retraction", split="train", max_length=256)

    # 80/20 split with a fixed seed so the validation set (and the exported
    # val_*.npy arrays) can be reproduced across runs.
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_ds, val_ds = random_split(
        dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(split_seed)
    )

    # Pinned host memory only speeds up host-to-GPU copies.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4,
                              collate_fn=collate_fn, pin_memory=use_cuda)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4,
                            collate_fn=collate_fn, pin_memory=use_cuda)

    print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

    # Calculate class weights for imbalance: the fraudulent class is weighted
    # by the authentic:fraudulent ratio in the training split.
    train_labels = [dataset.labels[idx] for idx in train_ds.indices]
    fraud_count = sum(train_labels)
    authentic_count = len(train_labels) - fraud_count
    print(f"Class distribution - Authentic: {authentic_count}, Fraudulent: {fraud_count}")
    weight_fraud = authentic_count / max(fraud_count, 1)
    class_weights = torch.tensor([1.0, weight_fraud]).to(device)
    print(f"Class weights: [1.0, {weight_fraud:.2f}]")

    # Infer feature dims from one collated batch.
    sample = next(iter(train_loader))
    tabular_dim = sample['tabular_features'].shape[1]
    metadata_dim = sample['metadata_features'].shape[1]
    print(f"Tabular: {tabular_dim}, Metadata: {metadata_dim}")

    # Model
    print("\nBuilding model...")
    model = MultimodalFraudDetector(
        text_model="allenai/scibert_scivocab_uncased",
        tabular_features=tabular_dim,
        metadata_features=metadata_dim,
        fused_dim=256,
        freeze_text_layers=8
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params: {total_params:,}")
    print(f"Trainable: {trainable:,}")
    print(f"Frozen: {total_params - trainable:,}")

    # Optimizer: linear decay over all epochs with 10% warmup.
    optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
    total_steps = len(train_loader) * num_epochs
    warmup = int(total_steps * 0.1)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup, num_training_steps=total_steps)

    # Mixed precision (disabled, i.e. a pass-through, on CPU)
    scaler = GradScaler(enabled=use_cuda)

    # Training
    print("\n" + "=" * 70)
    print("TRAINING")
    print("=" * 70)

    # Sentinel below any attainable F1 so the first epoch always writes a
    # checkpoint; otherwise a run whose F1 stays 0 leaves no checkpoint (or a
    # stale one from an earlier run) for the final evaluation to load.
    best_f1 = -1.0
    history = []

    for epoch in range(1, num_epochs + 1):
        print(f"\n>>> Epoch {epoch}/{num_epochs} <<<")

        train_metrics = train_epoch(model, train_loader, optimizer, scheduler, device, epoch, scaler, class_weights)
        print(f"Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}, "
              f"Prec: {train_metrics['precision']:.4f}, Rec: {train_metrics['recall']:.4f}, "
              f"F1: {train_metrics['f1']:.4f}, AUC: {train_metrics['auc']:.4f} "
              f"({train_metrics['time']:.1f}s)")

        val_metrics = evaluate(model, val_loader, device)[0]
        print(f"Val - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, "
              f"Prec: {val_metrics['precision']:.4f}, Rec: {val_metrics['recall']:.4f}, "
              f"F1: {val_metrics['f1']:.4f}, AUC: {val_metrics['auc']:.4f}")

        history.append({
            'epoch': epoch,
            'train': train_metrics,
            'val': val_metrics
        })

        if val_metrics['f1'] > best_f1:
            best_f1 = val_metrics['f1']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'f1': best_f1,
                'metrics': val_metrics,
                'tabular_dim': tabular_dim,
                'metadata_dim': metadata_dim,
            }, checkpoint_path)
            print(f"*** Saved best model (F1: {best_f1:.4f}) ***")

    # Final evaluation
    print("\n" + "=" * 70)
    print("FINAL EVALUATION")
    print("=" * 70)

    # Load best model
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    final_metrics, final_emb, final_anom, final_labels, final_preds = evaluate(model, val_loader, device)

    # labels=[0, 1] keeps the report and the 2x2 matrix well-formed even when
    # one class is absent from the labels and predictions.
    print("\nClassification Report:")
    print(classification_report(final_labels, final_preds, labels=[0, 1],
                                target_names=['authentic', 'fraudulent'], zero_division=0))

    print("Confusion Matrix:")
    cm = confusion_matrix(final_labels, final_preds, labels=[0, 1])
    print(f" Predicted")
    print(f" Auth Fraud")
    print(f"Actual Auth [{cm[0,0]:4d} {cm[0,1]:4d}]")
    print(f" Fraud [{cm[1,0]:4d} {cm[1,1]:4d}]")

    print(f"\nFinal Metrics:")
    print(f" Accuracy: {final_metrics['accuracy']:.4f}")
    print(f" Precision: {final_metrics['precision']:.4f}")
    print(f" Recall: {final_metrics['recall']:.4f}")
    print(f" F1-Score: {final_metrics['f1']:.4f}")
    print(f" AUC: {final_metrics['auc']:.4f}")

    # Save everything
    np.save(os.path.join(output_dir, 'val_embeddings.npy'), final_emb)
    np.save(os.path.join(output_dir, 'val_anomaly.npy'), final_anom)
    np.save(os.path.join(output_dir, 'val_labels.npy'), final_labels)
    np.save(os.path.join(output_dir, 'val_preds.npy'), final_preds)

    with open(os.path.join(output_dir, 'metrics.json'), 'w') as f:
        json.dump({
            'best_f1': best_f1,
            'final_metrics': final_metrics,
            'history': history,
            'model_info': {
                'total_params': total_params,
                'trainable_params': trainable,
                'tabular_dim': tabular_dim,
                'metadata_dim': metadata_dim,
            }
        }, f, indent=2)

    print(f"\n{'=' * 70}")
    print(f"TRAINING COMPLETE!")
    print(f"Best F1: {best_f1:.4f}")
    print(f"All outputs saved to: {output_dir}")
    print(f"{'=' * 70}")


if __name__ == '__main__':
    main()