---
tags:
- pytorch
- pytorch-geometric
- xG
- football
- soccer
- expected-goals
- gnn
- heterogeneous-graph
- graph-neural-network
- binary-classification
- one-graph-architecture
library_name: pytorch
---

# Heterogeneous GNN xG Prediction Model (Focal Loss, One Graph Architecture)

This is a Heterogeneous Graph Neural Network (GNN) trained to predict Expected Goals (xG) in football/soccer. It uses a single persistent graph with shot-indexed edges.

## Model Description

- **Architecture**: Heterogeneous GNN with GAT (graph attention) layers
- **Graph Structure**: One persistent graph with all nodes, shot-specific edge masking
- **Nodes**: shooter (player embeddings, 496 total), goal (learnable embedding), goalkeeper (learnable embedding)
- **Edges**: 4 edge types (bidirectional shooter↔goal and shooter↔goalkeeper)
- **Global Features**: 18 shot-level contextual features
- **Hidden Dimensions**: 64
- **Number of Layers**: 3 GAT layers
- **Attention Heads**: 4
- **Dropout Rate**: 0.3
- **Framework**: PyTorch Geometric
- **Loss Function**: Focal Loss (alpha=0.8773, gamma=2.0)

## Performance Metrics

- **Test Loss**: 0.08266454190015793
- **Test AUC**: 0.609370231628418
- **Test Accuracy**: 0.14332698285579681
- **Test F1**: 0.2105599045753479

## Global Features (18 features)

The model uses the following contextual features for each shot:

- `ball_closer_than_gk`
- `body_part_name_Left Foot`
- `body_part_name_Other`
- `body_part_name_Right Foot`
- `goal_dist_to_gk`
- `minute`
- `nearest_opponent_dist`
- `nearest_teammate_dist`
- `opponents_within_5m`
- `play_pattern_name_From Counter`
- `play_pattern_name_From Free Kick`
- `play_pattern_name_From Goal Kick`
- `play_pattern_name_From Keeper`
- `play_pattern_name_From Kick Off`
- `play_pattern_name_From Throw In`
- `play_pattern_name_Other`
- `play_pattern_name_Regular Play`
- `teammates_within_5m`

## Graph Structure

### Persistent Nodes

1. **Shooter Nodes**: 496 player embeddings (dimension: 64)
2. **Goal Node**: 1 learnable goal representation (dimension: 64)
3. **Goalkeeper Node**: 1 learnable goalkeeper representation (dimension: 64)

### Edge Types (with attributes, shot-indexed)

1. **shooter → goal**: Distance (meters) + Angle to goal (radians)
2. **goal → shooter**: Reverse edges with same attributes
3. **shooter → goalkeeper**: Distance to goalkeeper (meters)
4. **goalkeeper → shooter**: Reverse edges with same attributes

All edges are indexed by `shot_idx`, so the edges belonging to a single shot can be selected by masking at prediction time.

## Usage

```python
import importlib.util
import json

import torch
from torch_geometric.data import HeteroData
from huggingface_hub import hf_hub_download

REPO_ID = "rokati/focal_gnn_one_graph"

# Download the weights, the architecture definition, and the config
model_path = hf_hub_download(repo_id=REPO_ID, filename="best_gnn_model.pth")
architecture_path = hf_hub_download(repo_id=REPO_ID, filename="model_architecture.py")
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")

# Load configuration
with open(config_path, "r") as f:
    config = json.load(f)

# Import the architecture module from the downloaded file
spec = importlib.util.spec_from_file_location("model_architecture", architecture_path)
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)

# Create model instance
model = model_module.XGNet(
    num_players=config["num_players"],
    hid=config["hidden_dim"],
    p=config["dropout_rate"],
    heads=config["num_heads"],
    num_layers=config["num_layers"],
    use_norm=config["use_norm"],
    num_global_features=config["num_global_features"],
)

# Load weights on CPU and switch to inference mode
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model.eval()

# Prepare the persistent graph (build once with all players/shots).
# `graph` and `shot_idx` are not defined in this snippet: you must build the
# graph yourself and pick the index of the shot to score.
# Then predict by passing the graph and shot_idx:
with torch.no_grad():
    xg_prediction = torch.sigmoid(model(graph, shot_idx)).item()
```

## Training Details

The model was trained with:

- **Loss Function**: Focal Loss (addresses class imbalance)
- **Optimizer**: Adam (lr=1e-3, weight_decay=1e-4)
- **Scheduler**: ReduceLROnPlateau
- **Batch Size**: 32
- **Max Epochs**: 100
- **Early Stopping**: Patience=15 on validation F1
- **Framework**: PyTorch Lightning

## Model Architecture

The heterogeneous GNN uses:

1. **Node Embeddings**: Learnable embeddings for shooters, goal, and goalkeeper
2. **Edge-Conditioned Attention**: GAT layers with edge attributes
3. **Bidirectional Message Passing**: 3 layers with forward and reverse edges
4. **Shot-Indexed Masking**: Efficient prediction via edge masking
5. **Global Context**: Shot-level features encoded and combined
6. **Readout**: Concatenate shooter + goal + goalkeeper representations → MLP classifier

## Key Innovation

Unlike traditional approaches that create a separate graph for each shot, this model uses:

- **One persistent graph** with all nodes
- **Shot-indexed edges** for efficient masking
- **Bidirectional message passing** for richer representations

Because the nodes are shared across shots instead of being rebuilt per shot, this architecture is more memory-efficient and allows for better batch processing.

## License

MIT
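
## Appendix: Focal Loss Sketch

The card lists Focal Loss with alpha=0.8773 and gamma=2.0 but does not show the formula. The sketch below illustrates the standard binary focal loss for a single shot in plain Python. It is not taken from the training code: the function name `focal_loss` is hypothetical, and it assumes that alpha weights the positive class (goals) and 1 - alpha the negative class, which is the common convention but is not confirmed by this card.

```python
import math

# Values from the model card. Assumption: ALPHA weights the positive (goal) class.
ALPHA = 0.8773
GAMMA = 2.0


def focal_loss(logit: float, target: int,
               alpha: float = ALPHA, gamma: float = GAMMA) -> float:
    """Binary focal loss for one example (hypothetical helper, for illustration).

    FL = -alpha_t * (1 - p_t)^gamma * log(p_t)
    where p_t is the predicted probability of the true class.
    """
    p = 1.0 / (1.0 + math.exp(-logit))             # sigmoid: logit -> P(goal)
    p_t = p if target == 1 else 1.0 - p            # probability of the true class
    alpha_t = alpha if target == 1 else 1.0 - alpha
    p_t = min(max(p_t, 1e-12), 1.0)                # guard against log(0)
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# A confidently correct prediction is down-weighted far more than a wrong one.
easy = focal_loss(logit=4.0, target=1)    # model is confident and right
hard = focal_loss(logit=-4.0, target=1)   # model is confident and wrong
print(f"easy={easy:.6f} hard={hard:.6f}")
```

With gamma=0 and alpha=0.5 the expression reduces to half the usual binary cross-entropy, so gamma controls how strongly well-classified shots are down-weighted and alpha rebalances the rare goal class.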