rokati commited on
Commit
7ef71f2
·
verified ·
1 Parent(s): 3949a44

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +155 -0
README.md ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - pytorch
4
+ - pytorch-geometric
5
+ - xG
6
+ - football
7
+ - soccer
8
+ - expected-goals
9
+ - gnn
10
+ - heterogeneous-graph
11
+ - graph-neural-network
12
+ - binary-classification
13
+ - one-graph-architecture
14
+ library_name: pytorch
15
+ ---
16
+
17
+ # Heterogeneous GNN xG Prediction Model (Focal Loss, One Graph Architecture)
18
+
19
+ This is a Heterogeneous Graph Neural Network (GNN) model trained to predict Expected Goals (xG) in football/soccer using a single persistent graph with shot-indexed edges.
20
+
21
+ ## Model Description
22
+
23
+ - **Architecture**: Heterogeneous GNN with GAT attention layers
24
+ - **Graph Structure**: One persistent graph with all nodes, shot-specific edge masking
25
+ - **Nodes**: shooter (player embeddings, 496 total), goal (learnable embedding), goalkeeper (learnable embedding)
26
+ - **Edges**: 4 edge types (bidirectional shooter↔goal and shooter↔goalkeeper)
27
+ - **Global Features**: 18 shot-level contextual features
28
+ - **Hidden Dimensions**: 64
29
+ - **Number of Layers**: 3 GAT layers
30
+ - **Attention Heads**: 4
31
+ - **Dropout Rate**: 0.3
32
+ - **Framework**: PyTorch Geometric
33
+ - **Loss Function**: Focal Loss (alpha=0.8773, gamma=2.0)
34
+
35
+ ## Performance Metrics
36
+
37
+ - **Test Loss**: 0.08266454190015793
38
+ - **Test AUC**: 0.609370231628418
39
+ - **Test Accuracy**: 0.14332698285579681
40
+ - **Test F1**: 0.2105599045753479
41
+
42
+ ## Global Features (18 features)
43
+
44
+ The model uses the following contextual features for each shot:
45
+
46
+ - ball_closer_than_gk
47
+ - body_part_name_Left Foot
48
+ - body_part_name_Other
49
+ - body_part_name_Right Foot
50
+ - goal_dist_to_gk
51
+ - minute
52
+ - nearest_opponent_dist
53
+ - nearest_teammate_dist
54
+ - opponents_within_5m
55
+ - play_pattern_name_From Counter
56
+ - play_pattern_name_From Free Kick
57
+ - play_pattern_name_From Goal Kick
58
+ - play_pattern_name_From Keeper
59
+ - play_pattern_name_From Kick Off
60
+ - play_pattern_name_From Throw In
61
+ - play_pattern_name_Other
62
+ - play_pattern_name_Regular Play
63
+ - teammates_within_5m
64
+
65
+ ## Graph Structure
66
+
67
+ ### Persistent Nodes
68
+ 1. **Shooter Nodes**: 496 player embeddings (dimension: 64)
69
+ 2. **Goal Node**: 1 learnable goal representation (dimension: 64)
70
+ 3. **Goalkeeper Node**: 1 learnable goalkeeper representation (dimension: 64)
71
+
72
+ ### Edge Types (with attributes, shot-indexed)
73
+ 1. **shooter → goal**: Distance (meters) + Angle to goal (radians)
74
+ 2. **goal → shooter**: Reverse edges with same attributes
75
+ 3. **shooter → goalkeeper**: Distance to goalkeeper (meters)
76
+ 4. **goalkeeper → shooter**: Reverse edges with same attributes
77
+
78
+ All edges are indexed by shot_idx for efficient masking during prediction.
79
+
80
+ ## Usage
81
+
82
+ ```python
83
+ import torch
84
+ from torch_geometric.data import HeteroData
85
+ from huggingface_hub import hf_hub_download
86
+ import importlib.util
87
+ import json
88
+
89
+ # Download files
90
+ model_path = hf_hub_download(repo_id="rokati/focal_gnn_one_graph", filename="best_gnn_model.pth")
91
+ architecture_path = hf_hub_download(repo_id="rokati/focal_gnn_one_graph", filename="model_architecture.py")
92
+ config_path = hf_hub_download(repo_id="rokati/focal_gnn_one_graph", filename="config.json")
93
+
94
+ # Load configuration
95
+ with open(config_path, 'r') as f:
96
+ config = json.load(f)
97
+
98
+ # Load architecture
99
+ spec = importlib.util.spec_from_file_location("model_architecture", architecture_path)
100
+ model_module = importlib.util.module_from_spec(spec)
101
+ spec.loader.exec_module(model_module)
102
+
103
+ # Create model instance
104
+ model = model_module.XGNet(
105
+ num_players=config['num_players'],
106
+ hid=config['hidden_dim'],
107
+ p=config['dropout_rate'],
108
+ heads=config['num_heads'],
109
+ num_layers=config['num_layers'],
110
+ use_norm=config['use_norm'],
111
+ num_global_features=config['num_global_features']
112
+ )
113
+
114
+ # Load weights
115
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
116
+ model.eval()
117
+
118
+ # Prepare persistent graph (build once with all players/shots)
119
+ # Then predict by passing the graph and shot_idx:
120
+ with torch.no_grad():
121
+ xg_prediction = torch.sigmoid(model(graph, shot_idx)).item()
122
+ ```
123
+
124
+ ## Training Details
125
+
126
+ The model was trained with:
127
+ - **Loss Function**: Focal Loss (addresses class imbalance)
128
+ - **Optimizer**: Adam (lr=1e-3, weight_decay=1e-4)
129
+ - **Scheduler**: ReduceLROnPlateau
130
+ - **Batch Size**: 32
131
+ - **Max Epochs**: 100
132
+ - **Early Stopping**: Patience=15 on validation F1
133
+ - **Framework**: PyTorch Lightning
134
+
135
+ ## Model Architecture
136
+
137
+ The heterogeneous GNN uses:
138
+ 1. **Node Embeddings**: Learnable embeddings for shooters, goal, and goalkeeper
139
+ 2. **Edge-Conditioned Attention**: GAT layers with edge attributes
140
+ 3. **Bidirectional Message Passing**: 3 layers with forward and reverse edges
141
+ 4. **Shot-Indexed Masking**: Efficient prediction via edge masking
142
+ 5. **Global Context**: Shot-level features encoded and combined
143
+ 6. **Readout**: Concatenate shooter + goal + goalkeeper representations → MLP classifier
144
+
145
+ ## Key Innovation
146
+
147
+ Unlike traditional approaches that create separate graphs for each shot, this model uses:
148
+ - **One persistent graph** with all nodes
149
+ - **Shot-indexed edges** for efficient masking
150
+ - **Bidirectional message passing** for richer representations
151
+ - This architecture is more memory-efficient and allows for better batch processing
152
+
153
+ ## License
154
+
155
+ MIT