import os
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from loguru import logger
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    precision_recall_curve,
    roc_curve,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing

# Global lists for metrics, accumulated across runs of train_model
train_losses = []
all_train_metrics = []
all_val_metrics = []


# StandardScaler that replaces the NaNs produced by zero-variance columns
class SafeStandardScaler(StandardScaler):
    def transform(self, X):
        X_std = super().transform(X)
        return np.nan_to_num(X_std)


# Focal loss for class-imbalanced binary classification
class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits, targets):
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        pt = torch.exp(-bce)  # probability assigned to the true class
        loss = self.alpha * (1 - pt) ** self.gamma * bce
        return loss.mean()
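# Illustrative only (not called anywhere below): a quick sanity check of
# FocalLoss on hand-picked logits. The input values are arbitrary demo data.
def _focal_loss_example():
    criterion = FocalLoss(gamma=2.0, alpha=0.25)
    logits = torch.tensor([2.0, -1.5, 0.1])  # easy positive, easy negative, hard positive
    targets = torch.tensor([1.0, 0.0, 1.0])
    # The hard example dominates the mean loss because (1 - pt)^gamma
    # down-weights the two confident predictions relative to plain BCE.
    print(f"focal loss: {criterion(logits, targets).item():.4f}")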
# Build a bipartite user-topic graph from the post-level dataframe
def preprocess_data(df_final):
    edge_columns = ['post_length', 'sentiment_score', 'create_hour',
                    'time_since_prev_post', 'lexical_similarity']
    for col in edge_columns:
        df_final[col] = df_final[col].fillna(df_final[col].mean())

    # User-level node features, aggregated over each user's posts
    user_features = df_final.groupby('sec_id').agg({
        'create_days_since_creation': 'max',
        'topic': lambda x: len(set(x)),
        'post_length': 'mean',
        'sentiment_score': 'mean',
        'lexical_diversity': 'mean',
        'is_fake': 'first'
    }).reset_index()
    user_features['create_days_since_creation'] = user_features['create_days_since_creation'].clip(lower=1)
    # Align post counts to user_features by sec_id before dividing; dividing
    # the raw groupby Series by a RangeIndex-ed column would misalign on index.
    post_counts = df_final.groupby('sec_id').size().reindex(user_features['sec_id']).values
    user_features['posting_frequency'] = post_counts / user_features['create_days_since_creation']

    user_node_features = user_features[[
        'posting_frequency', 'topic', 'post_length', 'sentiment_score', 'lexical_diversity'
    ]].fillna(0).values
    scaler_user = SafeStandardScaler()
    user_node_features = scaler_user.fit_transform(user_node_features)
    # Pad user features with a zero column to match the topic feature width (6 dims)
    user_node_features = np.hstack([user_node_features, np.zeros((user_node_features.shape[0], 1))])

    user_id_map = {sid: idx for idx, sid in enumerate(user_features['sec_id'])}
    num_users = len(user_id_map)

    # Topic-level node features
    topic_features = df_final.groupby('topic').agg({
        'topic': 'count',
        'sentiment_score': ['mean', 'var'],
        'digg_count': 'mean',
        'comment_count': 'mean',
        'share_count': 'mean'
    }).reset_index()
    topic_features.columns = [
        'topic', 'popularity', 'sentiment_mean', 'sentiment_var',
        'digg_count_mean', 'comment_count_mean', 'share_count_mean'
    ]
    topic_features['sentiment_var'] = topic_features['sentiment_var'].fillna(0)
    topic_node_features = topic_features[[
        'popularity', 'sentiment_mean', 'sentiment_var',
        'digg_count_mean', 'comment_count_mean', 'share_count_mean'
    ]].fillna(0).values
    scaler_topic = SafeStandardScaler()
    topic_node_features = scaler_topic.fit_transform(topic_node_features)
    # Topic node indices start after the user nodes
    topic_id_map = {tid: idx + num_users for idx, tid in enumerate(topic_features['topic'])}

    # One bidirectional user<->topic edge per post, carrying post-level attributes
    edge_index, edge_features = [], []
    for _, row in df_final.iterrows():
        user_idx = user_id_map[row['sec_id']]
        topic_idx = topic_id_map[row['topic']]
        edge_index.extend([[user_idx, topic_idx], [topic_idx, user_idx]])
        edge_attr = [row[col] for col in edge_columns]
        edge_features.extend([edge_attr, edge_attr])

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_features_np = np.array(edge_features, dtype=np.float32)
    scaler_edge = SafeStandardScaler()
    edge_features = torch.tensor(scaler_edge.fit_transform(edge_features_np), dtype=torch.float32)

    node_features = np.vstack([user_node_features, topic_node_features])
    node_features = torch.tensor(np.nan_to_num(node_features), dtype=torch.float32)
    user_labels = torch.tensor(user_features['is_fake'].values, dtype=torch.float32)
    # Random 3-D positions; the model only uses pairwise distances between them
    position_vectors = torch.randn(node_features.shape[0], 3)

    data = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_features,
        y=user_labels,
        pos=position_vectors
    )
    data.num_users = num_users
    assert not torch.isnan(data.x).any(), "NaNs in node features"
    assert not torch.isnan(data.edge_attr).any(), "NaNs in edge features"
    return data, user_id_map, topic_id_map


# One message-passing layer with type-specific encoders for user and topic nodes
class EnergyMPNNLayer(MessagePassing):
    def __init__(self, input_node_dim, edge_dim, hidden_dim, pos_dim, dropout=0.4):
        super().__init__(aggr='mean')
        self.input_node_dim = input_node_dim
        self.edge_dim = edge_dim
        self.hidden_dim = hidden_dim
        self.pos_dim = pos_dim
        self.dropout = dropout

        self.user_mlp = nn.Sequential(
            nn.Linear(input_node_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.user_residual = nn.Linear(input_node_dim, hidden_dim)

        self.topic_mlp = nn.Sequential(
            nn.Linear(input_node_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.topic_residual = nn.Linear(input_node_dim, hidden_dim)

        # Message input: source and target embeddings, edge attributes, and
        # the scalar distance between node positions (hence the +1)
        message_input_dim = 2 * hidden_dim + edge_dim + 1
        self.message_mlp = nn.Sequential(
            nn.Linear(message_input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.update_mlp = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.update_residual = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, edge_index, edge_attr, pos, num_users):
        # Encode user and topic nodes with separate MLPs plus linear residuals
        user_x, topic_x = x[:num_users], x[num_users:]
        h_user = self.user_mlp(user_x) + self.user_residual(user_x)
        h_topic = self.topic_mlp(topic_x) + self.topic_residual(topic_x)
        h = torch.cat([h_user, h_topic], dim=0)
        return self.propagate(edge_index, x=h, edge_attr=edge_attr, pos=pos)

    def message(self, x_i, x_j, edge_attr, pos_i, pos_j):
        dist = torch.norm(pos_i - pos_j, p=2, dim=-1, keepdim=True)
        message_input = torch.cat([x_i, x_j, edge_attr, dist], dim=-1)
        return self.message_mlp(message_input)

    def update(self, aggr_out, x):
        update_input = torch.cat([x, aggr_out], dim=-1)
        return self.update_mlp(update_input) + self.update_residual(x)
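# Illustrative only (not called anywhere below): a smoke test of one layer on a
# toy bipartite graph with 3 user nodes and 2 topic nodes. All sizes here are
# arbitrary demo choices; eval() is used so BatchNorm accepts the tiny batches.
def _layer_shape_check():
    layer = EnergyMPNNLayer(input_node_dim=6, edge_dim=5, hidden_dim=16, pos_dim=3)
    layer.eval()
    x = torch.randn(5, 6)                      # nodes 0-2 are users, 3-4 are topics
    pos = torch.randn(5, 3)
    edge_index = torch.tensor([[0, 3, 1, 3, 2, 4],
                               [3, 0, 3, 1, 4, 2]])  # bidirectional user<->topic edges
    edge_attr = torch.randn(6, 5)
    out = layer(x, edge_index, edge_attr, pos, num_users=3)
    assert out.shape == (5, 16)                # one hidden_dim embedding per node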
# Full model: stacked message-passing layers plus an edge scoring head
class EnergyMPNN(nn.Module):
    def __init__(self, input_node_dim=6, edge_dim=5, hidden_dim=64, pos_dim=3, num_layers=2, dropout=0.4):
        super().__init__()
        self.input_node_dim = input_node_dim
        self.edge_dim = edge_dim
        self.hidden_dim = hidden_dim
        self.pos_dim = pos_dim
        self.num_layers = num_layers
        self.dropout = dropout

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            # Only the first layer consumes raw node features
            layer_input_dim = input_node_dim if i == 0 else hidden_dim
            self.layers.append(EnergyMPNNLayer(
                input_node_dim=layer_input_dim,
                edge_dim=edge_dim,
                hidden_dim=hidden_dim,
                pos_dim=pos_dim,
                dropout=dropout
            ))

        # Edge score head: endpoint embeddings, edge attributes, and distance
        score_input_dim = 2 * hidden_dim + edge_dim + 1
        self.score_mlp = nn.Sequential(
            nn.Linear(score_input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2), nn.BatchNorm1d(hidden_dim // 2), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x, edge_index, edge_attr, pos, num_users):
        h = x
        for layer in self.layers:
            h = layer(h, edge_index, edge_attr, pos, num_users)

        # Score every directed edge
        u, t = edge_index
        dist = torch.norm(pos[u] - pos[t], p=2, dim=-1)
        edge_input = torch.cat([h[u], h[t], edge_attr, dist.unsqueeze(-1)], dim=-1)
        edge_scores = self.score_mlp(edge_input).squeeze(-1)  # squeeze(-1) keeps the edge dim when E == 1

        # Each user's score is the mean of its outgoing user->topic edge scores
        user_scores = torch.zeros(num_users, device=x.device)
        edge_counts = torch.zeros(num_users, device=x.device)
        user_mask = edge_index[0] < num_users  # user->topic edges only
        user_indices = edge_index[0][user_mask]
        user_scores.scatter_add_(0, user_indices, edge_scores[user_mask])
        edge_counts.scatter_add_(0, user_indices, torch.ones_like(user_indices, dtype=torch.float))
        user_scores = user_scores / edge_counts.clamp(min=1)
        return user_scores, edge_scores


# Pick the decision threshold that maximizes F1 on the precision-recall curve,
# then report all metrics at that threshold
def calculate_metrics(probs, labels):
    precision, recall, thresholds = precision_recall_curve(labels.cpu(), probs.cpu())
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
    optimal_idx = np.argmax(f1_scores)
    # precision_recall_curve returns one more point than threshold; fall back to 0.5
    optimal_threshold = thresholds[optimal_idx] if optimal_idx < len(thresholds) else 0.5
    preds = (probs > optimal_threshold).float()
    return {
        'acc': accuracy_score(labels.cpu(), preds.cpu()),
        'f1': f1_score(labels.cpu(), preds.cpu()),
        'auc': roc_auc_score(labels.cpu(), probs.cpu()),
        'precision': precision_score(labels.cpu(), preds.cpu()),
        'recall': recall_score(labels.cpu(), preds.cpu()),
        'threshold': optimal_threshold
    }
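# Illustrative only (not called anywhere below): calculate_metrics on tiny
# hand-made predictions. The values are arbitrary; both classes must be
# present in the labels for roc_auc_score to be defined.
def _metrics_example():
    probs = torch.tensor([0.9, 0.8, 0.6, 0.3, 0.2, 0.1])
    labels = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
    m = calculate_metrics(probs, labels)
    print(f"threshold={m['threshold']:.2f} f1={m['f1']:.2f} auc={m['auc']:.2f}")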
# Train with the best configuration found during tuning; saves checkpoints,
# per-epoch metrics, and plots under save_path
def train_model(data, save_path, test_size, num_epochs=150):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = data.to(device)

    # Best configuration
    hidden_dim = 64
    num_layers = 2
    dropout = 0.2
    lr = 0.001
    gamma = 1.0
    weight_decay = 0.0005

    model = EnergyMPNN(
        input_node_dim=6,
        edge_dim=5,
        hidden_dim=hidden_dim,
        pos_dim=3,
        num_layers=num_layers,
        dropout=dropout
    ).to(device)

    # Create save directories
    os.makedirs(save_path, exist_ok=True)
    plot_dir = os.path.join(save_path, 'plots')
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(os.path.join(save_path, 'model_checkpoint'), exist_ok=True)

    # Stratified train/val/test split over user nodes
    user_indices = torch.arange(data.num_users, device='cpu')
    y_np = data.y.cpu().numpy()
    try:
        train_idx, test_idx = train_test_split(
            user_indices.numpy(), test_size=test_size, stratify=y_np, random_state=42
        )
        val_idx, test_idx = train_test_split(
            test_idx, test_size=0.5, stratify=y_np[test_idx], random_state=42
        )
    except Exception as e:
        print(f"Error in train/val split: {e}")
        return None, [], [], {}, 0

    data.train_mask = torch.zeros(data.num_users, dtype=torch.bool, device=device)
    data.val_mask = torch.zeros(data.num_users, dtype=torch.bool, device=device)
    data.test_mask = torch.zeros(data.num_users, dtype=torch.bool, device=device)
    data.train_mask[torch.tensor(train_idx, device=device)] = True
    data.val_mask[torch.tensor(val_idx, device=device)] = True
    data.test_mask[torch.tensor(test_idx, device=device)] = True

    criterion = FocalLoss(gamma=gamma, alpha=0.25)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=15)

    best_val_auc = 0
    epochs_no_improve = 0
    patience = 20
    train_losses, train_metrics_list, val_metrics_list = [], [], []
    start_time = time.time()

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        user_scores, _ = model(data.x, data.edge_index, data.edge_attr, data.pos, data.num_users)
        loss = criterion(user_scores[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            user_scores, _ = model(data.x, data.edge_index, data.edge_attr, data.pos, data.num_users)
            probs = torch.sigmoid(user_scores)
            train_metrics = calculate_metrics(probs[data.train_mask], data.y[data.train_mask])
            val_metrics = calculate_metrics(probs[data.val_mask], data.y[data.val_mask])

        train_losses.append(loss.item())
        train_metrics_list.append(train_metrics)
        val_metrics_list.append(val_metrics)
        all_train_metrics.append(train_metrics)
        all_val_metrics.append(val_metrics)
        scheduler.step(val_metrics['auc'])

        # Checkpoint on best validation AUC; early stop after `patience` stale epochs
        if val_metrics['auc'] > best_val_auc:
            best_val_auc = val_metrics['auc']
            torch.save(model.state_dict(), os.path.join(save_path, 'model_checkpoint', 'best_model.pth'))
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {loss.item():.4f} | "
              f"Val AUC: {val_metrics['auc']:.4f} | Val F1: {val_metrics['f1']:.4f}")

    training_time = time.time() - start_time

    # Reload the best checkpoint and evaluate on the held-out test set
    model.load_state_dict(torch.load(os.path.join(save_path, 'model_checkpoint', 'best_model.pth'),
                                     map_location=device))
    model.eval()
    with torch.no_grad():
        user_scores, _ = model(data.x, data.edge_index, data.edge_attr, data.pos, data.num_users)
        probs = torch.sigmoid(user_scores[data.test_mask])
        test_metrics = calculate_metrics(probs, data.y[data.test_mask])

    print("\nFinal Test Metrics:")
    print(f"AUC: {test_metrics['auc']:.4f} | F1: {test_metrics['f1']:.4f} | Accuracy: {test_metrics['acc']:.4f}")
    print(f"Precision: {test_metrics['precision']:.4f} | Recall: {test_metrics['recall']:.4f} | "
          f"Threshold: {test_metrics['threshold']:.4f}")
    print(f"Training Time: {training_time:.2f} seconds")

    # Save per-epoch metrics to CSVs
    train_metrics_df = pd.DataFrame([
        {
            'epoch': epoch + 1,
            'loss': train_losses[epoch],
            'auc': tm['auc'],
            'f1': tm['f1'],
            'accuracy': tm['acc'],
            'precision': tm['precision'],
            'recall': tm['recall'],
            'threshold': tm['threshold']
        }
        for epoch, tm in enumerate(train_metrics_list)
    ])
    train_metrics_df.to_csv(os.path.join(save_path, 'train_metrics.csv'), index=False)

    val_metrics_df = pd.DataFrame([
        {
            'epoch': epoch + 1,
            'auc': vm['auc'],
            'f1': vm['f1'],
            'accuracy': vm['acc'],
            'precision': vm['precision'],
            'recall': vm['recall'],
            'threshold': vm['threshold']
        }
        for epoch, vm in enumerate(val_metrics_list)
    ])
    val_metrics_df.to_csv(os.path.join(save_path, 'val_metrics.csv'), index=False)

    test_metrics_df = pd.DataFrame([test_metrics])
    test_metrics_df.to_csv(os.path.join(save_path, 'test_metrics.csv'), index=False)

    # Save training summary as CSV
    training_summary = pd.DataFrame([{
        'hidden_dim': hidden_dim,
        'num_layers': num_layers,
        'dropout': dropout,
        'lr': lr,
        'gamma': gamma,
        'weight_decay': weight_decay,
        'num_epochs': num_epochs,
        'training_time_seconds': training_time,
        'test_auc': test_metrics['auc'],
        'test_f1': test_metrics['f1'],
        'test_accuracy': test_metrics['acc'],
        'test_precision': test_metrics['precision'],
        'test_recall': test_metrics['recall'],
        'test_threshold': test_metrics['threshold']
    }])
    training_summary.to_csv(os.path.join(save_path, 'training_summary.csv'), index=False)

    # 1. Metrics over epochs: training loss plus six validation curves
    plt.figure(figsize=(15, 10))
    panels = [
        (train_metrics_df['loss'], 'Train Loss', 'Loss', 'Training Loss'),
        (val_metrics_df['auc'], 'Val AUC', 'AUC', 'Validation AUC'),
        (val_metrics_df['f1'], 'Val F1', 'F1', 'Validation F1'),
        (val_metrics_df['accuracy'], 'Val Accuracy', 'Accuracy', 'Validation Accuracy'),
        (val_metrics_df['precision'], 'Val Precision', 'Precision', 'Validation Precision'),
        (val_metrics_df['recall'], 'Val Recall', 'Recall', 'Validation Recall'),
        (val_metrics_df['threshold'], 'Val Threshold', 'Threshold', 'Validation Threshold'),
    ]
    for i, (series, label, ylabel, title) in enumerate(panels, start=1):
        plt.subplot(2, 4, i)
        plt.plot(train_metrics_df['epoch'], series, label=label)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, 'metrics_over_epochs.png'))
    plt.close()

    # 2. Precision-recall curve (test set)
    precision, recall, _ = precision_recall_curve(data.y[data.test_mask].cpu(), probs.cpu())
    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, label='Precision-Recall Curve')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve (Test Set)')
    plt.legend()
    plt.savefig(os.path.join(plot_dir, 'precision_recall_curve.png'))
    plt.close()
    # 3. ROC curve (test set)
    fpr, tpr, _ = roc_curve(data.y[data.test_mask].cpu(), probs.cpu())
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, label=f'ROC Curve (AUC = {test_metrics["auc"]:.4f})')
    plt.plot([0, 1], [0, 1], 'k--', label='Random')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve (Test Set)')
    plt.legend()
    plt.savefig(os.path.join(plot_dir, 'roc_curve.png'))
    plt.close()

    return model, train_losses, val_metrics_list[-1], test_metrics, training_time


# Main entry point: preprocess the dataframe, then train and evaluate
def trainer(save_path, test_size, num_epochs, df_final):
    try:
        data, _, _ = preprocess_data(df_final)
        logger.info(f"Data shapes: x={data.x.shape}, edge_index={data.edge_index.shape}, "
                    f"edge_attr={data.edge_attr.shape}, y={data.y.shape}, pos={data.pos.shape}")
        model, train_losses, val_metrics, test_metrics, training_time = train_model(
            data, save_path, test_size, num_epochs=num_epochs
        )
        logger.info("Training completed.")
    except Exception as e:
        # Log the full traceback rather than hiding errors at INFO level
        logger.exception(f"Error: {str(e)}")
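if __name__ == "__main__":
    # Example invocation (a sketch, not part of the original pipeline): the CSV
    # path and output directory below are hypothetical. df_final must contain
    # the columns referenced in preprocess_data: 'sec_id', 'topic', 'is_fake',
    # 'post_length', 'sentiment_score', 'create_hour', 'time_since_prev_post',
    # 'lexical_similarity', 'lexical_diversity', 'create_days_since_creation',
    # 'digg_count', 'comment_count', and 'share_count'.
    df_final = pd.read_csv("data/posts.csv")  # hypothetical input path
    trainer(save_path="runs/energy_mpnn", test_size=0.3, num_epochs=150, df_final=df_final)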