import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
import random
import time
import os
from loguru import logger
from sklearn.metrics import (accuracy_score, f1_score, precision_score, recall_score,
                             roc_auc_score, precision_recall_curve, roc_curve)
# Global lists for per-epoch metrics, shared across runs
train_losses = []
all_train_metrics = []
all_val_metrics = []
# StandardScaler variant that maps any NaNs surviving upstream imputation to 0
class SafeStandardScaler(StandardScaler):
    def transform(self, X):
        X_std = super().transform(X)
        return np.nan_to_num(X_std)
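# Illustration only (hypothetical values, not pipeline data): a NaN in the
# input would otherwise propagate straight through StandardScaler.transform;
# the override maps it to 0 so downstream tensors stay finite.
def _demo_safe_scaler():
    scaler = SafeStandardScaler().fit(np.array([[1.0], [2.0], [3.0]]))
    return scaler.transform(np.array([[np.nan]]))  # [[0.]] instead of [[nan]]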
# Focal loss for class-imbalanced binary classification: down-weights easy
# examples so training focuses on hard ones
class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits, targets):
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        pt = torch.exp(-bce)  # model's probability of the true class
        loss = self.alpha * (1 - pt) ** self.gamma * bce
        return loss.mean()
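# Usage sketch (random tensors, not pipeline data): the loss expects raw
# logits and float targets of the same shape. With gamma=0 the modulating
# factor (1 - pt) ** gamma is 1, so the loss reduces to alpha-weighted BCE.
def _demo_focal_loss():
    criterion = FocalLoss(gamma=2.0, alpha=0.25)
    logits = torch.randn(8)
    targets = torch.randint(0, 2, (8,)).float()
    return criterion(logits, targets)  # scalar tensor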
# Preprocess data: build the bipartite user-topic graph from the post-level dataframe
def preprocess_data(df_final):
    edge_columns = ['post_length', 'sentiment_score', 'create_hour',
                    'time_since_prev_post', 'lexical_similarity']
    # Impute missing edge features with column means
    for col in edge_columns:
        df_final[col] = df_final[col].fillna(df_final[col].mean())

    # User-level node features aggregated over each user's posts
    user_features = df_final.groupby('sec_id').agg({
        'create_days_since_creation': 'max',
        'topic': lambda x: len(set(x)),  # number of distinct topics
        'post_length': 'mean',
        'sentiment_score': 'mean',
        'lexical_diversity': 'mean',
        'is_fake': 'first'
    }).reset_index()
    user_features['create_days_since_creation'] = user_features['create_days_since_creation'].clip(lower=1)
    # groupby(...).size() is indexed by sec_id while user_features uses a
    # RangeIndex, so a direct division would misalign (yielding NaNs). Both are
    # sorted by sec_id, so take .values to divide positionally.
    user_features['posting_frequency'] = (
        df_final.groupby('sec_id').size().values / user_features['create_days_since_creation']
    )
    user_node_features = user_features[[
        'posting_frequency', 'topic', 'post_length',
        'sentiment_score', 'lexical_diversity'
    ]].fillna(0).values
    scaler_user = SafeStandardScaler()
    user_node_features = scaler_user.fit_transform(user_node_features)
    # Pad users with a zero column so user and topic features share a width of 6
    user_node_features = np.hstack([user_node_features, np.zeros((user_node_features.shape[0], 1))])
    user_id_map = {sid: idx for idx, sid in enumerate(user_features['sec_id'])}
    num_users = len(user_id_map)

    # Topic-level node features
    topic_features = df_final.groupby('topic').agg({
        'topic': 'count',
        'sentiment_score': ['mean', 'var'],
        'digg_count': 'mean',
        'comment_count': 'mean',
        'share_count': 'mean'
    }).reset_index()
    topic_features.columns = [
        'topic', 'popularity', 'sentiment_mean', 'sentiment_var',
        'digg_count_mean', 'comment_count_mean', 'share_count_mean'
    ]
    topic_features['sentiment_var'] = topic_features['sentiment_var'].fillna(0)
    topic_node_features = topic_features[[
        'popularity', 'sentiment_mean', 'sentiment_var',
        'digg_count_mean', 'comment_count_mean', 'share_count_mean'
    ]].fillna(0).values
    scaler_topic = SafeStandardScaler()
    topic_node_features = scaler_topic.fit_transform(topic_node_features)
    # Topic indices follow the user block in the node tensor
    topic_id_map = {tid: idx + num_users for idx, tid in enumerate(topic_features['topic'])}

    # One undirected user-topic edge (stored as two directed edges) per post
    edge_index, edge_features = [], []
    for _, row in df_final.iterrows():
        user_idx = user_id_map[row['sec_id']]
        topic_idx = topic_id_map[row['topic']]
        edge_index.extend([[user_idx, topic_idx], [topic_idx, user_idx]])
        edge_attr = [row[col] for col in edge_columns]
        edge_features.extend([edge_attr, edge_attr])
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_features_np = np.array(edge_features, dtype=np.float32)
    scaler_edge = SafeStandardScaler()
    edge_features = torch.tensor(scaler_edge.fit_transform(edge_features_np), dtype=torch.float32)

    node_features = np.vstack([user_node_features, topic_node_features])
    node_features = torch.tensor(np.nan_to_num(node_features), dtype=torch.float32)
    user_labels = torch.tensor(user_features['is_fake'].values, dtype=torch.float32)
    # Random 3-D positions feed the distance term in the message function
    position_vectors = torch.randn(node_features.shape[0], 3)
    data = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_features,
        y=user_labels,
        pos=position_vectors
    )
    data.num_users = num_users
    assert not torch.isnan(data.x).any(), "NaNs in node features"
    assert not torch.isnan(data.edge_attr).any(), "NaNs in edge features"
    return data, user_id_map, topic_id_map
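# Minimal smoke test for preprocess_data on a tiny synthetic frame containing
# exactly the columns the function reads (all values hypothetical):
def _demo_preprocess():
    rng = np.random.default_rng(0)
    n = 40
    df = pd.DataFrame({
        'sec_id': rng.integers(0, 8, n),
        'topic': rng.choice(['news', 'sport', 'music'], n),
        'post_length': rng.integers(10, 300, n).astype(float),
        'sentiment_score': rng.normal(0.0, 1.0, n),
        'create_hour': rng.integers(0, 24, n).astype(float),
        'time_since_prev_post': rng.exponential(5.0, n),
        'lexical_similarity': rng.random(n),
        'lexical_diversity': rng.random(n),
        'create_days_since_creation': rng.integers(1, 400, n).astype(float),
        'digg_count': rng.integers(0, 100, n).astype(float),
        'comment_count': rng.integers(0, 50, n).astype(float),
        'share_count': rng.integers(0, 20, n).astype(float),
    })
    df['is_fake'] = (df['sec_id'] % 2).astype(float)  # one label per user
    data, user_map, topic_map = preprocess_data(df)
    # e.g. Data(x=[11, 6], edge_index=[2, 80], ...) when all 8 sec_ids appear
    return data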
# EnergyMPNN Model: edge-aware message passing over the bipartite user-topic graph
class EnergyMPNNLayer(MessagePassing):
    def __init__(self, input_node_dim, edge_dim, hidden_dim, pos_dim, dropout=0.4):
        super().__init__(aggr='mean')
        self.input_node_dim = input_node_dim
        self.edge_dim = edge_dim
        self.hidden_dim = hidden_dim
        self.pos_dim = pos_dim
        self.dropout = dropout
        # Separate encoders (with linear residual connections) for user and topic nodes
        self.user_mlp = nn.Sequential(
            nn.Linear(input_node_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.user_residual = nn.Linear(input_node_dim, hidden_dim)
        self.topic_mlp = nn.Sequential(
            nn.Linear(input_node_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.topic_residual = nn.Linear(input_node_dim, hidden_dim)
        # Message input: receiver and sender embeddings, edge features, distance
        message_input_dim = 2 * hidden_dim + edge_dim + 1
        self.message_mlp = nn.Sequential(
            nn.Linear(message_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.update_mlp = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.update_residual = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, edge_index, edge_attr, pos, num_users):
        # Users occupy the first num_users rows of the node tensor
        user_x, topic_x = x[:num_users], x[num_users:]
        h_user = self.user_mlp(user_x) + self.user_residual(user_x)
        h_topic = self.topic_mlp(topic_x) + self.topic_residual(topic_x)
        h = torch.cat([h_user, h_topic], dim=0)
        return self.propagate(edge_index, x=h, edge_attr=edge_attr, pos=pos)

    def message(self, x_i, x_j, edge_attr, pos_i, pos_j):
        # Euclidean distance between endpoint positions augments each message
        dist = torch.norm(pos_i - pos_j, p=2, dim=-1, keepdim=True)
        message_input = torch.cat([x_i, x_j, edge_attr, dist], dim=-1)
        return self.message_mlp(message_input)

    def update(self, aggr_out, x):
        # Combine each node's state with its aggregated messages (plus residual)
        update_input = torch.cat([x, aggr_out], dim=-1)
        return self.update_mlp(update_input) + self.update_residual(x)
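# Shape sanity check for a single layer on a random graph (hypothetical sizes:
# 10 users + 4 topics, 6-dim nodes, 5-dim edges, 3-D positions):
def _demo_layer():
    layer = EnergyMPNNLayer(input_node_dim=6, edge_dim=5, hidden_dim=64, pos_dim=3)
    layer.eval()  # put BatchNorm1d in eval mode for a one-off forward pass
    x, pos = torch.randn(14, 6), torch.randn(14, 3)
    edge_index = torch.randint(0, 14, (2, 30))
    edge_attr = torch.randn(30, 5)
    out = layer(x, edge_index, edge_attr, pos, num_users=10)
    return out.shape  # torch.Size([14, 64])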
class EnergyMPNN(nn.Module):
    def __init__(self, input_node_dim=6, edge_dim=5, hidden_dim=64, pos_dim=3, num_layers=2, dropout=0.4):
        super().__init__()
        self.input_node_dim = input_node_dim
        self.edge_dim = edge_dim
        self.hidden_dim = hidden_dim
        self.pos_dim = pos_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            # First layer consumes raw features; later layers consume hidden states
            layer_input_dim = input_node_dim if i == 0 else hidden_dim
            self.layers.append(EnergyMPNNLayer(
                input_node_dim=layer_input_dim,
                edge_dim=edge_dim,
                hidden_dim=hidden_dim,
                pos_dim=pos_dim,
                dropout=dropout
            ))
        # Edge scorer over endpoint embeddings, edge features, and distance
        score_input_dim = 2 * hidden_dim + edge_dim + 1
        self.score_mlp = nn.Sequential(
            nn.Linear(score_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x, edge_index, edge_attr, pos, num_users):
        h = x
        for layer in self.layers:
            h = layer(h, edge_index, edge_attr, pos, num_users)
        # Score every edge from its endpoint embeddings
        u, t = edge_index
        dist = torch.norm(pos[u] - pos[t], p=2, dim=-1)
        edge_input = torch.cat([h[u], h[t], edge_attr, dist.unsqueeze(-1)], dim=-1)
        # squeeze(-1) rather than squeeze(): a single-edge graph must stay 1-D
        edge_scores = self.score_mlp(edge_input).squeeze(-1)
        # A user's score is the mean score of the edges it sends
        user_scores = torch.zeros(num_users, device=x.device)
        edge_counts = torch.zeros(num_users, device=x.device)
        user_mask = edge_index[0] < num_users
        user_indices = edge_index[0][user_mask]
        user_scores.scatter_add_(0, user_indices, edge_scores[user_mask])
        edge_counts.scatter_add_(0, user_indices, torch.ones_like(user_indices, dtype=torch.float))
        user_scores = user_scores / edge_counts.clamp(min=1)
        return user_scores, edge_scores
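# End-to-end smoke test of the scorer on a random bipartite graph (hypothetical
# sizes: 10 users, 4 topics, 30 user->topic edges):
def _demo_model():
    model = EnergyMPNN(input_node_dim=6, edge_dim=5, hidden_dim=64, pos_dim=3)
    model.eval()
    x, pos = torch.randn(14, 6), torch.randn(14, 3)
    src = torch.randint(0, 10, (30,))    # user indices
    dst = torch.randint(10, 14, (30,))   # topic indices (offset by num_users)
    edge_index = torch.stack([src, dst])
    edge_attr = torch.randn(30, 5)
    user_scores, edge_scores = model(x, edge_index, edge_attr, pos, num_users=10)
    return user_scores.shape, edge_scores.shape  # (torch.Size([10]), torch.Size([30]))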
# Compute threshold-tuned classification metrics from probabilities
def calculate_metrics(probs, labels):
    precision, recall, thresholds = precision_recall_curve(labels.cpu(), probs.cpu())
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
    optimal_idx = np.argmax(f1_scores)
    # precision_recall_curve returns one fewer threshold than curve points
    optimal_threshold = thresholds[optimal_idx] if optimal_idx < len(thresholds) else 0.5
    preds = (probs > optimal_threshold).float()
    return {
        'acc': accuracy_score(labels.cpu(), preds.cpu()),
        'f1': f1_score(labels.cpu(), preds.cpu()),
        'auc': roc_auc_score(labels.cpu(), probs.cpu()),
        'precision': precision_score(labels.cpu(), preds.cpu()),
        'recall': recall_score(labels.cpu(), preds.cpu()),
        'threshold': optimal_threshold
    }
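# Example with synthetic scores (illustrative only): calculate_metrics sweeps
# the PR curve, picks the threshold that maximizes F1, then reports all
# metrics at that threshold.
def _demo_metrics():
    labels = torch.randint(0, 2, (200,)).float()
    probs = torch.clamp(labels * 0.6 + torch.rand(200) * 0.4, 0.0, 1.0)
    return calculate_metrics(probs, labels)  # dict: acc/f1/auc/precision/recall/threshold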
# Train model with the best configuration found during tuning; saves metric
# CSVs, checkpoints, and plots under save_path
def train_model(data, save_path, test_size, num_epochs=150):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = data.to(device)
    # Best configuration
    hidden_dim = 64
    num_layers = 2
    dropout = 0.2
    lr = 0.001
    gamma = 1.0
    weight_decay = 0.0005
    model = EnergyMPNN(
        input_node_dim=6,
        edge_dim=5,
        hidden_dim=hidden_dim,
        pos_dim=3,
        num_layers=num_layers,
        dropout=dropout
    ).to(device)
    # Create save directories
    os.makedirs(save_path, exist_ok=True)
    plot_dir = os.path.join(save_path, 'plots')
    os.makedirs(plot_dir, exist_ok=True)
    # Stratified train/val/test split over user nodes
    user_indices = torch.arange(data.num_users, device='cpu')
    y_np = data.y.cpu().numpy()
    try:
        train_idx, test_idx = train_test_split(
            user_indices.numpy(),
            test_size=test_size,
            stratify=y_np,
            random_state=42
        )
        # Split the held-out portion evenly into validation and test sets
        val_idx, test_idx = train_test_split(
            test_idx,
            test_size=0.5,
            stratify=y_np[test_idx],
            random_state=42
        )
    except Exception as e:
        print(f"Error in train/val/test split: {e}")
        return None, [], [], {}, 0
    data.train_mask = torch.zeros(data.num_users, dtype=torch.bool, device=device)
    data.val_mask = torch.zeros(data.num_users, dtype=torch.bool, device=device)
    data.test_mask = torch.zeros(data.num_users, dtype=torch.bool, device=device)
    data.train_mask[torch.tensor(train_idx, device=device)] = True
    data.val_mask[torch.tensor(val_idx, device=device)] = True
    data.test_mask[torch.tensor(test_idx, device=device)] = True
    criterion = FocalLoss(gamma=gamma, alpha=0.25)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=15)
    best_val_auc = 0
    epochs_no_improve = 0
    patience = 20  # early-stopping patience, in epochs
    train_losses, train_metrics_list, val_metrics_list = [], [], []
    start_time = time.time()
    checkpoint_dir = os.path.join(save_path, "model_checkpoint")
    os.makedirs(checkpoint_dir, exist_ok=True)
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        user_scores, _ = model(data.x, data.edge_index, data.edge_attr, data.pos, data.num_users)
        loss = criterion(user_scores[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        model.eval()
        with torch.no_grad():
            user_scores, _ = model(data.x, data.edge_index, data.edge_attr, data.pos, data.num_users)
            probs = torch.sigmoid(user_scores)
            train_metrics = calculate_metrics(probs[data.train_mask], data.y[data.train_mask])
            val_metrics = calculate_metrics(probs[data.val_mask], data.y[data.val_mask])
        train_losses.append(loss.item())
        train_metrics_list.append(train_metrics)
        val_metrics_list.append(val_metrics)
        all_train_metrics.append(train_metrics)
        all_val_metrics.append(val_metrics)
        scheduler.step(val_metrics['auc'])
        # Checkpoint on validation AUC with early stopping
        if val_metrics['auc'] > best_val_auc:
            best_val_auc = val_metrics['auc']
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best_model.pth'))
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {loss.item():.4f} | Val AUC: {val_metrics['auc']:.4f} | Val F1: {val_metrics['f1']:.4f}")
    training_time = time.time() - start_time
    # Evaluate the best checkpoint on the held-out test set
    model.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'best_model.pth')))
    model.eval()
    with torch.no_grad():
        user_scores, _ = model(data.x, data.edge_index, data.edge_attr, data.pos, data.num_users)
        probs = torch.sigmoid(user_scores[data.test_mask])
        test_metrics = calculate_metrics(probs, data.y[data.test_mask])
    print("\nFinal Test Metrics:")
    print(f"AUC: {test_metrics['auc']:.4f} | F1: {test_metrics['f1']:.4f} | Accuracy: {test_metrics['acc']:.4f}")
    print(f"Precision: {test_metrics['precision']:.4f} | Recall: {test_metrics['recall']:.4f} | Threshold: {test_metrics['threshold']:.4f}")
    print(f"Training Time: {training_time:.2f} seconds")
    # Save per-epoch metrics to CSVs
    train_metrics_df = pd.DataFrame([
        {
            'epoch': epoch + 1,
            'loss': train_losses[epoch],
            'auc': tm['auc'],
            'f1': tm['f1'],
            'accuracy': tm['acc'],
            'precision': tm['precision'],
            'recall': tm['recall'],
            'threshold': tm['threshold']
        }
        for epoch, tm in enumerate(train_metrics_list)
    ])
    train_metrics_df.to_csv(os.path.join(save_path, 'train_metrics.csv'), index=False)
    val_metrics_df = pd.DataFrame([
        {
            'epoch': epoch + 1,
            'auc': vm['auc'],
            'f1': vm['f1'],
            'accuracy': vm['acc'],
            'precision': vm['precision'],
            'recall': vm['recall'],
            'threshold': vm['threshold']
        }
        for epoch, vm in enumerate(val_metrics_list)
    ])
    val_metrics_df.to_csv(os.path.join(save_path, 'val_metrics.csv'), index=False)
    test_metrics_df = pd.DataFrame([test_metrics])
    test_metrics_df.to_csv(os.path.join(save_path, 'test_metrics.csv'), index=False)
    # Save hyperparameters and final test metrics as a one-row summary
    training_summary = pd.DataFrame([{
        'hidden_dim': hidden_dim,
        'num_layers': num_layers,
        'dropout': dropout,
        'lr': lr,
        'gamma': gamma,
        'weight_decay': weight_decay,
        'num_epochs': num_epochs,
        'training_time_seconds': training_time,
        'test_auc': test_metrics['auc'],
        'test_f1': test_metrics['f1'],
        'test_accuracy': test_metrics['acc'],
        'test_precision': test_metrics['precision'],
        'test_recall': test_metrics['recall'],
        'test_threshold': test_metrics['threshold']
    }])
    training_summary.to_csv(os.path.join(save_path, 'training_summary.csv'), index=False)
    # Plot 1: training loss and validation metrics over epochs
    plt.figure(figsize=(15, 10))
    plt.subplot(2, 4, 1)
    plt.plot(train_metrics_df['epoch'], train_metrics_df['loss'], label='Train Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()
    val_panels = [('auc', 'AUC'), ('f1', 'F1'), ('accuracy', 'Accuracy'),
                  ('precision', 'Precision'), ('recall', 'Recall'), ('threshold', 'Threshold')]
    for i, (col, name) in enumerate(val_panels, start=2):
        plt.subplot(2, 4, i)
        plt.plot(val_metrics_df['epoch'], val_metrics_df[col], label=f'Val {name}')
        plt.xlabel('Epoch')
        plt.ylabel(name)
        plt.title(f'Validation {name}')
        plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, 'metrics_over_epochs.png'))
    plt.close()
    # Plot 2: precision-recall curve on the test set
    precision, recall, _ = precision_recall_curve(data.y[data.test_mask].cpu(), probs.cpu())
    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, label='Precision-Recall Curve')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve (Test Set)')
    plt.legend()
    plt.savefig(os.path.join(plot_dir, 'precision_recall_curve.png'))
    plt.close()
    # Plot 3: ROC curve on the test set
    fpr, tpr, _ = roc_curve(data.y[data.test_mask].cpu(), probs.cpu())
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, label=f'ROC Curve (AUC = {test_metrics["auc"]:.4f})')
    plt.plot([0, 1], [0, 1], 'k--', label='Random')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve (Test Set)')
    plt.legend()
    plt.savefig(os.path.join(plot_dir, 'roc_curve.png'))
    plt.close()
    return model, train_losses, val_metrics_list[-1], test_metrics, training_time
# Main execution
def trainer(save_path, test_size, num_epochs, df_final):
    try:
        data, _, _ = preprocess_data(df_final)
        logger.info(f"Data shapes: x={data.x.shape}, edge_index={data.edge_index.shape}, edge_attr={data.edge_attr.shape}, y={data.y.shape}, pos={data.pos.shape}")
        model, train_losses, val_metrics, test_metrics, training_time = train_model(
            data, save_path, test_size, num_epochs=num_epochs
        )
        logger.info("Training completed.")
    except Exception as e:
        logger.error(f"Error: {e}")