# Source: focal_gnn_one_graph / model_architecture.py
# Uploaded by rokati with huggingface_hub (commit 50a402c, verified).
import torch
from torch import nn
from torch_geometric.nn import HeteroConv, global_mean_pool, GATv2Conv
class XGNet(nn.Module):
    """Heterogeneous GNN for xG with persistent nodes and shot-indexed edges.

    Graph structure:
    - Nodes: shooter (num_players, persistent), goal (1, persistent),
      goalkeeper (1, persistent)
    - Edges: shooter -> goal (distance, angle), shooter -> goalkeeper (dist_to_gk)
    - Global: shot-level contextual features (18 by default)
    - Masking: edges and global features are selected by ``shot_idx`` so that
      a forward pass only sees the shot being predicted
    """

    def __init__(self, num_players: int, hid: int, p: float, heads: int, num_layers: int,
                 use_norm: bool, num_global_features: int = 18) -> None:
        """Build the encoders, message-passing stack and read-out MLP.

        Args:
            num_players: Number of distinct shooters; the embedding table gets
                one extra row (padding/UNK).
            hid: Hidden width shared by every node type and the global encoder.
            p: Dropout probability. Used for the embedding/residual dropout,
                passed as ``dropout`` to every ``GATv2Conv`` and used in the
                read-out MLP.
            heads: Number of attention heads per ``GATv2Conv`` (``concat=False``,
                so the output width stays ``hid``).
            num_layers: Number of ``HeteroConv`` layers.
            use_norm: If true, one ``LayerNorm`` per node type is applied to the
                output of each layer before the residual sum.
            num_global_features: Width of the shot-level feature vector.
        """
        super().__init__()

        # 1) Node encoders ---------------------------------------------------
        self.shooter_emb = nn.Embedding(num_players + 1, hid)  # +1 for padding/UNK
        self.goal_feat = nn.Parameter(torch.randn(1, hid) * 0.01)  # learnable goal
        self.gk_feat = nn.Parameter(torch.randn(1, hid) * 0.01)  # learnable goalkeeper

        # Global feature encoder
        self.global_encoder = nn.Linear(num_global_features, hid)
        self.dropout = nn.Dropout(p=p)

        # 2) Edge-conditioned message passing -------------------------------
        def mk_gat_with_edge(edge_dim: int):
            """Return a bipartite GATv2 layer that consumes ``edge_dim`` edge features."""
            return GATv2Conv(
                in_channels=(hid, hid),
                out_channels=hid,
                edge_dim=edge_dim,
                heads=heads,
                concat=False,
                dropout=p,
                add_self_loops=False,
            )

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList() if use_norm else None
        for _ in range(num_layers):
            conv = HeteroConv({
                ('shooter', 'shoots_at', 'goal'): mk_gat_with_edge(edge_dim=2),  # distance + angle
                ('goal', 'rev_shoots_at', 'shooter'): mk_gat_with_edge(edge_dim=2),  # reverse edge
                ('shooter', 'faces', 'goalkeeper'): mk_gat_with_edge(edge_dim=1),  # dist_to_gk
                ('goalkeeper', 'rev_faces', 'shooter'): mk_gat_with_edge(edge_dim=1),  # reverse edge
            }, aggr='sum')
            self.convs.append(conv)
            if use_norm:
                # Normalize for each node type
                self.norms.append(nn.ModuleDict({
                    'shooter': nn.LayerNorm(hid),
                    'goal': nn.LayerNorm(hid),
                    'goalkeeper': nn.LayerNorm(hid),
                }))

        # 3) Read-out --------------------------------------------------------
        self.output = nn.Sequential(
            nn.Linear(hid * 3, hid),  # Combine shooter + goal + goalkeeper
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(hid, hid // 2),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(hid // 2, 1),
        )

    # ----------------------------------------------------------------------
    def forward(self, data, shot_idx) -> torch.Tensor:
        """Predict a single shot.

        Args:
            data: Heterogeneous graph indexed by node type (``'shooter'``,
                ``'goal'``, ``'goalkeeper'``, ``'global'``) and by edge type
                (``('shooter', 'shoots_at', 'goal')``,
                ``('shooter', 'faces', 'goalkeeper')``). Edge stores must expose
                ``edge_index``, ``edge_attr`` and ``shot_idx``; the ``'global'``
                store must expose ``x`` and ``shot_idx``. Presumably a PyG
                ``HeteroData`` -- confirm against the caller.
            shot_idx: Identifier of the shot to predict; compared with ``==``
                against the stored ``shot_idx`` tensors.

        Returns:
            A 0-dim tensor holding the raw output of the read-out MLP (no
            activation is applied inside this module).

        Raises:
            IndexError: If no ``shoots_at`` edge carries ``shot_idx``.
            RuntimeError: If the number of global-feature rows carrying
                ``shot_idx`` is not exactly one.
        """
        # Prepare node feature dict (all persistent nodes)
        shooter_ids = data['shooter'].x.squeeze(-1).long()
        x = {
            'shooter': self.dropout(self.shooter_emb(shooter_ids)),
            'goal': self.goal_feat.expand(data['goal'].num_nodes, -1),
            'goalkeeper': self.gk_feat.expand(data['goalkeeper'].num_nodes, -1),
        }

        # Keep only the edges that belong to this shot. Each masked selection
        # is computed once and shared between the forward and reverse relation.
        shoots_at = data['shooter', 'shoots_at', 'goal']
        faces = data['shooter', 'faces', 'goalkeeper']
        shooter_goal_mask = (shoots_at.shot_idx == shot_idx)
        shooter_gk_mask = (faces.shot_idx == shot_idx)
        goal_edges = shoots_at.edge_index[:, shooter_goal_mask]
        gk_edges = faces.edge_index[:, shooter_gk_mask]
        goal_attr = shoots_at.edge_attr[shooter_goal_mask]
        gk_attr = faces.edge_attr[shooter_gk_mask]

        # The active shooter is read from the first shoots_at edge below, so an
        # empty selection cannot be predicted; fail early with a clear message
        # (same exception type the bare indexing would raise).
        if goal_edges.size(1) == 0:
            raise IndexError(
                f"no ('shooter', 'shoots_at', 'goal') edge found for shot_idx={shot_idx}"
            )

        # Exactly one global row is required: squeeze(0) only drops a leading
        # dimension of size one, anything else breaks the concatenation below.
        global_store = data['global']
        global_rows = global_store.x[global_store.shot_idx == shot_idx]
        if global_rows.size(0) != 1:
            raise RuntimeError(
                f"expected exactly one global feature row for shot_idx={shot_idx}, "
                f"found {global_rows.size(0)}"
            )

        edge_index_dict = {
            ('shooter', 'shoots_at', 'goal'): goal_edges,
            ('goal', 'rev_shoots_at', 'shooter'): goal_edges.flip(0),  # reverse
            ('shooter', 'faces', 'goalkeeper'): gk_edges,
            ('goalkeeper', 'rev_faces', 'shooter'): gk_edges.flip(0),  # reverse
        }
        edge_attr_dict = {
            ('shooter', 'shoots_at', 'goal'): goal_attr,
            ('goal', 'rev_shoots_at', 'shooter'): goal_attr,  # same attributes
            ('shooter', 'faces', 'goalkeeper'): gk_attr,
            ('goalkeeper', 'rev_faces', 'shooter'): gk_attr,  # same attributes
        }

        # Message passing with masked edges
        for li, conv in enumerate(self.convs):
            x_new = conv(x, edge_index_dict, edge_attr_dict)
            layer_norms = self.norms[li] if self.norms is not None else None
            for node_type in x:
                if node_type not in x_new:
                    continue  # node type received no update; keep it as is
                update = x_new[node_type]
                if layer_norms is not None:
                    update = layer_norms[node_type](update)
                # Residual connection followed by dropout
                x[node_type] = self.dropout(update + x[node_type])

        # Get the active shooter for this shot (source of the first edge)
        active_shooter_idx = goal_edges[0, 0]
        shooter_repr = x['shooter'][active_shooter_idx]  # (hid,)
        goal_repr = x['goal'][0]  # (hid,)
        gk_repr = x['goalkeeper'][0]  # (hid,)

        # Encode the global features for this shot
        global_feat = self.global_encoder(global_rows.squeeze(0))  # (hid,)

        # Combine all representations
        combined = torch.cat([
            shooter_repr + global_feat,  # Shooter with context
            goal_repr,
            gk_repr,
        ], dim=0)  # (hid * 3,)
        return self.output(combined.unsqueeze(0)).squeeze()  # 0-dim prediction