|
|
import torch |
|
|
from torch import nn |
|
|
from torch_geometric.nn import HeteroConv, global_mean_pool, GATv2Conv |
|
|
|
|
|
class XGNet(nn.Module):
    """Heterogeneous GNN for xG with persistent nodes and shot-indexed edges.

    Graph structure:
        - Nodes: ``shooter`` (embedded from integer ids, persistent), ``goal``
          and ``goalkeeper`` (persistent; each initialised from a single
          learned feature vector).
        - Edges: ``shooter -> goal`` with 2 edge features (distance, angle)
          and ``shooter -> goalkeeper`` with 1 edge feature (dist_to_gk).
          The reverse relations are derived on the fly in ``forward``.
        - Global: shot-level contextual features (``num_global_features``
          wide, 18 by default).
        - Masking: edges and global features carry a ``shot_idx``; ``forward``
          keeps only the entries that belong to the requested shot.
    """

    # Edge-feature width expected by the GAT layer of each relation.  The
    # insertion order is also the order in which the layers are created.
    _EDGE_DIMS = {
        ('shooter', 'shoots_at', 'goal'): 2,
        ('goal', 'rev_shoots_at', 'shooter'): 2,
        ('shooter', 'faces', 'goalkeeper'): 1,
        ('goalkeeper', 'rev_faces', 'shooter'): 1,
    }

    def __init__(self, num_players: int, hid: int, p: float, heads: int, num_layers: int,
                 use_norm: bool, num_global_features: int = 18):
        """Build the embeddings, message-passing stack and prediction head.

        Args:
            num_players: Number of distinct players; the embedding table gets
                ``num_players + 1`` rows.
            hid: Hidden width shared by every node type.
            p: Dropout probability, used for the feature dropout, the GAT
                attention dropout and the dropout layers of the output MLP.
            heads: Number of attention heads per GAT layer (heads are
                averaged, not concatenated, so the width stays ``hid``).
            num_layers: Number of heterogeneous message-passing layers.
            use_norm: If true, a per-node-type ``LayerNorm`` is applied to the
                output of every message-passing layer.
            num_global_features: Width of the shot-level feature vector.
        """
        super().__init__()

        # Node inputs: a learned embedding per shooter id and one learned
        # vector each for the goal and goalkeeper nodes.
        # NOTE(review): the extra embedding row presumably reserves an id for
        # unknown players -- confirm against the data pipeline.
        self.shooter_emb = nn.Embedding(num_players + 1, hid)
        self.goal_feat = nn.Parameter(torch.randn(1, hid) * 0.01)
        self.gk_feat = nn.Parameter(torch.randn(1, hid) * 0.01)

        # Projects the shot-level features into the hidden space.
        self.global_encoder = nn.Linear(num_global_features, hid)

        self.dropout = nn.Dropout(p=p)

        def mk_gat_with_edge(edge_dim: int):
            """Bipartite GATv2 layer that consumes ``edge_dim`` edge features."""
            return GATv2Conv(
                in_channels=(hid, hid),
                out_channels=hid,
                edge_dim=edge_dim,
                heads=heads,
                concat=False,
                dropout=p,
                add_self_loops=False,
            )

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList() if use_norm else None

        for _ in range(num_layers):
            self.convs.append(HeteroConv(
                {
                    edge_type: mk_gat_with_edge(edge_dim=edge_dim)
                    for edge_type, edge_dim in self._EDGE_DIMS.items()
                },
                aggr='sum',
            ))

            if use_norm:
                self.norms.append(nn.ModuleDict({
                    'shooter': nn.LayerNorm(hid),
                    'goal': nn.LayerNorm(hid),
                    'goalkeeper': nn.LayerNorm(hid),
                }))

        # Prediction head over [shooter + global | goal | goalkeeper].
        self.output = nn.Sequential(
            nn.Linear(hid * 3, hid),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(hid, hid // 2),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(hid // 2, 1),
        )

    def forward(self, data, shot_idx) -> torch.Tensor:
        """Score a single shot.

        Args:
            data: Heterogeneous graph container.  It is indexed with
                ``'shooter'``, ``'goal'``, ``'goalkeeper'``, ``'global'`` and
                the two ``shooter -> ...`` edge types.  The edge stores must
                expose ``edge_index``, ``edge_attr`` and ``shot_idx``; the
                ``'global'`` store must expose ``x`` and ``shot_idx``.
            shot_idx: Identifier of the shot to predict.  Only edges and
                global features tagged with this value are used.

        Returns:
            A 0-dimensional tensor holding the output of the final linear
            layer (no sigmoid is applied inside the model).

        Raises:
            IndexError: If no ``shooter -> goal`` edge is tagged with
                ``shot_idx``.
            ValueError: If ``shot_idx`` does not select exactly one row of
                global features.
        """
        shoots_at = data['shooter', 'shoots_at', 'goal']
        faces = data['shooter', 'faces', 'goalkeeper']

        # Restrict every relation to the edges of the requested shot.
        goal_mask = shoots_at.shot_idx == shot_idx
        gk_mask = faces.shot_idx == shot_idx
        goal_edges = shoots_at.edge_index[:, goal_mask]
        gk_edges = faces.edge_index[:, gk_mask]
        goal_attr = shoots_at.edge_attr[goal_mask]
        gk_attr = faces.edge_attr[gk_mask]

        # Fail early with a clear message instead of an opaque
        # "index 0 is out of bounds" after message passing has already run.
        if goal_edges.size(1) == 0:
            raise IndexError(
                f"no ('shooter', 'shoots_at', 'goal') edge found for shot_idx={shot_idx}"
            )

        global_rows = data['global'].x[data['global'].shot_idx == shot_idx]
        if global_rows.size(0) != 1:
            raise ValueError(
                f"expected exactly one global feature row for shot_idx={shot_idx}, "
                f"found {global_rows.size(0)}"
            )

        # Flatten the ids so that both (N, 1) and (N,) inputs yield an
        # (N, hid) embedding; squeeze(-1) would collapse a lone 1-D id to a
        # scalar and drop the node dimension.
        shooter_ids = data['shooter'].x.reshape(-1).long()
        x = {
            'shooter': self.dropout(self.shooter_emb(shooter_ids)),
            'goal': self.goal_feat.expand(data['goal'].num_nodes, -1),
            'goalkeeper': self.gk_feat.expand(data['goalkeeper'].num_nodes, -1),
        }

        # Reverse relations reuse the same edges (source/target swapped) and
        # the same edge attributes.
        edge_index_dict = {
            ('shooter', 'shoots_at', 'goal'): goal_edges,
            ('goal', 'rev_shoots_at', 'shooter'): goal_edges.flip(0),
            ('shooter', 'faces', 'goalkeeper'): gk_edges,
            ('goalkeeper', 'rev_faces', 'shooter'): gk_edges.flip(0),
        }
        edge_attr_dict = {
            ('shooter', 'shoots_at', 'goal'): goal_attr,
            ('goal', 'rev_shoots_at', 'shooter'): goal_attr,
            ('shooter', 'faces', 'goalkeeper'): gk_attr,
            ('goalkeeper', 'rev_faces', 'shooter'): gk_attr,
        }

        # Message passing with a residual connection per node type.
        for li, conv in enumerate(self.convs):
            x_new = conv(x, edge_index_dict, edge_attr_dict)
            layer_norms = self.norms[li] if self.norms is not None else None

            for node_type in x:
                # Node types that received no messages keep their features.
                if node_type not in x_new:
                    continue
                h = x_new[node_type]
                if layer_norms is not None:
                    h = layer_norms[node_type](h)
                x[node_type] = self.dropout(h + x[node_type])

        # The source node of the first matching edge is the shooter.
        active_shooter_idx = goal_edges[0, 0]
        shooter_repr = x['shooter'][active_shooter_idx]
        goal_repr = x['goal'][0]
        gk_repr = x['goalkeeper'][0]

        global_feat = self.global_encoder(global_rows[0])

        # Shot context is injected additively into the shooter representation.
        combined = torch.cat([
            shooter_repr + global_feat,
            goal_repr,
            gk_repr,
        ], dim=0)

        return self.output(combined.unsqueeze(0)).squeeze()
|
|
|