import torch
import torch.nn as nn


class TGN(nn.Module):
    """Message, memory-update, link-decoder and node-classifier heads of a TGN-style model.

    The module owns only learnable layers. The node-memory table itself is
    held by the caller and passed explicitly to :meth:`update_memory` and
    :meth:`predict_node`.

    Args:
        memory_dim: Width of one node's memory vector (GRU hidden size).
        node_dim: Width of the node features ``x_u`` / ``x_v`` / ``x``.
        edge_dim: Width of the edge features ``edge_attr``.
        time_dim: Width of a single time encoding. The message MLP and the
            decoder both expect their ``time_enc`` argument to be
            ``2 * time_dim`` wide (presumably one encoding per endpoint --
            TODO confirm with the caller).
        hidden_dim: Hidden width of the message MLP and the link decoder.
        classifier_hidden_dim: Hidden width of the node classifier. Defaults
            to 64, the value previously hard-coded.
    """

    def __init__(self, memory_dim, node_dim, edge_dim, time_dim, hidden_dim=128,
                 classifier_hidden_dim=64):
        super().__init__()

        self.memory_dim = memory_dim
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim

        # Message function: cat[h_u, h_v, edge_attr, time_enc] -> memory_dim.
        self.message_mlp = nn.Sequential(
            nn.Linear(2 * memory_dim + edge_dim + 2 * time_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, memory_dim),
        )

        # Memory updater. Despite the "_mlp" suffix this is a GRU cell; the
        # attribute name is kept so existing state_dict checkpoints still load.
        self.update_mlp = nn.GRUCell(memory_dim, memory_dim)

        # Link decoder: cat[h_u, x_u, h_v, x_v, edge_attr, time_enc] -> one
        # unnormalised score per row (no output activation).
        self.decoder = nn.Sequential(
            nn.Linear(
                2 * (memory_dim + node_dim) + edge_dim + 2 * time_dim,
                hidden_dim,
            ),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
        )

        # Node head: cat[memory, x] -> one unnormalised score per row.
        self.node_classifier = nn.Sequential(
            nn.Linear(memory_dim + node_dim, classifier_hidden_dim),
            nn.ReLU(),
            nn.Linear(classifier_hidden_dim, 1),
        )

    def compute_message(self, h_u, h_v, edge_attr, time_enc):
        """Compute one message per row from both endpoints' memory.

        The inputs are concatenated along dim 1, so the widths must sum to
        ``2 * memory_dim + edge_dim + 2 * time_dim``. Presumably ``h_u`` and
        ``h_v`` are ``memory_dim`` wide each -- confirm with the caller.

        Returns:
            Tensor of shape ``(batch, memory_dim)``.
        """
        features = torch.cat([h_u, h_v, edge_attr, time_enc], dim=1)
        return self.message_mlp(features)

    def update_memory(self, memory, node_ids, messages):
        """Run the GRU updater for ``node_ids`` and write the result into ``memory``.

        ``memory`` is modified in place and also returned. ``node_ids`` may be
        anything accepted by ``memory[node_ids]``: a 1-D integer index
        (tensor or list, negative values allowed) or a boolean row mask.

        If an id occurs more than once, every occurrence is run through the
        GRU against the same pre-update row, and the row keeps the result of
        the *last* occurrence. A plain ``memory[ids] = ...`` with duplicate
        ids leaves the winning value undefined in PyTorch, which made the
        stored memory nondeterministic.

        Args:
            memory: 2-D memory table ``(num_nodes, memory_dim)``.
            node_ids: Rows to update, one per message.
            messages: GRU inputs, ``(len(node_ids), memory_dim)``.

        Returns:
            The same ``memory`` object, updated.
        """
        idx = torch.as_tensor(node_ids, device=memory.device)
        if idx.numel() == 0:
            return memory

        # NOTE(review): the stored state is detached, so no gradient reaches
        # ``update_mlp`` through the returned memory. Confirm the caller
        # trains the GRU some other way.
        updated = self.update_mlp(messages, memory[idx]).detach()

        if idx.dtype == torch.bool:
            # A boolean mask selects each row at most once: no duplicates.
            memory[idx] = updated
            return memory

        # Map negative ids onto their positive row so -1 and N-1 dedupe
        # together. Out-of-range ids already raised in ``memory[idx]`` above.
        rows = torch.where(idx < 0, idx + memory.size(0), idx)

        # A stable sort keeps equal ids in their original order, so the final
        # element of each run of equal ids is that id's last occurrence.
        sorted_rows, order = torch.sort(rows, stable=True)
        last_of_run = torch.ones_like(sorted_rows, dtype=torch.bool)
        last_of_run[:-1] = sorted_rows[1:] != sorted_rows[:-1]
        keep = order[last_of_run]

        memory[rows[keep]] = updated[keep]
        return memory

    def predict(self, h_u, h_v, edge_attr, x_u, x_v, time_enc):
        """Score each (u, v) interaction with the link decoder.

        Each endpoint's memory is paired with its features before
        concatenation: ``cat[h_u, x_u, h_v, x_v, edge_attr, time_enc]``.

        Returns:
            1-D tensor with one unnormalised score per row.
        """
        features = torch.cat([h_u, x_u, h_v, x_v, edge_attr, time_enc], dim=1)
        return self.decoder(features).squeeze(-1)

    def predict_node(self, memory, x):
        """Score each node from its memory row and its features.

        Returns:
            1-D tensor with one unnormalised score per row.
        """
        combined = torch.cat([memory, x], dim=1)
        return self.node_classifier(combined).squeeze(-1)