| import torch |
| from torch import nn |
| from torch_geometric.nn import HeteroConv, global_mean_pool, GATv2Conv |
|
|
class XGNet(nn.Module):
    """
    Heterogeneous GNN for xG with Global Features.

    Graph Structure:
    - Nodes: shooter (player_id), goal (learnable)
    - Edges: goal → shooter (distance, angle_to_goal, dist_to_gk, angle_to_gk),
      each relation carrying a 1-dimensional edge attribute
    - Global: shot-level features (body part, play pattern, timing, etc.),
      ``num_global_features`` of them (18 by default)

    Args:
        num_players: Number of known players. The embedding table holds
            ``num_players + 1`` rows (presumably one spare row for an
            unknown/padding id - confirm against the dataset builder).
        hid: Hidden width shared by embeddings, convolutions and the
            global-feature encoder.
        p: Dropout probability, used for feature dropout, GAT attention
            dropout and the output head.
        heads: Number of attention heads per GATv2 layer (head outputs are
            averaged, not concatenated, so the width stays ``hid``).
        num_layers: Number of heterogeneous message-passing layers.
        use_norm: If true, a LayerNorm is applied to the shooter update of
            every layer.
        num_global_features: Input width of the global-feature encoder.
    """

    # Relations of the ('goal', <relation>, 'shooter') edge types. The order
    # is kept stable so parameters are created in the same order every time.
    _EDGE_RELATIONS = ('distance', 'angle_to_goal', 'dist_to_gk', 'angle_to_gk')

    def __init__(self, num_players: int, hid: int, p: float, heads: int, num_layers: int,
                 use_norm: bool, num_global_features: int = 18) -> None:
        super().__init__()

        # Node inputs: one embedding row per player id, and a single learnable
        # vector that is shared by every goal node.
        self.shooter_emb = nn.Embedding(num_players + 1, hid)
        self.goal_feat = nn.Parameter(torch.zeros(1, hid))

        # Projects the shot-level feature vector into the hidden space.
        self.global_encoder = nn.Linear(num_global_features, hid)

        self.dropout = nn.Dropout(p=p)

        def mk_gat_with_edge(edge_dim: int):
            """Build one bipartite GATv2 layer that consumes edge features."""
            return GATv2Conv(
                in_channels=(hid, hid),
                out_channels=hid,
                edge_dim=edge_dim,
                heads=heads,
                concat=False,
                dropout=p,
                add_self_loops=False,
            )

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        for _ in range(num_layers):
            # Every relation gets its own GATv2 layer; messages arriving at
            # the shooter nodes from the different relations are summed.
            self.convs.append(HeteroConv(
                {
                    ('goal', relation, 'shooter'): mk_gat_with_edge(edge_dim=1)
                    for relation in self._EDGE_RELATIONS
                },
                aggr='sum',
            ))
            if use_norm:
                self.norms.append(nn.LayerNorm(hid))

        # Prediction head: hid -> hid // 2 -> 1. No final activation is
        # applied here, so the module returns raw scores.
        self.output = nn.Sequential(
            nn.Linear(hid, hid // 2),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(hid // 2, 1),
        )

    @staticmethod
    def _shooter_batch(store, num_nodes: int, device) -> torch.Tensor:
        """Return the node-to-graph assignment for the shooter nodes.

        Uses ``store.batch`` when it exists and is not ``None``; otherwise all
        ``num_nodes`` nodes are assigned to graph 0 (single, un-batched graph).
        The fallback tensor is only allocated when it is actually needed.
        """
        batch = getattr(store, 'batch', None)
        if batch is None:
            batch = torch.zeros(num_nodes, dtype=torch.long, device=device)
        return batch

    def forward(self, data) -> torch.Tensor:
        """Compute one raw score per graph.

        Args:
            data: Heterogeneous graph (batch) providing ``data['shooter'].x``
                (player ids), ``data['goal'].num_nodes``,
                ``edge_index_dict``, ``edge_attr_dict`` and
                ``global_features``. ``global_features`` is assumed to be
                ``(num_graphs, 1, num_global_features)`` or
                ``(num_graphs, num_global_features)`` - TODO confirm.

        Returns:
            1-D tensor with one entry per row of the pooled graph
            representation (no sigmoid is applied).
        """
        # Look up the shooter embeddings from the integer player ids.
        player_ids = data['shooter'].x.squeeze(-1).long()
        x = {
            'goal': self.goal_feat.expand(data['goal'].num_nodes, -1),
            'shooter': self.dropout(self.shooter_emb(player_ids)),
        }

        # ``self.norms`` is empty when the model was built with use_norm=False.
        apply_norm = len(self.norms) > 0

        for li, conv in enumerate(self.convs):
            # Only shooter nodes receive messages, so only they are updated;
            # the goal features stay fixed across layers.
            shooter_updated = conv(x, data.edge_index_dict, data.edge_attr_dict)['shooter']
            if apply_norm:
                shooter_updated = self.norms[li](shooter_updated)

            # Residual connection followed by dropout.
            x['shooter'] = self.dropout(shooter_updated + x['shooter'])

        # Mean-pool the shooter nodes into one vector per graph.
        shooter_x = x['shooter']
        shooter_batch = self._shooter_batch(
            data['shooter'], shooter_x.size(0), shooter_x.device
        )
        g_repr = global_mean_pool(shooter_x, shooter_batch)

        # Fuse the pooled graph representation with the encoded shot-level
        # features by addition (both have width ``hid``).
        global_feat = self.global_encoder(data.global_features.squeeze(1))
        combined = g_repr + global_feat

        return self.output(combined).squeeze(1)
|
|