# heterogen_focal_without_kpis / model_architecture.py
# Uploaded by rokati with huggingface_hub (commit 7a44376, verified).
import torch
from torch import nn
from torch_geometric.nn import HeteroConv, global_mean_pool, GATv2Conv
class XGNet(nn.Module):
    """
    Heterogeneous GNN for expected goals (xG) with global shot features.

    Graph structure:
    - Nodes: ``shooter`` (embedded from its player id) and ``goal`` (a single
      learnable feature vector shared by every goal node).
    - Edges: goal -> shooter, one relation per geometric quantity
      (``distance``, ``angle_to_goal``, ``dist_to_gk``, ``angle_to_gk``), each
      carrying a 1-dimensional edge attribute.
    - Global: ``num_global_features`` shot-level features (default 18; body
      part, play pattern, timing, etc.) encoded by a linear layer and added to
      the pooled shooter representation.

    ``forward`` returns raw logits, not probabilities; apply ``torch.sigmoid``
    to obtain xG in (0, 1).

    NOTE(review): ``GATv2Conv`` uses ``edge_attr`` only when computing the
    attention coefficients; the message itself is the transformed source
    (goal) feature, which is identical for every graph. If each shooter has
    exactly one incoming edge per relation (one goal per graph, as the
    structure above suggests), the softmax makes every coefficient 1 in eval
    mode, so the edge attributes cannot influence the prediction. Confirm the
    graph layout; fixing it would change the parameter layout of saved
    checkpoints, so it is not changed here.

    Args:
        num_players: Number of known player ids. The embedding table has
            ``num_players + 1`` rows; the extra row is intended for
            padding/unknown players.
        hid: Hidden size of all node and global representations.
        p: Dropout probability, used on the shooter embeddings, the residual
            updates, the GAT attention coefficients and the read-out MLP.
        heads: Number of attention heads per GATv2 layer (averaged, not
            concatenated).
        num_layers: Number of heterogeneous message-passing layers.
        use_norm: If True, apply ``LayerNorm`` to each layer's shooter update.
        num_global_features: Width of ``data.global_features``.
    """

    # Goal -> shooter relations; each carries a 1-dimensional edge attribute.
    _RELATIONS = ('distance', 'angle_to_goal', 'dist_to_gk', 'angle_to_gk')

    def __init__(self, num_players: int, hid: int, p: float, heads: int, num_layers: int,
                 use_norm: bool, num_global_features: int = 18):
        super().__init__()

        # 1) Node / global encoders ------------------------------------------
        self.shooter_emb = nn.Embedding(num_players + 1, hid)  # +1 for padding/UNK
        self.goal_feat = nn.Parameter(torch.zeros(1, hid))     # learnable goal feature
        self.global_encoder = nn.Linear(num_global_features, hid)
        self.dropout = nn.Dropout(p=p)

        # 2) Edge-conditioned message passing --------------------------------
        def mk_gat_with_edge(edge_dim: int):
            """Build one bipartite goal -> shooter GATv2 layer with edge features."""
            return GATv2Conv(
                in_channels=(hid, hid),
                out_channels=hid,
                edge_dim=edge_dim,
                heads=heads,
                concat=False,          # average the heads so the output width stays `hid`
                dropout=p,
                add_self_loops=False,  # goal and shooter are different node types
            )

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()  # stays empty when use_norm is False
        for _ in range(num_layers):
            self.convs.append(HeteroConv(
                {('goal', rel, 'shooter'): mk_gat_with_edge(edge_dim=1)
                 for rel in self._RELATIONS},
                aggr='sum',  # sum the per-relation updates into one shooter update
            ))
            if use_norm:
                self.norms.append(nn.LayerNorm(hid))

        # 3) Read-out: pooled shooter + global context -> one logit ----------
        self.output = nn.Sequential(
            nn.Linear(hid, hid // 2),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(hid // 2, 1),
        )

    # ----------------------------------------------------------------------
    def forward(self, data):
        """
        Compute one xG logit per graph.

        Args:
            data: Heterogeneous graph or batch of graphs -- presumably a
                torch_geometric ``HeteroData`` / ``Batch`` (TODO confirm).
                Reads ``data['shooter'].x`` (player ids; the last dim is
                squeezed), ``data['goal'].num_nodes``, ``data.edge_index_dict``,
                ``data.edge_attr_dict``, the optional ``data['shooter'].batch``
                and ``data.global_features``.

        Returns:
            Tensor of shape ``(num_graphs,)`` holding raw logits; no sigmoid
            is applied here.
        """
        shooter_ids = data['shooter'].x.squeeze(-1).long()
        x = {
            'goal': self.goal_feat.expand(data['goal'].num_nodes, -1),
            'shooter': self.dropout(self.shooter_emb(shooter_ids)),
        }

        # Message passing. Every relation ends at 'shooter', so only shooter
        # features are updated; the goal features stay fixed across layers.
        use_norm = len(self.norms) > 0
        for li, conv in enumerate(self.convs):
            shooter_updated = conv(x, data.edge_index_dict, data.edge_attr_dict)['shooter']
            if use_norm:
                shooter_updated = self.norms[li](shooter_updated)
            x['shooter'] = self.dropout(shooter_updated + x['shooter'])  # residual

        # Graph-level pooling (handles batches transparently). The fallback
        # tensor is built only when no batch vector is present.
        shooter_x = x['shooter']
        shooter_batch = getattr(data['shooter'], 'batch', None)
        if shooter_batch is None:
            # Unbatched input: every shooter belongs to graph 0.
            shooter_batch = torch.zeros(
                shooter_x.size(0), dtype=torch.long, device=shooter_x.device
            )
        g_repr = global_mean_pool(shooter_x, shooter_batch)  # (num_graphs, hid)

        # Flatten to (num_graphs, num_global_features). This matches the former
        # squeeze(1) for (B, 1, F) and (B, F) inputs, but unlike squeeze(1) it
        # does not collapse the feature axis when num_global_features == 1.
        global_feat = self.global_encoder(
            data.global_features.reshape(-1, self.global_encoder.in_features)
        )  # (num_graphs, hid)

        # Combine pooled node embeddings + global context (element-wise sum).
        combined = g_repr + global_feat
        return self.output(combined).squeeze(1)  # raw logits, shape (num_graphs,)