Wonder-Griffin commited on
Commit
228af26
·
verified ·
1 Parent(s): baeeefb

Upload 8 files

Browse files
README.md ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: pytorch
3
+ license: mit
4
+ datasets:
5
+ - TorNet
6
+ tags:
7
+ - weather
8
+ - radar
9
+ - tornado
10
+ - NEXRAD
11
+ - MRMS
12
+ - HRRR
13
+ - lightning
14
+ metrics:
15
+ - auprc
16
+ - f1
17
+ - accuracy
18
+ - brier
19
+ - ece
20
+ pipeline_tag: image-classification
21
+ ---
22
+
23
+ # Wonder-Griffin/tornado-super-predictor
24
+
25
+ **TornadoSuperPredictor** from Storm-Oracle, trained on **TorNet (Zenodo)** patches.
26
+ Outputs a tornado probability per patch (optionally with atmospheric features).
27
+
28
+ ## Summary
29
+
30
+ - **Data**: TorNet (official split); optional recent holdout recommended.
31
+ - **Architecture**: CNN feature extractor + heads (probability, EF logits, location, timing, uncertainty).
32
+ - **Temporal**: 3 volume(s) stacked as channels.
33
+ - **Normalization**: zscore.
34
+ - **Loss**: bce (pos_weight=2.0).
35
+ - **Calibration**: Platt (A,B)=n/a,n/a; Temperature T=n/a.
36
+
37
+ ## Intended Use
38
+
39
+ - Research on tornado nowcasting from radar patches;
40
+ - Evaluation under class imbalance with PR metrics;
41
+ - **Not** an operational warning system without further validation & human oversight.
42
+
43
+ ## Dataset
44
+
45
+ - **Train examples**: 6
46
+ - **Eval examples**: 4
47
+ - **Class balance**: positives=n/a, negatives=n/a, pos_weight≈2.0
48
+
49
+ ## Evaluation (threshold = 0.5)
50
+
51
+ Confusion matrix (rows = truth, cols = prediction):
52
+
53
+ | | Pred 0 | Pred 1 |
54
+ |-------:|-------:|-------:|
55
+ | True 0 | 0 | 2 |
56
+ | True 1 | 0 | 2 |
57
+
58
+ Metrics:
59
+
60
+ - **AUPRC**: n/a
61
+ - **Accuracy**: n/a
62
+ - **(Optional)**: attach PR curve & reliability diagrams
63
+
64
+ ## Training
65
+
66
+ - Optimizer: AdamW (lr=1e-4, wd=1e-4 by default)
67
+ - Batch size: n/a
68
+ - Epochs: n/a
69
+ - Precision: 16-mixed
70
+ - Augmentations: flips/rotations/intensity jitter + optional crops
71
+ - Hardware: 1× GPU (FP16 mixed)
72
+
73
+ ## How to use
74
+
75
+ ```python
76
+ from huggingface_hub import snapshot_download
77
+ import torch, os, importlib.util, sys
78
+
79
+ repo_id = "Wonder-Griffin/tornado-super-predictor"
80
+ local_dir = snapshot_download(repo_id)
81
+ sys.path.insert(0, local_dir)
82
+
83
+ from modeling import load, apply_temperature
84
+ device = "cuda" if torch.cuda.is_available() else "cpu"
85
+ model = load(device=device)
86
+
87
+ # x: torch.Tensor of shape (B, C, 256, 256), C = 3 * T
88
+ B = 1; C = 3*3
89
+ x = torch.randn(B, C, 256, 256, device=device)
90
+
91
+ # atmospheric dict (optional—batch-shaped)
92
+ atmo = {
93
+ "cape": torch.zeros(B,1, device=device),
94
+ "wind_shear": torch.zeros(B,4, device=device),
95
+ "helicity": torch.zeros(B,2, device=device),
96
+ "temperature": torch.zeros(B,3, device=device),
97
+ "dewpoint": torch.zeros(B,2, device=device),
98
+ "pressure": torch.zeros(B,1, device=device),
99
+ }
100
+
101
+ with torch.no_grad():
102
+ out = model(x, atmo)
103
+ prob = out["tornado_probability"] # (B,)
daac0fa3/checkpoints/best.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2eb8887ef376af6359d69d3b63cb5c62190b1bbb74d1395bcacc761f91ca99a4
3
+ size 70716168
metrics.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auprc": "n/a",
3
+ "acc": "n/a",
4
+ "epochs": "n/a",
5
+ "batch_size": "n/a",
6
+ "precision": "16-mixed",
7
+ "n_pos": "n/a",
8
+ "n_neg": "n/a"
9
+ }
modelcard.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"splits": {"train": 6, "eval": 4}, "confusion_matrix": [[0, 2], [0, 2]], "loss": "bce", "pos_weight": 2.0, "time_steps": 3, "normalize": "zscore"}
modeling.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from backend.ml_models.tornado_predictor import TornadoSuperPredictor as Model
2
+ import torch
3
+ def load(device='cpu'):
4
+ m = Model().to(device)
5
+ m.load_state_dict(torch.load('pytorch_model.bin', map_location=device))
6
+ m.eval(); return m
7
+ def apply_temperature(logits, T):
8
+ return logits / max(T,1e-6)
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26084ab20a7844d2cce12cd6133cb8c169c0dc5a776edcfe0af37fa9a3540ad8
3
+ size 33218612
tornado_predictor.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 🌪️ STORM ORACLE — Tornado Super-Predictor (training-ready, no placeholders)
3
+
4
+ - RadarPatternExtractor: multi-scale CNN + spatial attention pooling
5
+ - AtmosphericConditionEncoder: per-variable MLPs -> tokens -> attention -> fused vector
6
+ - Heads: probability (sigmoid), EF (logits), location (reg), timing (reg), uncertainty (sigmoid)
7
+ - Calibration: single temperature parameter (learnable/fittable after training)
8
+ - ContinuousLearner: online fine-tuning with replay buffer and EMA weights
9
+ """
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Dict, List, Optional, Tuple
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+
19
+ # ----------------------------- Types ---------------------------------
20
+
21
+ @dataclass
22
+ class TornadoPredictionBatch:
23
+ """All outputs are BATCH TENSORS (no Python scalars)."""
24
+ tornado_probability: torch.Tensor # (B,)
25
+ ef_scale_probs: torch.Tensor # (B,6)
26
+ most_likely_ef_scale: torch.Tensor # (B,)
27
+ location_offset: torch.Tensor # (B,2)
28
+ timing_predictions: torch.Tensor # (B,3)
29
+ uncertainty_scores: torch.Tensor # (B,4) in [0,1]
30
+ radar_signatures: torch.Tensor # (B,3) [hook, meso, couplet]
31
+ atmospheric_indicators: torch.Tensor # (B,3) [cape, shear_norm, instability]
32
+ logits: Optional[torch.Tensor] = None # (B,) pre-sigmoid (for calibration/loss)
33
+
34
+
35
+ # ---------------------- Building blocks --------------------------------
36
+
37
+ class SpatialAttentionPool(nn.Module):
38
+ """
39
+ Turns a 2D feature map (B,C,H,W) into (B,C) using a learned query and MHA over H*W tokens.
40
+ """
41
+ def __init__(self, channels: int, num_heads: int = 8):
42
+ super().__init__()
43
+ self.channels = channels
44
+ self.pos_embed = nn.Parameter(torch.randn(1, channels, 1)) # simple scalar per-channel bias over tokens
45
+ self.query = nn.Parameter(torch.randn(1, 1, channels)) # learned global query token
46
+ self.attn = nn.MultiheadAttention(embed_dim=channels, num_heads=num_heads, batch_first=True)
47
+ self.ln = nn.LayerNorm(channels)
48
+
49
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
50
+ # x: (B,C,H,W) -> tokens: (B, H*W, C)
51
+ B, C, H, W = x.shape
52
+ tokens = x.view(B, C, H * W).transpose(1, 2) # (B, HW, C)
53
+ tokens = self.ln(tokens + self.pos_embed.expand(B, C, 1).transpose(1, 2)) # broadcast mild bias
54
+ q = self.query.expand(B, -1, -1) # (B,1,C)
55
+ pooled, _ = self.attn(q, tokens, tokens) # (B,1,C)
56
+ return pooled.squeeze(1) # (B,C)
57
+
58
+
59
+ class RadarPatternExtractor(nn.Module):
60
+ """
61
+ Advanced radar pattern extraction with spatial attention pooling.
62
+ Accepts variable input_channels (e.g., 3×T for T time steps).
63
+ """
64
+ def __init__(self, input_channels: int = 3):
65
+ super().__init__()
66
+ self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, padding=3)
67
+ self.conv2 = nn.Conv2d(64, 128, kernel_size=5, padding=2)
68
+ self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
69
+ self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
70
+
71
+ self.bn4 = nn.BatchNorm2d(512)
72
+
73
+ # Specialized detectors
74
+ self.hook_echo_detector = nn.Conv2d(512, 64, kernel_size=3, padding=1)
75
+ self.mesocyclone_detector = nn.Conv2d(512, 64, kernel_size=5, padding=2)
76
+ self.velocity_couplet_detector = nn.Conv2d(512, 64, kernel_size=3, padding=1)
77
+
78
+ # Attention pooling to summarize (B,512,H',W') -> (B,512)
79
+ self.pool = SpatialAttentionPool(512, num_heads=8)
80
+
81
+ # Combine base + specialists -> 512 + 64*3 = 704 -> project to 1024
82
+ self.proj = nn.Sequential(
83
+ nn.Linear(512 + 64 * 3, 1024),
84
+ nn.ReLU(),
85
+ nn.Dropout(0.5),
86
+ )
87
+
88
+ def forward(self, radar_data: torch.Tensor) -> Dict[str, torch.Tensor]:
89
+ # radar_data: (B,C,H,W)
90
+ x = F.relu(self.conv1(radar_data)); x = F.max_pool2d(x, 2)
91
+ x = F.relu(self.conv2(x)); x = F.max_pool2d(x, 2)
92
+ x = F.relu(self.conv3(x)); x = F.max_pool2d(x, 2)
93
+ x = F.relu(self.conv4(x)); x = self.bn4(x)
94
+
95
+ hook = F.relu(self.hook_echo_detector(x))
96
+ meso = F.relu(self.mesocyclone_detector(x))
97
+ vel = F.relu(self.velocity_couplet_detector(x))
98
+
99
+ base_vec = self.pool(x) # (B,512)
100
+ hook_vec = hook.mean(dim=(2, 3)) # (B,64)
101
+ meso_vec = meso.mean(dim=(2, 3)) # (B,64)
102
+ vel_vec = vel.mean(dim=(2, 3)) # (B,64)
103
+
104
+ fused = torch.cat([base_vec, hook_vec, meso_vec, vel_vec], dim=1) # (B,704)
105
+ combined = self.proj(fused) # (B,1024)
106
+
107
+ strengths = torch.stack([
108
+ hook_vec.mean(dim=1), # (B,)
109
+ meso_vec.mean(dim=1), # (B,)
110
+ vel_vec.mean(dim=1), # (B,)
111
+ ], dim=1) # (B,3)
112
+
113
+ return {
114
+ "combined_features": combined,
115
+ "signature_strengths": strengths, # hook, meso, velocity couplet
116
+ }
117
+
118
+
119
+ class AtmosphericConditionEncoder(nn.Module):
120
+ """
121
+ Encode environmental parameters using per-variable MLPs, then treat them as tokens and apply MHA.
122
+ """
123
+ def __init__(self):
124
+ super().__init__()
125
+ self.enc_cape = nn.Linear(1, 32)
126
+ self.enc_shear = nn.Linear(4, 64) # 0–1, 0–3, 0–6, deep
127
+ self.enc_helicity = nn.Linear(2, 32) # 0–1, 0–3
128
+ self.enc_temp = nn.Linear(3, 32) # sfc, 850, 500
129
+ self.enc_dewpoint = nn.Linear(2, 32) # sfc, 850
130
+ self.enc_pressure = nn.Linear(1, 16)
131
+
132
+ # we will embed each of the 6 groups to dim=64 and self-attend
133
+ self.to_64 = nn.ModuleDict({
134
+ "cape": nn.Linear(32, 64),
135
+ "shear": nn.Identity(), # already 64
136
+ "helicity": nn.Linear(32, 64),
137
+ "temp": nn.Linear(32, 64),
138
+ "dewpoint": nn.Linear(32, 64),
139
+ "pressure": nn.Linear(16, 64),
140
+ })
141
+ self.ln = nn.LayerNorm(64)
142
+ self.attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
143
+
144
+ self.fuse = nn.Sequential(
145
+ nn.Linear(64 * 6, 256),
146
+ nn.ReLU(),
147
+ nn.Dropout(0.3),
148
+ )
149
+
150
+ def forward(self, atmo: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
151
+ def ensure_2d(t: torch.Tensor, d: int) -> torch.Tensor:
152
+ # make (B,d)
153
+ t = t if t.ndim == 2 else t.view(-1, d)
154
+ return t
155
+
156
+ cape = ensure_2d(atmo.get("cape", torch.zeros(1, 1, device=next(self.parameters()).device)), 1)
157
+ shear= ensure_2d(atmo.get("wind_shear", torch.zeros(1, 4, device=next(self.parameters()).device)), 4)
158
+ hel = ensure_2d(atmo.get("helicity", torch.zeros(1, 2, device=next(self.parameters()).device)), 2)
159
+ temp = ensure_2d(atmo.get("temperature", torch.zeros(1, 3, device=next(self.parameters()).device)), 3)
160
+ dew = ensure_2d(atmo.get("dewpoint", torch.zeros(1, 2, device=next(self.parameters()).device)), 2)
161
+ pres = ensure_2d(atmo.get("pressure", torch.zeros(1, 1, device=next(self.parameters()).device)), 1)
162
+
163
+ cape_e = F.relu(self.enc_cape(cape)) # (B,32)
164
+ shear_e= F.relu(self.enc_shear(shear)) # (B,64)
165
+ hel_e = F.relu(self.enc_helicity(hel)) # (B,32)
166
+ temp_e = F.relu(self.enc_temp(temp)) # (B,32)
167
+ dew_e = F.relu(self.enc_dewpoint(dew)) # (B,32)
168
+ pres_e = F.relu(self.enc_pressure(pres)) # (B,16)
169
+
170
+ tokens = torch.stack([
171
+ self.ln(self.to_64["cape"](cape_e)),
172
+ self.ln(self.to_64["shear"](shear_e)),
173
+ self.ln(self.to_64["helicity"](hel_e)),
174
+ self.ln(self.to_64["temp"](temp_e)),
175
+ self.ln(self.to_64["dewpoint"](dew_e)),
176
+ self.ln(self.to_64["pressure"](pres_e)),
177
+ ], dim=1) # (B, 6, 64)
178
+
179
+ attn_out, _ = self.attn(tokens, tokens, tokens) # (B,6,64)
180
+ fused = self.fuse(attn_out.reshape(attn_out.size(0), -1)) # (B,256)
181
+
182
+ # easy indicators for explanations/QA
183
+ shear_mag = torch.linalg.vector_norm(shear, dim=-1) # (B,)
184
+ instab = cape.squeeze(-1) * shear_mag # (B,)
185
+
186
+ return {
187
+ "atmospheric_features": fused, # (B,256)
188
+ "cape_score": cape.squeeze(-1), # (B,)
189
+ "shear_magnitude": shear_mag, # (B,)
190
+ "instability_index": instab, # (B,)
191
+ }
192
+
193
+
194
+ # -------------------------- Main model --------------------------------
195
+
196
+ class TornadoSuperPredictor(nn.Module):
197
+ def __init__(self, in_channels: int = 3):
198
+ super().__init__()
199
+ self.radar_extractor = RadarPatternExtractor(input_channels=in_channels)
200
+ self.atmo_encoder = AtmosphericConditionEncoder()
201
+
202
+ fused_dim = 1024 + 256
203
+
204
+ self.prob_head = nn.Sequential(
205
+ nn.Linear(fused_dim, 512), nn.ReLU(), nn.Dropout(0.4),
206
+ nn.Linear(512, 256), nn.ReLU(),
207
+ nn.Linear(256, 1)
208
+ )
209
+ self.ef_head = nn.Sequential(
210
+ nn.Linear(fused_dim, 512), nn.ReLU(), nn.Dropout(0.4),
211
+ nn.Linear(512, 6)
212
+ )
213
+ self.loc_head = nn.Sequential(
214
+ nn.Linear(fused_dim, 512), nn.ReLU(), nn.Dropout(0.4),
215
+ nn.Linear(512, 2)
216
+ )
217
+ self.time_head = nn.Sequential(
218
+ nn.Linear(fused_dim, 512), nn.ReLU(), nn.Dropout(0.4),
219
+ nn.Linear(512, 3)
220
+ )
221
+ self.unc_head = nn.Sequential(
222
+ nn.Linear(fused_dim, 256), nn.ReLU(),
223
+ nn.Linear(256, 4)
224
+ )
225
+
226
+ # temperature parameter for calibration (start at 1.0)
227
+ self.register_parameter("log_temperature", nn.Parameter(torch.zeros(())))
228
+
229
+ self._init_weights()
230
+
231
+ def _init_weights(self):
232
+ for m in self.modules():
233
+ if isinstance(m, (nn.Linear, nn.Conv2d)):
234
+ if isinstance(m, nn.Linear):
235
+ nn.init.xavier_uniform_(m.weight)
236
+ else:
237
+ nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="relu")
238
+ if m.bias is not None:
239
+ nn.init.zeros_(m.bias)
240
+
241
+ @property
242
+ def temperature(self) -> torch.Tensor:
243
+ return torch.exp(self.log_temperature) # positive
244
+
245
+ def forward(self, radar_x: torch.Tensor, atmo: Dict[str, torch.Tensor]) -> TornadoPredictionBatch:
246
+ # radar_x: (B,C,H,W), atmo: dict of (B,dim)
247
+ r = self.radar_extractor(radar_x)
248
+ a = self.atmo_encoder(atmo)
249
+
250
+ fused = torch.cat([r["combined_features"], a["atmospheric_features"]], dim=1) # (B,1280)
251
+
252
+ logits = self.prob_head(fused).squeeze(-1) # (B,)
253
+ logits = logits / self.temperature.clamp_min(1e-6) # calibrated logits
254
+ probs = torch.sigmoid(logits) # (B,)
255
+
256
+ ef_logits = self.ef_head(fused) # (B,6)
257
+ ef_probs = F.softmax(ef_logits, dim=-1)
258
+ ef_idx = ef_probs.argmax(dim=-1)
259
+
260
+ loc = self.loc_head(fused) # (B,2)
261
+ tim = self.time_head(fused) # (B,3)
262
+ unc = torch.sigmoid(self.unc_head(fused)) # (B,4) in [0,1]
263
+
264
+ return TornadoPredictionBatch(
265
+ tornado_probability=probs,
266
+ ef_scale_probs=ef_probs,
267
+ most_likely_ef_scale=ef_idx,
268
+ location_offset=loc,
269
+ timing_predictions=tim,
270
+ uncertainty_scores=unc,
271
+ radar_signatures=r["signature_strengths"],
272
+ atmospheric_indicators=torch.stack([
273
+ a["cape_score"], a["shear_magnitude"], a["instability_index"]
274
+ ], dim=1),
275
+ logits=logits,
276
+ )
277
+
278
+
279
+ # --------------------- Continuous learning wrapper --------------------
280
+
281
+ class ContinuousLearner(nn.Module):
282
+ """
283
+ Light wrapper that adds:
284
+ - optimizer + (optional) pos_weight or focal loss
285
+ - EMA weights for stable inference during online updates
286
+ - small replay buffer to avoid catastrophic forgetting
287
+ """
288
+ def __init__(
289
+ self,
290
+ model: TornadoSuperPredictor,
291
+ lr: float = 1e-4,
292
+ wd: float = 1e-4,
293
+ use_focal: bool = False,
294
+ pos_weight: Optional[float] = None,
295
+ ema_decay: float = 0.999,
296
+ replay_capacity: int = 2048,
297
+ device: Optional[torch.device] = None,
298
+ ):
299
+ super().__init__()
300
+ self.model = model
301
+ self.device = device or next(model.parameters()).device
302
+ self.opt = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=wd)
303
+ self.use_focal = use_focal
304
+ self.pos_weight = None if pos_weight is None else torch.tensor(pos_weight, device=self.device)
305
+ self.ema_decay = ema_decay
306
+
307
+ # EMA weights
308
+ self.shadow = {k: v.detach().clone() for k, v in self.model.state_dict().items()}
309
+ self.replay_capacity = replay_capacity
310
+ self._replay = [] # list of tuples (radar_x, atmo_dict, y)
311
+
312
+ def _bce_loss(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
313
+ if self.pos_weight is not None:
314
+ return F.binary_cross_entropy_with_logits(logits, y.float(), pos_weight=self.pos_weight)
315
+ return F.binary_cross_entropy_with_logits(logits, y.float())
316
+
317
+ def _focal_loss(self, logits: torch.Tensor, y: torch.Tensor, gamma: float = 2.0, alpha: float = 0.5) -> torch.Tensor:
318
+ p = torch.sigmoid(logits)
319
+ pt = p * y + (1 - p) * (1 - y)
320
+ w = (1 - pt).pow(gamma)
321
+ at = alpha * y + (1 - alpha) * (1 - y)
322
+ loss = -(y * torch.log(p.clamp_min(1e-9)) + (1 - y) * torch.log((1 - p).clamp_min(1e-9))) * w * at
323
+ return loss.mean()
324
+
325
+ @torch.no_grad()
326
+ def _update_ema(self):
327
+ for k, v in self.model.state_dict().items():
328
+ self.shadow[k].mul_(self.ema_decay).add_(v, alpha=(1.0 - self.ema_decay))
329
+
330
+ def train_step(self, radar_x: torch.Tensor, atmo: Dict[str, torch.Tensor], y: torch.Tensor) -> Dict[str, float]:
331
+ self.model.train()
332
+ out = self.model(radar_x, atmo) # contains logits & probs
333
+
334
+ if self.use_focal:
335
+ loss = self._focal_loss(out.logits, y)
336
+ else:
337
+ loss = self._bce_loss(out.logits, y)
338
+
339
+ self.opt.zero_grad(set_to_none=True)
340
+ loss.backward()
341
+ nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
342
+ self.opt.step()
343
+ self._update_ema()
344
+
345
+ # push to replay
346
+ if self.replay_capacity > 0:
347
+ with torch.no_grad():
348
+ if len(self._replay) >= self.replay_capacity:
349
+ self._replay.pop(0)
350
+ # store small detached copy (avoid GPU memory blowup)
351
+ self._replay.append((
352
+ radar_x.detach().cpu(),
353
+ {k: v.detach().cpu() for k, v in atmo.items()},
354
+ y.detach().cpu()
355
+ ))
356
+
357
+ with torch.no_grad():
358
+ prob = out.tornado_probability.mean().item()
359
+ return {"loss": float(loss.item()), "avg_prob": prob}
360
+
361
+ @torch.no_grad()
362
+ def ema_state_dict(self) -> Dict[str, torch.Tensor]:
363
+ return {k: v.clone() for k, v in self.shadow.items()}
364
+
365
+ @torch.no_grad()
366
+ def load_ema_weights(self):
367
+ self.model.load_state_dict(self.ema_state_dict())
368
+
369
+ def replay_step(self, batch_size: int = 16) -> Optional[Dict[str, float]]:
370
+ if not self._replay:
371
+ return None
372
+ import random
373
+ idxs = random.sample(range(len(self._replay)), k=min(batch_size, len(self._replay)))
374
+ xs = torch.cat([self._replay[i][0] for i in idxs], dim=0).to(self.device)
375
+ ys = torch.cat([self._replay[i][2] for i in idxs], dim=0).to(self.device)
376
+ atmo = {}
377
+ # stack dict fields
378
+ keys = list(self._replay[idxs[0]][1].keys())
379
+ for k in keys:
380
+ atmo[k] = torch.cat([self._replay[i][1][k] for i in idxs], dim=0).to(self.device)
381
+ return self.train_step(xs, atmo, ys)
we3uhx9k/checkpoints/best.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fb0c43ae8d50c8d7bf6cc840efabfc318f274b7a552f9d5f23a8b53687ce3ba
3
+ size 78975760