Heterogeneous GNN xG Prediction Model (Focal Loss, One Graph Architecture)

This is a heterogeneous graph neural network (GNN) trained to predict expected goals (xG) in football (soccer). It uses a single persistent graph whose edges are indexed by shot.

Model Description

  • Architecture: Heterogeneous GNN with GAT attention layers
  • Graph Structure: One persistent graph with all nodes, shot-specific edge masking
    • Nodes: shooter (player embeddings, 496 total), goal (learnable embedding), goalkeeper (learnable embedding)
    • Edges: 4 edge types (bidirectional shooter↔goal and shooter↔goalkeeper)
    • Global Features: 18 shot-level contextual features
  • Hidden Dimensions: 64
  • Number of Layers: 3 GAT layers
  • Attention Heads: 4
  • Dropout Rate: 0.3
  • Framework: PyTorch Geometric
  • Loss Function: Focal Loss (alpha=0.8773, gamma=2.0)
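The card lists only the focal loss parameters. The sketch below assumes the standard binary focal loss, FL = -alpha_t (1 - p_t)^gamma log(p_t), computed from the raw logit, with alpha weighting the positive (goal) class and 1 - alpha weighting non-goals. The repository's own implementation may use a different alpha convention; the function names here are illustrative.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def binary_focal_loss(logit: float, target: int,
                      alpha: float = 0.8773, gamma: float = 2.0) -> float:
    """Focal loss for one shot, given the model's raw logit.

    Assumed convention: alpha_t = alpha for goals (target=1) and
    1 - alpha for non-goals (target=0).
    """
    log_p = -softplus(-logit)    # log(sigmoid(logit))
    log_1mp = -softplus(logit)   # log(1 - sigmoid(logit))
    if target == 1:
        log_pt, alpha_t = log_p, alpha
    else:
        log_pt, alpha_t = log_1mp, 1.0 - alpha
    p_t = math.exp(log_pt)       # probability assigned to the true class
    # (1 - p_t)**gamma shrinks the loss of confident, correct predictions
    return -alpha_t * (1.0 - p_t) ** gamma * log_pt
```

With gamma=0 this reduces to alpha-weighted binary cross-entropy; with gamma=2.0 a goal predicted at p=0.9 contributes about 100 times less loss than it would under that weighted cross-entropy, which shifts training toward hard, misclassified shots.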

Performance Metrics

  • Test Loss: 0.0827
  • Test AUC: 0.6094
  • Test Accuracy: 0.1433
  • Test F1: 0.2106

Global Features (18 features)

The model uses the following contextual features for each shot:

  • ball_closer_than_gk
  • body_part_name_Left Foot
  • body_part_name_Other
  • body_part_name_Right Foot
  • goal_dist_to_gk
  • minute
  • nearest_opponent_dist
  • nearest_teammate_dist
  • opponents_within_5m
  • play_pattern_name_From Counter
  • play_pattern_name_From Free Kick
  • play_pattern_name_From Goal Kick
  • play_pattern_name_From Keeper
  • play_pattern_name_From Kick Off
  • play_pattern_name_From Throw In
  • play_pattern_name_Other
  • play_pattern_name_Regular Play
  • teammates_within_5m
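The list above is alphabetical and looks like the output of one-hot encoding body_part_name and play_pattern_name. The sketch below builds the 18-dimensional vector in that order. Several points are assumptions: that the model expects exactly this order, that no scaling is applied, and that the raw keys ("body_part_name", "play_pattern_name") look like this. Categories with no column of their own (possibly "Head" and "From Corner", which would be dropped reference levels) encode as all zeros.

```python
from typing import Dict, List, Union

# Assumed to match the model's expected input order (the list above).
FEATURE_NAMES: List[str] = [
    "ball_closer_than_gk",
    "body_part_name_Left Foot",
    "body_part_name_Other",
    "body_part_name_Right Foot",
    "goal_dist_to_gk",
    "minute",
    "nearest_opponent_dist",
    "nearest_teammate_dist",
    "opponents_within_5m",
    "play_pattern_name_From Counter",
    "play_pattern_name_From Free Kick",
    "play_pattern_name_From Goal Kick",
    "play_pattern_name_From Keeper",
    "play_pattern_name_From Kick Off",
    "play_pattern_name_From Throw In",
    "play_pattern_name_Other",
    "play_pattern_name_Regular Play",
    "teammates_within_5m",
]

ONE_HOT_PREFIXES = ("body_part_name_", "play_pattern_name_")


def encode_shot(shot: Dict[str, Union[str, float, bool]]) -> List[float]:
    """Turn one raw (hypothetical) shot record into the 18 global features."""
    vec: List[float] = []
    for name in FEATURE_NAMES:
        prefix = next((p for p in ONE_HOT_PREFIXES if name.startswith(p)), None)
        if prefix is None:
            vec.append(float(shot[name]))        # numeric or boolean feature
        else:
            field = prefix[:-1]                  # e.g. "body_part_name"
            category = name[len(prefix):]        # e.g. "Left Foot"
            vec.append(1.0 if shot[field] == category else 0.0)
    return vec
```

For example, a header from a counter-attack produces zeros in all three body-part slots and a single 1.0 in the "From Counter" slot.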

Graph Structure

Persistent Nodes

  1. Shooter Nodes: 496 player embeddings (dimension: 64)
  2. Goal Node: 1 learnable goal representation (dimension: 64)
  3. Goalkeeper Node: 1 learnable goalkeeper representation (dimension: 64)

Edge Types (with attributes, shot-indexed)

  1. shooter → goal: Distance (meters) + Angle to goal (radians)
  2. goal → shooter: Reverse edges with same attributes
  3. shooter → goalkeeper: Distance to goalkeeper (meters)
  4. goalkeeper → shooter: Reverse edges with same attributes
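A minimal sketch of how the four shot-indexed edge types could be assembled in plain Python. The card does not give the pitch geometry or the exact angle definition. This sketch assumes a 105 x 68 m pitch with the goal centred at (105, 34) and uses the angle subtended by the posts, which is one common xG definition. The edge-type names, the shot dictionary keys, and the use of index 0 for the single goal and goalkeeper nodes are all hypothetical.

```python
import math
from typing import Dict, List, Tuple

# Assumed geometry in meters (not stated in the card).
GOAL_X, GOAL_Y, GOAL_WIDTH = 105.0, 34.0, 7.32


def goal_distance_and_angle(x: float, y: float) -> Tuple[float, float]:
    """Distance to the goal centre and the angle (radians) between the posts."""
    dx = GOAL_X - x
    dist = math.hypot(dx, GOAL_Y - y)
    left = math.atan2(GOAL_Y + GOAL_WIDTH / 2 - y, dx)
    right = math.atan2(GOAL_Y - GOAL_WIDTH / 2 - y, dx)
    return dist, abs(left - right)


def build_edges(shots: List[Dict[str, float]]) -> Dict[str, Dict[str, list]]:
    """Collect all four edge types, tagging every edge with its shot_idx."""
    kinds = ("shoots_at", "rev_shoots_at", "faces", "rev_faces")  # hypothetical names
    edges = {k: {"src": [], "dst": [], "attr": [], "shot_idx": []} for k in kinds}

    def add(kind: str, src: int, dst: int, attr: List[float], idx: int) -> None:
        store = edges[kind]
        store["src"].append(src)
        store["dst"].append(dst)
        store["attr"].append(attr)
        store["shot_idx"].append(idx)

    for idx, s in enumerate(shots):
        dist, angle = goal_distance_and_angle(s["x"], s["y"])
        gk_dist = math.hypot(s["gk_x"] - s["x"], s["gk_y"] - s["y"])
        p = int(s["player_id"])                       # row in the shooter embedding table
        add("shoots_at", p, 0, [dist, angle], idx)      # shooter -> goal
        add("rev_shoots_at", 0, p, [dist, angle], idx)  # goal -> shooter
        add("faces", p, 0, [gk_dist], idx)              # shooter -> goalkeeper
        add("rev_faces", 0, p, [gk_dist], idx)          # goalkeeper -> shooter
    return edges
```

Under these assumptions, a penalty-spot shot 11 m out gives a distance of 11.0 and an opening angle of 2·atan(3.66 / 11) ≈ 0.64 rad.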

Every edge stores the shot_idx of the shot it was created for, so predicting a given shot only requires keeping the edges whose shot_idx matches it.
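The masking step can be sketched as a filter over parallel edge lists. The storage layout here (a dict with "src", "dst", "attr" and "shot_idx" lists) is an illustrative assumption, not the repository's format.

```python
from typing import Dict, List


def edges_for_shot(edge_store: Dict[str, List], shot_idx: int) -> Dict[str, List]:
    """Keep only the edges whose shot_idx equals the requested shot.

    edge_store holds parallel lists keyed "src", "dst", "attr", "shot_idx".
    """
    keep = [i for i, s in enumerate(edge_store["shot_idx"]) if s == shot_idx]
    return {key: [values[i] for i in keep] for key, values in edge_store.items()}
```

With tensors in PyTorch Geometric, the same idea is a boolean mask such as `mask = edge_shot_idx == shot_idx`, applied as `edge_index[:, mask]` and `edge_attr[mask]` for each edge type.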

Usage

import torch
from torch_geometric.data import HeteroData
from huggingface_hub import hf_hub_download
import importlib.util
import json

# Download files
model_path = hf_hub_download(repo_id="rokati/focal_gnn_one_graph", filename="best_gnn_model.pth")
architecture_path = hf_hub_download(repo_id="rokati/focal_gnn_one_graph", filename="model_architecture.py")
config_path = hf_hub_download(repo_id="rokati/focal_gnn_one_graph", filename="config.json")

# Load configuration
with open(config_path, 'r') as f:
    config = json.load(f)

# Load architecture
spec = importlib.util.spec_from_file_location("model_architecture", architecture_path)
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)

# Create model instance
model = model_module.XGNet(
    num_players=config['num_players'],
    hid=config['hidden_dim'],
    p=config['dropout_rate'],
    heads=config['num_heads'],
    num_layers=config['num_layers'],
    use_norm=config['use_norm'],
    num_global_features=config['num_global_features']
)

# Load weights
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

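# Build the persistent HeteroData graph once, exactly as during training:
# all shooter, goal and goalkeeper nodes plus every shot's edges, each edge
# tagged with its shot_idx. `graph` is that HeteroData object and `shot_idx`
# is the integer index of the shot to score.
with torch.no_grad():
    # The model returns a logit; sigmoid maps it to an xG probability in [0, 1].
    xg_prediction = torch.sigmoid(model(graph, shot_idx)).item()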

Training Details

The model was trained with:

  • Loss Function: Focal Loss (addresses class imbalance)
  • Optimizer: Adam (lr=1e-3, weight_decay=1e-4)
  • Scheduler: ReduceLROnPlateau
  • Batch Size: 32
  • Max Epochs: 100
  • Early Stopping: Patience=15 on validation F1
  • Framework: PyTorch Lightning
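In PyTorch Lightning, patience-based early stopping on a maximised metric is usually configured with the `EarlyStopping` callback (`monitor=...`, `mode="max"`, `patience=15`). The name of the F1 metric logged by this repository is not given. The plain-Python sketch below shows the stopping rule itself: training halts once validation F1 has failed to improve for 15 consecutive epochs, capped at 100 epochs. The class and function names are illustrative.

```python
from typing import List


class F1EarlyStopping:
    """Stop when validation F1 has not improved for `patience` epochs."""

    def __init__(self, patience: int = 15, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_f1: float) -> bool:
        """Record one epoch's F1; return True if training should stop."""
        if val_f1 > self.best + self.min_delta:
            self.best = val_f1
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def epochs_run(f1_per_epoch: List[float], max_epochs: int = 100, patience: int = 15) -> int:
    """Return how many epochs would run under this stopping rule."""
    stopper = F1EarlyStopping(patience)
    for epoch, f1 in enumerate(f1_per_epoch[:max_epochs], start=1):
        if stopper.step(f1):
            return epoch
    return min(len(f1_per_epoch), max_epochs)
```

If F1 peaks at epoch 3 and never improves again, training stops at epoch 18, after 15 epochs without improvement.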

Model Architecture

The heterogeneous GNN uses:

  1. Node Embeddings: Learnable embeddings for shooters, goal, and goalkeeper
  2. Edge-Conditioned Attention: GAT layers with edge attributes
  3. Bidirectional Message Passing: 3 layers with forward and reverse edges
  4. Shot-Indexed Masking: each prediction uses only the edges tagged with the target shot's shot_idx
  5. Global Context: Shot-level features encoded and combined
  6. Readout: Concatenate shooter + goal + goalkeeper representations → MLP classifier
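For reference, a GAT layer that conditions attention on edge attributes (roughly the form used by PyTorch Geometric's `GATConv` when `edge_dim` is set) scores each incoming edge and normalises the scores over a node's neighbours. The notation below is generic. Whether this model uses `GATConv`, `GATv2Conv`, or a custom layer is not stated in the card. Here $h_i$ is the target node's representation, $h_j$ a neighbour's, and $x_{ij}$ the edge attributes (distance and angle, or goalkeeper distance). Each of the 4 heads computes this independently.

```latex
e_{ij} = \mathrm{LeakyReLU}\!\left(
  \mathbf{a}_{\mathrm{dst}}^{\top} W_{\mathrm{dst}} h_i
  + \mathbf{a}_{\mathrm{src}}^{\top} W_{\mathrm{src}} h_j
  + \mathbf{a}_{e}^{\top} W_{e} x_{ij} \right),
\qquad
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})},
\qquad
h_i' = \sum_{j \in \mathcal{N}(i)} \alpha_{ij}\, W_{\mathrm{src}} h_j
```

Because the edges run in both directions, the goal and goalkeeper nodes aggregate messages from shooters, and shooters aggregate messages from the goal and goalkeeper, within each of the 3 layers.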

Key Innovation

Unlike traditional approaches that create separate graphs for each shot, this model uses:

  • One persistent graph with all nodes
  • Shot-indexed edges for efficient masking
  • Bidirectional message passing for richer representations
  • Lower memory use and simpler batch processing than building a separate graph for every shot

License

MIT
