GraPHFormer / graphformer /models /tree_encoder.py
uzshah's picture
Initial commit: GraPHFormer codebase
cf84204
"""
Tree Encoder Components for Neuron Morphology Analysis
This module contains TreeLSTM and related components for encoding tree structures.
"""
import torch
import torch.nn as nn
import dgl
class MultiNodeAggregation(nn.Module):
    """Pool the features of all nodes of each tree into one vector per tree.

    ``aggregation_type`` selects the pooling rule:

    * ``"attention"``: an MLP scores every node; the scores are softmax-normalised
      within each tree and used for a weighted sum of the node features.
    * ``"weighted"``: an MLP ending in a sigmoid gates every node independently;
      the gated features are summed (the gates are not normalised).
    * ``"mean"`` / ``"max"``: parameter-free pooling over the tree's nodes.
    """

    def __init__(self, h_size, aggregation_type="attention"):
        super(MultiNodeAggregation, self).__init__()
        self.h_size = h_size
        self.aggregation_type = aggregation_type
        if aggregation_type == "attention":
            # Learnable scalar score per node.
            self.attention = nn.Sequential(
                nn.Linear(h_size, h_size),
                nn.Tanh(),
                nn.Linear(h_size, 1),
            )
        elif aggregation_type == "weighted":
            # Learnable gate in (0, 1) per node.
            self.weight_net = nn.Sequential(
                nn.Linear(h_size, h_size // 2),
                nn.ReLU(),
                nn.Linear(h_size // 2, 1),
                nn.Sigmoid(),
            )

    def _pool(self, tree_feats):
        """Reduce one tree's ``(num_nodes_i, h_size)`` features to ``(h_size,)``."""
        kind = self.aggregation_type
        if kind == "attention":
            scores = self.attention(tree_feats)  # (num_nodes_i, 1)
            weights = torch.softmax(scores, dim=0)  # normalise over the nodes
            return (tree_feats * weights).sum(dim=0)
        if kind == "weighted":
            gates = self.weight_net(tree_feats)  # (num_nodes_i, 1)
            return (tree_feats * gates).sum(dim=0)
        if kind == "mean":
            return tree_feats.mean(dim=0)
        if kind == "max":
            return tree_feats.max(dim=0)[0]
        raise ValueError(
            "unknown aggregation_type {!r}; expected one of "
            "'attention', 'weighted', 'mean', 'max'".format(kind)
        )

    def forward(self, g, node_features, offsets):
        """
        Args:
            g: DGL graph (not read here; kept for interface compatibility)
            node_features: (N, h_size) features for all nodes
            offsets: indices of root nodes for each tree in batch. Tree ``i`` is
                taken to be the contiguous rows ``offsets[i]:offsets[i + 1]``
                (the last tree runs to the end of ``node_features``).
                NOTE(review): this assumes every tree's nodes are stored
                contiguously starting at its root -- confirm against the
                batch collation code.
        Returns:
            aggregated: (batch_size, h_size) aggregated features
        """
        if torch.is_tensor(offsets):
            starts = [int(v) for v in offsets.tolist()]
        else:
            starts = [int(v) for v in offsets]

        if not starts:
            # Nothing to pool: return an empty batch instead of failing in stack().
            return node_features.new_zeros((0, node_features.shape[-1]))

        # Tree i ends where tree i + 1 begins. The previous implementation started
        # every tree after the first at offsets[i - 1], so its slice also contained
        # the preceding tree's nodes.
        ends = starts[1:] + [len(node_features)]
        pooled = [self._pool(node_features[lo:hi]) for lo, hi in zip(starts, ends)]
        return torch.stack(pooled, dim=0)  # (batch_size, h_size)
class TreeLSTMCell(nn.Module):
    """TreeLSTM cell written as DGL message / reduce / apply callbacks.

    Children send their hidden state ``h`` and cell state ``c`` to the parent.
    The parent pools the children's hidden states according to ``mode``
    (``"sum"``, ``"max"`` or ``"mean"``), applies one forget gate per child to
    the children's cell states, and then runs the input/output/update gates.

    ``W_iou`` is not used inside the cell; the models in this file call it to
    fill the initial ``"iou"`` node field.
    """

    def __init__(self, x_size, h_size, mode="sum"):
        super().__init__()
        self.h_size = h_size
        self.mode = mode
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(torch.zeros(1, 3 * h_size))
        self.U_f = nn.Linear(h_size, h_size)

    def message_func(self, edges):
        """Forward the source node's hidden and cell state along each edge."""
        sender = edges.src
        return {"h": sender["h"], "c": sender["c"]}

    def reduce_func(self, nodes):
        """Combine the incoming child states of each node.

        Returns the recurrent gate pre-activations under ``"iou"`` and the
        forget-gated sum of the children's cell states under ``"c"``.

        Raises:
            NotImplementedError: if ``self.mode`` is not sum / max / mean.
        """
        pooling = self.mode
        if pooling not in ("sum", "max", "mean"):
            raise NotImplementedError

        child_h = nodes.mailbox["h"]  # indexed (node, child, feature)
        if pooling == "sum":
            pooled_h = child_h.sum(dim=1)
        elif pooling == "max":
            pooled_h = child_h.max(dim=1)[0]
        else:
            pooled_h = child_h.mean(dim=1)

        # One forget gate per child, computed from that child's own hidden state.
        forget = torch.sigmoid(self.U_f(child_h))
        gated_c = (forget * nodes.mailbox["c"]).sum(dim=1)

        # NOTE(review): the returned "iou" replaces the node field of the same
        # name rather than being added to it -- confirm this is intended.
        return {"iou": self.U_iou(pooled_h), "c": gated_c}

    def apply_node_func(self, nodes):
        """Turn the ``"iou"`` pre-activations and ``"c"`` into new ``h`` / ``c``."""
        pre_act = nodes.data["iou"] + self.b_iou
        in_gate, out_gate, cand = pre_act.chunk(3, dim=1)
        new_c = torch.sigmoid(in_gate) * torch.tanh(cand) + nodes.data["c"]
        new_h = torch.sigmoid(out_gate) * torch.tanh(new_c)
        return {"h": new_h, "c": new_c}
class BidirectionalTreeLSTMCell(nn.Module):
    """TreeLSTM cell holding two independent parameter sets.

    The ``*_bu`` parameters and callbacks serve the bottom-up (child to parent)
    pass, the ``*_td`` ones the top-down (parent to child) pass. Each pass keeps
    its own node fields (``h_bu``/``c_bu``/``iou_bu`` and ``h_td``/``c_td``/``iou_td``).
    """

    def __init__(self, x_size, h_size, mode="sum"):
        super().__init__()
        self.h_size = h_size
        self.mode = mode

        # Bottom-up direction.
        self.W_iou_bu = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou_bu = nn.Linear(h_size, 3 * h_size, bias=False)
        self.b_iou_bu = nn.Parameter(torch.zeros(1, 3 * h_size))
        self.U_f_bu = nn.Linear(h_size, h_size)

        # Top-down direction.
        self.W_iou_td = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou_td = nn.Linear(h_size, 3 * h_size, bias=False)
        self.b_iou_td = nn.Parameter(torch.zeros(1, 3 * h_size))
        # NOTE(review): U_f_td is created but not referenced by any method of
        # this class.
        self.U_f_td = nn.Linear(h_size, h_size)

    @staticmethod
    def _lstm_update(pre_act, prev_c):
        """Apply input/output/update gates; return the new ``(h, c)`` pair."""
        in_gate, out_gate, cand = pre_act.chunk(3, dim=1)
        new_c = torch.sigmoid(in_gate) * torch.tanh(cand) + prev_c
        new_h = torch.sigmoid(out_gate) * torch.tanh(new_c)
        return new_h, new_c

    def message_func_bu(self, edges):
        """Bottom-up message: the source node's bottom-up states."""
        sender = edges.src
        return {"h_bu": sender["h_bu"], "c_bu": sender["c_bu"]}

    def reduce_func_bu(self, nodes):
        """Bottom-up reduce: pool child hidden states, forget-gate child cells.

        Raises:
            NotImplementedError: if ``self.mode`` is not sum / max / mean.
        """
        pooling = self.mode
        if pooling not in ("sum", "max", "mean"):
            raise NotImplementedError

        child_h = nodes.mailbox["h_bu"]
        if pooling == "sum":
            pooled_h = child_h.sum(dim=1)
        elif pooling == "max":
            pooled_h = child_h.max(dim=1)[0]
        else:
            pooled_h = child_h.mean(dim=1)

        forget = torch.sigmoid(self.U_f_bu(child_h))
        gated_c = (forget * nodes.mailbox["c_bu"]).sum(dim=1)
        return {"iou_bu": self.U_iou_bu(pooled_h), "c_bu": gated_c}

    def apply_node_func_bu(self, nodes):
        """Bottom-up node update from ``iou_bu`` (plus bias) and ``c_bu``."""
        new_h, new_c = self._lstm_update(
            nodes.data["iou_bu"] + self.b_iou_bu, nodes.data["c_bu"]
        )
        return {"h_bu": new_h, "c_bu": new_c}

    def message_func_td(self, edges):
        """Top-down message: the top-down states of each edge's destination node.

        NOTE(review): this reads ``edges.dst``; whether that delivers the
        parent's state to the child depends on the edge orientation and on how
        the propagation routine dispatches messages -- verify against the caller.
        """
        receiver = edges.dst
        return {"h_td": receiver["h_td"], "c_td": receiver["c_td"]}

    def reduce_func_td(self, nodes):
        """Top-down reduce: take the first mailbox entry as the parent state.

        An empty mailbox (intended for the root, which has no parent) yields
        all-zero ``iou_td`` and ``c_td``.
        """
        inbox_h = nodes.mailbox["h_td"]
        if inbox_h.shape[1] == 0:
            count = nodes.batch_size()
            dev = inbox_h.device
            return {
                "iou_td": torch.zeros(count, 3 * self.h_size, device=dev),
                "c_td": torch.zeros(count, self.h_size, device=dev),
            }

        # Only the first message is used; any further ones are ignored.
        parent_h = inbox_h[:, 0, :]
        parent_c = nodes.mailbox["c_td"][:, 0, :]
        return {"iou_td": self.U_iou_td(parent_h), "c_td": parent_c}

    def apply_node_func_td(self, nodes):
        """Top-down node update from ``iou_td`` (plus bias) and ``c_td``."""
        new_h, new_c = self._lstm_update(
            nodes.data["iou_td"] + self.b_iou_td, nodes.data["c_td"]
        )
        return {"h_td": new_h, "c_td": new_c}
class BidirectionalTreeLSTM(nn.Module):
    """Bidirectional TreeLSTM that combines bottom-up and top-down processing.

    Node features are projected to ``h_size``, two propagation passes are run
    with a ``BidirectionalTreeLSTMCell``, and the two resulting cell states are
    concatenated into a ``2 * h_size`` representation per node.

    Args:
        x_size: size of the raw node features.
        h_size: hidden size of each direction.
        num_classes: output size of the optional linear head.
        mode: child pooling mode passed to the cell.
        fc: if true, apply a linear head and return logits.
        bn: if true, insert a BatchNorm1d layer in the input projection.
        node_aggregation: if truthy, the aggregation type handed to
            ``MultiNodeAggregation``; otherwise one node per tree is selected
            via ``batch.offset``.
    """

    def __init__(self, x_size, h_size, num_classes, mode="sum", fc=True, bn=False,
                 node_aggregation=None):
        super(BidirectionalTreeLSTM, self).__init__()
        self.x_size = x_size
        self.h_size = h_size
        self.node_aggregation = node_aggregation

        # Input projection
        if bn:
            self.mlp1 = nn.Sequential(
                nn.Linear(x_size, h_size),
                nn.BatchNorm1d(h_size),
                nn.ReLU(),
            )
        else:
            self.mlp1 = nn.Sequential(
                nn.Linear(x_size, h_size),
                nn.ReLU(),
            )

        # Bidirectional TreeLSTM cell
        self.cell = BidirectionalTreeLSTMCell(h_size, h_size, mode=mode)

        # Node aggregation (operates on the concatenated 2 * h_size features)
        if node_aggregation:
            self.node_agg = MultiNodeAggregation(h_size * 2, aggregation_type=node_aggregation)

        # Classification head
        self.fc = fc
        if fc:
            self.linear = nn.Linear(h_size * 2, num_classes)

    def forward(self, batch):
        """Forward pass combining bottom-up and top-down TreeLSTM.

        Args:
            batch: object exposing ``graph`` (a DGL graph), ``feats`` (node
                feature tensor fed to the input projection) and ``offset``
                (tensor of per-tree node indices).
        Returns:
            ``(batch_size, num_classes)`` logits when ``fc`` is set, otherwise
            the ``(batch_size, 2 * h_size)`` tree representations.
        """
        # Run on whatever device the module's weights live on instead of a
        # hard-coded "cuda", so the model also works on CPU or a non-default GPU.
        device = self.cell.W_iou_bu.weight.device

        g = batch.graph.to(device)
        # Rebuild a bare graph from the edge list only (no node/edge data is
        # carried over). NOTE(review): the node count is then derived from the
        # edges -- confirm trailing isolated nodes cannot occur.
        g = dgl.graph(g.edges())
        n = g.number_of_nodes()

        # Input projection
        feats = self.mlp1(batch.feats.to(device))

        # Initialize node states for both directions
        g.ndata["iou_bu"] = self.cell.W_iou_bu(feats)
        g.ndata["iou_td"] = self.cell.W_iou_td(feats)
        for field in ("h_bu", "c_bu", "h_td", "c_td"):
            g.ndata[field] = torch.zeros((n, self.h_size), device=device)

        # Bottom-up pass
        dgl.prop_nodes_topo(
            g,
            message_func=self.cell.message_func_bu,
            reduce_func=self.cell.reduce_func_bu,
            apply_node_func=self.cell.apply_node_func_bu,
        )

        # Top-down pass (reverse topological order).
        # NOTE(review): confirm in the DGL docs that reverse=True also changes
        # which edges messages travel along, not only the visiting order.
        dgl.prop_nodes_topo(
            g,
            message_func=self.cell.message_func_td,
            reduce_func=self.cell.reduce_func_td,
            apply_node_func=self.cell.apply_node_func_td,
            reverse=True,
        )

        # Combine bottom-up and top-down representations (cell states, not h)
        h_combined = torch.cat([g.ndata.pop("c_bu"), g.ndata.pop("c_td")], dim=1)

        # Aggregate node features
        if self.node_aggregation:
            h = self.node_agg(g, h_combined, batch.offset.long())
        else:
            h = h_combined[batch.offset.long()]

        # Classification
        if self.fc:
            return self.linear(h)
        return h
class TreeLSTM(nn.Module):
    """TreeLSTM encoder with an input projection and optional linear head.

    Args:
        x_size: size of the raw node features.
        h_size: hidden size of the cell and of the projected features.
        num_classes: output size of the optional linear head.
        mode: child pooling mode passed to ``TreeLSTMCell``.
        fc: if true, apply a linear head and return logits.
        bn: if true, insert a BatchNorm1d layer in the input projection.
        node_aggregation: if truthy, the aggregation type handed to
            ``MultiNodeAggregation``; otherwise one node per tree is selected
            via ``batch.offset``.
    """

    def __init__(self, x_size, h_size, num_classes, mode="sum", fc=True, bn=False,
                 node_aggregation=None):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.h_size = h_size
        self.node_aggregation = node_aggregation
        if bn:
            self.mlp1 = nn.Sequential(
                nn.Linear(x_size, h_size),
                nn.BatchNorm1d(h_size),
                nn.ReLU(),
            )
        else:
            self.mlp1 = nn.Sequential(
                nn.Linear(x_size, h_size),
                nn.ReLU(),
            )
        self.cell = TreeLSTMCell(h_size, h_size, mode=mode)
        if node_aggregation:
            self.node_agg = MultiNodeAggregation(h_size, aggregation_type=node_aggregation)
        self.fc = fc
        if fc:
            self.linear = nn.Linear(h_size, num_classes)

    def forward(self, batch):
        """Encode every tree of ``batch``.

        Args:
            batch: object exposing ``graph`` (a DGL graph), ``feats`` (node
                feature tensor fed to the input projection) and ``offset``
                (tensor of per-tree node indices).
        Returns:
            ``(batch_size, num_classes)`` logits when ``fc`` is set, otherwise
            the ``(batch_size, h_size)`` tree representations.
        """
        # Follow the device of the module's weights instead of a hard-coded
        # "cuda", so the model also works on CPU or a non-default GPU.
        device = self.cell.W_iou.weight.device

        g = batch.graph.to(device)
        # Rebuild a bare graph from the edge list only. NOTE(review): the node
        # count is then derived from the edges -- confirm trailing isolated
        # nodes cannot occur.
        g = dgl.graph(g.edges())
        n = g.number_of_nodes()

        feats = self.mlp1(batch.feats.to(device))
        g.ndata["iou"] = self.cell.W_iou(feats)
        g.ndata["h"] = torch.zeros((n, self.h_size), device=device)
        g.ndata["c"] = torch.zeros((n, self.h_size), device=device)

        dgl.prop_nodes_topo(
            g,
            message_func=self.cell.message_func,
            reduce_func=self.cell.reduce_func,
            apply_node_func=self.cell.apply_node_func,
        )

        # The cell state (not the hidden state) is used as node representation.
        h = g.ndata.pop("c")
        if self.node_aggregation:
            h = self.node_agg(g, h, batch.offset.long())
        else:
            h = h[batch.offset.long()]

        if self.fc:
            return self.linear(h)
        return h
class TreeLSTM_wo_MLP(nn.Module):
    """TreeLSTM without initial MLP projection.

    Raw node features of size ``x_size`` are fed straight into the cell's
    ``W_iou`` projection.

    Args:
        x_size: size of the raw node features.
        h_size: hidden size of the cell.
        num_classes: output size of the optional linear head.
        mode: child pooling mode passed to ``TreeLSTMCell``.
        fc: if true, apply a linear head and return logits.
    """

    def __init__(self, x_size, h_size, num_classes, mode="sum", fc=True):
        super(TreeLSTM_wo_MLP, self).__init__()
        self.x_size = x_size
        self.h_size = h_size
        self.cell = TreeLSTMCell(x_size, h_size, mode=mode)
        self.fc = fc
        if fc:
            self.linear = nn.Linear(h_size, num_classes)

    def forward(self, batch):
        """Encode every tree of ``batch``.

        Args:
            batch: object exposing ``graph`` (a DGL graph), ``feats`` (node
                feature tensor) and ``offset`` (tensor of per-tree node indices).
        Returns:
            ``(batch_size, num_classes)`` logits when ``fc`` is set, otherwise
            the ``(batch_size, h_size)`` cell states selected by ``batch.offset``.
        """
        # Follow the device of the module's weights instead of a hard-coded
        # "cuda", so the model also works on CPU or a non-default GPU.
        device = self.cell.W_iou.weight.device

        g = batch.graph.to(device)
        # Rebuild a bare graph from the edge list only. NOTE(review): the node
        # count is then derived from the edges -- confirm trailing isolated
        # nodes cannot occur.
        g = dgl.graph(g.edges())
        n = g.number_of_nodes()

        feats = batch.feats.to(device)
        g.ndata["iou"] = self.cell.W_iou(feats)
        g.ndata["h"] = torch.zeros((n, self.h_size), device=device)
        g.ndata["c"] = torch.zeros((n, self.h_size), device=device)

        dgl.prop_nodes_topo(
            g,
            message_func=self.cell.message_func,
            reduce_func=self.cell.reduce_func,
            apply_node_func=self.cell.apply_node_func,
        )

        # One cell state per tree, picked by the per-tree index in batch.offset.
        h = g.ndata.pop("c")[batch.offset.long()]
        if self.fc:
            return self.linear(h)
        return h
class TreeLSTMCellv2(nn.Module):
    """TreeLSTM Cell with alternative architecture.

    Compared with ``TreeLSTMCell`` the children's hidden states are summarised
    twice -- once with ``mode`` (sum / max / mean) and once with an element-wise
    max -- and the concatenation (size ``2 * h_size``) drives both the gate
    pre-activations and a single forget gate shared by all children of a node.
    """

    def __init__(self, x_size, h_size, mode="sum"):
        super(TreeLSTMCellv2, self).__init__()
        self.h_size = h_size
        self.mode = mode
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(torch.zeros(1, 3 * h_size))
        self.U_f = nn.Linear(2 * h_size, h_size)

    def message_func(self, edges):
        """Forward the source node's hidden and cell state along each edge."""
        return {"h": edges.src["h"], "c": edges.src["c"]}

    def reduce_func(self, nodes):
        """Combine the incoming child states of each node.

        Returns the recurrent gate pre-activations under ``"iou"`` and the
        forget-gated sum of the children's cell states under ``"c"``.

        Raises:
            NotImplementedError: if ``self.mode`` is not sum / max / mean
                (same contract as ``TreeLSTMCell.reduce_func``; previously an
                unsupported mode surfaced as an unbound local variable).
        """
        child_h = nodes.mailbox["h"]  # indexed (node, child, feature)
        if self.mode == "sum":
            h_cat = child_h.sum(dim=1)
        elif self.mode == "max":
            h_cat = child_h.max(dim=1)[0]
        elif self.mode == "mean":
            h_cat = child_h.mean(dim=1)
        else:
            raise NotImplementedError

        h_max = child_h.max(dim=1)[0]
        h_combined = torch.cat([h_cat, h_max], dim=1)  # (node, 2 * h_size)

        # The forget gate depends only on the pooled summary, so it is the same
        # for every child: compute it once and broadcast over the child axis.
        f = torch.sigmoid(self.U_f(h_combined)).unsqueeze(1)
        c = torch.sum(f * nodes.mailbox["c"], 1)
        return {"iou": self.U_iou(h_combined), "c": c}

    def apply_node_func(self, nodes):
        """Turn the ``"iou"`` pre-activations and ``"c"`` into new ``h`` / ``c``."""
        iou = nodes.data["iou"] + self.b_iou
        i, o, u = torch.chunk(iou, 3, 1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
        c = i * u + nodes.data["c"]
        h = o * torch.tanh(c)
        return {"h": h, "c": c}
class TreeLSTMDoubleCell(nn.Module):
    """TreeLSTM Cell with double hidden state (original implementation).

    Two LSTM stages are stacked inside every node: stage 1 is driven by the
    node's ``"iouf"`` field, stage 2 by ``W2_iouf`` applied to the stage-1 cell
    state. Each stage keeps its own ``(h, c)`` pair (``h1``/``c1``, ``h2``/``c2``).

    ``init_state`` is a one-shot flag: the first ``apply_node_func`` call
    computes states from the input alone and clears the flag; later calls pass
    the existing states through. Nothing in this class sets it back to True.
    """

    def __init__(self, x_size, h_size, mode="sum"):
        super().__init__()
        self.W1_iouf = nn.Linear(x_size, 4 * h_size)
        self.U1_iouf = nn.Linear(h_size, 4 * h_size)
        self.W2_iouf = nn.Linear(h_size, 4 * h_size)
        self.U2_iouf = nn.Linear(h_size, 4 * h_size)
        self.mode = mode
        self.init_state = True
        self.h_size = h_size

    @staticmethod
    def _lstm_step(pre_act, prev_c):
        """One LSTM update from packed (i, o, u, f) pre-activations."""
        raw_i, raw_o, raw_u, raw_f = pre_act.chunk(4, dim=1)
        new_c = torch.sigmoid(raw_i) * torch.tanh(raw_u) + torch.sigmoid(raw_f) * prev_c
        new_h = torch.sigmoid(raw_o) * torch.tanh(new_c)
        return new_h, new_c

    def message_func(self, edges):
        """Forward both stages' hidden and cell states of the source node."""
        sender = edges.src
        return {key: sender[key] for key in ("h1", "c1", "h2", "c2")}

    def reduce_func(self, nodes):
        """Pool the child states and run both LSTM stages for the node.

        Raises:
            ValueError: if ``self.mode`` is neither ``"sum"`` nor ``"mean"``.
        """
        inbox = nodes.mailbox
        child_h1, child_c1 = inbox["h1"], inbox["c1"]
        child_h2, child_c2 = inbox["h2"], inbox["c2"]

        if self.mode == "sum":
            pooled = [t.sum(-2) for t in (child_h1, child_c1, child_h2, child_c2)]
        elif self.mode == "mean":
            pooled = [t.mean(-2) for t in (child_h1, child_c1, child_h2, child_c2)]
        else:
            raise ValueError("must in [sum, mean]")
        agg_h1, agg_c1, agg_h2, agg_c2 = pooled

        # Stage 1: node input plus the pooled stage-1 hidden state of the children.
        stage1_pre = nodes.data["iouf"] + self.U1_iouf(agg_h1)
        out_h1, out_c1 = self._lstm_step(stage1_pre, agg_c1)

        # Stage 2: fed by the fresh stage-1 cell state and the pooled stage-2
        # hidden state of the children.
        stage2_pre = self.W2_iouf(out_c1) + self.U2_iouf(agg_h2)
        out_h2, out_c2 = self._lstm_step(stage2_pre, agg_c2)

        return {"h1": out_h1, "c1": out_c1, "h2": out_h2, "c2": out_c2}

    def apply_node_func(self, nodes):
        """Initialise states on the first call; pass them through afterwards."""
        current = nodes.data
        if not self.init_state:
            return {key: current[key] for key in ("h1", "c1", "h2", "c2")}

        # First call: no child contribution, only the node's own input.
        out_h1, out_c1 = self._lstm_step(current["iouf"], current["c1"])
        out_h2, out_c2 = self._lstm_step(self.W2_iouf(out_c1), current["c2"])
        self.init_state = False
        return {"h1": out_h1, "c1": out_c1, "h2": out_h2, "c2": out_c2}
class TreeLSTMDouble(nn.Module):
    """TreeLSTM with double hidden state aggregation.

    Args:
        x_size: size of the raw node features.
        h_size: hidden size of the cell and of the projected features.
        num_classes: output size of the optional linear head.
        mode: child pooling mode passed to ``TreeLSTMDoubleCell``.
        fc: if true, apply a linear head and return logits.
        bn: if true, insert BatchNorm1d layers in the input MLP.
        node_aggregation: if truthy, a ``MultiNodeAggregation`` module is
            created. NOTE(review): ``forward_backbone`` never calls it and
            always selects nodes via ``batch.offset`` -- confirm whether the
            aggregation was meant to be applied here as in ``TreeLSTM``.
    """

    def __init__(self, x_size, h_size, num_classes, mode="sum", fc=True, bn=False, node_aggregation=None):
        super(TreeLSTMDouble, self).__init__()
        self.x_size, self.h_size = x_size, h_size
        self.node_aggregation = node_aggregation
        if bn:
            self.mlp1 = nn.Sequential(
                nn.Linear(x_size, h_size),
                nn.BatchNorm1d(h_size),
                nn.ReLU(),
                nn.Linear(h_size, 2 * h_size),
                nn.BatchNorm1d(2 * h_size),
                nn.ReLU(),
                nn.Linear(2 * h_size, h_size),
            )
        else:
            self.mlp1 = nn.Sequential(
                nn.Linear(x_size, h_size),
                nn.ReLU(),
                nn.Linear(h_size, 2 * h_size),
                nn.ReLU(),
                nn.Linear(2 * h_size, h_size),
            )
        self.cell = TreeLSTMDoubleCell(h_size, h_size, mode=mode)
        if node_aggregation:
            self.node_agg = MultiNodeAggregation(h_size, aggregation_type=node_aggregation)
        self.fc = fc
        if fc:
            self.linear = nn.Linear(h_size, num_classes)

    def forward_backbone(self, batch):
        """Run the tree propagation and return one ``c2`` state per tree.

        Args:
            batch: object exposing ``graph`` (a DGL graph), ``feats`` (node
                feature tensor fed to the input MLP) and ``offset`` (tensor of
                per-tree node indices).
        Returns:
            ``(batch_size, h_size)`` stage-2 cell states selected by
            ``batch.offset``.
        """
        # Follow the device of the module's weights instead of a hard-coded
        # "cuda", so the model also works on CPU or a non-default GPU.
        device = self.cell.W1_iouf.weight.device

        g = batch.graph.to(device)
        # Rebuild a bare graph from the edge list only. NOTE(review): the node
        # count is then derived from the edges -- confirm trailing isolated
        # nodes cannot occur.
        g = dgl.graph(g.edges())
        n = g.number_of_nodes()

        # feed embedding
        feats = self.mlp1(batch.feats.to(device))
        g.ndata["iouf"] = self.cell.W1_iouf(feats)
        for field in ("h1", "c1", "h2", "c2"):
            g.ndata[field] = torch.zeros((n, self.h_size), device=device)

        # The cell clears its one-shot `init_state` flag on the first node
        # update and never sets it again, so without re-arming it here only the
        # very first forward pass would initialise the first frontier of nodes.
        self.cell.init_state = True

        # propagate
        dgl.prop_nodes_topo(
            g,
            message_func=self.cell.message_func,
            reduce_func=self.cell.reduce_func,
            apply_node_func=self.cell.apply_node_func,
        )

        return g.ndata.pop("c2")[batch.offset.long()]

    def forward(self, batch):
        """Return logits when ``fc`` is set, otherwise the backbone output."""
        logits = self.forward_backbone(batch)
        if self.fc:
            logits = self.linear(logits)
        return logits
class TreeLSTMv2(nn.Module):
    """TreeLSTM variant with improved architecture.

    Identical pipeline to ``TreeLSTM`` but built on ``TreeLSTMCellv2`` and
    without the optional multi-node aggregation.

    Args:
        x_size: size of the raw node features.
        h_size: hidden size of the cell and of the projected features.
        num_classes: output size of the optional linear head.
        mode: child pooling mode passed to ``TreeLSTMCellv2``.
        fc: if true, apply a linear head and return logits.
        bn: if true, insert a BatchNorm1d layer in the input projection.
    """

    def __init__(self, x_size, h_size, num_classes, mode="sum", fc=True, bn=False):
        super(TreeLSTMv2, self).__init__()
        self.x_size, self.h_size = x_size, h_size
        if bn:
            self.mlp1 = nn.Sequential(
                nn.Linear(x_size, h_size),
                nn.BatchNorm1d(h_size),
                nn.ReLU(),
            )
        else:
            self.mlp1 = nn.Sequential(
                nn.Linear(x_size, h_size),
                nn.ReLU(),
            )
        self.cell = TreeLSTMCellv2(h_size, h_size, mode=mode)
        self.fc = fc
        if fc:
            self.linear = nn.Linear(h_size, num_classes)

    def forward(self, batch):
        """Encode every tree of ``batch``.

        Args:
            batch: object exposing ``graph`` (a DGL graph), ``feats`` (node
                feature tensor fed to the input projection) and ``offset``
                (tensor of per-tree node indices).
        Returns:
            ``(batch_size, num_classes)`` logits when ``fc`` is set, otherwise
            the ``(batch_size, h_size)`` cell states selected by ``batch.offset``.
        """
        # Follow the device of the module's weights instead of a hard-coded
        # "cuda", so the model also works on CPU or a non-default GPU.
        device = self.cell.W_iou.weight.device

        g = batch.graph.to(device)
        # Rebuild a bare graph from the edge list only. NOTE(review): the node
        # count is then derived from the edges -- confirm trailing isolated
        # nodes cannot occur.
        g = dgl.graph(g.edges())
        n = g.number_of_nodes()

        feats = self.mlp1(batch.feats.to(device))
        g.ndata["iou"] = self.cell.W_iou(feats)
        g.ndata["h"] = torch.zeros((n, self.h_size), device=device)
        g.ndata["c"] = torch.zeros((n, self.h_size), device=device)

        dgl.prop_nodes_topo(
            g,
            message_func=self.cell.message_func,
            reduce_func=self.cell.reduce_func,
            apply_node_func=self.cell.apply_node_func,
        )

        # One cell state per tree, picked by the per-tree index in batch.offset.
        h = g.ndata.pop("c")[batch.offset.long()]
        if self.fc:
            return self.linear(h)
        return h