witgaw commited on
Commit
01fa32b
·
verified ·
1 Parent(s): b2a358c

Upload STGformer model trained on METR-LA

Browse files
Files changed (5) hide show
  1. README.md +74 -0
  2. config.json +32 -0
  3. hub_metadata.json +11 -0
  4. metadata.json +36 -0
  5. model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - traffic-forecasting
4
+ - time-series
5
+ - graph-neural-network
6
+ - transformer
7
+ - stgformer
8
+ datasets:
9
+ - metr-la
10
+ ---
11
+
12
+ # STGformer Model - METR-LA
13
+
14
+ Spatio-Temporal Graph Transformer (STGformer) trained on METR-LA dataset for traffic speed forecasting.
15
+
16
+ ## Model Description
17
+
18
+ This model uses a transformer-based graph neural network architecture that combines:
19
+ - Self-attention mechanisms for capturing temporal dependencies
20
+ - Spatial graph convolution for modeling spatial relationships
21
+ - Adaptive embeddings for learning node-specific patterns
22
+ - Time-of-day embeddings for capturing daily patterns
23
+
24
+ ## Evaluation Metrics
25
+
26
+ - **Test MAE (15 min)**: 2.5637
27
+ - **Test MAPE (15 min)**: 0.0654
28
+ - **Test RMSE (15 min)**: 4.8755
29
+
30
+
31
+ ## Dataset
32
+
33
+ **METR-LA**: Traffic speed data from highway sensors.
34
+
35
+ ## Usage
36
+
37
+ ```python
38
+ from utils.stgformer import load_from_hub
39
+
40
+ # Load model from Hub
41
+ model, scaler = load_from_hub("METR-LA")
42
+
43
+ # Get predictions
44
+ import numpy as np
45
+ x = np.random.randn(10, 12, 207, 2) # (batch, seq_len, nodes, [value, tod])
46
+ predictions = model.predict(x)
47
+ ```
48
+
49
+ ## Training
50
+
51
+ Model was trained using the STGformer implementation with configuration:
52
+ - Input features: 2 [speed, time-of-day]
53
+ - Time-of-day embedding dimension: 24
54
+ - Day-of-week embedding dimension: 0 (disabled)
55
+ - Adaptive embedding dimension: 80
56
+ - Number of attention heads: 4
57
+ - Number of layers: 3
58
+
59
+ ## Citation
60
+
61
+ If you use this model, please cite the STGformer paper:
62
+
63
+ ```bibtex
64
+ @article{stgformer,
65
+ title={STGformer: Spatio-Temporal Graph Transformer for Traffic Forecasting},
66
+ author={Author names},
67
+ journal={Conference/Journal},
68
+ year={Year}
69
+ }
70
+ ```
71
+
72
+ ## License
73
+
74
+ This model checkpoint is released under the same license as the training code.
config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "num_nodes": 207,
3
+ "in_steps": 12,
4
+ "out_steps": 12,
5
+ "input_dim": 2,
6
+ "output_dim": 1,
7
+ "steps_per_day": 288,
8
+ "input_embedding_dim": 24,
9
+ "tod_embedding_dim": 24,
10
+ "dow_embedding_dim": 0,
11
+ "adaptive_embedding_dim": 80,
12
+ "num_heads": 4,
13
+ "num_layers": 3,
14
+ "dropout": 0.1,
15
+ "dropout_a": 0.3,
16
+ "kernel_size": [
17
+ 1
18
+ ],
19
+ "epochs": 100,
20
+ "batch_size": 64,
21
+ "learning_rate": 0.001,
22
+ "weight_decay": 0.0003,
23
+ "milestones": [
24
+ 20,
25
+ 30
26
+ ],
27
+ "lr_decay_rate": 0.1,
28
+ "early_stop": 10,
29
+ "clip_grad": 0,
30
+ "device": "cuda",
31
+ "verbose": 1
32
+ }
hub_metadata.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset": "METR-LA",
3
+ "upload_date": "2025-11-10T18:22:38.458773",
4
+ "metrics": {
5
+ "Test MAE (15 min)": 2.5637319087982178,
6
+ "Test MAPE (15 min)": 0.06541310995817184,
7
+ "Test RMSE (15 min)": 4.875480432556589
8
+ },
9
+ "framework": "PyTorch",
10
+ "model_type": "STGformer"
11
+ }
metadata.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "config": {
3
+ "num_nodes": 207,
4
+ "in_steps": 12,
5
+ "out_steps": 12,
6
+ "input_dim": 2,
7
+ "output_dim": 1,
8
+ "steps_per_day": 288,
9
+ "input_embedding_dim": 24,
10
+ "tod_embedding_dim": 24,
11
+ "dow_embedding_dim": 0,
12
+ "adaptive_embedding_dim": 80,
13
+ "num_heads": 4,
14
+ "num_layers": 3,
15
+ "dropout": 0.1,
16
+ "dropout_a": 0.3,
17
+ "kernel_size": [
18
+ 1
19
+ ],
20
+ "epochs": 100,
21
+ "batch_size": 64,
22
+ "learning_rate": 0.001,
23
+ "weight_decay": 0.0003,
24
+ "milestones": [
25
+ 20,
26
+ 30
27
+ ],
28
+ "lr_decay_rate": 0.1,
29
+ "early_stop": 10,
30
+ "clip_grad": 0,
31
+ "device": "cuda",
32
+ "verbose": 1
33
+ },
34
+ "scaler_mean": 54.40592575073242,
35
+ "scaler_std": 19.49374008178711
36
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d596a483045a7187645aef67a34283469b1d8cc0964f8218f54465e3c79dd052
3
+ size 3530912