---
tags:
- pytorch
- pytorch-geometric
- xG
- football
- soccer
- expected-goals
- gnn
- heterogeneous-graph
- graph-neural-network
- binary-classification
- one-graph-architecture
library_name: pytorch
---

# Heterogeneous GNN xG Prediction Model (Focal Loss, One Graph Architecture)

This is a Heterogeneous Graph Neural Network (GNN) model trained to predict Expected Goals (xG) in football/soccer using a single persistent graph with shot-indexed edges.

## Model Description

- **Architecture**: Heterogeneous GNN with GAT attention layers
- **Graph Structure**: One persistent graph with all nodes, shot-specific edge masking
  - **Nodes**: shooter (player embeddings, 496 total), goal (learnable embedding), goalkeeper (learnable embedding)
  - **Edges**: 4 edge types (bidirectional shooter↔goal and shooter↔goalkeeper)
  - **Global Features**: 18 shot-level contextual features
- **Hidden Dimensions**: 64
- **Number of Layers**: 3 GAT layers
- **Attention Heads**: 4
- **Dropout Rate**: 0.3
- **Framework**: PyTorch Geometric
- **Loss Function**: Focal Loss (alpha=0.8773, gamma=2.0)
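
For readers unfamiliar with focal loss, the following is a minimal sketch of a binary focal loss using the `alpha` and `gamma` values listed above. It is not taken from the training code: the function name `binary_focal_loss` is hypothetical, and it assumes that `alpha` weights the positive class (goals) and `1 - alpha` the negative class.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss(logits, targets, alpha=0.8773, gamma=2.0):
    """Sketch of a binary focal loss on raw logits (hypothetical helper)."""
    # Per-sample binary cross-entropy, computed from logits for stability.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    # p_t is the probability the model assigns to the true class.
    p_t = torch.exp(-bce)
    # Assumption: alpha weights goals (label 1), 1 - alpha weights non-goals.
    alpha_t = alpha * targets + (1.0 - alpha) * (1.0 - targets)
    # (1 - p_t) ** gamma down-weights examples that are already well classified.
    return (alpha_t * (1.0 - p_t) ** gamma * bce).mean()


logits = torch.tensor([0.0, 0.0])
targets = torch.tensor([1.0, 0.0])
loss = binary_focal_loss(logits, targets)
```

With `gamma=0` and `alpha=0.5` this reduces to half the ordinary binary cross-entropy, which is a quick way to sanity-check an implementation.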

## Performance Metrics

- **Test Loss**: 0.08266454190015793
- **Test AUC**: 0.609370231628418
- **Test Accuracy**: 0.14332698285579681
- **Test F1**: 0.2105599045753479

## Global Features (18 features)

The model uses the following contextual features for each shot:

- ball_closer_than_gk
- body_part_name_Left Foot
- body_part_name_Other
- body_part_name_Right Foot
- goal_dist_to_gk
- minute
- nearest_opponent_dist
- nearest_teammate_dist
- opponents_within_5m
- play_pattern_name_From Counter
- play_pattern_name_From Free Kick
- play_pattern_name_From Goal Kick
- play_pattern_name_From Keeper
- play_pattern_name_From Kick Off
- play_pattern_name_From Throw In
- play_pattern_name_Other
- play_pattern_name_Regular Play
- teammates_within_5m
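
As an illustration, the feature names above can be assembled into a fixed-order vector as sketched below. This is an assumption-laden sketch, not the original preprocessing: the helper `build_global_features` is hypothetical, the column order is assumed to match the list above, unlisted categories are assumed to map to all-zero one-hot columns, and any scaling applied during training is not reproduced.

```python
GLOBAL_FEATURES = [
    "ball_closer_than_gk",
    "body_part_name_Left Foot",
    "body_part_name_Other",
    "body_part_name_Right Foot",
    "goal_dist_to_gk",
    "minute",
    "nearest_opponent_dist",
    "nearest_teammate_dist",
    "opponents_within_5m",
    "play_pattern_name_From Counter",
    "play_pattern_name_From Free Kick",
    "play_pattern_name_From Goal Kick",
    "play_pattern_name_From Keeper",
    "play_pattern_name_From Kick Off",
    "play_pattern_name_From Throw In",
    "play_pattern_name_Other",
    "play_pattern_name_Regular Play",
    "teammates_within_5m",
]


def build_global_features(shot):
    """Turn a shot dict into an 18-value list (hypothetical helper).

    `shot` holds raw numeric fields plus the categorical strings
    'body_part_name' and 'play_pattern_name'.
    """
    values = dict.fromkeys(GLOBAL_FEATURES, 0.0)
    # Copy numeric fields that are already named like a feature column.
    for name in GLOBAL_FEATURES:
        if name in shot:
            values[name] = float(shot[name])
    # One-hot encode the two categorical fields; unlisted categories stay at zero.
    for prefix in ("body_part_name", "play_pattern_name"):
        column = f"{prefix}_{shot.get(prefix)}"
        if column in values:
            values[column] = 1.0
    return [values[name] for name in GLOBAL_FEATURES]


example_shot = {
    "minute": 10,
    "goal_dist_to_gk": 2.5,
    "opponents_within_5m": 2,
    "body_part_name": "Left Foot",
    "play_pattern_name": "Regular Play",
}
features = build_global_features(example_shot)
```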

## Graph Structure

### Persistent Nodes
1. **Shooter Nodes**: 496 player embeddings (dimension: 64)
2. **Goal Node**: 1 learnable goal representation (dimension: 64)
3. **Goalkeeper Node**: 1 learnable goalkeeper representation (dimension: 64)

### Edge Types (with attributes, shot-indexed)
1. **shooter → goal**: Distance (meters) + Angle to goal (radians)
2. **goal → shooter**: Reverse edges with same attributes
3. **shooter → goalkeeper**: Distance to goalkeeper (meters)
4. **goalkeeper → shooter**: Reverse edges with same attributes

All edges are indexed by shot_idx for efficient masking during prediction.
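
The edge attributes above can be computed from shot coordinates roughly as follows. This is a sketch under stated assumptions: coordinates are in meters on a 105 x 68 pitch with the goal center at (105, 34), and "angle to goal" is taken to mean the angle subtended by the two posts as seen from the shot location. The model's actual preprocessing may use a different coordinate system or angle definition, and both function names are hypothetical.

```python
import math

GOAL_X, GOAL_Y = 105.0, 34.0  # assumed goal center on a 105 x 68 m pitch
GOAL_WIDTH = 7.32             # distance between the posts in meters


def shooter_goal_edge_attr(x, y):
    """Return (distance in meters, angle in radians) for a shot at (x, y)."""
    dx = GOAL_X - x
    distance = math.hypot(dx, GOAL_Y - y)
    # Direction to each post, then the opening angle between them.
    to_far_post = math.atan2(GOAL_Y + GOAL_WIDTH / 2 - y, dx)
    to_near_post = math.atan2(GOAL_Y - GOAL_WIDTH / 2 - y, dx)
    return distance, abs(to_far_post - to_near_post)


def shooter_goalkeeper_edge_attr(x, y, gk_x, gk_y):
    """Return the shooter-to-goalkeeper distance in meters."""
    return math.hypot(gk_x - x, gk_y - y)


# A shot from the penalty spot, 11 m in front of the goal center.
distance, angle = shooter_goal_edge_attr(94.0, 34.0)
gk_distance = shooter_goalkeeper_edge_attr(94.0, 34.0, 104.0, 34.0)
```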

## Usage

```python
import torch
from torch_geometric.data import HeteroData
from huggingface_hub import hf_hub_download
import importlib.util
import json

# Download files
model_path = hf_hub_download(repo_id="rokati/focal_gnn_one_graph", filename="best_gnn_model.pth")
architecture_path = hf_hub_download(repo_id="rokati/focal_gnn_one_graph", filename="model_architecture.py")
config_path = hf_hub_download(repo_id="rokati/focal_gnn_one_graph", filename="config.json")

# Load configuration
with open(config_path, 'r') as f:
    config = json.load(f)

# Load architecture
spec = importlib.util.spec_from_file_location("model_architecture", architecture_path)
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)

# Create model instance
model = model_module.XGNet(
    num_players=config['num_players'],
    hid=config['hidden_dim'],
    p=config['dropout_rate'],
    heads=config['num_heads'],
    num_layers=config['num_layers'],
    use_norm=config['use_norm'],
    num_global_features=config['num_global_features']
)

# Load weights
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

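# Build the persistent graph once (e.g. a HeteroData object with all players and shots).
# NOTE: `graph` and `shot_idx` are not defined in this snippet; supply your own graph
# and the index of the shot to score.
with torch.no_grad():
    xg_prediction = torch.sigmoid(model(graph, shot_idx)).item()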
```

## Training Details

The model was trained with:
- **Loss Function**: Focal Loss (addresses class imbalance)
- **Optimizer**: Adam (lr=1e-3, weight_decay=1e-4)
- **Scheduler**: ReduceLROnPlateau
- **Batch Size**: 32
- **Max Epochs**: 100
- **Early Stopping**: Patience=15 on validation F1
- **Training Framework**: PyTorch Lightning
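
The optimizer settings above translate to plain PyTorch roughly as sketched here. The `nn.Linear` placeholder stands in for the real network, the scheduler's `mode`, `factor`, and `patience` are assumptions (the card does not state them), and `should_stop` is a hypothetical helper that mimics early stopping with patience 15 on validation F1.

```python
import torch
from torch import nn

model = nn.Linear(18, 1)  # placeholder module, not the real XGNet

# Values taken from the list above.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Assumed scheduler arguments: mode="max" because F1 should increase.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=5
)


def should_stop(val_f1_history, patience=15):
    """True once the best validation F1 is at least `patience` epochs old."""
    if not val_f1_history:
        return False
    # max() returns the first best epoch, so ties do not count as improvement.
    best_epoch = max(range(len(val_f1_history)), key=val_f1_history.__getitem__)
    return len(val_f1_history) - 1 - best_epoch >= patience
```

In a manual training loop, `scheduler.step(val_f1)` would be called once per epoch after validation, followed by a `should_stop` check.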

## Model Architecture

The heterogeneous GNN uses:
1. **Node Embeddings**: Learnable embeddings for shooters, goal, and goalkeeper
2. **Edge-Conditioned Attention**: GAT layers with edge attributes
3. **Bidirectional Message Passing**: 3 layers with forward and reverse edges
4. **Shot-Indexed Masking**: Efficient prediction via edge masking
5. **Global Context**: Shot-level features encoded and combined
6. **Readout**: Concatenate shooter + goal + goalkeeper representations → MLP classifier
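
The readout step (items 5 and 6) might look like the sketch below. It is not the code in `model_architecture.py`: the class name `ReadoutSketch` is hypothetical, and it assumes the global features are encoded to the hidden size by a single linear layer and concatenated with the three node representations before the MLP.

```python
import torch
from torch import nn


class ReadoutSketch(nn.Module):
    """Illustrative readout: concatenate node states and context, then classify."""

    def __init__(self, hid=64, num_global_features=18, p=0.3):
        super().__init__()
        # Assumed encoder for the shot-level context features.
        self.global_encoder = nn.Sequential(
            nn.Linear(num_global_features, hid), nn.ReLU()
        )
        # MLP over shooter + goal + goalkeeper + context (4 * hid inputs).
        self.classifier = nn.Sequential(
            nn.Linear(4 * hid, hid), nn.ReLU(), nn.Dropout(p), nn.Linear(hid, 1)
        )

    def forward(self, shooter, goal, goalkeeper, global_features):
        context = self.global_encoder(global_features)
        joint = torch.cat([shooter, goal, goalkeeper, context], dim=-1)
        # One logit per shot; apply a sigmoid to obtain an xG probability.
        return self.classifier(joint).squeeze(-1)


readout = ReadoutSketch().eval()
with torch.no_grad():
    logits = readout(
        torch.randn(5, 64), torch.randn(5, 64), torch.randn(5, 64), torch.randn(5, 18)
    )
```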

## Key Innovation

Unlike traditional approaches that create separate graphs for each shot, this model uses:
- **One persistent graph** with all nodes
- **Shot-indexed edges** for efficient masking
- **Bidirectional message passing** for richer representations
- Because the graph is built once and only the edges for the requested shot are selected, no per-shot graph has to be constructed, which reduces memory use and simplifies batching
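
Shot-indexed edge masking can be illustrated with a few tensors. In this sketch the names `edges_for_shot` and `edge_shot_idx` are hypothetical; it assumes each edge carries the index of the shot it belongs to, so selecting one shot's edges is a single boolean mask.

```python
import torch


def edges_for_shot(edge_index, edge_attr, edge_shot_idx, shot_idx):
    """Select the edges (and their attributes) that belong to one shot."""
    mask = edge_shot_idx == shot_idx
    return edge_index[:, mask], edge_attr[mask]


# Three shooter -> goal edges: shooters 0, 1 and 2 all point at goal node 0.
edge_index = torch.tensor([[0, 1, 2],
                           [0, 0, 0]])
# Assumed attribute layout per edge: [distance in meters, angle in radians].
edge_attr = torch.tensor([[10.0, 0.5],
                          [20.0, 0.3],
                          [15.0, 0.4]])
# Shot that each edge belongs to.
edge_shot_idx = torch.tensor([0, 1, 2])

shot_edges, shot_attr = edges_for_shot(edge_index, edge_attr, edge_shot_idx, shot_idx=1)
```

The node embeddings stay shared across shots; only the edge tensors are filtered per prediction.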

## License

MIT