---
tags:
- pytorch
- pytorch-geometric
- xG
- football
- soccer
- expected-goals
- gnn
- heterogeneous-graph
- graph-neural-network
- binary-classification
- one-graph-architecture
library_name: pytorch
---
# Heterogeneous GNN xG Prediction Model (Focal Loss, One Graph Architecture)
This is a Heterogeneous Graph Neural Network (GNN) model trained to predict Expected Goals (xG) in football/soccer using a single persistent graph with shot-indexed edges.
## Model Description
- **Architecture**: Heterogeneous GNN with GAT attention layers
- **Graph Structure**: One persistent graph with all nodes, shot-specific edge masking
- **Nodes**: shooter (player embeddings, 496 total), goal (learnable embedding), goalkeeper (learnable embedding)
- **Edges**: 4 edge types (bidirectional shooter↔goal and shooter↔goalkeeper)
- **Global Features**: 18 shot-level contextual features
- **Hidden Dimensions**: 64
- **Number of Layers**: 3 GAT layers
- **Attention Heads**: 4
- **Dropout Rate**: 0.3
- **Framework**: PyTorch Geometric
- **Loss Function**: Focal Loss (alpha=0.8773, gamma=2.0)
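The focal loss listed above has the standard binary form $\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log p_t$, where $p_t$ is the probability the model assigns to the true class. The standard-library sketch below evaluates it for a single shot and uses this card's `alpha` and `gamma` as defaults. It assumes the common convention that `alpha` weights the positive (goal) class and `1 - alpha` weights the negative class, which the card does not confirm. `binary_focal_loss` is a hypothetical helper, not the repository's implementation.

```python
import math


def binary_focal_loss(p: float, y: int, alpha: float = 0.8773, gamma: float = 2.0,
                      eps: float = 1e-7) -> float:
    """Focal loss for one predicted goal probability p and label y (1 = goal, 0 = no goal).

    Assumption: alpha weights the positive class and 1 - alpha the negative class.
    """
    p = min(max(p, eps), 1.0 - eps)             # clamp so log() never sees 0
    p_t = p if y == 1 else 1.0 - p              # probability given to the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha  # class-balancing weight
    # (1 - p_t) ** gamma shrinks the loss of confidently correct ("easy") shots
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# With gamma = 0 the expression reduces to alpha-weighted binary cross-entropy
print(binary_focal_loss(0.5, 1, alpha=0.5, gamma=0.0))  # equals 0.5 * ln 2
```

Raising `gamma` above 0 lowers the loss for confident correct predictions more than for uncertain ones, so training concentrates on the harder shots.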
## Performance Metrics
- **Test Loss**: 0.0827
- **Test AUC**: 0.6094
- **Test Accuracy**: 0.1433
- **Test F1**: 0.2106
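Accuracy and F1 are computed from thresholded predictions, while AUC uses the raw scores. The first two can therefore shift considerably with the threshold, but AUC does not. The card does not state which threshold was used. The sketch below assumes that `p >= 0.5` counts as a predicted goal and recomputes the three classification metrics from predicted probabilities using only the standard library. `evaluate_predictions` is a hypothetical helper, and its pairwise AUC loop is quadratic, so it is meant for illustration only.

```python
def evaluate_predictions(probs, labels, threshold=0.5):
    """Accuracy and F1 at `threshold` (assumed), plus threshold-free ROC AUC."""
    preds = [1 if p >= threshold else 0 for p in probs]
    pairs = list(zip(preds, labels))
    tp = sum(1 for yh, y in pairs if yh == 1 and y == 1)
    fp = sum(1 for yh, y in pairs if yh == 1 and y == 0)
    fn = sum(1 for yh, y in pairs if yh == 0 and y == 1)

    accuracy = sum(1 for yh, y in pairs if yh == y) / len(labels)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    # AUC = probability that a random goal is scored above a random non-goal (ties count half)
    pos = [p for p, y in zip(probs, labels) if y == 1]
    neg = [p for p, y in zip(probs, labels) if y == 0]
    if pos and neg:
        wins = sum(1.0 if a > b else 0.5 if a == b else 0.0 for a in pos for b in neg)
        auc = wins / (len(pos) * len(neg))
    else:
        auc = float("nan")  # undefined when only one class is present
    return {"accuracy": accuracy, "f1": f1, "auc": auc}


print(evaluate_predictions([0.9, 0.8, 0.3, 0.2], [1, 0, 1, 0]))
```

Sweeping `threshold` over a held-out set is a quick way to see how sensitive accuracy and F1 are to that choice.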
## Global Features (18 features)
The model uses the following contextual features for each shot:
- ball_closer_than_gk
- body_part_name_Left Foot
- body_part_name_Other
- body_part_name_Right Foot
- goal_dist_to_gk
- minute
- nearest_opponent_dist
- nearest_teammate_dist
- opponents_within_5m
- play_pattern_name_From Counter
- play_pattern_name_From Free Kick
- play_pattern_name_From Goal Kick
- play_pattern_name_From Keeper
- play_pattern_name_From Kick Off
- play_pattern_name_From Throw In
- play_pattern_name_Other
- play_pattern_name_Regular Play
- teammates_within_5m
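The feature names look like one-hot dummy columns for `body_part_name` and `play_pattern_name`, alongside several numeric fields. The sketch below assembles the 18-value vector in the order listed above. It assumes that this order matches training and that the numeric values were used unscaled; neither point is documented. Any category without its own column, such as a header (`Head`) or `From Corner` if those occur, is encoded as all zeros here, on the assumption that these were dropped reference categories. The helper and the example shot are hypothetical.

```python
# Column order as listed on this card; assumed to match the training data.
GLOBAL_FEATURES = [
    "ball_closer_than_gk",
    "body_part_name_Left Foot",
    "body_part_name_Other",
    "body_part_name_Right Foot",
    "goal_dist_to_gk",
    "minute",
    "nearest_opponent_dist",
    "nearest_teammate_dist",
    "opponents_within_5m",
    "play_pattern_name_From Counter",
    "play_pattern_name_From Free Kick",
    "play_pattern_name_From Goal Kick",
    "play_pattern_name_From Keeper",
    "play_pattern_name_From Kick Off",
    "play_pattern_name_From Throw In",
    "play_pattern_name_Other",
    "play_pattern_name_Regular Play",
    "teammates_within_5m",
]
CATEGORICAL = ("body_part_name", "play_pattern_name")


def encode_global_features(shot: dict) -> list[float]:
    """Turn one raw shot record into the 18-value global feature vector."""
    row = dict.fromkeys(GLOBAL_FEATURES, 0.0)
    for name in GLOBAL_FEATURES:
        if not name.startswith(CATEGORICAL):   # numeric column: copy as-is
            row[name] = float(shot[name])
    for field in CATEGORICAL:                  # categorical column: one-hot
        column = f"{field}_{shot[field]}"
        if column in row:                      # categories without a column stay all zeros
            row[column] = 1.0
    return [row[name] for name in GLOBAL_FEATURES]


shot = {  # hypothetical example values
    "ball_closer_than_gk": 0, "goal_dist_to_gk": 2.5, "minute": 63,
    "nearest_opponent_dist": 1.8, "nearest_teammate_dist": 6.0,
    "opponents_within_5m": 2, "teammates_within_5m": 1,
    "body_part_name": "Left Foot", "play_pattern_name": "Regular Play",
}
vec = encode_global_features(shot)
```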
## Graph Structure
### Persistent Nodes
1. **Shooter Nodes**: 496 player embeddings (dimension: 64)
2. **Goal Node**: 1 learnable goal representation (dimension: 64)
3. **Goalkeeper Node**: 1 learnable goalkeeper representation (dimension: 64)
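Because the shooter embedding table has a fixed 496 rows, raw player IDs have to be mapped to contiguous node indices. The goal and the goalkeeper each have a single node at index 0 of their type. The minimal sketch below uses hypothetical names. The card does not document how the released model handles players unseen during training; this version simply raises `KeyError` for them.

```python
def build_node_index(player_ids):
    """Map raw player IDs to contiguous shooter-node indices 0..N-1 (first-seen order)."""
    index = {}
    for pid in player_ids:
        if pid not in index:
            index[pid] = len(index)
    return index


def shooter_node(index, player_id):
    """Look up a shooter node; raises KeyError for a player not seen in training."""
    return index[player_id]


GOAL_NODE = 0        # the single shared goal node
GOALKEEPER_NODE = 0  # the single shared goalkeeper node

shooter_index = build_node_index([10, 7, 10, 3])  # hypothetical player IDs
```

For this model, the index built from the training shots would hold 496 entries.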
### Edge Types (with attributes, shot-indexed)
1. **shooter → goal**: Distance (meters) + Angle to goal (radians)
2. **goal → shooter**: Reverse edges with same attributes
3. **shooter → goalkeeper**: Distance to goalkeeper (meters)
4. **goalkeeper → shooter**: Reverse edges with same attributes
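The card names the edge attributes but does not say how they are computed. The minimal sketch below assumes a hypothetical 105 × 68 m pitch frame in which the attacked goal is centred at (105, 34). It interprets "angle to goal" as the angle subtended by the two posts at the shot location, which is one common xG definition; the model may use a different one.

```python
import math

# Hypothetical pitch frame in metres, attacking the goal at x = 105.
GOAL_CENTRE = (105.0, 34.0)
GOAL_WIDTH = 7.32  # distance between the posts


def shooter_goal_attrs(shot_xy, goal_centre=GOAL_CENTRE, goal_width=GOAL_WIDTH):
    """[distance (m), angle (rad)] for a shooter -> goal edge (angle between the posts)."""
    x, y = shot_xy
    gx, gy = goal_centre
    distance = math.hypot(gx - x, gy - y)
    # Vectors from the shot location to each post
    ax, ay = gx - x, gy - goal_width / 2 - y
    bx, by = gx - x, gy + goal_width / 2 - y
    # Angle between the two vectors via atan2(cross, dot)
    angle = abs(math.atan2(ax * by - ay * bx, ax * bx + ay * by))
    return [distance, angle]


def shooter_gk_attrs(shot_xy, gk_xy):
    """[distance (m)] for a shooter -> goalkeeper edge."""
    return [math.hypot(gk_xy[0] - shot_xy[0], gk_xy[1] - shot_xy[1])]


# A shot from 11 m straight in front of goal, with the keeper 1 m off the line
print(shooter_goal_attrs((94.0, 34.0)), shooter_gk_attrs((94.0, 34.0), (104.0, 34.0)))
```

Following the card, each reverse edge (goal → shooter, goalkeeper → shooter) reuses the same attribute list as its forward edge.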
Every edge stores the index of the shot it belongs to (`shot_idx`), so at prediction time only the edges of the shot being scored are kept.
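One way to picture the shot-indexed edges is as a per-type edge list in which each edge also records its `shot_idx`. Selecting a shot then amounts to filtering on that field. The sketch below uses plain Python lists and assumes single goal and goalkeeper nodes at index 0 of their types. The relation names, which follow PyTorch Geometric's `(source, relation, target)` key style, are hypothetical. This is an illustration, not the repository's graph builder.

```python
from collections import defaultdict

# Hypothetical edge-type keys in PyTorch Geometric's (source, relation, target) style
SHOOTER_GOAL = ("shooter", "to", "goal")
GOAL_SHOOTER = ("goal", "rev_to", "shooter")
SHOOTER_GK = ("shooter", "faces", "goalkeeper")
GK_SHOOTER = ("goalkeeper", "rev_faces", "shooter")


def build_edge_store(shots):
    """Build per-type edge lists; every edge remembers its shot in 'shot_idx'.

    Each shot is a dict with 'shooter' (shooter node index), 'goal_attr'
    ([distance, angle]) and 'gk_attr' ([distance]).
    """
    store = defaultdict(lambda: {"src": [], "dst": [], "attr": [], "shot_idx": []})

    def add(edge_type, src, dst, attr, idx):
        edges = store[edge_type]
        edges["src"].append(src)
        edges["dst"].append(dst)
        edges["attr"].append(attr)
        edges["shot_idx"].append(idx)

    for idx, shot in enumerate(shots):
        s = shot["shooter"]
        add(SHOOTER_GOAL, s, 0, shot["goal_attr"], idx)  # goal is node 0
        add(GOAL_SHOOTER, 0, s, shot["goal_attr"], idx)  # reverse edge, same attributes
        add(SHOOTER_GK, s, 0, shot["gk_attr"], idx)      # goalkeeper is node 0
        add(GK_SHOOTER, 0, s, shot["gk_attr"], idx)      # reverse edge, same attributes
    return dict(store)


def edges_for_shot(store, shot_idx):
    """Masking step: keep only the edges that belong to `shot_idx`."""
    selected = {}
    for edge_type, edges in store.items():
        keep = [i for i, k in enumerate(edges["shot_idx"]) if k == shot_idx]
        selected[edge_type] = {
            key: [edges[key][i] for i in keep] for key in ("src", "dst", "attr")
        }
    return selected


shots = [  # hypothetical shots
    {"shooter": 5, "goal_attr": [11.0, 0.64], "gk_attr": [10.0]},
    {"shooter": 2, "goal_attr": [20.0, 0.30], "gk_attr": [18.0]},
]
store = build_edge_store(shots)
sub = edges_for_shot(store, 1)
```

With tensors, the same filter is a boolean mask, for example `mask = edge_shot_idx == shot_idx` followed by `edge_index[:, mask]` and `edge_attr[mask]`. Here `edge_shot_idx`, `edge_index` and `edge_attr` are illustrative per-type tensors.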
## Usage
```python
import importlib.util
import json

import torch
from huggingface_hub import hf_hub_download
from torch_geometric.data import HeteroData  # type of the persistent graph

REPO_ID = "rokati/focal_gnn_one_graph"

# Download the weights, the architecture definition and the config
model_path = hf_hub_download(repo_id=REPO_ID, filename="best_gnn_model.pth")
architecture_path = hf_hub_download(repo_id=REPO_ID, filename="model_architecture.py")
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")

# Load configuration
with open(config_path, "r") as f:
    config = json.load(f)

# Import XGNet from the downloaded model_architecture.py
spec = importlib.util.spec_from_file_location("model_architecture", architecture_path)
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)

# Create model instance
model = model_module.XGNet(
    num_players=config["num_players"],
    hid=config["hidden_dim"],
    p=config["dropout_rate"],
    heads=config["num_heads"],
    num_layers=config["num_layers"],
    use_norm=config["use_norm"],
    num_global_features=config["num_global_features"],
)

# Load weights on CPU and switch to inference mode (disables dropout)
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model.eval()

# `graph` must be the persistent HeteroData (all players plus shot-indexed edges),
# built once exactly as during training; `shot_idx` is the index of the shot to score.
# Neither is constructed in this snippet.
with torch.no_grad():
    xg_prediction = torch.sigmoid(model(graph, shot_idx)).item()  # logit -> probability
```
## Training Details
The model was trained with:
- **Loss Function**: Focal Loss (down-weights easy, well-classified shots to counter class imbalance, since goals are rare)
- **Optimizer**: Adam (lr=1e-3, weight_decay=1e-4)
- **Scheduler**: ReduceLROnPlateau
- **Batch Size**: 32
- **Max Epochs**: 100
- **Early Stopping**: Patience=15 on validation F1
- **Framework**: PyTorch Lightning
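Early stopping with a patience of 15 on validation F1 means that training halts once F1 has failed to improve for 15 consecutive validation epochs. The standard-library sketch below implements that rule. The class name and the `min_delta` parameter are illustrative and are not taken from the training code.

```python
class F1EarlyStopper:
    """Stop when validation F1 has not improved for `patience` epochs (higher is better)."""

    def __init__(self, patience: int = 15, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_f1: float) -> bool:
        """Record one epoch's validation F1; return True when training should stop."""
        if val_f1 > self.best + self.min_delta:
            self.best = val_f1      # new best: reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1    # no improvement this epoch
        return self.bad_epochs >= self.patience


stopper = F1EarlyStopper(patience=15)
history = [0.10, 0.15, 0.21] + [0.20] * 15  # hypothetical validation F1 values
for epoch, f1 in enumerate(history):
    if stopper.step(f1):
        print(f"stop after epoch {epoch}, best F1 = {stopper.best}")
        break
```

In PyTorch Lightning, the `EarlyStopping` callback with `mode="max"` and `patience=15` provides the same behaviour. The card does not state the name of the logged F1 metric that the training code monitors.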
## Model Architecture
The heterogeneous GNN uses:
1. **Node Embeddings**: Learnable embeddings for shooters, goal, and goalkeeper
2. **Edge-Conditioned Attention**: GAT layers with edge attributes
3. **Bidirectional Message Passing**: 3 layers with forward and reverse edges
4. **Shot-Indexed Masking**: Each prediction uses only the edges of the shot being scored
5. **Global Context**: Shot-level features encoded and combined
6. **Readout**: Concatenate shooter + goal + goalkeeper representations → MLP classifier
## Key Innovation
Unlike traditional approaches that create separate graphs for each shot, this model uses:
- **One persistent graph** with all nodes
- **Shot-indexed edges** for efficient masking
- **Bidirectional message passing** for richer representations
- **Lower memory use and easier batching**, because each player, the goal and the goalkeeper are stored once rather than once per shot
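A rough way to see where the memory saving comes from is to count nodes and edges. The sketch below assumes that a per-shot alternative would build three nodes (shooter, goal, goalkeeper) and the same four directed edges for every shot. Under that assumption the edge count is identical in both layouts, and the difference lies entirely in the nodes. `graph_sizes` is a hypothetical helper.

```python
def graph_sizes(num_shots: int, num_players: int = 496) -> dict:
    """Node and edge counts for a persistent graph versus one small graph per shot."""
    edges = 4 * num_shots  # shooter<->goal and shooter<->goalkeeper, both directions, per shot
    return {
        "persistent": {"nodes": num_players + 2, "edges": edges},  # players + goal + goalkeeper
        "per_shot": {"nodes": 3 * num_shots, "edges": edges},      # 3 fresh nodes per shot
    }


print(graph_sizes(10_000))  # hypothetical dataset size
```

For a hypothetical 10,000 shots, this gives 498 persistent nodes versus 30,000 per-shot nodes, with 40,000 edges in both layouts.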
## License
MIT