Wonder-Griffin commited on
Commit
4d2983f
·
verified ·
1 Parent(s): 269e9de

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +128 -103
README.md CHANGED
@@ -1,103 +1,128 @@
1
- ---
2
- library_name: pytorch
3
- license: mit
4
- datasets:
5
- - TorNet
6
- tags:
7
- - weather
8
- - radar
9
- - tornado
10
- - NEXRAD
11
- - MRMS
12
- - HRRR
13
- - lightning
14
- metrics:
15
- - auprc
16
- - f1
17
- - accuracy
18
- - brier
19
- - ece
20
- pipeline_tag: image-classification
21
- ---
22
-
23
- # Wonder-Griffin/tornado-super-predictor
24
-
25
- **TornadoSuperPredictor** from Storm-Oracle, trained on **TorNet (Zenodo)** patches.
26
- Outputs a tornado probability per patch (optionally with atmospheric features).
27
-
28
- ## Summary
29
-
30
- - **Data**: TorNet (official split); optional recent holdout recommended.
31
- - **Architecture**: CNN feature extractor + heads (probability, EF logits, location, timing, uncertainty).
32
- - **Temporal**: 3 volume(s) stacked as channels.
33
- - **Normalization**: zscore.
34
- - **Loss**: bce (pos_weight=2.0).
35
- - **Calibration**: Platt (A,B)=n/a,n/a; Temperature T=n/a.
36
-
37
- ## Intended Use
38
-
39
- - Research on tornado nowcasting from radar patches;
40
- - Evaluation under class imbalance with PR metrics;
41
- - **Not** an operational warning system without further validation & human oversight.
42
-
43
- ## Dataset
44
-
45
- - **Train examples**: 6
46
- - **Eval examples**: 4
47
- - **Class balance**: positives=n/a, negatives=n/a, pos_weight≈2.0
48
-
49
- ## Evaluation (threshold = 0.5)
50
-
51
- Confusion matrix (rows = truth, cols = prediction):
52
-
53
- | | Pred 0 | Pred 1 |
54
- |-------:|-------:|-------:|
55
- | True 0 | 0 | 2 |
56
- | True 1 | 0 | 2 |
57
-
58
- Metrics:
59
-
60
- - **AUPRC**: n/a
61
- - **Accuracy**: n/a
62
- - **(Optional)**: attach PR curve & reliability diagrams
63
-
64
- ## Training
65
-
66
- - Optimizer: AdamW (lr=1e-4, wd=1e-4 by default)
67
- - Batch size: n/a
68
- - Epochs: n/a
69
- - Precision: 16-mixed
70
- - Augmentations: flips/rotations/intensity jitter + optional crops
71
- - Hardware: GPU (FP16 mixed)
72
-
73
- ## How to use
74
-
75
- ```python
76
- from huggingface_hub import snapshot_download
77
- import torch, os, importlib.util, sys
78
-
79
- repo_id = "Wonder-Griffin/tornado-super-predictor"
80
- local_dir = snapshot_download(repo_id)
81
- sys.path.insert(0, local_dir)
82
-
83
- from modeling import load, apply_temperature
84
- device = "cuda" if torch.cuda.is_available() else "cpu"
85
- model = load(device=device)
86
-
87
- # x: torch.Tensor of shape (B, C, 256, 256), C = 3 * T
88
- B = 1; C = 3*3
89
- x = torch.randn(B, C, 256, 256, device=device)
90
-
91
- # atmospheric dict (optional—batch-shaped)
92
- atmo = {
93
- "cape": torch.zeros(B,1, device=device),
94
- "wind_shear": torch.zeros(B,4, device=device),
95
- "helicity": torch.zeros(B,2, device=device),
96
- "temperature": torch.zeros(B,3, device=device),
97
- "dewpoint": torch.zeros(B,2, device=device),
98
- "pressure": torch.zeros(B,1, device=device),
99
- }
100
-
101
- with torch.no_grad():
102
- out = model(x, atmo)
103
- prob = out["tornado_probability"] # (B,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: pytorch
3
+ license: mit
4
+ datasets:
5
+ - TorNet
6
+ tags:
7
+ - weather
8
+ - radar
9
+ - tornado
10
+ - tornado_prediction
11
+ - NEXRAD
12
+ - MRMS
13
+ - HRRR
14
+ - lightning
15
+ metrics:
16
+ - auprc
17
+ - f1
18
+ - accuracy
19
+ - brier
20
+ - ece
21
+ pipeline_tag: image-classification
22
+ ---
23
+
24
+ # Wonder-Griffin/tornado-super-predictor
25
+
26
+ **TornadoSuperPredictor** from Storm-Oracle, trained on **TorNet (Zenodo)** patches.
27
+ Outputs a tornado probability per patch (optionally with atmospheric features).
28
+
29
+ ## Summary
30
+
31
+ - **Data**: TorNet (official split); optional recent holdout recommended.
32
+ - **Architecture**: CNN feature extractor + heads (probability, EF logits, location, timing, uncertainty).
33
+ - **Temporal**: 3 volume(s) stacked as channels.
34
+ - **Normalization**: zscore.
35
+ - **Loss**: bce (pos_weight=2.0).
36
+ - **Calibration**: Platt (A,B)=n/a,n/a; Temperature T=n/a.
37
+
38
+ ## Intended Use
39
+
40
+ - Research on tornado nowcasting from radar patches;
41
+ - Evaluation under class imbalance with PR metrics;
42
+ - **Not** an operational warning system without further validation & human oversight.
43
+
44
+ ## Dataset
45
+
46
+ - **Train examples**: 6
47
+ - **Eval examples**: 4
48
+ - **Class balance**: positives=n/a, negatives=n/a, pos_weight≈2.0
49
+
50
+ ## Evaluation (threshold = 0.5)
51
+
52
+ Confusion matrix (rows = truth, cols = prediction):
53
+
54
+ | | Pred 0 | Pred 1 |
55
+ |-------:|-------:|-------:|
56
+ | True 0 | 0 | 2 |
57
+ | True 1 | 0 | 2 |
58
+
59
+ Metrics:
60
+
61
+ - **AUPRC**: n/a
62
+ - **Accuracy**: n/a
63
+ - **(Optional)**: attach PR curve & reliability diagrams
64
+
65
+ ## Training
66
+
67
+ - Optimizer: AdamW (lr=1e-4, wd=1e-4 by default)
68
+ - Batch size: n/a
69
+ - Epochs: n/a
70
+ - Precision: 16-mixed
71
+ - Augmentations: flips/rotations/intensity jitter + optional crops
72
+ - Hardware: 1× GPU (FP16 mixed)
73
+
74
+ ## Quickstart
75
+
76
+ ```python
77
+ import torch
78
+ from transformers import AutoModel
79
+
80
+ repo = "Wonder-Griffin/TorNet-Oracle"
81
+ model = AutoModel.from_pretrained(repo, trust_remote_code=True).eval()
82
+
83
+ # Example dummy batch
84
+ B, T, H, W = 2, 1, 256, 256 # T time steps -> in_channels = 3*T (reflectivity, velocity, spectrum width?)
85
+ radar_x = torch.randn(B, 3*T, H, W)
86
+
87
+ # Atmospheric dictionary (use only what you have; shapes must be (B, dim))
88
+ atmo = {
89
+ "cape": torch.randn(B, 1),
90
+ "wind_shear": torch.randn(B, 4), # 0–1, 0–3, 0–6, deep
91
+ "helicity": torch.randn(B, 2), # 0–1, 0–3
92
+ "temperature": torch.randn(B, 3), # sfc, 850, 500
93
+ "dewpoint": torch.randn(B, 2), # sfc, 850
94
+ "pressure": torch.randn(B, 1),
95
+ }
96
+
97
+ out = model(radar_x=radar_x, atmo=atmo)
98
+ print(out.tornado_probability.shape) # (B,)
99
+ print(out.ef_scale_probs.shape) # (B, 6)
100
+ print(out.location_offset.shape) # (B, 2)
101
+ print(out.timing_predictions.shape) # (B, 3)
102
+ ---
103
+
104
+ # 3) Notes to avoid common gotchas
105
+
106
+ - **Export the class names**: Make sure `StormOracleModel` and `StormOracleConfig` are importable at the repo root via `__init__.py`. Hugging Face uses that when `trust_remote_code=True`.
107
+ - **Architectures**: The `"architectures"` array in `config.json` **must** include `"StormOracleModel"`.
108
+ - **Weights**: You already have `pytorch_model.bin`/**or** `model.safetensors`. Either is fine. Keep the filenames standard.
109
+ - **Forward signature**: With remote code, it’s okay that `forward` takes `radar_x` and `atmo`. Users pass them as keyword args as shown.
110
+ - **Version pins**: If you rely on features from newer `transformers`, keep the `transformers_version` in `config.json` current.
111
+
112
+ ---
113
+
114
+ # 4) Optional niceties
115
+
116
+ - **`hubconf.py`** (for `torch.hub` users):
117
+ ```python
118
+ from .tornado_predictor import TornadoSuperPredictor
119
+
120
+ def storm_oracle(in_channels=3, pretrained=False, hf_repo=None, map_location="cpu"):
121
+ model = TornadoSuperPredictor(in_channels=in_channels)
122
+ if pretrained and hf_repo is not None:
123
+ from huggingface_hub import hf_hub_download
124
+ path = hf_hub_download(hf_repo, filename="pytorch_model.bin")
125
+ import torch
126
+ state = torch.load(path, map_location=map_location)
127
+ model.load_state_dict(state, strict=True)
128
+ return model