| import torch |
|
|
| from sklearn.metrics import roc_auc_score, average_precision_score |
| from src.tgn.time_encoding import TimeEncoding |
| from src.tgn.memory import Memory |
|
|
|
|
def evaluate(model, memory, graph_data, norm_stats, *, device="cpu",
             memory_dim=64, time_dim=16, batch_size=1024, time_encoder=None):
    """Score the test edges of ``graph_data`` with ``model``; report ROC-AUC and AP.

    A fresh ``Memory`` is created and filled by replaying the training edges
    through ``model.compute_message``. The test edges are then scored with
    ``model.predict`` against that memory. Test edges do not update memory.

    Args:
        model: Object exposing ``compute_message`` and ``predict``. If it is a
            ``torch.nn.Module`` it is put in eval mode for the duration of the
            call. Its previous training mode is restored afterwards.
        memory: Not used. A new ``Memory`` is always created here. The
            parameter is kept for interface compatibility.
        graph_data: Mapping with keys "x", "edge_index", "edge_attr", "y",
            "train_idx" and "test_idx". Column 1 of the raw "edge_attr" is
            used as the edge timestamp.
        norm_stats: Mapping with "ea_mean" / "ea_std" (edge-feature
            standardisation) and "t_min" / "t_max" (timestamp min-max scaling).
        device: Device for the memory, the time encoder and all input tensors.
            Defaults to CPU, as before.
        memory_dim: Width of the freshly created memory.
        time_dim: Output width of the freshly created ``TimeEncoding``.
        batch_size: Number of training edges replayed per memory update.
        time_encoder: Optional pre-built time encoder, used instead of a new
            ``TimeEncoding(time_dim)``.

    Returns:
        Tuple ``(roc_auc, average_precision, probs, y_true)``. ``probs`` and
        ``y_true`` are NumPy arrays over the test edges.
    """
    device = torch.device(device)

    # Build raw tensors on CPU and normalise them there, so the values in
    # ``norm_stats`` do not need to live on ``device``.
    edge_index = torch.tensor(graph_data["edge_index"], dtype=torch.long)
    raw_edge_attr = torch.tensor(graph_data["edge_attr"], dtype=torch.float32)
    labels = torch.tensor(graph_data["y"], dtype=torch.float32)
    x = torch.tensor(graph_data["x"], dtype=torch.float32)

    # NOTE(review): node features are standardised with their own statistics,
    # not with ``norm_stats``. Confirm this matches how training normalises x.
    x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-6)

    # NOTE(review): unlike x, no epsilon is added here. A zero entry in
    # ``ea_std`` would produce inf/nan. Kept as-is to match existing behaviour.
    edge_attr = (raw_edge_attr - norm_stats["ea_mean"]) / norm_stats["ea_std"]

    # Timestamps come from the *un-normalised* edge attributes and are
    # min-max scaled. ``raw_edge_attr`` is untouched: the line above is not
    # in-place.
    t_min, t_max = norm_stats["t_min"], norm_stats["t_max"]
    timestamps = (raw_edge_attr[:, 1] - t_min) / (t_max - t_min + 1e-6)

    x = x.to(device)
    edge_index = edge_index.to(device)
    edge_attr = edge_attr.to(device)
    timestamps = timestamps.to(device)

    train_idx = graph_data["train_idx"]
    test_idx = graph_data["test_idx"]

    # The ``memory`` argument is not used. Evaluation always starts from a
    # fresh memory that is filled from the training edges only.
    memory = Memory(x.shape[0], memory_dim=memory_dim, device=device)

    if time_encoder is None:
        # NOTE(review): a newly constructed TimeEncoding is not the encoder the
        # model was trained with. If it has learned or randomly initialised
        # parameters, pass the trained encoder via ``time_encoder``.
        time_encoder = TimeEncoding(time_dim).to(device)

    was_training = isinstance(model, torch.nn.Module) and model.training
    if was_training:
        # Disable dropout / batch-norm training behaviour while scoring.
        model.eval()
    try:
        with torch.no_grad():
            _replay_train_edges(model, memory, time_encoder, edge_index,
                                edge_attr, timestamps, train_idx, batch_size)

            u = edge_index[0, test_idx]
            v = edge_index[1, test_idx]
            logits = model.predict(
                memory.get(u), memory.get(v), edge_attr[test_idx],
                x[u], x[v], time_encoder(timestamps[test_idx]),
            )
    finally:
        if was_training:
            model.train()

    probs = torch.sigmoid(logits).cpu().numpy()
    y_true = labels[test_idx].cpu().numpy()

    roc = roc_auc_score(y_true, probs)
    pr = average_precision_score(y_true, probs)
    return roc, pr, probs, y_true


def _replay_train_edges(model, memory, time_encoder, edge_index, edge_attr,
                        timestamps, train_idx, batch_size):
    """Fill ``memory`` by replaying the edges in ``train_idx`` in mini-batches.

    Intended to run under ``torch.no_grad()``, as ``evaluate`` does. Edges are
    consumed in the order given by ``train_idx``. That order is presumably
    chronological, but this function does not check it.
    """
    for start in range(0, len(train_idx), batch_size):
        batch_ids = train_idx[start:start + batch_size]
        src = edge_index[0, batch_ids]
        dst = edge_index[1, batch_ids]

        time_enc = time_encoder(timestamps[batch_ids])
        h_src = memory.get(src).detach()
        h_dst = memory.get(dst).detach()
        msg = model.compute_message(h_src, h_dst, edge_attr[batch_ids], time_enc)

        # Each edge sends the same message to both of its endpoints.
        node_ids = torch.cat([src, dst])
        messages = torch.cat([msg, msg])

        # Mean-aggregate the messages per node, so each touched node gets
        # exactly one memory update per batch.
        unique_nodes, inverse_idx = torch.unique(node_ids, return_inverse=True)
        agg_msg = messages.new_zeros((unique_nodes.numel(), messages.shape[1]))
        agg_msg.index_add_(0, inverse_idx, messages)
        counts = torch.bincount(inverse_idx, minlength=unique_nodes.numel())
        memory.update(unique_nodes, agg_msg / counts.unsqueeze(1))