import torch
from torch import nn
from torch_geometric.nn import HeteroConv, global_mean_pool, GATv2Conv


class XGNet(nn.Module):
    """Heterogeneous GNN for xG with persistent nodes and shot-indexed edges.

    Graph structure:
      - Nodes: shooter (one embedding row per player id, persistent),
        goal (persistent), goalkeeper (persistent).
      - Edges: shooter -> goal (distance, angle) and
        shooter -> goalkeeper (dist_to_gk), plus their reversed relations.
      - Global: shot-level contextual features
        (``num_global_features`` wide, 18 by default).
      - Masking: edges and global features carry a ``shot_idx`` attribute;
        ``forward`` keeps only the entries belonging to the requested shot.
    """

    def __init__(self, num_players: int, hid: int, p: float, heads: int,
                 num_layers: int, use_norm: bool,
                 num_global_features: int = 18) -> None:
        """Build the encoders, message-passing stack and read-out head.

        Args:
            num_players: Number of distinct player ids; the embedding table
                gets ``num_players + 1`` rows.
            hid: Hidden width used for every node type and every layer.
            p: Dropout probability (node features, GAT attention, read-out).
            heads: Number of attention heads per ``GATv2Conv``; head outputs
                are averaged (``concat=False``), so the width stays ``hid``.
            num_layers: Number of ``HeteroConv`` message-passing layers.
            use_norm: If true, apply a per-node-type ``LayerNorm`` to each
                layer's output before the residual connection.
            num_global_features: Width of the shot-level feature vector.
        """
        super().__init__()

        # 1) Node encoders ---------------------------------------------------
        # The extra row is reserved for padding/UNK ids. No ``padding_idx``
        # is configured, so that row is trained like any other.
        self.shooter_emb = nn.Embedding(num_players + 1, hid)
        # Single learnable vectors, expanded to the node count in forward().
        self.goal_feat = nn.Parameter(torch.randn(1, hid) * 0.01)
        self.gk_feat = nn.Parameter(torch.randn(1, hid) * 0.01)

        # Global feature encoder
        self.global_encoder = nn.Linear(num_global_features, hid)

        self.dropout = nn.Dropout(p=p)

        # 2) Edge-conditioned message passing --------------------------------
        def mk_gat_with_edge(edge_dim: int):
            """Return a bipartite GATv2 layer that consumes edge features."""
            return GATv2Conv(
                in_channels=(hid, hid),
                out_channels=hid,
                edge_dim=edge_dim,
                heads=heads,
                concat=False,
                dropout=p,
                add_self_loops=False,
            )

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList() if use_norm else None

        for _ in range(num_layers):
            conv = HeteroConv({
                # distance + angle
                ('shooter', 'shoots_at', 'goal'): mk_gat_with_edge(edge_dim=2),
                # reverse edge
                ('goal', 'rev_shoots_at', 'shooter'): mk_gat_with_edge(edge_dim=2),
                # dist_to_gk
                ('shooter', 'faces', 'goalkeeper'): mk_gat_with_edge(edge_dim=1),
                # reverse edge
                ('goalkeeper', 'rev_faces', 'shooter'): mk_gat_with_edge(edge_dim=1),
            }, aggr='sum')
            self.convs.append(conv)

            if use_norm:
                # One LayerNorm per node type and per layer.
                self.norms.append(nn.ModuleDict({
                    'shooter': nn.LayerNorm(hid),
                    'goal': nn.LayerNorm(hid),
                    'goalkeeper': nn.LayerNorm(hid),
                }))

        # 3) Read-out --------------------------------------------------------
        self.output = nn.Sequential(
            nn.Linear(hid * 3, hid),  # combine shooter + goal + goalkeeper
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(hid, hid // 2),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(hid // 2, 1),
        )

    # ----------------------------------------------------------------------
    def forward(self, data, shot_idx) -> torch.Tensor:
        """Predict the xG score of a single shot.

        Args:
            data: Heterogeneous graph object (presumably a PyG ``HeteroData``
                -- confirm against the caller). It must provide
                ``data['shooter'].x`` (player ids), ``num_nodes`` for
                ``'goal'`` and ``'goalkeeper'``, ``edge_index`` /
                ``edge_attr`` / ``shot_idx`` on the ``shoots_at`` and
                ``faces`` relations, and ``x`` / ``shot_idx`` on ``'global'``.
            shot_idx: Identifier of the shot to score; compared with ``==``
                against the stored ``shot_idx`` attributes.

        Returns:
            The read-out value with all singleton dimensions squeezed away.
            No sigmoid is applied here, so the value is unbounded
            (presumably a logit -- confirm against the training loss).

        Raises:
            IndexError: If no ``shooter -> goal`` edge is tagged with
                ``shot_idx`` (there is no active shooter to read out).
            RuntimeError: If the number of global feature rows tagged with
                ``shot_idx`` is not exactly one.
        """
        # Node features for all persistent nodes.
        shooter_ids = data['shooter'].x.squeeze(-1).long()
        x = {
            'shooter': self.dropout(self.shooter_emb(shooter_ids)),
            'goal': self.goal_feat.expand(data['goal'].num_nodes, -1),
            'goalkeeper': self.gk_feat.expand(data['goalkeeper'].num_nodes, -1),
        }

        # Keep only the edges that belong to this shot. Each masked tensor is
        # computed once and shared by the forward and reverse relations.
        goal_store = data['shooter', 'shoots_at', 'goal']
        gk_store = data['shooter', 'faces', 'goalkeeper']
        goal_mask = (goal_store.shot_idx == shot_idx)
        gk_mask = (gk_store.shot_idx == shot_idx)

        goal_edges = goal_store.edge_index[:, goal_mask]
        gk_edges = gk_store.edge_index[:, gk_mask]
        goal_attr = goal_store.edge_attr[goal_mask]
        gk_attr = gk_store.edge_attr[gk_mask]

        # Fail early with a readable message instead of an opaque
        # out-of-bounds error after all message passing has already run.
        if goal_edges.size(1) == 0:
            raise IndexError(
                f"no ('shooter', 'shoots_at', 'goal') edge found for "
                f"shot_idx={shot_idx}"
            )

        # Exactly one global feature row must describe the shot; any other
        # count would otherwise surface later as a shape error in torch.cat.
        global_mask = (data['global'].shot_idx == shot_idx)
        global_rows = data['global'].x[global_mask]
        if global_rows.size(0) != 1:
            raise RuntimeError(
                f"expected exactly one global feature row for "
                f"shot_idx={shot_idx}, found {global_rows.size(0)}"
            )

        edge_index_dict = {
            ('shooter', 'shoots_at', 'goal'): goal_edges,
            ('goal', 'rev_shoots_at', 'shooter'): goal_edges.flip(0),
            ('shooter', 'faces', 'goalkeeper'): gk_edges,
            ('goalkeeper', 'rev_faces', 'shooter'): gk_edges.flip(0),
        }
        # Reverse relations reuse the attributes of their forward edges.
        edge_attr_dict = {
            ('shooter', 'shoots_at', 'goal'): goal_attr,
            ('goal', 'rev_shoots_at', 'shooter'): goal_attr,
            ('shooter', 'faces', 'goalkeeper'): gk_attr,
            ('goalkeeper', 'rev_faces', 'shooter'): gk_attr,
        }

        # Message passing on the masked edges: optional LayerNorm, then a
        # residual connection followed by dropout, per node type.
        for li, conv in enumerate(self.convs):
            x_new = conv(x, edge_index_dict, edge_attr_dict)
            for node_type in x:
                if node_type not in x_new:
                    continue  # node type received no update in this layer
                h = x_new[node_type]
                if self.norms is not None:
                    h = self.norms[li][node_type](h)
                x[node_type] = self.dropout(h + x[node_type])

        # The active shooter is the source node of the first matching edge.
        active_shooter_idx = goal_edges[0, 0]
        shooter_repr = x['shooter'][active_shooter_idx]  # (hid,)
        goal_repr = x['goal'][0]                         # (hid,)
        gk_repr = x['goalkeeper'][0]                     # (hid,)

        # Encode the shot-level context; squeeze(0) drops the single-row dim.
        global_feat = self.global_encoder(global_rows.squeeze(0))  # (hid,)

        # Combine all representations.
        combined = torch.cat([
            shooter_repr + global_feat,  # shooter with context
            goal_repr,
            gk_repr,
        ], dim=0)  # (hid * 3,)

        return self.output(combined.unsqueeze(0)).squeeze()