| """ |
| Training script for multimodal fraudulent paper detection - GPU optimized. |
| Uses mixed precision (fp16), class weighting, and saves full metrics. |
| """ |
|
|
| import os |
| import sys |
|
|
# Make sibling modules (model.py, data_loader.py) importable regardless of the
# directory the script is launched from. Fall back to the current working
# directory when __file__ is undefined (e.g. code pasted into a notebook cell).
try:
    SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    SCRIPT_DIR = os.getcwd()
if SCRIPT_DIR not in sys.path:
    sys.path.insert(0, SCRIPT_DIR)
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader, random_split |
| from transformers import get_linear_schedule_with_warmup |
| from torch.cuda.amp import autocast, GradScaler |
| import numpy as np |
| from sklearn.metrics import ( |
| accuracy_score, precision_recall_fscore_support, |
| roc_auc_score, classification_report, confusion_matrix |
| ) |
| from tqdm import tqdm |
| import json |
| import time |
|
|
| from model import MultimodalFraudDetector |
| from data_loader import FraudPaperDataset, collate_fn |
|
|
|
|
def compute_metrics(predictions, labels, probs):
    """Compute binary classification metrics with class 1 (fraudulent) as positive.

    Args:
        predictions: 2-D array of per-class scores (e.g. logits); the predicted
            class of each row is its argmax along axis 1.
        labels: ground-truth labels, one per row of ``predictions``.
        probs: 2-D array of per-class probabilities; column 1 is used as the
            positive-class score for ROC-AUC.

    Returns:
        dict with builtin floats 'accuracy', 'precision', 'recall', 'f1' and
        'auc'. 'auc' is 0.5 (chance level) whenever ROC-AUC cannot be computed,
        e.g. when ``labels`` contains only one class.
    """
    preds = np.argmax(predictions, axis=1)
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )
    try:
        auc = float(roc_auc_score(labels, probs[:, 1]))
    except ValueError:
        # Raised e.g. when only one class is present in ``labels``. Catch only
        # this instead of a bare except so real bugs and Ctrl-C still propagate.
        auc = 0.5
    if not np.isfinite(auc):
        # Some scikit-learn versions may return NaN (with a warning) instead of
        # raising; treat it the same way as the exception path.
        auc = 0.5
    # Builtin floats keep the dict JSON-serialisable and safe to store in
    # checkpoints loaded with torch.load(weights_only=True).
    return {
        'accuracy': float(accuracy),
        'precision': float(precision),
        'recall': float(recall),
        'f1': float(f1),
        'auc': auc,
    }
|
|
|
|
def _anomaly_loss(anomaly_score, labels):
    """Class-balanced squared error pulling anomaly scores to 1 (fraud) and 0 (authentic).

    Each class contributes the mean squared error over its own samples, so the
    minority class is not swamped. A class absent from the batch contributes
    nothing: taking ``torch.mean`` of an empty selection yields NaN, which used
    to poison the whole loss on all-fraud batches.
    """
    fraud_mask = labels == 1
    authentic_mask = ~fraud_mask
    loss = anomaly_score.new_zeros(())
    if fraud_mask.any():
        loss = loss + torch.mean((anomaly_score[fraud_mask] - 1.0) ** 2)
    if authentic_mask.any():
        loss = loss + torch.mean(anomaly_score[authentic_mask] ** 2)
    return loss


def train_epoch(model, dataloader, optimizer, scheduler, device, epoch, scaler, class_weights):
    """Run one training epoch (fp16 autocast on CUDA) and return its metrics.

    Per-batch loss:
        class-weighted cross-entropy on the logits
        + 0.1 * mean((modality_scores - 0.5) ** 2)
        + 0.1 * class-balanced anomaly-score loss (see ``_anomaly_loss``)

    Args:
        model: detector called with text/tabular/metadata tensors; its output
            dict must provide 'logits', 'modality_scores' and 'anomaly_score'.
        dataloader: yields dicts with 'input_ids', 'attention_mask',
            'tabular_features', 'metadata_features' and 'labels'.
        optimizer: optimizer, stepped through ``scaler``.
        scheduler: learning-rate scheduler, stepped once per batch.
        device: torch device (or device string) batches are moved to.
        epoch: epoch number; only used to label the progress bar.
        scaler: ``GradScaler`` used for loss scaling.
        class_weights: per-class weights for the cross-entropy term.

    Returns:
        dict: ``compute_metrics`` output plus 'loss' (mean per-batch loss) and
        'time' (wall-clock seconds spent in the epoch).
    """
    model.train()
    # fp16 autocast only applies on CUDA; elsewhere train in full precision.
    use_amp = torch.device(device).type == 'cuda'
    # Built once: the weighted criterion is identical for every batch.
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    total_loss = 0.0
    logits_chunks, label_chunks, prob_chunks = [], [], []
    start_time = time.time()

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
    for batch in pbar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        tabular = batch['tabular_features'].to(device)
        metadata = batch['metadata_features'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            outputs = model(
                text_input_ids=input_ids,
                text_attention_mask=attention_mask,
                tabular_features=tabular,
                metadata_features=metadata,
            )
            logits = outputs['logits']
            ce_loss = criterion(logits, labels)
            # Penalise modality scores for drifting away from 0.5.
            consistency_loss = torch.mean((outputs['modality_scores'] - 0.5) ** 2) * 0.1
            anomaly_loss = _anomaly_loss(outputs['anomaly_score'], labels)
            loss = ce_loss + consistency_loss + 0.1 * anomaly_loss

        scaler.scale(loss).backward()
        # Unscale before clipping so the 1.0 max-norm applies to true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        # Upcast fp16 logits before softmax / NumPy so probabilities keep precision.
        batch_logits = logits.detach().float()
        prob_chunks.append(torch.softmax(batch_logits, dim=1).cpu().numpy())
        logits_chunks.append(batch_logits.cpu().numpy())
        label_chunks.append(labels.cpu().numpy())
        pbar.set_postfix({'loss': batch_loss})

    elapsed = time.time() - start_time
    metrics = compute_metrics(
        np.concatenate(logits_chunks),
        np.concatenate(label_chunks),
        np.concatenate(prob_chunks),
    )
    metrics['loss'] = total_loss / len(dataloader)
    metrics['time'] = elapsed
    return metrics
|
|
|
|
def evaluate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: detector whose output dict provides 'logits', 'fused_embedding'
            and 'anomaly_score'.
        dataloader: yields dicts with 'input_ids', 'attention_mask',
            'tabular_features', 'metadata_features' and 'labels'.
        device: torch device (or device string) batches are moved to.

    Returns:
        tuple ``(metrics, embeddings, anomaly_scores, labels, predictions)``:
        ``metrics`` is the ``compute_metrics`` dict plus 'loss' (mean per-batch
        unweighted cross-entropy); the other items are NumPy arrays concatenated
        over all batches, and ``predictions`` is the argmax of the logits.
    """
    model.eval()
    # fp16 autocast only applies on CUDA; elsewhere evaluate in full precision.
    use_amp = torch.device(device).type == 'cuda'
    # NOTE: unweighted, unlike the class-weighted training loss, so train and
    # validation losses are not directly comparable.
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    logits_chunks, label_chunks, prob_chunks = [], [], []
    embedding_chunks, anomaly_chunks = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            tabular = batch['tabular_features'].to(device)
            metadata = batch['metadata_features'].to(device)
            labels = batch['labels'].to(device)

            with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                outputs = model(
                    text_input_ids=input_ids,
                    text_attention_mask=attention_mask,
                    tabular_features=tabular,
                    metadata_features=metadata,
                )
                logits = outputs['logits']
                loss = criterion(logits, labels)

            total_loss += loss.item()
            # Upcast fp16 outputs: low-precision probabilities can create ties
            # that distort ROC-AUC, and fp32 arrays are what gets saved to disk.
            logits = logits.float()
            prob_chunks.append(torch.softmax(logits, dim=1).cpu().numpy())
            logits_chunks.append(logits.cpu().numpy())
            label_chunks.append(labels.cpu().numpy())
            embedding_chunks.append(outputs['fused_embedding'].float().cpu().numpy())
            anomaly_chunks.append(outputs['anomaly_score'].float().cpu().numpy())

    all_logits = np.concatenate(logits_chunks)
    all_labels = np.concatenate(label_chunks)
    all_probs = np.concatenate(prob_chunks)
    all_embeddings = np.concatenate(embedding_chunks)
    all_anomaly = np.concatenate(anomaly_chunks)

    metrics = compute_metrics(all_logits, all_labels, all_probs)
    metrics['loss'] = total_loss / len(dataloader)
    return metrics, all_embeddings, all_anomaly, all_labels, np.argmax(all_logits, axis=1)
|
|
|
|
def _to_float_metrics(metrics):
    """Return a copy of ``metrics`` with every value converted to a builtin float.

    scikit-learn metric functions may return NumPy scalars. Plain floats keep
    metrics.json portable and let the checkpoint be reloaded by ``torch.load``
    in ``weights_only`` mode, which rejects pickled NumPy scalars.
    """
    return {key: float(value) for key, value in metrics.items()}


def main(num_epochs=5, batch_size=16, learning_rate=2e-5, output_dir='./outputs', seed=42):
    """Train the multimodal fraud detector and write all artefacts to ``output_dir``.

    Pipeline: load the ``Lihuchen/pubmed_retraction`` train split, make a
    seeded 80/20 train/validation split, train for ``num_epochs`` epochs while
    keeping the checkpoint with the best validation F1, then reload it and
    write a classification report, confusion matrix, the validation
    embeddings / anomaly scores / labels / predictions (``.npy``) and
    ``metrics.json``.

    Args:
        num_epochs: number of training epochs (also sizes the LR schedule).
        batch_size: batch size for both data loaders.
        learning_rate: AdamW learning rate.
        output_dir: directory for the checkpoint, arrays and metrics.
        seed: seed for the train/validation split.
    """
    print("=" * 70)
    print("MULTIMODAL FRAUD DETECTION - GPU OPTIMIZED TRAINING")
    print("=" * 70)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_cuda = device.type == 'cuda'
    print(f"Device: {device}")
    if use_cuda:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    os.makedirs(output_dir, exist_ok=True)
    checkpoint_path = os.path.join(output_dir, 'best_model.pt')

    # ---- Data ---------------------------------------------------------------
    print("\nLoading dataset...")
    dataset = FraudPaperDataset("Lihuchen/pubmed_retraction", split="train", max_length=256)

    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    # Seeded so the split (and thus the saved val_*.npy arrays) is reproducible.
    train_ds, val_ds = random_split(
        dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(seed),
    )

    # Pinned host memory only helps copies to a CUDA device.
    loader_kwargs = dict(batch_size=batch_size, num_workers=4,
                         collate_fn=collate_fn, pin_memory=use_cuda)
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

    # Up-weight the fraud class by the authentic:fraud ratio of the training split.
    train_labels = [dataset.labels[idx] for idx in train_ds.indices]
    fraud_count = sum(train_labels)
    authentic_count = len(train_labels) - fraud_count
    print(f"Class distribution - Authentic: {authentic_count}, Fraudulent: {fraud_count}")
    weight_fraud = authentic_count / max(fraud_count, 1)
    class_weights = torch.tensor([1.0, weight_fraud]).to(device)
    print(f"Class weights: [1.0, {weight_fraud:.2f}]")

    # Infer feature widths from one collated sample; unlike next(iter(loader))
    # this does not start the loader's worker processes for a single batch.
    sample = collate_fn([train_ds[0]])
    tabular_dim = sample['tabular_features'].shape[1]
    metadata_dim = sample['metadata_features'].shape[1]
    print(f"Tabular: {tabular_dim}, Metadata: {metadata_dim}")

    # ---- Model & optimisation ------------------------------------------------
    print("\nBuilding model...")
    model = MultimodalFraudDetector(
        text_model="allenai/scibert_scivocab_uncased",
        tabular_features=tabular_dim,
        metadata_features=metadata_dim,
        fused_dim=256,
        freeze_text_layers=8,
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params: {total_params:,}")
    print(f"Trainable: {trainable:,}")
    print(f"Frozen: {total_params - trainable:,}")

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    total_steps = len(train_loader) * num_epochs
    warmup = int(total_steps * 0.1)  # linear warm-up over the first 10% of steps
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup, num_training_steps=total_steps
    )

    # fp16 loss scaling is only meaningful on CUDA; when disabled, scaler.step()
    # simply calls optimizer.step() and the other scaler calls are no-ops.
    scaler = GradScaler(enabled=use_cuda)

    # ---- Training loop -------------------------------------------------------
    print("\n" + "=" * 70)
    print("TRAINING")
    print("=" * 70)

    # Start below any achievable F1 so the first epoch always writes a
    # checkpoint. With 0 and a strict '>', a run whose validation F1 stayed 0
    # never saved one and the reload below raised FileNotFoundError.
    best_f1 = -1.0
    history = []

    for epoch in range(1, num_epochs + 1):
        print(f"\n>>> Epoch {epoch}/{num_epochs} <<<")
        train_metrics = _to_float_metrics(
            train_epoch(model, train_loader, optimizer, scheduler, device, epoch, scaler, class_weights)
        )
        print(f"Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}, "
              f"Prec: {train_metrics['precision']:.4f}, Rec: {train_metrics['recall']:.4f}, "
              f"F1: {train_metrics['f1']:.4f}, AUC: {train_metrics['auc']:.4f} "
              f"({train_metrics['time']:.1f}s)")

        val_metrics = _to_float_metrics(evaluate(model, val_loader, device)[0])
        print(f"Val - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, "
              f"Prec: {val_metrics['precision']:.4f}, Rec: {val_metrics['recall']:.4f}, "
              f"F1: {val_metrics['f1']:.4f}, AUC: {val_metrics['auc']:.4f}")

        history.append({
            'epoch': epoch,
            'train': train_metrics,
            'val': val_metrics,
        })

        if val_metrics['f1'] > best_f1:
            best_f1 = val_metrics['f1']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'f1': best_f1,
                'metrics': val_metrics,
                'tabular_dim': tabular_dim,
                'metadata_dim': metadata_dim,
            }, checkpoint_path)
            print(f"*** Saved best model (F1: {best_f1:.4f}) ***")

    # ---- Final evaluation ----------------------------------------------------
    print("\n" + "=" * 70)
    print("FINAL EVALUATION")
    print("=" * 70)

    # map_location lets the checkpoint load on whichever device is active now.
    # Metrics were stored as plain floats, so weights_only loading accepts the file.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # NOTE: this reuses the validation split that selected the checkpoint, so
    # the numbers are optimistically biased relative to a held-out test set.
    final_metrics, final_emb, final_anom, final_labels, final_preds = evaluate(model, val_loader, device)
    final_metrics = _to_float_metrics(final_metrics)

    # labels=[0, 1] keeps the report and matrix two-class even when the
    # validation labels/predictions contain a single class; otherwise
    # target_names has the wrong length and cm[1, 1] raises IndexError.
    print("\nClassification Report:")
    print(classification_report(final_labels, final_preds, labels=[0, 1],
                                target_names=['authentic', 'fraudulent'], zero_division=0))

    print("Confusion Matrix:")
    cm = confusion_matrix(final_labels, final_preds, labels=[0, 1])
    print(" Predicted")
    print(" Auth Fraud")
    print(f"Actual Auth [{cm[0,0]:4d} {cm[0,1]:4d}]")
    print(f" Fraud [{cm[1,0]:4d} {cm[1,1]:4d}]")

    print("\nFinal Metrics:")
    print(f" Accuracy: {final_metrics['accuracy']:.4f}")
    print(f" Precision: {final_metrics['precision']:.4f}")
    print(f" Recall: {final_metrics['recall']:.4f}")
    print(f" F1-Score: {final_metrics['f1']:.4f}")
    print(f" AUC: {final_metrics['auc']:.4f}")

    np.save(os.path.join(output_dir, 'val_embeddings.npy'), final_emb)
    np.save(os.path.join(output_dir, 'val_anomaly.npy'), final_anom)
    np.save(os.path.join(output_dir, 'val_labels.npy'), final_labels)
    np.save(os.path.join(output_dir, 'val_preds.npy'), final_preds)

    with open(os.path.join(output_dir, 'metrics.json'), 'w') as f:
        json.dump({
            'best_f1': best_f1,
            'final_metrics': final_metrics,
            'history': history,
            'model_info': {
                'total_params': total_params,
                'trainable_params': trainable,
                'tabular_dim': tabular_dim,
                'metadata_dim': metadata_dim,
            },
        }, f, indent=2)

    print(f"\n{'=' * 70}")
    print("TRAINING COMPLETE!")
    print(f"Best F1: {best_f1:.4f}")
    print(f"All outputs saved to: {output_dir}")
    print(f"{'=' * 70}")
|
|
|
|
if __name__ == "__main__":
    # Run the full training pipeline only when executed as a script, not on import.
    main()
|
|