import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import time
import os
from loguru import logger
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score, precision_recall_curve, roc_curve
# Global lists for metrics
train_losses = []
all_train_metrics = []
all_val_metrics = []
# SafeStandardScaler: replaces the NaN/inf that zero-variance columns produce with 0
class SafeStandardScaler(StandardScaler):
    def transform(self, X):
        X_std = super().transform(X)
        return np.nan_to_num(X_std)
# Focal loss for class-imbalanced binary classification (Lin et al., 2017)
class FocalLoss(nn.Module):
def __init__(self, gamma=2.0, alpha=0.25):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
def forward(self, logits, targets):
bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
pt = torch.exp(-bce)
loss = self.alpha * (1 - pt) ** self.gamma * bce
return loss.mean()
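# Minimal usage sketch (dummy logits/targets, illustrative only):
#   criterion = FocalLoss(gamma=2.0, alpha=0.25)
#   loss = criterion(torch.randn(8), torch.randint(0, 2, (8,)).float())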
# Preprocess data
def preprocess_data(df_final):
edge_columns = ['post_length', 'sentiment_score', 'create_hour',
'time_since_prev_post', 'lexical_similarity']
for col in edge_columns:
df_final[col] = df_final[col].fillna(df_final[col].mean())
user_features = df_final.groupby('sec_id').agg({
'create_days_since_creation': 'max',
'topic': lambda x: len(set(x)),
'post_length': 'mean',
'sentiment_score': 'mean',
'lexical_diversity': 'mean',
'is_fake': 'first'
}).reset_index()
user_features['create_days_since_creation'] = user_features['create_days_since_creation'].clip(lower=1)
    # Use .values on both sides: groupby().size() is indexed by sec_id while
    # user_features uses a 0..n range index, so a plain division would misalign.
    user_features['posting_frequency'] = (
        df_final.groupby('sec_id').size().values
        / user_features['create_days_since_creation'].values
    )
user_node_features = user_features[[
'posting_frequency', 'topic', 'post_length',
'sentiment_score', 'lexical_diversity'
]].fillna(0).values
scaler_user = SafeStandardScaler()
user_node_features = scaler_user.fit_transform(user_node_features)
    # Pad with a zero column so user features match the 6-dim topic features
    user_node_features = np.hstack([user_node_features, np.zeros((user_node_features.shape[0], 1))])
user_id_map = {sid: idx for idx, sid in enumerate(user_features['sec_id'])}
num_users = len(user_id_map)
    # Named aggregation avoids aggregating the groupby key column directly
    # (counting 'sec_id' gives the same per-topic row count) and yields flat
    # column names without a manual rename.
    topic_features = df_final.groupby('topic').agg(
        popularity=('sec_id', 'count'),
        sentiment_mean=('sentiment_score', 'mean'),
        sentiment_var=('sentiment_score', 'var'),
        digg_count_mean=('digg_count', 'mean'),
        comment_count_mean=('comment_count', 'mean'),
        share_count_mean=('share_count', 'mean')
    ).reset_index()
topic_features['sentiment_var'] = topic_features['sentiment_var'].fillna(0)
topic_node_features = topic_features[[
'popularity', 'sentiment_mean', 'sentiment_var',
'digg_count_mean', 'comment_count_mean', 'share_count_mean'
]].fillna(0).values
scaler_topic = SafeStandardScaler()
topic_node_features = scaler_topic.fit_transform(topic_node_features)
topic_id_map = {tid: idx + num_users for idx, tid in enumerate(topic_features['topic'])}
edge_index, edge_features = [], []
for _, row in df_final.iterrows():
user_idx = user_id_map[row['sec_id']]
topic_idx = topic_id_map[row['topic']]
edge_index.extend([[user_idx, topic_idx], [topic_idx, user_idx]])
edge_attr = [row[col] for col in edge_columns]
edge_features.extend([edge_attr, edge_attr])
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
edge_features_np = np.array(edge_features, dtype=np.float32)
scaler_edge = SafeStandardScaler()
edge_features = torch.tensor(scaler_edge.fit_transform(edge_features_np), dtype=torch.float32)
node_features = np.vstack([user_node_features, topic_node_features])
node_features = torch.tensor(np.nan_to_num(node_features), dtype=torch.float32)
user_labels = torch.tensor(user_features['is_fake'].values, dtype=torch.float32)
    # Random 3-D positions; the MPNN layers use pairwise distances as an extra edge feature
    position_vectors = torch.randn(node_features.shape[0], 3)
data = Data(
x=node_features,
edge_index=edge_index,
edge_attr=edge_features,
y=user_labels,
pos=position_vectors
)
data.num_users = num_users
assert not torch.isnan(data.x).any(), "NaNs in node features"
assert not torch.isnan(data.edge_attr).any(), "NaNs in edge features"
return data, user_id_map, topic_id_map
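# A minimal smoke test for preprocess_data on synthetic rows. The column names
# come from the function above; every value is made up and only meant to show
# the expected input schema.
def _demo_preprocess_data():
    rng = np.random.default_rng(0)
    n = 60
    demo_df = pd.DataFrame({
        'sec_id': rng.integers(0, 10, n),
        'topic': rng.choice(['news', 'sports', 'music'], n),
        'post_length': rng.integers(5, 200, n).astype(float),
        'sentiment_score': rng.normal(0.0, 1.0, n),
        'create_hour': rng.integers(0, 24, n).astype(float),
        'time_since_prev_post': rng.exponential(5.0, n),
        'lexical_similarity': rng.random(n),
        'create_days_since_creation': rng.integers(1, 365, n),
        'lexical_diversity': rng.random(n),
        'is_fake': rng.integers(0, 2, n).astype(float),
        'digg_count': rng.integers(0, 100, n),
        'comment_count': rng.integers(0, 50, n),
        'share_count': rng.integers(0, 30, n),
    })
    data, user_id_map, topic_id_map = preprocess_data(demo_df)
    print(f"nodes={data.x.shape}, edges={data.edge_index.shape}, users={data.num_users}")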
# EnergyMPNN Model
class EnergyMPNNLayer(MessagePassing):
def __init__(self, input_node_dim, edge_dim, hidden_dim, pos_dim, dropout=0.4):
super(EnergyMPNNLayer, self).__init__(aggr='mean')
self.input_node_dim = input_node_dim
self.edge_dim = edge_dim
self.hidden_dim = hidden_dim
self.pos_dim = pos_dim
self.dropout = dropout
self.user_mlp = nn.Sequential(
nn.Linear(input_node_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
)
self.user_residual = nn.Linear(input_node_dim, hidden_dim)
self.topic_mlp = nn.Sequential(
nn.Linear(input_node_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
)
self.topic_residual = nn.Linear(input_node_dim, hidden_dim)
        message_input_dim = 2 * hidden_dim + edge_dim + 1  # [x_i, x_j, edge_attr, dist]
self.message_mlp = nn.Sequential(
nn.Linear(message_input_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
)
self.update_mlp = nn.Sequential(
nn.Linear(hidden_dim + hidden_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
)
self.update_residual = nn.Linear(hidden_dim, hidden_dim)
def forward(self, x, edge_index, edge_attr, pos, num_users):
user_x = x[:num_users]
topic_x = x[num_users:]
user_residual = self.user_residual(user_x)
h_user = self.user_mlp(user_x)
h_user = h_user + user_residual
topic_residual = self.topic_residual(topic_x)
h_topic = self.topic_mlp(topic_x)
h_topic = h_topic + topic_residual
h = torch.cat([h_user, h_topic], dim=0)
h = self.propagate(edge_index, x=h, edge_attr=edge_attr, pos=pos)
return h
def message(self, x_i, x_j, edge_attr, pos_i, pos_j):
dist = torch.norm(pos_i - pos_j, p=2, dim=-1, keepdim=True)
message_input = torch.cat([x_i, x_j, edge_attr, dist], dim=-1)
message = self.message_mlp(message_input)
return message
def update(self, aggr_out, x):
update_input = torch.cat([x, aggr_out], dim=-1)
update_residual = self.update_residual(x)
h = self.update_mlp(update_input) + update_residual
return h
class EnergyMPNN(nn.Module):
def __init__(self, input_node_dim=6, edge_dim=5, hidden_dim=64, pos_dim=3, num_layers=2, dropout=0.4):
super(EnergyMPNN, self).__init__()
self.input_node_dim = input_node_dim
self.edge_dim = edge_dim
self.hidden_dim = hidden_dim
self.pos_dim = pos_dim
self.num_layers = num_layers
self.dropout = dropout
self.layers = nn.ModuleList()
for i in range(num_layers):
layer_input_dim = input_node_dim if i == 0 else hidden_dim
self.layers.append(EnergyMPNNLayer(
input_node_dim=layer_input_dim,
edge_dim=edge_dim,
hidden_dim=hidden_dim,
pos_dim=pos_dim,
dropout=dropout
))
score_input_dim = 2 * hidden_dim + edge_dim + 1
self.score_mlp = nn.Sequential(
nn.Linear(score_input_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.BatchNorm1d(hidden_dim // 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, 1)
)
def forward(self, x, edge_index, edge_attr, pos, num_users):
h = x
for i, layer in enumerate(self.layers):
h = layer(h, edge_index, edge_attr, pos, num_users)
        # Edges run in both directions, so name the endpoints src/dst rather
        # than user/topic
        src, dst = edge_index
        dist = torch.norm(pos[src] - pos[dst], p=2, dim=-1)
        edge_input = torch.cat([
            h[src], h[dst], edge_attr, dist.unsqueeze(-1)
        ], dim=-1)
        edge_scores = self.score_mlp(edge_input).squeeze(-1)
user_scores = torch.zeros(num_users, device=x.device)
edge_counts = torch.zeros(num_users, device=x.device)
user_mask = (edge_index[0] < num_users)
user_indices = edge_index[0][user_mask]
user_scores.scatter_add_(0, user_indices, edge_scores[user_mask])
edge_counts.scatter_add_(0, user_indices, torch.ones_like(user_indices, dtype=torch.float))
edge_counts = edge_counts.clamp(min=1)
user_scores = user_scores / edge_counts
return user_scores, edge_scores
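# A minimal shape check for EnergyMPNN on random tensors, assuming the default
# constructor dimensions above. A sanity-check sketch, not part of training.
def _demo_energympnn_forward():
    num_users, num_topics = 6, 3
    n = num_users + num_topics
    # One user->topic and one topic->user edge per pair, mirroring preprocess_data
    pairs = [(u, num_users + t) for u in range(num_users) for t in range(num_topics)]
    edge_index = torch.tensor(
        [[u, t] for u, t in pairs] + [[t, u] for u, t in pairs],
        dtype=torch.long
    ).t().contiguous()
    x = torch.randn(n, 6)                            # input_node_dim=6
    edge_attr = torch.randn(edge_index.shape[1], 5)  # edge_dim=5
    pos = torch.randn(n, 3)                          # pos_dim=3
    model = EnergyMPNN()
    model.eval()  # use BatchNorm running stats; avoids small-batch noise
    with torch.no_grad():
        user_scores, edge_scores = model(x, edge_index, edge_attr, pos, num_users)
    assert user_scores.shape == (num_users,)
    assert edge_scores.shape == (edge_index.shape[1],)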
# Compute classification metrics at the F1-optimal threshold from the PR curve
def calculate_metrics(probs, labels):
precision, recall, thresholds = precision_recall_curve(labels.cpu(), probs.cpu())
f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
optimal_idx = np.argmax(f1_scores)
optimal_threshold = thresholds[optimal_idx] if optimal_idx < len(thresholds) else 0.5
preds = (probs > optimal_threshold).float()
return {
'acc': accuracy_score(labels.cpu(), preds.cpu()),
'f1': f1_score(labels.cpu(), preds.cpu()),
'auc': roc_auc_score(labels.cpu(), probs.cpu()),
'precision': precision_score(labels.cpu(), preds.cpu()),
'recall': recall_score(labels.cpu(), preds.cpu()),
'threshold': optimal_threshold
}
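# Usage sketch (synthetic scores, illustrative only):
#   probs = torch.rand(100); labels = (torch.rand(100) > 0.7).float()
#   print(calculate_metrics(probs, labels))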
# Train model with the best configuration, saving plots and metrics
def train_model(data, save_path, test_size, num_epochs=150):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)
# Best configuration
hidden_dim = 64
num_layers = 2
dropout = 0.2
lr = 0.001
gamma = 1.0
weight_decay = 0.0005
model = EnergyMPNN(
input_node_dim=6,
edge_dim=5,
hidden_dim=hidden_dim,
pos_dim=3,
num_layers=num_layers,
dropout=dropout
).to(device)
# Create save directories
os.makedirs(save_path, exist_ok=True)
    plot_dir = os.path.join(save_path, 'plots')
    os.makedirs(plot_dir, exist_ok=True)
    checkpoint_dir = os.path.join(save_path, 'model_checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)
user_indices = torch.arange(data.num_users, device='cpu')
y_np = data.y.cpu().numpy()
try:
train_idx, test_idx = train_test_split(
user_indices.numpy(),
test_size=test_size,
stratify=y_np,
random_state=42
)
val_idx, test_idx = train_test_split(
test_idx,
test_size=0.5,
stratify=y_np[test_idx],
random_state=42
)
except Exception as e:
print(f"Error in train/val split: {e}")
return None, [], [], {}, 0
    data.train_mask = torch.zeros(data.num_users, dtype=torch.bool, device=device)
    data.val_mask = torch.zeros(data.num_users, dtype=torch.bool, device=device)
    data.test_mask = torch.zeros(data.num_users, dtype=torch.bool, device=device)
data.train_mask[torch.tensor(train_idx, device=device)] = True
data.val_mask[torch.tensor(val_idx, device=device)] = True
data.test_mask[torch.tensor(test_idx, device=device)] = True
criterion = FocalLoss(gamma=gamma, alpha=0.25)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=15)
best_val_auc = 0
epochs_no_improve = 0
patience = 20
train_losses, train_metrics_list, val_metrics_list = [], [], []
start_time = time.time()
for epoch in range(num_epochs):
model.train()
optimizer.zero_grad()
user_scores, _ = model(data.x, data.edge_index, data.edge_attr, data.pos, data.num_users)
loss = criterion(user_scores[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
model.eval()
with torch.no_grad():
user_scores, _ = model(data.x, data.edge_index, data.edge_attr, data.pos, data.num_users)
probs = torch.sigmoid(user_scores)
train_metrics = calculate_metrics(probs[data.train_mask], data.y[data.train_mask])
val_metrics = calculate_metrics(probs[data.val_mask], data.y[data.val_mask])
train_losses.append(loss.item())
train_metrics_list.append(train_metrics)
val_metrics_list.append(val_metrics)
all_train_metrics.append(train_metrics)
all_val_metrics.append(val_metrics)
        scheduler.step(val_metrics['auc'])
        # checkpoint_dir is created once before the loop rather than every epoch
        if val_metrics['auc'] > best_val_auc:
            best_val_auc = val_metrics['auc']
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best_model.pth'))
epochs_no_improve = 0
else:
epochs_no_improve += 1
if epochs_no_improve >= patience:
print(f"Early stopping at epoch {epoch+1}")
break
print(f"Epoch {epoch+1}/{num_epochs} | Loss: {loss.item():.4f} | Val AUC: {val_metrics['auc']:.4f} | Val F1: {val_metrics['f1']:.4f}")
training_time = time.time() - start_time
    model.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'best_model.pth'), map_location=device))
model.eval()
with torch.no_grad():
user_scores, _ = model(data.x, data.edge_index, data.edge_attr, data.pos, data.num_users)
probs = torch.sigmoid(user_scores[data.test_mask])
test_metrics = calculate_metrics(probs, data.y[data.test_mask])
print("\nFinal Test Metrics:")
print(f"AUC: {test_metrics['auc']:.4f} | F1: {test_metrics['f1']:.4f} | Accuracy: {test_metrics['acc']:.4f}")
print(f"Precision: {test_metrics['precision']:.4f} | Recall: {test_metrics['recall']:.4f} | Threshold: {test_metrics['threshold']:.4f}")
print(f"Training Time: {training_time:.2f} seconds")
# Save metrics to CSVs
train_metrics_df = pd.DataFrame([
{
'epoch': epoch + 1,
'loss': train_losses[epoch],
'auc': tm['auc'],
'f1': tm['f1'],
'accuracy': tm['acc'],
'precision': tm['precision'],
'recall': tm['recall'],
'threshold': tm['threshold']
}
for epoch, tm in enumerate(train_metrics_list)
])
train_metrics_df.to_csv(os.path.join(save_path, 'train_metrics.csv'), index=False)
val_metrics_df = pd.DataFrame([
{
'epoch': epoch + 1,
'auc': vm['auc'],
'f1': vm['f1'],
'accuracy': vm['acc'],
'precision': vm['precision'],
'recall': vm['recall'],
'threshold': vm['threshold']
}
for epoch, vm in enumerate(val_metrics_list)
])
val_metrics_df.to_csv(os.path.join(save_path, 'val_metrics.csv'), index=False)
test_metrics_df = pd.DataFrame([test_metrics])
test_metrics_df.to_csv(os.path.join(save_path, 'test_metrics.csv'), index=False)
# Save training summary as CSV
training_summary = pd.DataFrame([{
'hidden_dim': hidden_dim,
'num_layers': num_layers,
'dropout': dropout,
'lr': lr,
'gamma': gamma,
'weight_decay': weight_decay,
'num_epochs': num_epochs,
'training_time_seconds': training_time,
'test_auc': test_metrics['auc'],
'test_f1': test_metrics['f1'],
'test_accuracy': test_metrics['acc'],
'test_precision': test_metrics['precision'],
'test_recall': test_metrics['recall'],
'test_threshold': test_metrics['threshold']
}])
training_summary.to_csv(os.path.join(save_path, 'training_summary.csv'), index=False)
# Create and save plots
    # 1. Metrics over epochs (loss from the train log, the rest from validation)
    plt.figure(figsize=(15, 10))
    panels = [
        (train_metrics_df, 'loss', 'Loss', 'Train Loss', 'Training Loss'),
        (val_metrics_df, 'auc', 'AUC', 'Val AUC', 'Validation AUC'),
        (val_metrics_df, 'f1', 'F1', 'Val F1', 'Validation F1'),
        (val_metrics_df, 'accuracy', 'Accuracy', 'Val Accuracy', 'Validation Accuracy'),
        (val_metrics_df, 'precision', 'Precision', 'Val Precision', 'Validation Precision'),
        (val_metrics_df, 'recall', 'Recall', 'Val Recall', 'Validation Recall'),
        (val_metrics_df, 'threshold', 'Threshold', 'Val Threshold', 'Validation Threshold'),
    ]
    for i, (df, col, ylabel, label, title) in enumerate(panels, start=1):
        plt.subplot(2, 4, i)
        plt.plot(df['epoch'], df[col], label=label)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, 'metrics_over_epochs.png'))
plt.close()
# 2. Precision-Recall Curve (Test Set)
precision, recall, _ = precision_recall_curve(data.y[data.test_mask].cpu(), probs.cpu())
plt.figure(figsize=(8, 6))
plt.plot(recall, precision, label='Precision-Recall Curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve (Test Set)')
plt.legend()
plt.savefig(os.path.join(plot_dir, 'precision_recall_curve.png'))
plt.close()
# 3. ROC Curve (Test Set)
fpr, tpr, _ = roc_curve(data.y[data.test_mask].cpu(), probs.cpu())
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, label=f'ROC Curve (AUC = {test_metrics["auc"]:.4f})')
plt.plot([0, 1], [0, 1], 'k--', label='Random')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve (Test Set)')
plt.legend()
plt.savefig(os.path.join(plot_dir, 'roc_curve.png'))
plt.close()
return model, train_losses, val_metrics_list[-1], test_metrics, training_time
# Main execution
def trainer(save_path, test_size, num_epochs, df_final):
    try:
        data, _, _ = preprocess_data(df_final)
logger.info(f"Data shapes: x={data.x.shape}, edge_index={data.edge_index.shape}, edge_attr={data.edge_attr.shape}, y={data.y.shape}, pos={data.pos.shape}")
        model, train_losses, val_metrics, test_metrics, training_time = train_model(
            data, save_path, test_size, num_epochs=num_epochs
        )
logger.info("Training completed.")
    except Exception as e:
        logger.exception(f"Error: {e}")