File size: 6,484 Bytes
50a402c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import torch
from torch import nn
from torch_geometric.nn import HeteroConv, global_mean_pool, GATv2Conv
class XGNet(nn.Module):
    """
    Heterogeneous GNN for xG with Persistent Nodes and Shot-Indexed Edges.

    Graph Structure:
    - Nodes: shooter (num_players, persistent), goal (1, persistent), goalkeeper (1, persistent)
    - Edges: shooter -> goal (distance, angle), shooter -> goalkeeper (dist_to_gk)
    - Global: shot-level contextual features (``num_global_features``, 18 by default)
    - Masking: All edges and global features indexed by shot_idx for prediction
    """

    # Edge types used both to build the convolutions and to key the per-shot
    # edge dictionaries in ``forward``.
    _SHOOTS_AT = ('shooter', 'shoots_at', 'goal')
    _REV_SHOOTS_AT = ('goal', 'rev_shoots_at', 'shooter')
    _FACES = ('shooter', 'faces', 'goalkeeper')
    _REV_FACES = ('goalkeeper', 'rev_faces', 'shooter')

    def __init__(self, num_players: int, hid: int, p: float, heads: int, num_layers: int,
                 use_norm: bool, num_global_features: int = 18):
        """
        Build the encoders, message-passing layers and read-out head.

        Args:
            num_players: Number of distinct shooter ids; the embedding table
                has ``num_players + 1`` rows (one extra for padding/UNK).
            hid: Hidden width shared by all node types.
            p: Dropout probability (embeddings, attention, residuals, read-out).
            heads: Number of attention heads per GATv2 layer (averaged, not
                concatenated, because ``concat=False``).
            num_layers: Number of HeteroConv message-passing layers.
            use_norm: If True, apply a per-node-type LayerNorm after each layer.
            num_global_features: Width of the shot-level global feature vector.
        """
        super().__init__()

        # 1) Node encoders ---------------------------------------------------
        self.shooter_emb = nn.Embedding(num_players + 1, hid)  # +1 for padding/UNK
        self.goal_feat = nn.Parameter(torch.randn(1, hid) * 0.01)  # learnable goal
        self.gk_feat = nn.Parameter(torch.randn(1, hid) * 0.01)  # learnable goalkeeper

        # Global feature encoder
        self.global_encoder = nn.Linear(num_global_features, hid)
        self.dropout = nn.Dropout(p=p)

        # 2) Edge-conditioned message passing -------------------------------
        def mk_gat_with_edge(edge_dim: int):
            """GAT with edge features"""
            return GATv2Conv(
                in_channels=(hid, hid),
                out_channels=hid,
                edge_dim=edge_dim,
                heads=heads,
                concat=False,
                dropout=p,
                add_self_loops=False,
            )

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList() if use_norm else None

        for _ in range(num_layers):
            self.convs.append(HeteroConv({
                self._SHOOTS_AT: mk_gat_with_edge(edge_dim=2),      # distance + angle
                self._REV_SHOOTS_AT: mk_gat_with_edge(edge_dim=2),  # reverse edge
                self._FACES: mk_gat_with_edge(edge_dim=1),          # dist_to_gk
                self._REV_FACES: mk_gat_with_edge(edge_dim=1),      # reverse edge
            }, aggr='sum'))

            if use_norm:
                # Normalize for each node type
                self.norms.append(nn.ModuleDict({
                    'shooter': nn.LayerNorm(hid),
                    'goal': nn.LayerNorm(hid),
                    'goalkeeper': nn.LayerNorm(hid),
                }))

        # 3) Read-out --------------------------------------------------------
        self.output = nn.Sequential(
            nn.Linear(hid * 3, hid),  # Combine shooter + goal + goalkeeper
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(hid, hid // 2),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(hid // 2, 1),
        )

    # ----------------------------------------------------------------------
    def forward(self, data, shot_idx) -> torch.Tensor:
        """
        Forward pass for a specific shot.

        Args:
            data: HeteroData with all nodes and edges. The code reads
                ``data['shooter'].x`` (player ids), ``num_nodes`` of the
                'goal' and 'goalkeeper' stores, ``edge_index`` / ``edge_attr``
                / ``shot_idx`` of the two shooter-originating edge types, and
                ``x`` / ``shot_idx`` of the 'global' store.
            shot_idx: Index of the shot to predict (for masking edges/features)

        Returns:
            A 0-dim tensor: the raw output of the final linear layer of the
            read-out head (no sigmoid is applied inside this module).

        Raises:
            IndexError: If no shooter -> goal edge is tagged with ``shot_idx``.
            RuntimeError: If the number of global feature rows tagged with
                ``shot_idx`` is not exactly one.
        """
        # Prepare node feature dict (all persistent nodes).
        # reshape(-1) flattens (N, 1) and (N,) alike; squeeze(-1) would turn a
        # single 1-D id into a 0-dim tensor and drop the node dimension.
        shooter_ids = data['shooter'].x.reshape(-1).long()
        x = {
            'shooter': self.dropout(self.shooter_emb(shooter_ids)),
            'goal': self.goal_feat.expand(data['goal'].num_nodes, -1),
            'goalkeeper': self.gk_feat.expand(data['goalkeeper'].num_nodes, -1),
        }

        # Mask edges for this specific shot
        sg_store = data['shooter', 'shoots_at', 'goal']
        sk_store = data['shooter', 'faces', 'goalkeeper']
        shooter_goal_mask = (sg_store.shot_idx == shot_idx)
        shooter_gk_mask = (sk_store.shot_idx == shot_idx)

        sg_edges = sg_store.edge_index[:, shooter_goal_mask]
        sk_edges = sk_store.edge_index[:, shooter_gk_mask]
        sg_attr = sg_store.edge_attr[shooter_goal_mask]
        sk_attr = sk_store.edge_attr[shooter_gk_mask]

        if sg_edges.size(1) == 0:
            # Without this edge the active shooter cannot be identified.
            raise IndexError(
                f"no ('shooter', 'shoots_at', 'goal') edge found for shot_idx={shot_idx}"
            )

        edge_index_dict = {
            self._SHOOTS_AT: sg_edges,
            self._REV_SHOOTS_AT: sg_edges.flip(0),  # reverse
            self._FACES: sk_edges,
            self._REV_FACES: sk_edges.flip(0),  # reverse
        }
        edge_attr_dict = {
            self._SHOOTS_AT: sg_attr,
            self._REV_SHOOTS_AT: sg_attr,  # same attributes
            self._FACES: sk_attr,
            self._REV_FACES: sk_attr,  # same attributes
        }

        # Message passing with masked edges
        for li, conv in enumerate(self.convs):
            x_new = conv(x, edge_index_dict, edge_attr_dict)
            layer_norms = self.norms[li] if self.norms is not None else None

            # Optional normalization, then residual connection + dropout.
            # Node types absent from the conv output keep their features.
            for node_type in x:
                if node_type not in x_new:
                    continue
                h = x_new[node_type]
                if layer_norms is not None:
                    h = layer_norms[node_type](h)
                x[node_type] = self.dropout(h + x[node_type])

        # Get the active shooter for this shot (source of the first masked edge)
        active_shooter_idx = sg_edges[0, 0]
        shooter_repr = x['shooter'][active_shooter_idx]  # (hid,)
        goal_repr = x['goal'][0]  # (hid,)
        gk_repr = x['goalkeeper'][0]  # (hid,)

        # Get global features for this shot
        global_store = data['global']
        global_rows = global_store.x[global_store.shot_idx == shot_idx]
        if global_rows.size(0) != 1:
            raise RuntimeError(
                f"expected exactly one global feature row for shot_idx={shot_idx}, "
                f"found {global_rows.size(0)}"
            )
        global_feat = self.global_encoder(global_rows.squeeze(0))  # (hid,)

        # Combine all representations
        combined = torch.cat([
            shooter_repr + global_feat,  # Shooter with context
            goal_repr,
            gk_repr,
        ], dim=0)  # (hid * 3,)

        return self.output(combined.unsqueeze(0)).squeeze()  # scalar xG prediction
|