---
tags:
- pytorch
- pytorch-geometric
- xG
- football
- soccer
- expected-goals
- gnn
- heterogeneous-graph
- graph-neural-network
- binary-classification
- one-graph-architecture
library_name: pytorch
---

# Heterogeneous GNN xG Prediction Model (Focal Loss, One Graph Architecture)

This is a heterogeneous graph neural network (GNN) trained to predict expected goals (xG), i.e. the probability that a football (soccer) shot results in a goal. It uses a single persistent graph in which every edge is tagged with the shot it belongs to.

## Model Description

- **Architecture**: Heterogeneous GNN with GAT (graph attention) layers
- **Graph Structure**: One persistent graph containing all nodes; edges are masked per shot
- **Nodes**: shooter (learnable player embeddings, 496 players), goal (one learnable embedding), goalkeeper (one learnable embedding)
- **Edges**: 4 edge types (shooter↔goal and shooter↔goalkeeper, in both directions)
- **Global Features**: 18 shot-level contextual features
- **Hidden Dimension**: 64
- **Number of Layers**: 3 GAT layers
- **Attention Heads**: 4
- **Dropout Rate**: 0.3
- **Framework**: PyTorch Geometric
- **Loss Function**: Focal loss (alpha=0.8773, gamma=2.0)
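
For reference, a binary focal loss with the hyperparameters above can be sketched as follows. This is a minimal sketch assuming the standard formulation (Lin et al., 2017), FL = -α_t (1 - p_t)^γ log(p_t), with `alpha` weighting the positive (goal) class; the repository's own loss code is not shown in this card and may differ, and `binary_focal_loss` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss(logits: torch.Tensor, targets: torch.Tensor,
                      alpha: float = 0.8773, gamma: float = 2.0) -> torch.Tensor:
    """Mean binary focal loss on raw logits (hypothetical helper, not from the repo)."""
    targets = targets.float()
    # Per-example BCE equals -log(p_t), computed stably from logits.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    # p_t: predicted probability of the true class.
    p_t = p * targets + (1 - p) * (1 - targets)
    # alpha_t: alpha for goals (y=1), 1 - alpha for non-goals (y=0).
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # (1 - p_t)^gamma down-weights shots the model already classifies confidently.
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()


logits = torch.tensor([2.0, -1.0, 0.5])
targets = torch.tensor([1, 0, 0])
print(binary_focal_loss(logits, targets))
```

With `gamma=0` and `alpha=0.5` the expression reduces to half the ordinary binary cross-entropy, which is a quick sanity check for any implementation.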

## Performance Metrics

- **Test Loss**: 0.0827
- **Test AUC**: 0.6094
- **Test Accuracy**: 0.1433
- **Test F1**: 0.2106

## Global Features (18 features)

The model uses the following contextual features for each shot. The `body_part_name_*` and `play_pattern_name_*` entries are one-hot indicators for the shot's body part and play pattern.

- `ball_closer_than_gk`
- `body_part_name_Left Foot`
- `body_part_name_Other`
- `body_part_name_Right Foot`
- `goal_dist_to_gk`
- `minute`
- `nearest_opponent_dist`
- `nearest_teammate_dist`
- `opponents_within_5m`
- `play_pattern_name_From Counter`
- `play_pattern_name_From Free Kick`
- `play_pattern_name_From Goal Kick`
- `play_pattern_name_From Keeper`
- `play_pattern_name_From Kick Off`
- `play_pattern_name_From Throw In`
- `play_pattern_name_Other`
- `play_pattern_name_Regular Play`
- `teammates_within_5m`
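
A raw shot record can be turned into this 18-dimensional vector as sketched below. This is a minimal sketch under assumptions: the column order is taken to be the order listed above, the input is a plain dict with categorical `body_part_name` and `play_pattern_name` strings, and a category without its own column (anything not listed) encodes as all zeros in its group. `encode_shot` is a hypothetical helper; the card does not say whether numeric features were scaled during training, so a real pipeline should apply the same preprocessing that was used there.

```python
from typing import Mapping

# Feature order as listed in this card (assumed to match the model's input order).
FEATURE_NAMES = [
    "ball_closer_than_gk",
    "body_part_name_Left Foot",
    "body_part_name_Other",
    "body_part_name_Right Foot",
    "goal_dist_to_gk",
    "minute",
    "nearest_opponent_dist",
    "nearest_teammate_dist",
    "opponents_within_5m",
    "play_pattern_name_From Counter",
    "play_pattern_name_From Free Kick",
    "play_pattern_name_From Goal Kick",
    "play_pattern_name_From Keeper",
    "play_pattern_name_From Kick Off",
    "play_pattern_name_From Throw In",
    "play_pattern_name_Other",
    "play_pattern_name_Regular Play",
    "teammates_within_5m",
]

ONE_HOT_PREFIXES = ("body_part_name", "play_pattern_name")


def encode_shot(shot: Mapping[str, object]) -> list[float]:
    """Build the 18-dim global feature vector for one shot (hypothetical helper)."""
    vector = []
    for name in FEATURE_NAMES:
        for prefix in ONE_HOT_PREFIXES:
            if name.startswith(prefix + "_"):
                # One-hot column: 1.0 if the shot's category matches this column.
                category = name[len(prefix) + 1:]
                vector.append(1.0 if shot.get(prefix) == category else 0.0)
                break
        else:
            # Numeric column: copy the raw value.
            vector.append(float(shot[name]))
    return vector


shot = {
    "ball_closer_than_gk": 1, "goal_dist_to_gk": 2.5, "minute": 67,
    "nearest_opponent_dist": 1.8, "nearest_teammate_dist": 6.0,
    "opponents_within_5m": 3, "teammates_within_5m": 1,
    "body_part_name": "Right Foot", "play_pattern_name": "Regular Play",
}
features = encode_shot(shot)
```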

## Graph Structure

### Persistent Nodes

1. **Shooter Nodes**: 496 player embeddings (dimension 64)
2. **Goal Node**: 1 learnable goal representation (dimension 64)
3. **Goalkeeper Node**: 1 learnable goalkeeper representation (dimension 64)
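
These persistent nodes can be held in a few learnable tensors. A minimal sketch, assuming the shooter embeddings form an `nn.Embedding` table and the goal and goalkeeper nodes are single learnable vectors; `PersistentNodes` and its initialization scale are hypothetical, and the actual `XGNet` may organize its parameters differently.

```python
import torch
from torch import nn


class PersistentNodes(nn.Module):
    """Learnable node states for the persistent graph (hypothetical module)."""

    def __init__(self, num_players: int = 496, hid: int = 64):
        super().__init__()
        # One embedding row per player who can appear as a shooter.
        self.shooter = nn.Embedding(num_players, hid)
        # A single shared goal node and a single shared goalkeeper node.
        self.goal = nn.Parameter(torch.randn(1, hid) * 0.02)
        self.goalkeeper = nn.Parameter(torch.randn(1, hid) * 0.02)

    def forward(self) -> dict[str, torch.Tensor]:
        # Features keyed by node type, the layout a heterogeneous GNN consumes.
        return {
            "shooter": self.shooter.weight,
            "goal": self.goal,
            "goalkeeper": self.goalkeeper,
        }


x_dict = PersistentNodes()()
```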

### Edge Types (with attributes, shot-indexed)

1. **shooter → goal**: distance (meters) and angle to goal (radians)
2. **goal → shooter**: reverse edges with the same attributes
3. **shooter → goalkeeper**: distance to goalkeeper (meters)
4. **goalkeeper → shooter**: reverse edges with the same attributes

Every edge also stores the index of the shot it belongs to (`shot_idx`), so the edges of a single shot can be selected with a boolean mask at prediction time.
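
Shot-indexed masking amounts to comparing each edge's stored shot index with the shot being scored. A minimal sketch for one edge type, assuming the shot indices are kept in a separate per-edge tensor alongside `edge_index` and `edge_attr`; `edges_for_shot` and the toy values are hypothetical, and in the full graph the same mask would be applied to each of the four edge types.

```python
import torch


def edges_for_shot(edge_index: torch.Tensor, edge_attr: torch.Tensor,
                   edge_shot_idx: torch.Tensor, shot_idx: int):
    """Keep only the edges that belong to one shot (hypothetical helper).

    edge_index:    [2, E] source/target node indices
    edge_attr:     [E, D] per-edge attributes (e.g. distance, angle)
    edge_shot_idx: [E] shot index stored on each edge
    """
    mask = edge_shot_idx == shot_idx
    return edge_index[:, mask], edge_attr[mask]


# Three shooter -> goal edges from three different shots (toy values).
edge_index = torch.tensor([[5, 12, 5], [0, 0, 0]])  # shooter ids -> the single goal node
edge_attr = torch.tensor([[11.0, 0.5], [24.0, 0.25], [8.0, 0.75]])  # [distance m, angle rad]
edge_shot_idx = torch.tensor([0, 1, 2])

ei, ea = edges_for_shot(edge_index, edge_attr, edge_shot_idx, shot_idx=2)
```

Player 5 appears in two shots here, yet only the edge from shot 2 survives the mask, which is how one persistent shooter node can serve many shots.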

## Usage

```python
import importlib.util
import json

import torch
from huggingface_hub import hf_hub_download
from torch_geometric.data import HeteroData  # the persistent graph is a HeteroData object

REPO_ID = "rokati/focal_gnn_one_graph"

# Download the weights, the model definition, and the hyperparameter config
model_path = hf_hub_download(repo_id=REPO_ID, filename="best_gnn_model.pth")
architecture_path = hf_hub_download(repo_id=REPO_ID, filename="model_architecture.py")
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")

# Load configuration
with open(config_path, "r") as f:
    config = json.load(f)

# Import XGNet from the downloaded model_architecture.py
spec = importlib.util.spec_from_file_location("model_architecture", architecture_path)
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)

# Create the model with the saved hyperparameters
model = model_module.XGNet(
    num_players=config["num_players"],
    hid=config["hidden_dim"],
    p=config["dropout_rate"],
    heads=config["num_heads"],
    num_layers=config["num_layers"],
    use_norm=config["use_norm"],
    num_global_features=config["num_global_features"],
)

# Load weights on CPU and switch off dropout
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

# `graph`: the persistent HeteroData graph (all players and all shots), built once
# with the same node/edge schema used in training.
# `shot_idx`: the index of the shot to score.
with torch.no_grad():
    xg_prediction = torch.sigmoid(model(graph, shot_idx)).item()
```

## Training Details

The model was trained with:

- **Loss Function**: Focal loss (addresses the class imbalance between goals and non-goals)
- **Optimizer**: Adam (lr=1e-3, weight_decay=1e-4)
- **Scheduler**: ReduceLROnPlateau
- **Batch Size**: 32
- **Max Epochs**: 100
- **Early Stopping**: patience=15 on validation F1
- **Framework**: PyTorch Lightning
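
Outside Lightning, the listed optimizer, scheduler, and early-stopping rule can be approximated in plain PyTorch. A minimal sketch under assumptions: the card does not say which metric ReduceLROnPlateau monitors or with what factor and patience, so this assumes it also tracks validation F1 (`mode="max"`) with PyTorch's defaults; `make_optimizer` and `EarlyStopper` are hypothetical helpers standing in for Lightning's `configure_optimizers` and `EarlyStopping` callback.

```python
import torch
from torch import nn


def make_optimizer(model: nn.Module):
    # Adam with the learning rate and weight decay listed above.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    # Assumption: the scheduler watches validation F1, so higher is better.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")
    return optimizer, scheduler


class EarlyStopper:
    """Signal a stop when validation F1 has not improved for `patience` epochs."""

    def __init__(self, patience: int = 15):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_f1: float) -> bool:
        if val_f1 > self.best:
            self.best = val_f1
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience  # True means "stop now"


optimizer, scheduler = make_optimizer(nn.Linear(18, 1))
stopper = EarlyStopper(patience=15)
```

In a training loop, call `scheduler.step(val_f1)` and `stopper.step(val_f1)` once per epoch, for at most 100 epochs, and break when the stopper returns `True`.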

## Model Architecture

The heterogeneous GNN consists of:

1. **Node Embeddings**: learnable embeddings for shooters, goal, and goalkeeper
2. **Edge-Conditioned Attention**: GAT layers that use the edge attributes (distance, angle) when computing attention
3. **Bidirectional Message Passing**: 3 layers over both forward and reverse edges
4. **Shot-Indexed Masking**: each prediction uses only the edges of the current shot, selected by masking
5. **Global Context**: the 18 shot-level features are encoded and combined
6. **Readout**: shooter, goal, and goalkeeper representations are concatenated and passed to an MLP classifier
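
Steps 5 and 6 can be sketched as a small prediction head. This is a minimal sketch under assumptions: the card does not specify how the global features are combined, so here they pass through a linear encoder and are concatenated with the three node representations before the MLP. `ReadoutHead`, the MLP width, and the ReLU activations are illustrative choices, not taken from `model_architecture.py`.

```python
import torch
from torch import nn


class ReadoutHead(nn.Module):
    """Concatenate node states and encoded shot context, then score with an MLP."""

    def __init__(self, hid: int = 64, num_global_features: int = 18, p: float = 0.3):
        super().__init__()
        # Assumed global-feature encoder: one linear layer into the hidden size.
        self.global_enc = nn.Sequential(nn.Linear(num_global_features, hid), nn.ReLU())
        # Input: shooter + goal + goalkeeper + encoded global context.
        self.mlp = nn.Sequential(
            nn.Linear(4 * hid, hid),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(hid, 1),
        )

    def forward(self, h_shooter, h_goal, h_gk, global_feats):
        # Each h_* is [batch, hid]; global_feats is [batch, num_global_features].
        z = torch.cat([h_shooter, h_goal, h_gk, self.global_enc(global_feats)], dim=-1)
        return self.mlp(z).squeeze(-1)  # raw logits; apply sigmoid to get xG


head = ReadoutHead().eval()
logits = head(torch.randn(5, 64), torch.randn(5, 64), torch.randn(5, 64), torch.randn(5, 18))
xg = torch.sigmoid(logits)
```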

## Key Innovation

Unlike approaches that build a separate small graph for every shot, this model uses:

- **One persistent graph** containing all nodes
- **Shot-indexed edges** that let each prediction select only its own edges via masking
- **Bidirectional message passing** for richer node representations

Because player, goal, and goalkeeper nodes are stored once instead of being duplicated in every per-shot graph, this design is more memory-efficient and makes batching shots simpler.

## License

MIT