---
tags:
- pytorch
- pytorch-geometric
- xG
- football
- soccer
- expected-goals
- gnn
- heterogeneous-graph
- graph-neural-network
- binary-classification
- one-graph-architecture
library_name: pytorch
---
# Heterogeneous GNN xG Prediction Model (Focal Loss, One Graph Architecture)
This is a Heterogeneous Graph Neural Network (GNN) model trained to predict Expected Goals (xG) in football/soccer using a single persistent graph with shot-indexed edges.
## Model Description
- **Architecture**: Heterogeneous GNN with GAT attention layers
- **Graph Structure**: One persistent graph with all nodes, shot-specific edge masking
- **Nodes**: shooter (player embeddings, 496 total), goal (learnable embedding), goalkeeper (learnable embedding)
- **Edges**: 4 edge types (bidirectional shooter↔goal and shooter↔goalkeeper)
- **Global Features**: 18 shot-level contextual features
- **Hidden Dimensions**: 64
- **Number of Layers**: 3 GAT layers
- **Attention Heads**: 4
- **Dropout Rate**: 0.3
- **Framework**: PyTorch Geometric
- **Loss Function**: Focal Loss (alpha=0.8773, gamma=2.0)
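Focal loss rescales cross-entropy as `FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)`. The `(1 - p_t)**gamma` factor down-weights shots the model already classifies confidently, so training concentrates on harder examples. The sketch below is a minimal PyTorch version of the binary (sigmoid) form. It assumes the common convention in which `alpha` weights the positive (goal) class and `1 - alpha` weights the negative class. The repository's own implementation may differ in reduction or weighting details.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss(logits: torch.Tensor, targets: torch.Tensor,
                      alpha: float = 0.8773, gamma: float = 2.0) -> torch.Tensor:
    """Mean binary focal loss from raw logits (sketch, not the repository's code)."""
    targets = targets.float()
    # Per-example cross-entropy, which equals -log(p_t)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    # p_t: predicted probability of the true class
    p_t = p * targets + (1 - p) * (1 - targets)
    # alpha_t: alpha for goals (target 1), 1 - alpha for non-goals (target 0)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()


logits = torch.tensor([2.0, -1.0, 0.5])
targets = torch.tensor([1.0, 0.0, 1.0])
print(binary_focal_loss(logits, targets))
```

With `gamma=0` and `alpha=0.5`, the loss reduces to half of ordinary binary cross-entropy. This is a quick sanity check for any implementation.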
## Performance Metrics
- **Test Loss**: 0.0827
- **Test AUC**: 0.6094
- **Test Accuracy**: 0.1433
- **Test F1**: 0.2106
## Global Features (18 features)
The model uses the following contextual features for each shot:
- ball_closer_than_gk
- body_part_name_Left Foot
- body_part_name_Other
- body_part_name_Right Foot
- goal_dist_to_gk
- minute
- nearest_opponent_dist
- nearest_teammate_dist
- opponents_within_5m
- play_pattern_name_From Counter
- play_pattern_name_From Free Kick
- play_pattern_name_From Goal Kick
- play_pattern_name_From Keeper
- play_pattern_name_From Kick Off
- play_pattern_name_From Throw In
- play_pattern_name_Other
- play_pattern_name_Regular Play
- teammates_within_5m
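The list mixes continuous measurements with one-hot columns named `<column>_<category>` (for example `body_part_name_Left Foot`). This matches the naming produced by one-hot encoding with `pandas.get_dummies`. Below is a minimal sketch of assembling the 18-dimensional vector in the order listed above. It assumes raw shot records are dictionaries with these field names. It also assumes that a category without its own column is encoded as all zeros, for example a baseline category dropped during encoding. The helper `build_global_features` and the example record (made-up values) are hypothetical. Check the feature order the trained model expects against `config.json` or the training code.

```python
# Feature order exactly as listed in this model card (alphabetical).
GLOBAL_FEATURES = [
    "ball_closer_than_gk",
    "body_part_name_Left Foot",
    "body_part_name_Other",
    "body_part_name_Right Foot",
    "goal_dist_to_gk",
    "minute",
    "nearest_opponent_dist",
    "nearest_teammate_dist",
    "opponents_within_5m",
    "play_pattern_name_From Counter",
    "play_pattern_name_From Free Kick",
    "play_pattern_name_From Goal Kick",
    "play_pattern_name_From Keeper",
    "play_pattern_name_From Kick Off",
    "play_pattern_name_From Throw In",
    "play_pattern_name_Other",
    "play_pattern_name_Regular Play",
    "teammates_within_5m",
]

# Raw categorical fields that were one-hot encoded (assumption based on the names)
ONE_HOT_PREFIXES = ("body_part_name", "play_pattern_name")


def build_global_features(shot: dict) -> list[float]:
    """Hypothetical helper: map a raw shot record to the 18 global features."""
    vector = []
    for name in GLOBAL_FEATURES:
        prefix = next((p for p in ONE_HOT_PREFIXES if name.startswith(p + "_")), None)
        if prefix is not None:
            # One-hot column: 1.0 if the shot's category matches this column, else 0.0
            category = name[len(prefix) + 1:]
            vector.append(1.0 if shot[prefix] == category else 0.0)
        else:
            # Continuous or boolean column, used as-is
            vector.append(float(shot[name]))
    return vector


# Made-up example record
shot = {
    "ball_closer_than_gk": True,
    "body_part_name": "Right Foot",
    "goal_dist_to_gk": 2.1,
    "minute": 63,
    "nearest_opponent_dist": 1.8,
    "nearest_teammate_dist": 6.4,
    "opponents_within_5m": 3,
    "play_pattern_name": "From Counter",
    "teammates_within_5m": 1,
}
features = build_global_features(shot)
print(len(features))  # 18
```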
## Graph Structure
### Persistent Nodes
1. **Shooter Nodes**: 496 player embeddings (dimension: 64)
2. **Goal Node**: 1 learnable goal representation (dimension: 64)
3. **Goalkeeper Node**: 1 learnable goalkeeper representation (dimension: 64)
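The sketch below shows one way these persistent node representations could be held in PyTorch. It assumes one `nn.Embedding` row per player and a single learnable vector each for the goal and the goalkeeper. The class name `PersistentNodes` is hypothetical, and the actual `XGNet` may organise its parameters differently.

```python
import torch
from torch import nn


class PersistentNodes(nn.Module):
    """Hypothetical container for the three persistent node types."""

    def __init__(self, num_players: int = 496, hid: int = 64):
        super().__init__()
        self.shooter = nn.Embedding(num_players, hid)                # one row per player
        self.goal = nn.Parameter(torch.randn(1, hid) * 0.02)        # single goal node
        self.goalkeeper = nn.Parameter(torch.randn(1, hid) * 0.02)  # single goalkeeper node

    def forward(self) -> dict:
        # Node-feature dict keyed by node type, the usual input to heterogeneous message passing
        return {
            "shooter": self.shooter.weight,
            "goal": self.goal,
            "goalkeeper": self.goalkeeper,
        }


x_dict = PersistentNodes()()
print({k: tuple(v.shape) for k, v in x_dict.items()})
# {'shooter': (496, 64), 'goal': (1, 64), 'goalkeeper': (1, 64)}
```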
### Edge Types (with attributes, shot-indexed)
1. **shooter → goal**: Distance (meters) + Angle to goal (radians)
2. **goal → shooter**: Reverse edges with same attributes
3. **shooter → goalkeeper**: Distance to goalkeeper (meters)
4. **goalkeeper → shooter**: Reverse edges with same attributes
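The card does not define exactly how distance and angle are measured. The following sketch makes several assumptions:

- Locations are in metres on a 105 x 68 m pitch.
- The attacked goal is centred at (105, 34).
- "Angle to goal" is the angle subtended by the 7.32 m goal mouth, a common xG feature. The model may use a different angle definition.

The helpers `shooter_goal_attr` and `shooter_gk_attr` are hypothetical.

```python
import math

GOAL_X, GOAL_Y = 105.0, 34.0  # assumed goal-centre position (metres)
GOAL_WIDTH = 7.32             # distance between the posts (metres)


def shooter_goal_attr(x: float, y: float) -> tuple[float, float]:
    """(distance, angle) for a shooter -> goal edge (hypothetical helper)."""
    dx, dy = GOAL_X - x, GOAL_Y - y
    distance = math.hypot(dx, dy)
    # Angle between the lines to the two posts. atan2 keeps it correct even
    # very close to the goal line, where the angle exceeds 90 degrees.
    half = GOAL_WIDTH / 2
    angle = math.atan2(GOAL_WIDTH * abs(dx), dx * dx + dy * dy - half * half)
    return distance, angle


def shooter_gk_attr(shooter_xy, gk_xy) -> tuple[float]:
    """(distance,) for a shooter -> goalkeeper edge (hypothetical helper)."""
    return (math.dist(shooter_xy, gk_xy),)


# Reverse edges (goal -> shooter, goalkeeper -> shooter) reuse the same values.
print(shooter_goal_attr(94.0, 34.0))  # straight on from 11 m
print(shooter_gk_attr((94.0, 34.0), (104.0, 34.0)))
```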
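Every edge carries a `shot_idx` identifying the shot it belongs to. At prediction time, the edges for a single shot can therefore be selected with a mask instead of building a new graph.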
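The sketch below shows shot-indexed masking. It assumes each edge type stores an `edge_index`, an `edge_attr` and a per-edge `shot_idx` tensor. This mirrors the per-edge-type stores of a PyG `HeteroData` object but uses plain tensors. The helper `edges_for_shot` and the toy data are hypothetical.

```python
import torch


def edges_for_shot(edge_store: dict, shot_idx: int) -> dict:
    """Keep only the edges of one edge type that belong to shot `shot_idx`."""
    mask = edge_store["shot_idx"] == shot_idx             # boolean mask over edges
    return {
        "edge_index": edge_store["edge_index"][:, mask],  # [2, num_kept]
        "edge_attr": edge_store["edge_attr"][mask],       # [num_kept, attr_dim]
    }


# Toy shooter -> goal edges for three shots: shooters 5, 12 and 5 all point at goal node 0.
shooter_to_goal = {
    "edge_index": torch.tensor([[5, 12, 5],
                                [0, 0, 0]]),
    "edge_attr": torch.tensor([[11.0, 0.64], [20.5, 0.33], [8.2, 0.79]]),  # distance, angle
    "shot_idx": torch.tensor([0, 1, 2]),
}

subset = edges_for_shot(shooter_to_goal, shot_idx=2)
print(subset["edge_index"])
print(subset["edge_attr"])
```

The same mask would be applied to each of the four edge types, so message passing for one shot only uses that shot's edges.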
## Usage
```python
import importlib.util
import json

import torch
from huggingface_hub import hf_hub_download
from torch_geometric.data import HeteroData  # type of the persistent graph

REPO_ID = "rokati/focal_gnn_one_graph"

# Download files
model_path = hf_hub_download(repo_id=REPO_ID, filename="best_gnn_model.pth")
architecture_path = hf_hub_download(repo_id=REPO_ID, filename="model_architecture.py")
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")

# Load configuration
with open(config_path, "r") as f:
    config = json.load(f)

# Load the architecture definition (XGNet) from the downloaded file
spec = importlib.util.spec_from_file_location("model_architecture", architecture_path)
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)

# Create model instance
model = model_module.XGNet(
    num_players=config["num_players"],
    hid=config["hidden_dim"],
    p=config["dropout_rate"],
    heads=config["num_heads"],
    num_layers=config["num_layers"],
    use_norm=config["use_norm"],
    num_global_features=config["num_global_features"],
)

# Load weights (a state dict, so weights_only=True is enough and avoids unpickling code)
state_dict = torch.load(model_path, map_location=torch.device("cpu"), weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# `graph` must be the persistent HeteroData graph, built once with all players and
# shot-indexed edges, and `shot_idx` the index of the shot to score.
# This snippet does not build them; create both with your own data pipeline first.
with torch.no_grad():
    xg_prediction = torch.sigmoid(model(graph, shot_idx)).item()
```
## Training Details
The model was trained with:
- **Loss Function**: Focal Loss (addresses class imbalance)
- **Optimizer**: Adam (lr=1e-3, weight_decay=1e-4)
- **Scheduler**: ReduceLROnPlateau
- **Batch Size**: 32
- **Max Epochs**: 100
- **Early Stopping**: Patience=15 on validation F1
- **Framework**: PyTorch Lightning
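The optimizer, scheduler and stopping rule above can be sketched in plain PyTorch. This is not the Lightning training script, and the following pieces are assumptions for illustration:

- the stand-in `nn.Linear` model;
- `mode="max"` for the scheduler, which assumes it also tracks validation F1 (the card does not say);
- the `EarlyStopper` helper;
- the placeholder F1 values.

```python
import torch
from torch import nn

model = nn.Linear(18, 1)  # stand-in for XGNet
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Assumed: the scheduler steps on validation F1, so higher is better
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max")


class EarlyStopper:
    """Signal a stop after `patience` epochs without a new best validation F1."""

    def __init__(self, patience: int = 15):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_f1: float) -> bool:
        if val_f1 > self.best:
            self.best, self.bad_epochs = val_f1, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopper(patience=15)
stopped_at = None
for epoch in range(100):  # max epochs
    # ... train one epoch on batches of 32, then evaluate ...
    val_f1 = min(epoch, 5) / 10  # placeholder: improves for 5 epochs, then plateaus
    scheduler.step(val_f1)
    if stopper.step(val_f1):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)
```

In PyTorch Lightning, the same rule is normally expressed with the `EarlyStopping` callback (`monitor`, `patience=15`, `mode="max"`). The monitored metric name depends on what the training code logs.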
## Model Architecture
The heterogeneous GNN uses:
1. **Node Embeddings**: Learnable embeddings for shooters, goal, and goalkeeper
2. **Edge-Conditioned Attention**: GAT layers with edge attributes
3. **Bidirectional Message Passing**: 3 layers with forward and reverse edges
4. **Shot-Indexed Masking**: Efficient prediction via edge masking
5. **Global Context**: Shot-level features encoded and combined
6. **Readout**: Concatenate shooter + goal + goalkeeper representations → MLP classifier
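The sketch below illustrates the readout step. It assumes the 18 global features are first projected to the hidden size. They are then concatenated with the final shooter, goal and goalkeeper representations, using the hidden size (64) and dropout (0.3) listed above. How `XGNet` actually encodes and combines the global context is not documented here, so the layer sizes and the `Readout` class are hypothetical.

```python
import torch
from torch import nn


class Readout(nn.Module):
    """Hypothetical readout: [shooter | goal | goalkeeper | global context] -> xG logit."""

    def __init__(self, hid: int = 64, num_global_features: int = 18, p: float = 0.3):
        super().__init__()
        self.global_encoder = nn.Sequential(nn.Linear(num_global_features, hid), nn.ReLU())
        self.classifier = nn.Sequential(
            nn.Linear(4 * hid, hid),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(hid, 1),  # single logit; apply sigmoid to get xG
        )

    def forward(self, h_shooter, h_goal, h_gk, global_feats):
        # Each h_* is [batch, hid]; global_feats is [batch, num_global_features]
        g = self.global_encoder(global_feats)
        z = torch.cat([h_shooter, h_goal, h_gk, g], dim=-1)  # [batch, 4 * hid]
        return self.classifier(z).squeeze(-1)                # [batch]


readout = Readout().eval()
batch = 32
logits = readout(torch.randn(batch, 64), torch.randn(batch, 64),
                 torch.randn(batch, 64), torch.randn(batch, 18))
xg = torch.sigmoid(logits)
print(xg.shape)  # torch.Size([32])
```

The model outputs a logit rather than a probability. This matches the usage example, which applies `torch.sigmoid` to the model output.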
## Key Innovation
Unlike traditional approaches that create separate graphs for each shot, this model uses:
- **One persistent graph** with all nodes
- **Shot-indexed edges** for efficient masking
- **Bidirectional message passing** for richer representations
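- Because player, goal and goalkeeper nodes are stored once and each shot only selects its own edges, the graph never has to be rebuilt per shot. This reduces memory use and simplifies batching.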
## License
MIT