Alex Y committed on
Commit
589d487
·
1 Parent(s): 7df8e27

initial commit

Browse files
Files changed (2) hide show
  1. app.py +1299 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,1299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+ import matplotlib.pyplot as plt
7
+ from matplotlib.figure import Figure
8
+ import matplotlib.patches as mpatches
9
+
10
+ import pysteps
11
+ from pysteps import io, rcparams, motion, datasets
12
+ from pysteps.motion.lucaskanade import dense_lucaskanade
13
+ from pysteps.nowcasts import linda as pysteps_linda
14
+ from pysteps.utils import conversion
15
+ from sklearn.metrics import mean_squared_error
16
+ import time
17
+ from datetime import datetime
18
+ import warnings
19
+ warnings.filterwarnings('ignore')
20
+
21
+
22
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
+ print(f"Using device: {device}")
24
+
25
class LINDAPINNModel(nn.Module):
    """Learnable LINDA-style integro-difference model for 2-D rain-rate fields.

    One step combines a kernel-convolution (dispersal) term, a logistic
    growth term and an optional semi-Lagrangian advection of the input field.
    The kernel width, survival probability, growth rate and carrying capacity
    are scalar ``nn.Parameter`` objects; the kernel is additionally modulated
    by a small MLP (``kernel_net``).

    The module relies on the module-level ``device`` for tensor placement.
    """

    def __init__(self, layers=None):
        """Create sub-networks and scalar parameters and move them to ``device``.

        Args:
            layers: Widths of the fully connected stack kept in
                ``self.layers``. ``None`` selects
                ``[4, 256, 256, 256, 256, 256, 1]``.
        """
        super().__init__()

        # None sentinel instead of a mutable list default, so the default
        # list can never be shared (and mutated) across instances.
        if layers is None:
            layers = [4, 256, 256, 256, 256, 256, 1]

        # Fully connected stack. It is registered (and therefore trained /
        # saved) but none of the methods below call it.
        self.layers = nn.ModuleList()
        last_hidden = len(layers) - 2
        for i, (n_in, n_out) in enumerate(zip(layers[:-1], layers[1:])):
            self.layers.append(nn.Linear(n_in, n_out))
            # Xavier init for every layer except the output layer.
            if i < last_hidden:
                nn.init.xavier_uniform_(self.layers[i].weight)

        # Maps (dx, dy, t) to a scalar that modulates the Gaussian kernel.
        self.kernel_net = nn.Sequential(
            nn.Linear(3, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, 1)
        )
        # Four inputs -> value in (0, 1). Not referenced by the methods of
        # this class; kept so existing checkpoints keep loading.
        self.advection_net = nn.Sequential(
            nn.Linear(4, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

        for net in [self.kernel_net, self.advection_net]:
            for layer in net:
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_uniform_(layer.weight)

        # Raw (unconstrained) scalars; the squashing functions are applied
        # where the values are used.
        self.log_sigma = nn.Parameter(torch.tensor(0.0))
        self.survival_prob = nn.Parameter(torch.tensor(0.8))
        self.growth_rate = nn.Parameter(torch.tensor(0.1))
        self.carrying_capacity = nn.Parameter(torch.tensor(10.0))

        # Single move after everything is registered so that the scalar
        # parameters above end up on the target device too.
        self.to(device)

    def dispersal_kernel(self, dx, dy, t=None):
        """Evaluate the redistribution kernel at the given offsets.

        Args:
            dx, dy: Tensors of equal shape holding the x / y offsets.
            t: Optional time (number or 0-dim tensor/array). When given, the
                Gaussian is multiplied by ``sigmoid(kernel_net([dx, dy, t]))``;
                when ``None`` the weight is 1.

        Returns:
            Tensor with the shape of ``dx``.
        """
        # The +0.1 keeps the width away from zero.
        sigma = torch.exp(self.log_sigma) + 0.1

        if t is not None:
            t_value = t.item() if hasattr(t, 'item') else float(t)
            t_tensor = torch.full_like(dx, t_value, device=device)

            kernel_input = torch.cat(
                [
                    dx.flatten().unsqueeze(1),
                    dy.flatten().unsqueeze(1),
                    t_tensor.flatten().unsqueeze(1),
                ],
                dim=1,
            )
            kernel_weight = torch.sigmoid(self.kernel_net(kernel_input))
            kernel_weight = kernel_weight.reshape(dx.shape)
        else:
            kernel_weight = torch.tensor(1.0, device=device)

        kernel = kernel_weight * torch.exp(-(dx**2 + dy**2) / (2 * sigma**2))
        kernel = kernel / (2 * np.pi * sigma**2)

        return kernel

    def compute_integral_term(self, R_field, x_coords, y_coords, t):
        """Convolve ``R_field`` with the dispersal kernel via FFT.

        The product of the two FFTs is a circular convolution, i.e. the
        domain is treated as periodic.

        Args:
            R_field: 2-D tensor ``(ny, nx)``.
            x_coords, y_coords: 1-D coordinate sequences; only the spacing of
                the first two entries is used (1.0 if either has < 2 entries).
            t: Time forwarded to :meth:`dispersal_kernel` (may be ``None``).

        Returns:
            Non-negative 2-D tensor ``(ny, nx)``.
        """
        ny, nx = R_field.shape

        # Offsets measured from the grid centre (index n // 2).
        Y_grid, X_grid = torch.meshgrid(
            torch.arange(ny, dtype=torch.float32, device=device) - ny // 2,
            torch.arange(nx, dtype=torch.float32, device=device) - nx // 2,
            indexing='ij'
        )

        if len(x_coords) > 1 and len(y_coords) > 1:
            dx_scale = x_coords[1] - x_coords[0]
            dy_scale = y_coords[1] - y_coords[0]
        else:
            dx_scale = 1.0
            dy_scale = 1.0

        X_grid = X_grid * dx_scale
        Y_grid = Y_grid * dy_scale

        kernel = self.dispersal_kernel(X_grid, Y_grid, t)

        # Unit discrete sum so the convolution itself preserves total mass.
        kernel = kernel / torch.sum(kernel)

        # Move the kernel centre (index n // 2) to index 0. ifftshift is the
        # correct inverse here: fftshift coincides with it only for even
        # sizes and shifts the result by one pixel on odd-sized grids.
        kernel_at_origin = torch.fft.ifftshift(kernel)

        kernel_fft = torch.fft.rfft2(kernel_at_origin)
        field_fft = torch.fft.rfft2(R_field)

        # irfft2 already returns a real tensor.
        integral_result = torch.fft.irfft2(kernel_fft * field_fft, s=(ny, nx))

        # Round-off can produce tiny negative values.
        integral_result = torch.clamp(integral_result, min=0.0)

        # NOTE(review): the kernel is normalised to a unit *sum*, so this
        # factor rescales total mass whenever the pixel area is not 1.
        pixel_area = dx_scale * dy_scale
        integral_result = integral_result * pixel_area

        return integral_result

    def apply_advection(self, field, advection_field, metadata):
        """Apply one semi-Lagrangian advection step with bilinear sampling.

        The computation runs in NumPy/SciPy, so no gradient flows through it.

        Args:
            field: 2-D tensor or array ``(ny, nx)`` to advect.
            advection_field: Array-like indexed as ``[0]`` = u (x-component)
                and ``[1]`` = v (y-component), each ``(ny, nx)``.
            metadata: Dict; ``xpixelsize`` / ``ypixelsize`` are read with a
                default of 1000.

        Returns:
            Advected field, as a tensor (same dtype/device) if ``field`` was a
            tensor, otherwise as a NumPy array.
        """
        from scipy.ndimage import map_coordinates

        is_tensor = isinstance(field, torch.Tensor)
        # detach() so tensors that require grad can be converted as well.
        field_np = field.detach().cpu().numpy() if is_tensor else field

        u = advection_field[0]
        v = advection_field[1]

        ny, nx = field_np.shape
        Y, X = np.meshgrid(np.arange(ny), np.arange(nx), indexing='ij')

        # Time step: 5 minutes in seconds.
        dt = 5 * 60

        dx = metadata.get('xpixelsize', 1000)
        dy = metadata.get('ypixelsize', 1000)

        # NOTE(review): this scaling treats u, v as distance-per-second in the
        # same length unit as the pixel size; optical-flow outputs are often
        # pixels per timestep - confirm the units supplied by callers.
        u_grid = u * dt / dx
        v_grid = v * dt / dy

        # Departure points of the backward trajectories, clipped to the grid.
        X_back = np.clip(X - u_grid, 0, nx - 1)
        Y_back = np.clip(Y - v_grid, 0, ny - 1)

        coords = np.array([Y_back.ravel(), X_back.ravel()])
        advected = map_coordinates(field_np, coords, order=1, mode='constant', cval=0.0)
        advected = advected.reshape(ny, nx)

        if is_tensor:
            return torch.tensor(advected, dtype=field.dtype, device=field.device)
        return advected

    def linda_equation(self, R_current, x_coords, y_coords, t, advection_field, metadata):
        """Advance ``R_current`` by one step.

        Computes ``s * (K * R) + r * R * (1 - R / C)`` and blends it 70/30
        with the advected input field (or with the input itself when no
        advection field is given). The result is clamped to be non-negative.
        """
        if not R_current.is_cuda and device.type == 'cuda':
            R_current = R_current.to(device)

        # 1. Dispersal: survival probability times the kernel convolution.
        integral_term = self.compute_integral_term(R_current, x_coords, y_coords, t)
        dispersal_term = torch.sigmoid(self.survival_prob) * integral_term

        # 2. Logistic growth with squashed rate and capacity > 1.
        growth_rate = torch.sigmoid(self.growth_rate)
        carrying_capacity = F.softplus(self.carrying_capacity) + 1.0
        growth_term = growth_rate * R_current * (1 - R_current / carrying_capacity)

        # 3. Advection of the input field.
        if advection_field is not None:
            advected_field = self.apply_advection(R_current, advection_field, metadata)
        else:
            advected_field = R_current

        R_next = dispersal_term + growth_term
        R_next = 0.7 * R_next + 0.3 * advected_field

        return torch.clamp(R_next, min=0.0)

    def forward(self, R_field, x_coords, y_coords, t, advection_field=None, metadata=None):
        """Predict the next field from ``R_field``.

        Args:
            R_field: 2-D tensor ``(ny, nx)``.
            x_coords, y_coords: 1-D coordinate sequences.
            t: Time forwarded to the kernel network.
            advection_field: Optional velocity field, see
                :meth:`apply_advection`.
            metadata: Optional dict with pixel sizes for the advection step;
                ``None`` behaves like an empty dict (defaults are used).
        """
        if not R_field.is_cuda and device.type == 'cuda':
            R_field = R_field.to(device)

        if metadata is None:
            metadata = {}
        return self.linda_equation(R_field, x_coords, y_coords, t, advection_field, metadata=metadata)
254
+
255
class LINDAPINNTrainer:
    """Owns a :class:`LINDAPINNModel` plus optimiser and trains it on radar frames."""

    def __init__(self, spatial_domain=(-100, 100), temporal_domain=(0, 6)):
        """Create the model, an Adam optimiser and a plateau LR scheduler.

        Args:
            spatial_domain: ``(min, max)`` used for both axes when the radar
                metadata carries no pixel size.
            temporal_domain: ``(min, max)`` time range; stored only.
        """
        self.model = LINDAPINNModel()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001, weight_decay=1e-5)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=100)

        self.x_min, self.x_max = spatial_domain
        self.t_min, self.t_max = temporal_domain

        print(f"Model initialized on device: {device}")
        print(f"Model parameters on device: {next(self.model.parameters()).device}")

    def prepare_training_data_from_radar(self, rainrate_sequence, metadata, advection_field=None):
        """Turn a ``(nt, ny, nx)`` sequence into consecutive-frame training pairs.

        A pair ``(t, t+1)`` is kept only when more than 100 pixels exceed
        0.1 in either frame.

        Returns:
            List of dicts with keys ``R_current``, ``R_next``, ``x_coords``,
            ``y_coords``, ``t``, ``mask``, ``advection`` and ``metadata``.
        """
        nt, ny, nx = rainrate_sequence.shape

        if 'xpixelsize' in metadata and 'ypixelsize' in metadata:
            # Pixel sizes are divided by 1000 (metres -> kilometres, assuming
            # the metadata is in metres).
            x_coords = np.arange(nx) * metadata['xpixelsize'] / 1000.0
            y_coords = np.arange(ny) * metadata['ypixelsize'] / 1000.0
        else:
            x_coords = np.linspace(self.x_min, self.x_max, nx)
            y_coords = np.linspace(self.x_min, self.x_max, ny)

        training_pairs = []

        for t in range(nt - 1):
            R_current = rainrate_sequence[t]
            R_next = rainrate_sequence[t + 1]

            mask = (R_current > 0.1) | (R_next > 0.1)

            if np.sum(mask) > 100:
                training_pairs.append({
                    'R_current': torch.tensor(R_current, dtype=torch.float32, device=device),
                    'R_next': torch.tensor(R_next, dtype=torch.float32, device=device),
                    'x_coords': x_coords,
                    'y_coords': y_coords,
                    't': float(t),
                    'mask': torch.tensor(mask, dtype=torch.bool, device=device),
                    'advection': advection_field,
                    'metadata': metadata
                })

        return training_pairs

    def compute_physics_loss(self, training_pair):
        """Compute the combined loss for one training pair.

        Returns:
            ``(total_loss, details)`` where ``total_loss`` is a scalar tensor
            and ``details`` maps ``data_loss``, ``dispersal_conservation``,
            ``growth_penalty``, ``advection_diff`` and ``smoothness_loss`` to
            Python floats (unweighted values).
        """
        R_current = training_pair['R_current']
        R_target = training_pair['R_next']
        x_coords = training_pair['x_coords']
        y_coords = training_pair['y_coords']
        t = training_pair['t']
        mask = training_pair['mask']
        advection = training_pair.get('advection', None)
        metadata = training_pair.get('metadata', {})

        R_predicted = self.model(R_current, x_coords, y_coords, t, advection)

        # 1. Supervised data term on the wet pixels only.
        if torch.sum(mask) > 0:
            data_loss = F.mse_loss(R_predicted[mask], R_target[mask])
        else:
            data_loss = torch.tensor(0.0, device=device)

        # 2. Physics terms.
        with torch.enable_grad():
            integral_term = self.model.compute_integral_term(R_current, x_coords, y_coords, t)

            # Relative change of total mass caused by the dispersal step.
            total_before = torch.sum(R_current)
            total_integral = torch.sum(integral_term)
            dispersal_conservation = torch.abs(total_integral - total_before) / (total_before + 1e-6)

            # Same squashing as LINDAPINNModel.linda_equation, so the penalty
            # is evaluated on the growth the model actually applies.
            growth_rate = torch.sigmoid(self.model.growth_rate)
            carrying_capacity = F.softplus(self.model.carrying_capacity) + 1.0
            max_growth = growth_rate * R_current * (1 - R_current / carrying_capacity)
            growth_penalty = torch.mean(F.relu(max_growth - 0.5))

            # Total-mass change caused by advection.
            if advection is not None:
                advected = self.model.apply_advection(R_current, advection, metadata)
                advection_diff = torch.mean(torch.abs(torch.sum(advected) - torch.sum(R_current)))
            else:
                advection_diff = torch.tensor(0.0, device=device)

        # 3. Smoothness of the prediction (squared finite differences).
        if R_predicted.shape[0] > 1 and R_predicted.shape[1] > 1:
            grad_x = torch.diff(R_predicted, dim=1)
            grad_y = torch.diff(R_predicted, dim=0)
            smoothness_loss = torch.mean(grad_x**2) + torch.mean(grad_y**2)
        else:
            smoothness_loss = torch.tensor(0.0, device=device)

        # 4. Keep the raw scalar parameters close to their initial values.
        param_reg = (
            torch.abs(self.model.log_sigma) +
            torch.abs(self.model.survival_prob - 0.8) +
            torch.abs(self.model.growth_rate - 0.1)
        )

        total_loss = (
            data_loss +
            0.1 * dispersal_conservation +
            0.05 * growth_penalty +
            0.05 * advection_diff +
            0.01 * smoothness_loss +
            0.01 * param_reg
        )

        return total_loss, {
            'data_loss': data_loss.item(),
            'dispersal_conservation': dispersal_conservation.item(),
            'growth_penalty': growth_penalty.item(),
            'advection_diff': advection_diff.item(),
            'smoothness_loss': smoothness_loss.item()
        }

    def train_on_radar_sequence(self, rainrate_sequence, metadata, epochs=10, verbose=True,
                                advection_field=None):
        """Train the model on consecutive frames of a radar sequence.

        Args:
            rainrate_sequence: Array ``(nt, ny, nx)``.
            metadata: Radar metadata dict (pixel sizes are read from it).
            epochs: Number of passes over the training pairs.
            verbose: Print progress every second epoch.
            advection_field: Optional velocity field attached to every
                training pair; ``None`` trains without advection.

        Returns:
            ``(losses, physics_losses)``: per-epoch mean total loss and
            per-epoch mean of the non-data part of the loss
            (total minus the supervised data term).

        Raises:
            ValueError: If no frame pair passes the rain-coverage filter.
        """
        training_data = self.prepare_training_data_from_radar(
            rainrate_sequence, metadata, advection_field
        )

        if not training_data:
            raise ValueError("No valid training data found!")

        print(f"Created {len(training_data)} training pairs")
        print(f"Training on device: {device}")

        losses = []
        physics_losses = []
        loss_components = {
            'data_loss': [],
            'dispersal_conservation': [],
            'growth_penalty': [],
            'advection_diff': [],
            'smoothness_loss': []
        }

        for epoch in range(epochs):
            epoch_loss = 0.0
            epoch_physics_loss = 0.0
            epoch_components = dict.fromkeys(loss_components, 0.0)
            valid_batches = 0

            np.random.shuffle(training_data)

            for training_pair in training_data:
                self.optimizer.zero_grad()

                loss, loss_details = self.compute_physics_loss(training_pair)
                loss_value = loss.item()

                # Pairs without a usable loss are skipped entirely so that
                # every running sum is averaged over the same set of pairs.
                if not (loss.requires_grad and loss_value > 0):
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

                epoch_loss += loss_value
                # Everything in the total that is not the supervised term.
                epoch_physics_loss += loss_value - loss_details['data_loss']
                for key, value in loss_details.items():
                    if key in epoch_components:
                        epoch_components[key] += value
                valid_batches += 1

            if valid_batches > 0:
                avg_loss = epoch_loss / valid_batches
                avg_physics_loss = epoch_physics_loss / valid_batches
                for key in epoch_components:
                    epoch_components[key] /= valid_batches
                    loss_components[key].append(epoch_components[key])
            else:
                avg_loss = 0
                avg_physics_loss = 0
                for key in loss_components:
                    loss_components[key].append(0)

            losses.append(avg_loss)
            physics_losses.append(avg_physics_loss)

            self.scheduler.step(avg_loss)

            if verbose and epoch % 2 == 0:
                print(f'Epoch {epoch}/{epochs}:')
                print(f' Total Loss: {avg_loss:.6f}')
                print(f' Physics Loss: {avg_physics_loss:.6f}')

                if valid_batches > 0:
                    print(f' Loss components:')
                    for key, value in epoch_components.items():
                        print(f' {key}: {value:.6f}')

                print(f' Valid batches: {valid_batches}/{len(training_data)}')
                print(f' Learned params: σ={torch.exp(self.model.log_sigma).item():.3f}, '
                      f's={torch.sigmoid(self.model.survival_prob).item():.3f}, '
                      f'r={torch.sigmoid(self.model.growth_rate).item():.3f}')
                print(f' Learning rate: {self.optimizer.param_groups[0]["lr"]:.6f}')
                if torch.cuda.is_available():
                    print(f' GPU Memory: {torch.cuda.memory_allocated()/1024**2:.1f} MB')
                print()

        return losses, physics_losses
473
+
474
+
475
def load_swiss_radar_data():
    """Load a MeteoSwiss radar sequence through pysteps, else synthetic data.

    Any exception raised while downloading or reading is reported on stdout
    and the function falls back to ``generate_synthetic_data()``.

    Returns:
        ``(rainrate_sequence, metadata)``.
    """
    try:
        print("Attempting to download pysteps data...")
        # NOTE(review): confirm against the installed pysteps version that
        # download_pysteps_data() may be called without a target directory
        # and that pysteps.datasets provides create_file_list; if either call
        # fails, the except branch below silently switches to synthetic data.
        root_path = pysteps.datasets.download_pysteps_data()

        # Sample event: 2016-09-08 00:00 to 12:00 at 5-minute steps.
        fns = pysteps.datasets.create_file_list(
            root_path, "mchrzc12",
            "201609080000", "201609081200",
            timestep=5
        )

        if len(fns) == 0:
            raise FileNotFoundError("No files found in dataset")

        print(f"Found {len(fns)} radar files")

        importer = io.get_method("mchrzc12")

        # Forward importer-specific keyword arguments when the importer
        # object exposes them.
        importer_kwargs = importer.kwargs if hasattr(importer, 'kwargs') else {}
        rainrate_sequence, _, metadata = io.read_timeseries(
            fns, importer, **importer_kwargs
        )

        print(f"Loaded radar sequence shape: {rainrate_sequence.shape}")
        print(f"Pixel resolution: {metadata.get('xpixelsize', 'unknown')}m x {metadata.get('ypixelsize', 'unknown')}m")

        return rainrate_sequence, metadata

    except Exception as e:
        # Deliberate best-effort: the app must stay usable offline.
        print(f"Failed to load pysteps data: {e}")
        print("Generating synthetic data instead...")
        return generate_synthetic_data()
516
+
517
def generate_synthetic_data():
    """Build a reproducible 12-frame, 256x256 synthetic rain-rate sequence.

    Each frame holds one large Gaussian cell that drifts across the grid
    plus two small cells at random positions. The global NumPy RNG is seeded
    with 42, so repeated calls return identical data.

    Returns:
        ``(rainrate_sequence, metadata)`` where the sequence has shape
        ``(12, 256, 256)`` and the metadata describes a 1 km, mm/h grid.
    """
    np.random.seed(42)
    n_frames, height, width = 12, 256, 256

    # Pixel index grids are the same for every frame.
    rows, cols = np.mgrid[0:height, 0:width]
    frames = np.zeros((n_frames, height, width))

    for step in range(n_frames):
        # Centre of the main cell: steady drift in x, slow oscillation in y.
        cx = int(width * 0.3 + (width * 0.4) * step / n_frames)
        cy = int(height * 0.5 + 20 * np.sin(step * 0.5))

        main_cell = np.exp(-((cols - cx)**2 + (rows - cy)**2) / (2 * 30**2))

        # Slowly varying intensity with a random peak amplitude in [5, 7).
        evolution = 1.0 + 0.2 * np.sin(step * 0.3)
        frames[step] = evolution * main_cell * (5 + 2 * np.random.random())

        # Two weak, narrow cells at random positions.
        for _ in range(2):
            sx = int(np.random.random() * width)
            sy = int(np.random.random() * height)
            blob = np.exp(-((cols - sx)**2 + (rows - sy)**2) / (2 * 15**2))
            frames[step] += 0.5 * blob * np.random.random()

    metadata = {
        'xpixelsize': 1000.0,  # 1 km resolution
        'ypixelsize': 1000.0,
        'unit': 'mm/h',
        'accutime': 5.0,  # 5 minute accumulation
        'transform': None
    }

    print(f"Generated synthetic data shape: {frames.shape}")
    return frames, metadata
556
+
557
def _stack_as_float32(converted):
    """Coerce a conversion result into a single numeric float32 array."""
    if isinstance(converted, (list, tuple)):
        # Sequence of 2-D frames -> (n, ny, nx).
        return np.stack([np.asarray(a, dtype=np.float32) for a in converted], axis=0)
    if isinstance(converted, np.ndarray) and converted.dtype == object:
        # Object array of frames -> convert each element, then stack.
        return np.stack([np.asarray(a, dtype=np.float32) for a in converted.tolist()], axis=0)
    return np.asarray(converted, dtype=np.float32)


def train_traditional_linda(rainrate_sequence, metadata):
    """Run a pysteps LINDA ensemble nowcast on the start of the sequence.

    The first three frames are the input; the following frames are returned
    as ground truth for the forecast.

    Args:
        rainrate_sequence: Array ``(nt, ny, nx)``.
        metadata: pysteps metadata dict for the sequence.

    Returns:
        Dict with keys ``model_name``, ``predictions``, ``ground_truth``,
        ``metadata`` and ``motion_field``.

    Raises:
        ValueError: If the sequence has no frame left to forecast after the
            input frames.
    """
    print("\n=== Training Traditional LINDA ===")

    n_input = 3  # Use 3 timesteps for prediction
    # NOTE(review): the original comment here said "6 timesteps ahead" while
    # the value is 128 - confirm the intended forecast horizon.
    n_forecast = 128

    n_frames = rainrate_sequence.shape[0]
    if n_frames <= n_input:
        # Without this guard n_forecast would become zero or negative below.
        raise ValueError(
            f"Need more than {n_input} timesteps to run a forecast, got {n_frames}."
        )

    if n_frames < n_input + n_forecast:
        print("Warning: Not enough timesteps for proper train/test split")
        n_forecast = min(3, n_frames - n_input)

    R_input = rainrate_sequence[:n_input]
    R_truth = rainrate_sequence[n_input:n_input + n_forecast]

    print(f"Input shape: {R_input.shape}")
    print(f"Truth shape: {R_truth.shape}")

    # to_rainrate may hand back (array, metadata); only the array is needed.
    conv_out = conversion.to_rainrate(R_input, metadata)
    if isinstance(conv_out, tuple) and len(conv_out) >= 1:
        conv_arr = conv_out[0]
    else:
        conv_arr = conv_out

    R_input_rr = _stack_as_float32(conv_arr)

    # Replace NaN / inf with zero rain before optical flow.
    finite_mask = np.isfinite(R_input_rr)
    if not finite_mask.all():
        R_input_rr[~finite_mask] = 0.0

    print("Computing motion field...")
    motion_field = dense_lucaskanade(R_input_rr)
    print(f"Motion field shape: {motion_field.shape}")

    print("Initializing LINDA...")

    linda_forecast = pysteps_linda.forecast(
        R_input_rr,    # 3D: (n_input, ny, nx)
        motion_field,  # (2, ny, nx)
        n_forecast,
        kmperpixel=1,
        timestep=5,
        n_ens_members=10,
        vel_pert_kwargs={"p_pert_par": [1.0, 0.1, 0.01, 0.1, 0.01]}
    )

    print(f"LINDA forecast shape: {linda_forecast.shape}")

    return {
        'model_name': 'Traditional LINDA',
        'predictions': linda_forecast,
        'ground_truth': R_truth,
        'metadata': metadata,
        'motion_field': motion_field
    }
640
+
641
def train_custom_pinn(rainrate_sequence, metadata):
    """Train the LINDA-PINN on all but the last frames and evaluate on the rest.

    The last ``n_test`` (3) frames are held out. Each held-out frame is
    predicted from the frame immediately before it, so ``predictions[i]`` and
    ``ground_truth[i]`` always refer to the same time step.

    Args:
        rainrate_sequence: Array ``(nt, ny, nx)``.
        metadata: Radar metadata dict (pixel sizes are read from it).

    Returns:
        Dict with keys ``model_name``, ``predictions``, ``ground_truth``,
        ``metadata``, ``training_time``, ``losses`` and ``physics_losses``.
        If training or prediction raises, a dict with zero predictions and
        ``model_name`` ``'LINDA-PINN (Failed)'`` is returned instead.
    """
    print("\n=== Training Custom LINDA-PINN ===")

    trainer = LINDAPINNTrainer()

    n_test = 3
    train_sequence = rainrate_sequence[:-n_test]
    # Held-out frames plus three preceding frames of context.
    test_sequence = rainrate_sequence[-n_test-3:]

    print(f"Training sequence shape: {train_sequence.shape}")
    print(f"Test sequence shape: {test_sequence.shape}")

    try:
        start_time = time.time()
        losses, physics_losses = trainer.train_on_radar_sequence(
            train_sequence, metadata, epochs=10, verbose=True
        )
        training_time = time.time() - start_time

        print(f"Training completed in {training_time:.2f} seconds")

        print("Making predictions...")

        # Inputs are the frames directly preceding each held-out frame, so
        # every prediction has a matching ground-truth frame.
        input_frames = test_sequence[-(n_test + 1):-1]
        n_pred = len(input_frames)
        ground_truth = test_sequence[len(test_sequence) - n_pred:]

        # The coordinate axes are identical for every frame.
        ny, nx = rainrate_sequence.shape[1:]
        if 'xpixelsize' in metadata and 'ypixelsize' in metadata:
            x_coords = np.arange(nx) * metadata['xpixelsize'] / 1000.0
            y_coords = np.arange(ny) * metadata['ypixelsize'] / 1000.0
        else:
            x_coords = np.linspace(-100, 100, nx)
            y_coords = np.linspace(-100, 100, ny)

        predictions = []
        for t, current_frame in enumerate(input_frames):
            current_tensor = torch.tensor(current_frame, dtype=torch.float32, device=device)
            with torch.no_grad():
                next_frame = trainer.model(current_tensor, x_coords, y_coords, float(t))
            predictions.append(next_frame.cpu().numpy())

        if predictions:
            predictions = np.array(predictions)
        else:
            # Nothing could be predicted: fall back to all-zero frames.
            predictions = np.zeros((n_test, *rainrate_sequence.shape[1:]))
            ground_truth = test_sequence[-n_test:]

        return {
            'model_name': 'LINDA-PINN',
            'predictions': predictions,
            'ground_truth': ground_truth,
            'metadata': metadata,
            'training_time': training_time,
            'losses': losses,
            'physics_losses': physics_losses
        }

    except Exception as e:
        # Deliberate best-effort: report the failure and return placeholders
        # so the comparison can still be rendered.
        print(f"PINN training failed: {e}")
        n_pred = min(3, rainrate_sequence.shape[0] - 1)
        dummy_predictions = np.zeros((n_pred, *rainrate_sequence.shape[1:]))
        ground_truth = rainrate_sequence[-n_pred:]

        return {
            'model_name': 'LINDA-PINN (Failed)',
            'predictions': dummy_predictions,
            'ground_truth': ground_truth,
            'metadata': metadata,
            'training_time': 0,
            'losses': [],
            'physics_losses': []
        }
723
+
724
def _empty_metrics(total_points=0):
    """Return the 'nothing to compare' result with the full set of keys."""
    return {
        'rmse': float('inf'),
        'mae': float('inf'),
        'correlation': 0,
        'accuracy': 0,
        'valid_points': 0,
        'total_points': int(total_points)
    }


def compute_metrics(predictions, ground_truth):
    """Compute RMSE, MAE, correlation and a +-20 % accuracy score.

    Predictions are aligned to the ground truth before comparison:

    * identical shapes are compared directly;
    * one extra leading axis (e.g. ensemble members) is averaged away;
    * ``K`` frames against ``T`` frames: if ``K`` is a multiple of ``T`` the
      frames are treated as ``K // T`` groups of ``T`` and averaged,
      otherwise the first ``T`` frames are used.

    Non-finite values in either array are excluded point-wise.

    Args:
        predictions: Array-like, or ``None``.
        ground_truth: Array-like, or ``None``.

    Returns:
        Dict with keys ``rmse``, ``mae``, ``correlation``, ``accuracy``
        (percentage of points with relative error below 0.2),
        ``valid_points`` and ``total_points``. When nothing can be compared
        ``rmse`` and ``mae`` are ``inf`` and the counts are still present.

    Raises:
        ValueError: If the shapes cannot be aligned.
    """
    if predictions is None or ground_truth is None:
        return _empty_metrics()

    pred = np.asarray(predictions)
    truth = np.asarray(ground_truth)

    # At least a 2-D field is required on both sides.
    if pred.ndim < 2 or truth.ndim < 2:
        return _empty_metrics()

    if pred.shape == truth.shape:
        aligned_pred = pred
    elif pred.ndim == truth.ndim + 1 and pred.shape[1:] == truth.shape:
        # (M, ...) against (...): average over the leading axis.
        aligned_pred = np.mean(pred, axis=0)
    elif pred.ndim == truth.ndim and pred.shape[1:] == truth.shape[1:]:
        K = pred.shape[0]
        T = truth.shape[0]
        if T == 0:
            # No ground-truth frames: nothing to score (and K % T is undefined).
            return _empty_metrics()
        if K % T == 0:
            aligned_pred = pred.reshape(K // T, T, *pred.shape[1:]).mean(axis=0)
        elif K >= T:
            # Most conservative choice: keep the first T frames.
            aligned_pred = pred[:T]
        else:
            raise ValueError(f"Predictions have fewer timesteps ({K}) than ground truth ({T}).")
    else:
        raise ValueError(f"Incompatible shapes: predictions {pred.shape}, ground_truth {truth.shape}")

    if aligned_pred.shape != truth.shape:
        raise ValueError(f"Failed to align shapes: aligned_pred {aligned_pred.shape}, truth {truth.shape}")

    pred_flat = aligned_pred.flatten()
    truth_flat = truth.flatten()

    valid_mask = np.isfinite(pred_flat) & np.isfinite(truth_flat)
    pred_valid = pred_flat[valid_mask]
    truth_valid = truth_flat[valid_mask]

    if pred_valid.size == 0:
        return _empty_metrics(pred_flat.size)

    abs_error = np.abs(pred_valid - truth_valid)
    rmse = np.sqrt(np.mean(abs_error ** 2))
    mae = np.mean(abs_error)

    # Correlation is undefined for a constant series; report 0 in that case.
    if np.std(pred_valid) > 0 and np.std(truth_valid) > 0:
        correlation = np.corrcoef(pred_valid, truth_valid)[0, 1]
    else:
        correlation = 0.0

    # The small constant avoids division by zero on dry pixels.
    relative_error = abs_error / (np.abs(truth_valid) + 1e-6)
    accuracy = float(np.mean(relative_error < 0.2) * 100.0)

    return {
        'rmse': float(rmse),
        'mae': float(mae),
        'correlation': float(correlation),
        'accuracy': accuracy,
        'valid_points': int(pred_valid.size),
        'total_points': int(pred_flat.size)
    }
809
+
810
+
811
+
812
+
813
+
814
+
815
def print_comparison(linda_results, pinn_results):
    """Print a console report comparing the two models' forecast metrics.

    Args:
        linda_results: Result dict with ``'model_name'``, ``'predictions'``
            and ``'ground_truth'`` entries.
        pinn_results: Result dict with the same entries; an optional
            ``'training_time'`` entry (seconds) is printed when present.

    Metrics are obtained from ``compute_metrics``. Nothing is returned.
    """
    banner = "=" * 60

    def _print_model(results, metrics):
        # One report section per model.
        print(f"\n{results['model_name']}:")
        print(f" RMSE: {metrics['rmse']:.4f}")
        print(f" MAE: {metrics['mae']:.4f}")
        print(f" Correlation: {metrics['correlation']:.4f}")
        print(f" Accuracy (±20%): {metrics['accuracy']:.2f}%")
        # compute_metrics leaves out the point counts when no finite samples
        # were available, so fall back to 0 instead of raising KeyError.
        valid = metrics.get('valid_points', 0)
        total = metrics.get('total_points', 0)
        print(f" Valid points: {valid}/{total}")

    def _verdict(label, linda_value, pinn_value, lower_is_better):
        # Strict comparisons in both directions: equal (or NaN) values tie.
        if lower_is_better:
            linda_ahead = linda_value < pinn_value
            pinn_ahead = pinn_value < linda_value
        else:
            linda_ahead = linda_value > pinn_value
            pinn_ahead = pinn_value > linda_value
        if linda_ahead:
            return f"{label}: {linda_results['model_name']} wins"
        if pinn_ahead:
            return f"{label}: {pinn_results['model_name']} wins"
        return f"{label}: Tie"

    print("\n" + banner)
    print("MODEL COMPARISON RESULTS")
    print(banner)

    linda_metrics = compute_metrics(linda_results['predictions'], linda_results['ground_truth'])
    pinn_metrics = compute_metrics(pinn_results['predictions'], pinn_results['ground_truth'])

    _print_model(linda_results, linda_metrics)
    _print_model(pinn_results, pinn_metrics)
    if 'training_time' in pinn_results:
        print(f" Training time: {pinn_results['training_time']:.2f}s")

    print(f"\n{banner}")
    print("SUMMARY:")

    verdicts = [
        _verdict("RMSE", linda_metrics['rmse'], pinn_metrics['rmse'], True),
        _verdict("Accuracy", linda_metrics['accuracy'], pinn_metrics['accuracy'], False),
    ]
    for verdict in verdicts:
        print(f" {verdict}")

    print(banner)
866
+
867
+
868
+ import gradio as gr
869
+ import matplotlib.pyplot as plt
870
+ import matplotlib.animation as animation
871
+ from matplotlib.animation import PillowWriter
872
+ import io
873
+ import base64
874
+ from PIL import Image
875
+
876
+ # Add this method to visualize predictions
877
def create_prediction_visualization(linda_results, pinn_results, max_frames=6):
    """Build a 3-row figure: ground truth, LINDA forecast, PINN forecast.

    Args:
        linda_results: Result dict; its ``'predictions'`` and ``'ground_truth'``
            arrays are plotted. A 4-D prediction array is averaged over its
            first axis (treated as the ensemble axis) before plotting.
        pinn_results: Result dict; its ``'predictions'`` array is plotted.
        max_frames: Upper bound on the number of columns (lead times) shown.

    Returns:
        The matplotlib ``Figure``.

    Raises:
        ValueError: If no frame is available to plot.
    """
    linda_pred = linda_results['predictions']
    pinn_pred = pinn_results['predictions']
    # NOTE(review): only LINDA's ground truth is shown; the PINN result dict
    # carries its own 'ground_truth' which may cover different frames - confirm.
    ground_truth = linda_results['ground_truth']

    if linda_pred.ndim == 4:  # (ensemble, time, ny, nx) -> ensemble mean
        linda_pred = np.mean(linda_pred, axis=0)

    n_frames = min(max_frames, ground_truth.shape[0], linda_pred.shape[0], pinn_pred.shape[0])
    if n_frames < 1:
        raise ValueError("No frames available to visualize.")

    # Common colour scale taken from finite values only, so a single NaN in a
    # forecast cannot turn vmax into NaN.
    vmin = 0
    shown = (ground_truth[:n_frames], linda_pred[:n_frames], pinn_pred[:n_frames])
    finite_maxima = []
    for frames in shown:
        finite = np.isfinite(frames)
        if finite.any():
            finite_maxima.append(float(np.max(frames[finite])))
    vmax = max(finite_maxima) if finite_maxima else 1.0
    if vmax <= vmin:
        vmax = vmin + 1.0

    # squeeze=False keeps `axes` 2-D even for a single column; constrained
    # layout is used because tight_layout does not cope with a colorbar that
    # spans several axes.
    fig, axes = plt.subplots(3, n_frames, figsize=(n_frames * 3, 9),
                             squeeze=False, constrained_layout=True)

    rows = (("Truth", ground_truth), ("LINDA", linda_pred), ("PINN", pinn_pred))
    image = None
    for row, (label, frames) in enumerate(rows):
        for t in range(n_frames):
            ax = axes[row, t]
            image = ax.imshow(frames[t], cmap='viridis', vmin=vmin, vmax=vmax)
            ax.set_title(f'{label} t+{t+1}')
            ax.axis('off')

    # Every panel shares vmin/vmax, so any image can drive the colorbar.
    fig.colorbar(image, ax=axes, orientation='horizontal', pad=0.1, fraction=0.05)
    return fig
926
+
927
def create_loss_plot(pinn_results):
    """Plot the PINN training-loss curves.

    Args:
        pinn_results: Result dict. ``'losses'`` holds the total loss values
            and the optional ``'physics_losses'`` holds the physics loss values.

    Returns:
        A matplotlib ``Figure`` with the total loss on the left and the
        physics loss on the right, or ``None`` when there are no losses to plot.
    """
    losses = pinn_results.get('losses')
    if losses is None or len(losses) == 0:
        return None

    fig, (total_ax, physics_ax) = plt.subplots(1, 2, figsize=(12, 4))

    total_ax.plot(losses, label='Total Loss', linewidth=2)
    total_ax.set_xlabel('Epoch')
    total_ax.set_ylabel('Loss')
    total_ax.set_title('PINN Training Loss')
    total_ax.grid(True, alpha=0.3)
    total_ax.legend()

    physics_losses = pinn_results.get('physics_losses')
    if physics_losses is not None and len(physics_losses) > 0:
        physics_ax.plot(physics_losses, label='Physics Loss', linewidth=2, color='orange')
        physics_ax.set_xlabel('Epoch')
        physics_ax.set_ylabel('Physics Loss')
        physics_ax.set_title('Physics-Informed Loss')
        physics_ax.grid(True, alpha=0.3)
        physics_ax.legend()
    else:
        # Nothing to draw: hide the panel rather than show an empty frame.
        physics_ax.set_visible(False)

    plt.tight_layout()
    return fig
953
+
954
+ # Modified training functions with parameters
955
def train_traditional_linda_with_params(rainrate_sequence, metadata,
                                        n_ens_members=10,
                                        vel_pert_p1=1.0,
                                        vel_pert_p2=0.1,
                                        vel_pert_p3=0.01,
                                        vel_pert_p4=0.1,
                                        vel_pert_p5=0.01,
                                        kmperpixel=1,
                                        timestep=5):
    """Run a pysteps LINDA forecast with user-supplied parameters.

    The first 3 frames of ``rainrate_sequence`` are the model input and up to
    6 of the following frames are kept as ground truth.

    Args:
        rainrate_sequence: Array indexed by time along its first axis.
        metadata: Metadata passed to ``conversion.to_rainrate`` and returned
            unchanged in the result.
        n_ens_members: Number of ensemble members requested from LINDA.
        vel_pert_p1..vel_pert_p5: Values forwarded in ``vel_pert_kwargs``.
        kmperpixel: Grid resolution forwarded to the forecast.
        timestep: Time step forwarded to the forecast.

    Returns:
        Dict with ``'model_name'``, ``'predictions'``, ``'ground_truth'``,
        ``'metadata'`` and ``'motion_field'``.

    Raises:
        ValueError: If the sequence has no frame left to forecast.
    """
    print("\n=== Training Traditional LINDA with Custom Parameters ===")

    n_input = 3
    n_forecast = min(6, rainrate_sequence.shape[0] - n_input)
    if n_forecast < 1:
        raise ValueError(
            f"Need more than {n_input} frames to forecast, got {rainrate_sequence.shape[0]}."
        )

    R_input = rainrate_sequence[:n_input]
    # NOTE(review): the truth frames are returned without going through
    # to_rainrate, unlike the input frames - confirm units match the forecast.
    R_truth = rainrate_sequence[n_input:n_input + n_forecast]

    # to_rainrate may hand back (array, metadata) or just the array.
    conv_out = conversion.to_rainrate(R_input, metadata)
    if isinstance(conv_out, tuple) and len(conv_out) >= 1:
        conv_arr = conv_out[0]
    else:
        conv_arr = conv_out

    # Normalise list/tuple/object-array results into one float32 stack.
    if isinstance(conv_arr, (list, tuple)):
        R_input_rr = np.stack([np.asarray(a, dtype=np.float32) for a in conv_arr], axis=0)
    elif isinstance(conv_arr, np.ndarray) and conv_arr.dtype == object:
        R_input_rr = np.stack(
            [np.asarray(a, dtype=np.float32) for a in conv_arr.tolist()], axis=0
        )
    else:
        R_input_rr = np.asarray(conv_arr, dtype=np.float32)

    # Replace NaN/inf with zero before motion estimation.
    finite_mask = np.isfinite(R_input_rr)
    if not finite_mask.all():
        R_input_rr[~finite_mask] = 0.0

    motion_field = dense_lucaskanade(R_input_rr)

    # NOTE(review): verify that "p_pert_par" is a key the pysteps velocity
    # perturbation generator accepts (its documented parameters are "p_par"
    # and "p_perp") - left unchanged here.
    linda_forecast = pysteps_linda.forecast(
        R_input_rr,
        motion_field,
        n_forecast,
        kmperpixel=kmperpixel,
        timestep=timestep,
        n_ens_members=n_ens_members,
        vel_pert_kwargs={"p_pert_par": [vel_pert_p1, vel_pert_p2, vel_pert_p3, vel_pert_p4, vel_pert_p5]}
    )

    return {
        'model_name': 'Traditional LINDA',
        'predictions': linda_forecast,
        'ground_truth': R_truth,
        'metadata': metadata,
        'motion_field': motion_field
    }
1014
+
1015
def train_custom_pinn_with_params(rainrate_sequence, metadata,
                                  epochs=10,
                                  learning_rate=0.001,
                                  weight_decay=1e-5,
                                  batch_size=1,
                                  hidden_layers=256,
                                  num_layers=5,
                                  initial_sigma=0.0,
                                  initial_survival=0.8,
                                  initial_growth=0.1):
    """Train the LINDA-PINN model with user-supplied hyperparameters.

    Args:
        rainrate_sequence: Array indexed by time along its first axis.
        metadata: Metadata dict; ``'xpixelsize'``/``'ypixelsize'`` are used to
            build the coordinate axes when both are present.
        epochs: Number of training epochs.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        batch_size: Accepted for interface compatibility; currently unused.
        hidden_layers: Width of each hidden layer.
        num_layers: Number of hidden layers.
        initial_sigma: Initial value written to ``model.log_sigma``.
        initial_survival: Initial value written to ``model.survival_prob``.
        initial_growth: Initial value written to ``model.growth_rate``.

    Returns:
        Result dict with ``'model_name'``, ``'predictions'``,
        ``'ground_truth'``, ``'metadata'``, ``'training_time'``, ``'losses'``
        and ``'physics_losses'``; ``'final_params'`` is added on success only.
        If training or prediction raises, an all-zero placeholder result
        named ``'LINDA-PINN (Failed)'`` is returned instead.
    """
    print("\n=== Training Custom LINDA-PINN with Custom Parameters ===")

    # 4 inputs -> `num_layers` hidden layers of width `hidden_layers` -> 1 output
    layers = [4] + [hidden_layers] * num_layers + [1]

    trainer = LINDAPINNTrainer()
    # Keep the replacement model on the same device as the input tensors
    # created below.
    trainer.model = LINDAPINNModel(layers=layers).to(device)

    # Overwrite the values in place so each parameter keeps its dtype, shape
    # and device (assigning `.data = torch.tensor(x)` would replace all three).
    # NOTE(review): survival/growth are reported through a sigmoid below, so
    # these raw initial values are presumably logits, not probabilities - confirm.
    with torch.no_grad():
        trainer.model.log_sigma.fill_(float(initial_sigma))
        trainer.model.survival_prob.fill_(float(initial_survival))
        trainer.model.growth_rate.fill_(float(initial_growth))

    trainer.optimizer = torch.optim.Adam(
        trainer.model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay
    )
    trainer.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        trainer.optimizer,
        patience=max(10, epochs // 10)
    )

    # Hold out the last frames; the test window also includes 3 frames of context.
    n_test = 3
    train_sequence = rainrate_sequence[:-n_test]
    test_sequence = rainrate_sequence[-n_test - 3:]

    try:
        start_time = time.time()
        losses, physics_losses = trainer.train_on_radar_sequence(
            train_sequence, metadata, epochs=epochs, verbose=True
        )
        training_time = time.time() - start_time

        has_pixel_size = 'xpixelsize' in metadata and 'ypixelsize' in metadata

        predictions = []
        for t in range(n_test):
            if t + 3 >= len(test_sequence):
                continue
            current_frame = test_sequence[t + 3]

            ny, nx = current_frame.shape
            if has_pixel_size:
                x_coords = np.arange(nx) * metadata['xpixelsize'] / 1000.0
                y_coords = np.arange(ny) * metadata['ypixelsize'] / 1000.0
            else:
                x_coords = np.linspace(-100, 100, nx)
                y_coords = np.linspace(-100, 100, ny)

            current_tensor = torch.tensor(current_frame, dtype=torch.float32, device=device)

            with torch.no_grad():
                next_frame = trainer.model(current_tensor, x_coords, y_coords, float(t))
            predictions.append(next_frame.cpu().numpy())

        if predictions:
            predictions = np.array(predictions)
        else:
            predictions = np.zeros((n_test, *rainrate_sequence.shape[1:]))

        if len(test_sequence) > 4:
            ground_truth = test_sequence[4:4 + len(predictions)]
        else:
            ground_truth = test_sequence[-len(predictions):]

        return {
            'model_name': 'LINDA-PINN',
            'predictions': predictions,
            'ground_truth': ground_truth,
            'metadata': metadata,
            'training_time': training_time,
            'losses': losses,
            'physics_losses': physics_losses,
            'final_params': {
                'sigma': torch.exp(trainer.model.log_sigma).item(),
                'survival': torch.sigmoid(trainer.model.survival_prob).item(),
                'growth': torch.sigmoid(trainer.model.growth_rate).item()
            }
        }

    except Exception as e:
        # Deliberate best-effort: report the failure and return a placeholder
        # so the comparison can still be rendered.
        print(f"PINN training failed: {e}")
        n_pred = min(3, rainrate_sequence.shape[0] - 1)
        return {
            'model_name': 'LINDA-PINN (Failed)',
            'predictions': np.zeros((n_pred, *rainrate_sequence.shape[1:])),
            'ground_truth': rainrate_sequence[-n_pred:],
            'metadata': metadata,
            'training_time': 0,
            'losses': [],
            'physics_losses': []
        }
1113
+
1114
+ # Main Gradio interface function
1115
def run_comparison(
    # LINDA parameters
    linda_n_ens_members, linda_vel_p1, linda_vel_p2, linda_vel_p3, linda_vel_p4, linda_vel_p5,
    linda_kmperpixel, linda_timestep,
    # PINN parameters
    pinn_epochs, pinn_lr, pinn_weight_decay, pinn_hidden_layers, pinn_num_layers,
    pinn_initial_sigma, pinn_initial_survival, pinn_initial_growth,
    # Data selection
    use_synthetic_data
):
    """Run both models on one dataset and package the results for the UI.

    The positional arguments arrive from the Gradio widgets in the order
    listed above. Synthetic data is used when ``use_synthetic_data`` is true
    or when loading the real radar data fails.

    Returns:
        Tuple ``(results_markdown, predictions_figure, loss_figure)``;
        ``loss_figure`` is whatever ``create_loss_plot`` returns.
    """
    if use_synthetic_data:
        rainrate_sequence, metadata = generate_synthetic_data()
    else:
        try:
            rainrate_sequence, metadata = load_swiss_radar_data()
        except Exception as exc:
            # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
            print("Failed to load real data, using synthetic instead")
            print(f"  reason: {exc}")
            rainrate_sequence, metadata = generate_synthetic_data()

    linda_results = train_traditional_linda_with_params(
        rainrate_sequence, metadata,
        n_ens_members=int(linda_n_ens_members),
        vel_pert_p1=linda_vel_p1,
        vel_pert_p2=linda_vel_p2,
        vel_pert_p3=linda_vel_p3,
        vel_pert_p4=linda_vel_p4,
        vel_pert_p5=linda_vel_p5,
        kmperpixel=linda_kmperpixel,
        timestep=linda_timestep
    )

    pinn_results = train_custom_pinn_with_params(
        rainrate_sequence, metadata,
        epochs=int(pinn_epochs),
        learning_rate=pinn_lr,
        weight_decay=pinn_weight_decay,
        hidden_layers=int(pinn_hidden_layers),
        num_layers=int(pinn_num_layers),
        initial_sigma=pinn_initial_sigma,
        initial_survival=pinn_initial_survival,
        initial_growth=pinn_initial_growth
    )

    linda_metrics = compute_metrics(linda_results['predictions'], linda_results['ground_truth'])
    pinn_metrics = compute_metrics(pinn_results['predictions'], pinn_results['ground_truth'])

    pred_fig = create_prediction_visualization(linda_results, pinn_results)
    loss_fig = create_loss_plot(pinn_results)

    def _param(name):
        # 'final_params' is absent when PINN training failed; applying a float
        # format spec to the 'N/A' placeholder would raise ValueError.
        value = pinn_results.get('final_params', {}).get(name)
        return f"{value:.3f}" if isinstance(value, (int, float)) else "N/A"

    def _winner(linda_is_better, pinn_is_better):
        if linda_is_better:
            return 'LINDA'
        if pinn_is_better:
            return 'PINN'
        return 'Tie'

    rmse_winner = _winner(linda_metrics['rmse'] < pinn_metrics['rmse'],
                          pinn_metrics['rmse'] < linda_metrics['rmse'])
    accuracy_winner = _winner(linda_metrics['accuracy'] > pinn_metrics['accuracy'],
                              pinn_metrics['accuracy'] > linda_metrics['accuracy'])

    # Assembled line by line so the Markdown never picks up source indentation.
    results_text = "\n".join([
        "",
        "## Model Comparison Results",
        "",
        "### Traditional LINDA",
        f"- **RMSE**: {linda_metrics['rmse']:.4f}",
        f"- **MAE**: {linda_metrics['mae']:.4f}",
        f"- **Correlation**: {linda_metrics['correlation']:.4f}",
        f"- **Accuracy (±20%)**: {linda_metrics['accuracy']:.2f}%",
        "",
        "### LINDA-PINN",
        f"- **RMSE**: {pinn_metrics['rmse']:.4f}",
        f"- **MAE**: {pinn_metrics['mae']:.4f}",
        f"- **Correlation**: {pinn_metrics['correlation']:.4f}",
        f"- **Accuracy (±20%)**: {pinn_metrics['accuracy']:.2f}%",
        f"- **Training Time**: {pinn_results.get('training_time', 0):.2f}s",
        "",
        "### Learned PINN Parameters",
        f"- **Sigma**: {_param('sigma')}",
        f"- **Survival**: {_param('survival')}",
        f"- **Growth**: {_param('growth')}",
        "",
        "### Winner",
        f"- **RMSE**: {rmse_winner}",
        f"- **Accuracy**: {accuracy_winner}",
        "",
    ])

    return results_text, pred_fig, loss_fig
1199
+
1200
+ # Create Gradio interface
1201
def create_gradio_app():
    """Build the Gradio Blocks UI for the LINDA vs LINDA-PINN comparison.

    The layout has one column of LINDA sliders, one column of PINN sliders,
    a data-source checkbox with the run button, and output areas for the
    Markdown report and the two plots. Clicking the button calls
    ``run_comparison`` with the widget values in the order of its signature.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    intro_md = "\n".join([
        "# LINDA vs LINDA-PINN Weather Nowcasting Comparison",
        "",
        "Compare traditional LINDA with Physics-Informed Neural Network (PINN) "
        "approach for precipitation nowcasting.",
        "Adjust hyperparameters for both models and see how they perform!",
    ])
    # Acronym expansion corrected to the one used by pysteps for LINDA.
    about_md = "\n".join([
        "### About",
        "- **LINDA**: Lagrangian INtegro-Difference equation model with Autoregression",
        "- **PINN**: Physics-Informed Neural Network implementation of LINDA",
        "- Metrics shown are computed on test data (last 3 timesteps)",
    ])

    with gr.Blocks(title="LINDA vs LINDA-PINN Comparison") as app:
        gr.Markdown(intro_md)

        with gr.Row():
            with gr.Column():
                gr.Markdown("### LINDA Parameters")
                linda_n_ens = gr.Slider(1, 50, value=10, step=1, label="Ensemble Members")
                linda_vel_p1 = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Velocity Perturbation P1")
                linda_vel_p2 = gr.Slider(0.01, 0.5, value=0.1, step=0.01, label="Velocity Perturbation P2")
                linda_vel_p3 = gr.Slider(0.001, 0.1, value=0.01, step=0.001, label="Velocity Perturbation P3")
                linda_vel_p4 = gr.Slider(0.01, 0.5, value=0.1, step=0.01, label="Velocity Perturbation P4")
                linda_vel_p5 = gr.Slider(0.001, 0.1, value=0.01, step=0.001, label="Velocity Perturbation P5")
                linda_km = gr.Slider(0.5, 5.0, value=1.0, step=0.1, label="KM per Pixel")
                linda_timestep = gr.Slider(1, 15, value=5, step=1, label="Timestep (minutes)")

            with gr.Column():
                gr.Markdown("### PINN Parameters")
                pinn_epochs = gr.Slider(5, 100, value=10, step=5, label="Training Epochs")
                pinn_lr = gr.Slider(0.0001, 0.01, value=0.001, step=0.0001, label="Learning Rate")
                pinn_weight_decay = gr.Slider(1e-6, 1e-3, value=1e-5, step=1e-6, label="Weight Decay")
                pinn_hidden = gr.Slider(64, 512, value=256, step=64, label="Hidden Layer Size")
                pinn_layers = gr.Slider(2, 8, value=5, step=1, label="Number of Layers")
                pinn_sigma = gr.Slider(-2.0, 2.0, value=0.0, step=0.1, label="Initial Log Sigma")
                pinn_survival = gr.Slider(0.1, 1.0, value=0.8, step=0.1, label="Initial Survival Probability")
                pinn_growth = gr.Slider(0.01, 0.5, value=0.1, step=0.01, label="Initial Growth Rate")

        with gr.Row():
            use_synthetic = gr.Checkbox(value=True, label="Use Synthetic Data (faster)")
            run_btn = gr.Button("Run Comparison", variant="primary")

        with gr.Row():
            results_output = gr.Markdown()

        with gr.Row():
            predictions_plot = gr.Plot(label="Predictions Comparison")
            loss_plot = gr.Plot(label="PINN Training Loss")

        # The input order must match run_comparison's positional parameters.
        run_btn.click(
            fn=run_comparison,
            inputs=[
                linda_n_ens, linda_vel_p1, linda_vel_p2, linda_vel_p3, linda_vel_p4, linda_vel_p5,
                linda_km, linda_timestep,
                pinn_epochs, pinn_lr, pinn_weight_decay, pinn_hidden, pinn_layers,
                pinn_sigma, pinn_survival, pinn_growth,
                use_synthetic
            ],
            outputs=[results_output, predictions_plot, loss_plot]
        )

        gr.Markdown(about_md)

    return app
1264
+
1265
+ # Launch the app
1266
if __name__ == "__main__":
    # Script entry point: build the Blocks UI, then serve it.
    # share=True asks Gradio to also create a public share link.
    app = create_gradio_app()
    app.launch(share=True)
1269
+
1270
+
1271
+ # if __name__ == "__main__":
1272
+ # print("Starting LINDA vs LINDA-PINN Comparison...")
1273
+ #
1274
+ # try:
1275
+ # # Load data
1276
+ # print("Loading radar data...")
1277
+ # rainrate_sequence, metadata = load_swiss_radar_data()
1278
+ #
1279
+ # if rainrate_sequence is None or len(rainrate_sequence) < 6:
1280
+ # raise ValueError("Insufficient data for comparison")
1281
+ #
1282
+ # print(f"Data loaded successfully: {rainrate_sequence.shape}")
1283
+ # print(f"Data range: {np.min(rainrate_sequence):.3f} to {np.max(rainrate_sequence):.3f}")
1284
+ #
1285
+ # # Train traditional LINDA
1286
+ # linda_results = train_traditional_linda(rainrate_sequence, metadata)
1287
+ #
1288
+ # # Train custom PINN
1289
+ # pinn_results = train_custom_pinn(rainrate_sequence, metadata)
1290
+ #
1291
+ # # Print comparison
1292
+ # print_comparison(linda_results, pinn_results)
1293
+ # # print(linda_results)
1294
+ # print("\nComparison completed successfully!")
1295
+ #
1296
+ # except Exception as e:
1297
+ # print(f"Error in main execution: {e}")
1298
+ # import traceback
1299
+ # traceback.print_exc()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ scikit-learn
4
+ pysteps
5
+ matplotlib
6
+ gradio
7
+ scipy
8
+ pillow