# temporal-twins-anon -- "Add anonymous Temporal Twins code release" (commit a3682cf, verified)
import torch
import torch.nn as nn
class TGN(nn.Module):
    """TGN-style temporal graph model: message/memory modules plus two heads.

    The model does not own the node memory. Callers keep a memory tensor,
    pass it to :meth:`update_memory`, and pass memory rows to
    :meth:`compute_message`, :meth:`predict` and :meth:`predict_node`.

    Submodules (registered in this order; keep it for checkpoint and
    optimizer-state compatibility):

    * ``message_mlp`` -- 2-layer MLP mapping
      ``[h_u | h_v | edge_attr | time_enc]`` to a ``memory_dim`` message.
    * ``update_mlp`` -- ``nn.GRUCell(memory_dim, memory_dim)``. It folds a
      message (input) into a memory row (hidden state).
    * ``decoder`` -- 3-layer time-aware edge predictor. It outputs one raw
      (un-activated) score per row from both endpoints' memory and node
      features, the edge features and the time encoding.
    * ``node_classifier`` -- 2-layer node risk classifier. It outputs one raw
      (un-activated) score per row from memory and node features.

    All inputs are concatenated along ``dim=1``, so they are expected to be
    2-D ``(batch, features)`` tensors with matching batch sizes.

    Args:
        memory_dim: Width of a node memory row.
        node_dim: Width of a node feature row.
        edge_dim: Width of an edge feature row.
        time_dim: Half the width of ``time_enc``. Every layer that consumes a
            time encoding expects ``2 * time_dim`` columns (presumably one
            encoding per endpoint -- TODO confirm against the caller).
        hidden_dim: Hidden width of ``message_mlp`` and of ``decoder``'s
            first layer. ``decoder``'s second hidden layer is
            ``hidden_dim // 2``.
        classifier_hidden_dim: Hidden width of ``node_classifier``. Defaults
            to 64, the value previously hard-coded.
    """

    def __init__(
        self,
        memory_dim,
        node_dim,
        edge_dim,
        time_dim,
        hidden_dim=128,
        classifier_hidden_dim=64,
    ):
        super().__init__()
        self.memory_dim = memory_dim
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim

        # Message function: [h_u | h_v | edge_attr | time_enc] -> memory_dim.
        self.message_mlp = nn.Sequential(
            nn.Linear(2 * memory_dim + edge_dim + 2 * time_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, memory_dim),
        )

        # Memory update: GRUCell(input=message, hidden=previous memory row).
        self.update_mlp = nn.GRUCell(memory_dim, memory_dim)

        # Edge predictor (time-aware):
        # [h_u | x_u | h_v | x_v | edge_attr | time_enc] -> 1 score.
        self.decoder = nn.Sequential(
            nn.Linear(
                2 * (memory_dim + node_dim) + edge_dim + 2 * time_dim,
                hidden_dim,
            ),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
        )

        # Node risk classifier: [memory | x] -> 1 score.
        self.node_classifier = nn.Sequential(
            nn.Linear(memory_dim + node_dim, classifier_hidden_dim),
            nn.ReLU(),
            nn.Linear(classifier_hidden_dim, 1),
        )

    def compute_message(self, h_u, h_v, edge_attr, time_enc):
        """Build one ``memory_dim``-wide message per row.

        Args:
            h_u: Memory rows of the first endpoint (nominally ``memory_dim``
                wide).
            h_v: Memory rows of the second endpoint (nominally
                ``memory_dim`` wide).
            edge_attr: Edge feature rows (nominally ``edge_dim`` wide).
            time_enc: Time encoding rows (nominally ``2 * time_dim`` wide).

            Together the four inputs must total
            ``2 * memory_dim + edge_dim + 2 * time_dim`` columns.

        Returns:
            Tensor of shape ``(batch, memory_dim)``.
        """
        return self.message_mlp(
            torch.cat([h_u, h_v, edge_attr, time_enc], dim=1)
        )

    def update_memory(self, memory, node_ids, messages):
        """Apply one GRU step per ``node_ids`` entry and store it in ``memory``.

        Args:
            memory: Memory table. It is modified **in place**.
            node_ids: Rows to update. Accepts an integer tensor, a list of
                ints, or any other index ``memory[...]`` accepts.
            messages: One message (GRU input) per selected row.

        Returns:
            The same ``memory`` object, after the in-place write.

        If an id appears more than once, every occurrence is computed from
        the same pre-update row, and the LAST occurrence is the one stored.
        Plain index assignment leaves duplicate indices undefined (results
        can differ on CUDA), so 1-D integer ids are de-duplicated first.
        """
        # Detached so the stored memory carries no autograd history between
        # calls.
        # NOTE(review): this also means no gradient reaches message_mlp or
        # update_mlp through `memory`; confirm the training loop trains them
        # some other way.
        updated = self.update_mlp(messages, memory[node_ids]).detach()

        ids = None
        if torch.is_tensor(node_ids):
            ids = node_ids
        elif isinstance(node_ids, list):
            ids = torch.as_tensor(node_ids)

        if (
            ids is not None
            and ids.dim() == 1
            and ids.dtype in (torch.int64, torch.int32)
        ):
            # A stable sort keeps equal ids in their original order, so the
            # final element of each run of equal ids is its last occurrence.
            sorted_ids, order = torch.sort(ids, stable=True)
            is_last = torch.ones_like(sorted_ids, dtype=torch.bool)
            is_last[:-1] = sorted_ids[1:] != sorted_ids[:-1]
            keep = order[is_last]
            memory[ids[keep]] = updated[keep]
        else:
            # Masks, slices, multi-dim indices: no duplicates to resolve here.
            memory[node_ids] = updated
        return memory

    def predict(self, h_u, h_v, edge_attr, x_u, x_v, time_enc):
        """Score each edge row with the time-aware decoder.

        The argument order differs from the concatenation order. The decoder
        sees ``[h_u | x_u | h_v | x_v | edge_attr | time_enc]``; keep that
        order to stay compatible with trained weights.

        Args:
            h_u: Memory rows of the first endpoint.
            h_v: Memory rows of the second endpoint.
            edge_attr: Edge feature rows.
            x_u: Node feature rows of the first endpoint.
            x_v: Node feature rows of the second endpoint.
            time_enc: Time encoding rows (nominally ``2 * time_dim`` wide).

        Returns:
            1-D tensor of raw scores, shape ``(batch,)``. No sigmoid is
            applied.
        """
        return self.decoder(
            torch.cat([h_u, x_u, h_v, x_v, edge_attr, time_enc], dim=1)
        ).squeeze(-1)

    def predict_node(self, memory, x):
        """Score each node row with the node risk classifier.

        Args:
            memory: Memory rows of the nodes to score, not the whole table
                unless every node is scored.
            x: Node feature rows, aligned with ``memory``.

        Returns:
            1-D tensor of raw scores, shape ``(batch,)``. No sigmoid is
            applied.
        """
        combined = torch.cat([memory, x], dim=1)
        return self.node_classifier(combined).squeeze(-1)