Heterogeneous GNN xG Prediction Model (Focal Loss, One Graph Architecture)
This is a Heterogeneous Graph Neural Network (GNN) model trained to predict Expected Goals (xG) in football/soccer using a single persistent graph with shot-indexed edges.
Model Description
- Architecture: Heterogeneous GNN with GAT attention layers
- Graph Structure: One persistent graph with all nodes, shot-specific edge masking
- Nodes: shooter (player embeddings, 496 total), goal (learnable embedding), goalkeeper (learnable embedding)
- Edges: 4 edge types (bidirectional shooter↔goal and shooter↔goalkeeper)
- Global Features: 18 shot-level contextual features
- Hidden Dimensions: 64
- Number of Layers: 3 GAT layers
- Attention Heads: 4
- Dropout Rate: 0.3
- Framework: PyTorch Geometric
- Loss Function: Focal Loss (alpha=0.8773, gamma=2.0)
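For reference, the standard binary focal loss scales the cross-entropy of each shot by (1 - p_t)^gamma, where p_t is the probability the model gives to the true outcome. This shrinks the contribution of easy, confidently classified non-goals. The plain-Python sketch below plugs in the card's alpha=0.8773 and gamma=2.0. It assumes the common convention in which alpha weights the positive (goal) class, and it assumes a mean reduction. The repository's exact convention is not documented, so `binary_focal_loss` and `mean_focal_loss` are hypothetical helpers, not the training code.

```python
import math


def _sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def binary_focal_loss(logit: float, target: int,
                      alpha: float = 0.8773, gamma: float = 2.0) -> float:
    """Focal loss for a single shot given a raw logit and a 0/1 goal label.

    Assumption: alpha weights the positive (goal) class and 1 - alpha the
    non-goal class; the repository's convention is not documented.
    """
    p = _sigmoid(logit)                          # predicted goal probability
    p_t = p if target == 1 else 1.0 - p          # probability assigned to the true class
    alpha_t = alpha if target == 1 else 1.0 - alpha
    p_t = min(max(p_t, 1e-12), 1.0)              # avoid log(0)
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def mean_focal_loss(logits, targets, alpha: float = 0.8773, gamma: float = 2.0) -> float:
    """Mean-reduced focal loss over a batch of shots (reduction is an assumption)."""
    losses = [binary_focal_loss(z, y, alpha, gamma) for z, y in zip(logits, targets)]
    return sum(losses) / len(losses)
```

Setting gamma=0 and alpha=0.5 recovers half the ordinary binary cross-entropy, which is a quick sanity check. With gamma=2, a confidently correct non-goal contributes far less to the loss than a missed goal.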
Performance Metrics
- Test Loss: 0.08266454190015793
- Test AUC: 0.609370231628418
- Test Accuracy: 0.14332698285579681
- Test F1: 0.2105599045753479
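AUC does not depend on a threshold. Accuracy and F1, by contrast, depend on where the sigmoid output is cut into goal or no goal, and the card does not state which threshold produced the figures above. The plain-Python sketch below shows how the three metrics relate for a list of predicted xG values and 0/1 outcomes. `classification_metrics` is a hypothetical helper, and its default threshold of 0.5 is an assumption.

```python
def classification_metrics(probs, labels, threshold=0.5):
    """Return (accuracy, F1, ROC AUC) for predicted probabilities and 0/1 labels."""
    preds = [1 if p >= threshold else 0 for p in probs]
    tp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 1)
    accuracy = sum(1 for p, y in zip(preds, labels) if p == y) / len(labels)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    # AUC = probability that a random goal is scored above a random non-goal (ties count 0.5).
    pos = [p for p, y in zip(probs, labels) if y == 1]
    neg = [p for p, y in zip(probs, labels) if y == 0]
    pairs = [1.0 if a > b else 0.5 if a == b else 0.0 for a in pos for b in neg]
    auc = sum(pairs) / len(pairs) if pairs else float("nan")
    return accuracy, f1, auc
```

Changing only `threshold` moves accuracy and F1 but leaves AUC unchanged.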
Global Features (18 features)
The model uses the following contextual features for each shot:
- ball_closer_than_gk
- body_part_name_Left Foot
- body_part_name_Other
- body_part_name_Right Foot
- goal_dist_to_gk
- minute
- nearest_opponent_dist
- nearest_teammate_dist
- opponents_within_5m
- play_pattern_name_From Counter
- play_pattern_name_From Free Kick
- play_pattern_name_From Goal Kick
- play_pattern_name_From Keeper
- play_pattern_name_From Kick Off
- play_pattern_name_From Throw In
- play_pattern_name_Other
- play_pattern_name_Regular Play
- teammates_within_5m
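The `body_part_name_*` and `play_pattern_name_*` entries are one-hot columns. The sketch below assembles the 18 values in the order listed above from a raw shot record. It is a hypothetical helper under stated assumptions. First, the card does not say whether numeric features are scaled, so they pass through unchanged. Second, categories without a column (presumably dropped reference levels such as a header or a corner) encode as all zeros.

```python
# The 18 global features, in the order listed in this model card.
GLOBAL_FEATURES = [
    "ball_closer_than_gk",
    "body_part_name_Left Foot",
    "body_part_name_Other",
    "body_part_name_Right Foot",
    "goal_dist_to_gk",
    "minute",
    "nearest_opponent_dist",
    "nearest_teammate_dist",
    "opponents_within_5m",
    "play_pattern_name_From Counter",
    "play_pattern_name_From Free Kick",
    "play_pattern_name_From Goal Kick",
    "play_pattern_name_From Keeper",
    "play_pattern_name_From Kick Off",
    "play_pattern_name_From Throw In",
    "play_pattern_name_Other",
    "play_pattern_name_Regular Play",
    "teammates_within_5m",
]


def build_global_features(shot: dict) -> list[float]:
    """Map one raw shot record (hypothetical layout) to the 18-value feature vector."""
    vector = []
    for name in GLOBAL_FEATURES:
        if name.startswith("body_part_name_"):
            # One-hot: 1.0 if the shot's body part matches this column, else 0.0.
            value = shot["body_part_name"] == name[len("body_part_name_"):]
        elif name.startswith("play_pattern_name_"):
            value = shot["play_pattern_name"] == name[len("play_pattern_name_"):]
        else:
            # Numeric or boolean feature, passed through unscaled (assumption).
            value = shot[name]
        vector.append(float(value))
    return vector
```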
Graph Structure
Persistent Nodes
- Shooter Nodes: 496 player embeddings (dimension: 64)
- Goal Node: 1 learnable goal representation (dimension: 64)
- Goalkeeper Node: 1 learnable goalkeeper representation (dimension: 64)
Edge Types (with attributes, shot-indexed)
- shooter → goal: Distance (meters) + Angle to goal (radians)
- goal → shooter: Reverse edges with same attributes
- shooter → goalkeeper: Distance to goalkeeper (meters)
- goalkeeper → shooter: Reverse edges with same attributes
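The card does not give the coordinate system or define "angle to goal." The sketch below shows one common way to compute these edge attributes, under hypothetical assumptions:

- Positions are in meters, with the goal line at x = 105 and the goal centered at y = 34.
- "Angle" means the opening angle between the two posts (7.32 m apart) as seen from the shooter.

All constants and helper names are hypothetical.

```python
import math

# Hypothetical pitch geometry in meters; the repository's conventions may differ.
GOAL_X = 105.0
GOAL_Y = 34.0
GOAL_WIDTH = 7.32


def shooter_goal_attrs(x: float, y: float) -> tuple[float, float]:
    """Distance to the goal centre (m) and opening angle between the posts (rad)."""
    distance = math.hypot(GOAL_X - x, GOAL_Y - y)
    # Vectors from the shooter to the two posts.
    lx, ly = GOAL_X - x, (GOAL_Y - GOAL_WIDTH / 2) - y
    rx, ry = GOAL_X - x, (GOAL_Y + GOAL_WIDTH / 2) - y
    # Unsigned angle between the two vectors via atan2(|cross|, dot).
    angle = math.atan2(abs(lx * ry - ly * rx), lx * rx + ly * ry)
    return distance, angle


def shooter_keeper_attr(x: float, y: float, gk_x: float, gk_y: float) -> float:
    """Euclidean distance from shooter to goalkeeper (m)."""
    return math.hypot(gk_x - x, gk_y - y)
```

The reverse edges (goal → shooter, goalkeeper → shooter) would simply reuse these values.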
Every edge stores the shot_idx of the shot it belongs to, so a prediction for one shot can mask out the edges of all other shots.
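Shot-indexed masking amounts to keeping only the edges whose stored shot index matches the shot being scored. The plain-Python sketch below assumes that each relation stores parallel lists of source nodes, target nodes, attributes and shot indices. This layout and the `ShotIndexedEdges` class are hypothetical. In PyTorch Geometric the same idea is usually a boolean mask over the columns of an `edge_index` tensor.

```python
from dataclasses import dataclass, field


@dataclass
class ShotIndexedEdges:
    """Edges of one relation (e.g. shooter -> goal), each tagged with its shot index."""
    src: list = field(default_factory=list)       # source node ids
    dst: list = field(default_factory=list)       # target node ids
    attr: list = field(default_factory=list)      # per-edge attribute tuples
    shot_idx: list = field(default_factory=list)  # shot each edge belongs to

    def add(self, s, d, a, shot):
        self.src.append(s)
        self.dst.append(d)
        self.attr.append(a)
        self.shot_idx.append(shot)

    def for_shot(self, shot):
        """Masking step: return only the edges created for `shot`."""
        keep = [i for i, k in enumerate(self.shot_idx) if k == shot]
        return ([self.src[i] for i in keep],
                [self.dst[i] for i in keep],
                [self.attr[i] for i in keep])


shooter_to_goal = ShotIndexedEdges()
goal_to_shooter = ShotIndexedEdges()
# (player node id, distance_m, angle_rad) per shot; illustrative values only.
shots = [(12, 11.0, 0.64), (305, 24.5, 0.29), (12, 18.2, 0.38)]
for k, (player, dist, ang) in enumerate(shots):
    shooter_to_goal.add(player, 0, (dist, ang), k)  # the single goal node has id 0
    goal_to_shooter.add(0, player, (dist, ang), k)  # reverse edge, same attributes
```

Player 12 is one shared node that appears in two shots, yet masking on shot 2 returns only that shot's edge.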
Usage
import torch
from torch_geometric.data import HeteroData
from huggingface_hub import hf_hub_download
import importlib.util
import json
# Download files
model_path = hf_hub_download(repo_id="rokati/focal_gnn_one_graph", filename="best_gnn_model.pth")
architecture_path = hf_hub_download(repo_id="rokati/focal_gnn_one_graph", filename="model_architecture.py")
config_path = hf_hub_download(repo_id="rokati/focal_gnn_one_graph", filename="config.json")
# Load configuration
with open(config_path, 'r') as f:
    config = json.load(f)
# Load architecture
spec = importlib.util.spec_from_file_location("model_architecture", architecture_path)
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)
# Create model instance
model = model_module.XGNet(
    num_players=config['num_players'],
    hid=config['hidden_dim'],
    p=config['dropout_rate'],
    heads=config['num_heads'],
    num_layers=config['num_layers'],
    use_norm=config['use_norm'],
    num_global_features=config['num_global_features'],
)
# Load weights
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()
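# Build the persistent graph once (all players and shots, see Graph Structure),
# then score a single shot by passing the graph and that shot's index.
# `graph` (a HeteroData instance) and `shot_idx` must be prepared beforehand;
# they are not defined in this snippet.
with torch.no_grad():
    xg_prediction = torch.sigmoid(model(graph, shot_idx)).item()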
Training Details
The model was trained with:
- Loss Function: Focal Loss, which down-weights easy examples to counter the class imbalance between goals and non-goals
- Optimizer: Adam (lr=1e-3, weight_decay=1e-4)
- Scheduler: ReduceLROnPlateau
- Batch Size: 32
- Max Epochs: 100
- Early Stopping: Patience=15 on validation F1
- Framework: PyTorch Lightning
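Early stopping with patience 15 on validation F1 means training halts once F1 has failed to improve for 15 consecutive epochs. In PyTorch Lightning this is normally configured through the `EarlyStopping` callback with `mode="max"`. The name of the logged metric is not given here. The plain-Python sketch below only illustrates the rule. The `EarlyStopping` class is a hypothetical stand-in, not Lightning's callback, and the min_delta of 0 is an assumption.

```python
class EarlyStopping:
    """Stop when a maximised metric has not improved for `patience` epochs."""

    def __init__(self, patience: int = 15, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch: int, value: float) -> bool:
        """Record one epoch's metric; return True when training should stop."""
        if value > self.best + self.min_delta:
            self.best, self.best_epoch, self.bad_epochs = value, epoch, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=15)
val_f1 = [0.10, 0.15, 0.21] + [0.20] * 20  # hypothetical curve that plateaus
for epoch, f1 in enumerate(val_f1):
    if stopper.step(epoch, f1):
        break
```

Here the best F1 is reached at epoch 2, and training stops at epoch 17 after 15 epochs without improvement.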
Model Architecture
The heterogeneous GNN uses:
- Node Embeddings: Learnable embeddings for shooters, goal, and goalkeeper
- Edge-Conditioned Attention: GAT layers with edge attributes
- Bidirectional Message Passing: 3 layers with forward and reverse edges
- Shot-Indexed Masking: each prediction uses only the edges belonging to the target shot
- Global Context: Shot-level features encoded and combined
- Readout: Concatenate shooter + goal + goalkeeper representations → MLP classifier
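The equations below give the standard edge-conditioned GAT attention, the form that PyTorch Geometric's `GATConv` computes when `edge_dim` is set. They are a reference only. The card does not say whether the repository uses `GATConv`, `GATv2Conv`, or another variant. For bipartite relations such as shooter → goal, source and target nodes normally get separate projection weights. Each of the 4 heads computes its own coefficients, and the heads are then concatenated or averaged. Here x_i is the receiving node, N(i) its neighbours in one relation, and e_ij the edge attributes (distance, angle). The card does not specify where the global features enter, so they are omitted from the readout line.

```latex
% Attention of node i over its neighbours N(i), for one relation and one head
\alpha_{ij} =
  \frac{\exp\!\big(\mathrm{LeakyReLU}\big(\mathbf{a}^{\top}[\,\Theta\mathbf{x}_i \,\|\, \Theta\mathbf{x}_j \,\|\, \Theta_e\mathbf{e}_{ij}\,]\big)\big)}
       {\sum_{k \in \mathcal{N}(i)} \exp\!\big(\mathrm{LeakyReLU}\big(\mathbf{a}^{\top}[\,\Theta\mathbf{x}_i \,\|\, \Theta\mathbf{x}_k \,\|\, \Theta_e\mathbf{e}_{ik}\,]\big)\big)}

% Updated representation of node i
\mathbf{x}_i' = \sum_{j \in \mathcal{N}(i)} \alpha_{ij}\, \Theta\mathbf{x}_j

% Readout: concatenate the three 64-dimensional node representations (192 values), then score
\widehat{xG} = \sigma\big(\mathrm{MLP}([\,\mathbf{h}_{\text{shooter}} \,\|\, \mathbf{h}_{\text{goal}} \,\|\, \mathbf{h}_{\text{gk}}\,])\big)
```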
Key Innovation
Unlike traditional approaches that create separate graphs for each shot, this model uses:
- One persistent graph with all nodes
- Shot-indexed edges for efficient masking
- Bidirectional message passing for richer representations
Because every shot reuses the same shooter, goal, and goalkeeper nodes instead of copying them into a graph of its own, this design is more memory-efficient and simpler to batch.
License
MIT