Lilly Makkos committed on
Commit ef16512 · 0 Parent(s):

fresh new main branch

.gitattributes ADDED
@@ -0,0 +1,6 @@
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.safetensor filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ veg/MultiTaskConvLSTM_veg_variables filter=lfs diff=lfs merge=lfs -text
+ no_veg/MultiTaskConvLSTM_no_veg_variables filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,56 @@
+ # MultiTask ConvLSTM for Precipitation Prediction
+
+ This repository contains two MultiTask ConvLSTM models:
+ - **veg/**: model trained with vegetation input variables
+ - **no_veg/**: model trained without vegetation input variables
+
+ Both directories include:
+ - `ConvLSTM.py`: base ConvLSTM layers
+ - `MultiTaskConvLSTM.py`: MultiTask ConvLSTM model definition
+ - `example_inference.py`: inference script
+ - `data/`: example `.pth` test files
+
+ These scripts are provided for reproducibility of the model architecture and workflow.
+ Exact runtime and performance may vary depending on hardware.
+
+ ## Example Data
+
+ We provide large test `.pth` files so you can run the inference script immediately, without any preprocessing.
+ These files are already preprocessed and normalized from the ECMWF ERA5 reanalysis data.
+
+ Each `.pth` file loads as a list of batches:
+
+ - `X_batch`: shape `(B, T_in, C_in, H*W)`
+ - `y_batch`: shape `(B, T_out, C_out, H*W)`
+ - `y_zero_batch`: shape `(B, T_out, C_out, H*W)`
+
+ with `H=81`, `W=97`. Inside `evaluate(...)`, these are reshaped to `(B, T, C, H, W)`.
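+
+ For example, a minimal loading sketch (assuming the LFS data files have been pulled):
+
+ ```python
+ import torch
+
+ batches = torch.load("no_veg/data/normalized_test_data_no_veg_input.pth")
+ X_batch, y_batch, y_zero_batch = batches[0]    # each batch is a tuple
+ B, T_in, C_in, G = X_batch.shape               # G = H*W = 81*97 = 7857
+ X_batch = X_batch.view(B, T_in, C_in, 81, 97)  # (B, T_in, C_in, H, W)
+ ```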
+
+ ---
+
+ ## How to Use
+
+ Ensure all files are in the correct directory, then run the `example_inference.py` file:
+
+ ```bash
+ # 1. Get the repo
+ git clone https://huggingface.co/<your-username>/MultiTaskConvLSTM
+ cd MultiTaskConvLSTM
+
+ # 2. Install minimal deps
+ pip install -r requirements.txt
+
+ # 3. Run inference (choose one variant)
+ python veg/example_inference.py
+ # or
+ python no_veg/example_inference.py
+ ```
+
+ ## Citation
+
+ If you use this model, please cite:
+
+ > Lilly Horvath-Makkos (2025). [title] [journal]
+
+ BibTeX:
+
+ ```bibtex
+ @article{horvathmakkos2025,
+   title={Title},
+   author={Horvath-Makkos, Lilly},
+   journal={Journal},
+   year={2025}
+ }
+ ```
no_veg/ConvLSTM.py ADDED
@@ -0,0 +1,184 @@
+ # ConvLSTM definition
+
+
+ import torch.nn as nn
+ import torch
+
+
+ class ConvLSTMCell(nn.Module):
+
+     def __init__(self, input_dim, hidden_dim, kernel_size, bias):
+         # input_dim is the number of channels of the input tensor, hidden_dim is the number
+         # of channels of the hidden state, and bias is a boolean: whether or not to add a bias.
+
+         super(ConvLSTMCell, self).__init__()
+
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+
+         self.kernel_size = kernel_size
+         self.padding = kernel_size[0] // 2, kernel_size[1] // 2
+         self.bias = bias
+
+         self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
+                               out_channels=4 * self.hidden_dim,
+                               kernel_size=self.kernel_size,
+                               padding=self.padding,
+                               bias=self.bias)
+
+     def forward(self, input_tensor, cur_state):
+         h_cur, c_cur = cur_state
+
+         combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along the channel axis
+
+         combined_conv = self.conv(combined)
+         cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
+         i = torch.sigmoid(cc_i)
+         f = torch.sigmoid(cc_f)
+         o = torch.sigmoid(cc_o)
+         g = torch.tanh(cc_g)
+
+         c_next = f * c_cur + i * g
+         h_next = o * torch.tanh(c_next)
+
+         return h_next, c_next
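+
+     # For reference, one cell step computes the standard ConvLSTM update
+     # (all four gates come from the single combined convolution above; * is convolution):
+     #   i_t = sigmoid(W_i * [x_t, h_{t-1}] + b_i)
+     #   f_t = sigmoid(W_f * [x_t, h_{t-1}] + b_f)
+     #   o_t = sigmoid(W_o * [x_t, h_{t-1}] + b_o)
+     #   g_t = tanh(W_g * [x_t, h_{t-1}] + b_g)
+     #   c_t = f_t * c_{t-1} + i_t * g_t   (elementwise)
+     #   h_t = o_t * tanh(c_t)             (elementwise)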
+
+     def init_hidden(self, batch_size, image_size):
+         height, width = image_size
+         return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
+                 torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
+
+
+ class ConvLSTM(nn.Module):
+     """
+     Parameters:
+         input_dim: number of channels in the input
+         hidden_dim: number of hidden channels
+         kernel_size: size of the convolution kernel
+         num_layers: number of ConvLSTM layers stacked on each other
+         batch_first: whether or not dimension 0 is the batch dimension
+         bias: whether to use a bias in the convolutions
+         return_all_layers: return the outputs and states of all layers
+     Note: performs "same" padding.
+
+     Input:
+         A tensor of shape (B, T, C, H, W) or (T, B, C, H, W)
+     Output:
+         A tuple of two lists, each of length num_layers (or length 1 if return_all_layers is False):
+             0 - layer_output_list: the per-layer output sequences, each of shape (B, T, C, H, W)
+             1 - last_state_list: the per-layer final states; each element is a tuple (h, c)
+                 of the hidden state and the cell memory
+     Example:
+         >> x = torch.rand((32, 10, 64, 128, 128))
+         >> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
+         >> _, last_states = convlstm(x)
+         >> h = last_states[0][0]  # 0 for layer index, 0 for h index
+     """
+
+     def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
+                  batch_first=False, bias=True, return_all_layers=False):
+         super(ConvLSTM, self).__init__()
+
+         self._check_kernel_size_consistency(kernel_size)
+
+         # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
+         kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
+         hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
+         if not len(kernel_size) == len(hidden_dim) == num_layers:
+             raise ValueError('Inconsistent list length.')
+
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+         self.kernel_size = kernel_size
+         self.num_layers = num_layers
+         self.batch_first = batch_first
+         self.bias = bias
+         self.return_all_layers = return_all_layers
+
+         cell_list = []
+         for i in range(0, self.num_layers):
+             cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
+             cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
+                                           hidden_dim=self.hidden_dim[i],
+                                           kernel_size=self.kernel_size[i],
+                                           bias=self.bias))
+
+         self.cell_list = nn.ModuleList(cell_list)
+
+     def forward(self, input_tensor, hidden_state=None):
+         """
+         Parameters
+         ----------
+         input_tensor:
+             5-D tensor of shape (t, b, c, h, w) or (b, t, c, h, w)
+         hidden_state:
+             None. Stateful operation is not implemented yet.
+
+         Returns
+         -------
+         layer_output_list, last_state_list
+         """
+         if not self.batch_first:
+             # (t, b, c, h, w) -> (b, t, c, h, w)
+             input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
+
+         b, _, _, h, w = input_tensor.size()
+
+         # Stateful ConvLSTM is not implemented yet
+         if hidden_state is not None:
+             raise NotImplementedError()
+         else:
+             # Since the init is done in forward, we can pass the image size here
+             hidden_state = self._init_hidden(batch_size=b,
+                                              image_size=(h, w))
+
+         layer_output_list = []
+         last_state_list = []
+
+         seq_len = input_tensor.size(1)
+         cur_layer_input = input_tensor
+
+         for layer_idx in range(self.num_layers):
+
+             h, c = hidden_state[layer_idx]
+             output_inner = []
+             for t in range(seq_len):
+                 h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
+                                                  cur_state=[h, c])
+                 output_inner.append(h)
+
+             layer_output = torch.stack(output_inner, dim=1)
+             cur_layer_input = layer_output
+
+             layer_output_list.append(layer_output)
+             last_state_list.append([h, c])
+
+         if not self.return_all_layers:
+             layer_output_list = layer_output_list[-1:]
+             last_state_list = last_state_list[-1:]
+
+         return layer_output_list, last_state_list
+
+     def _init_hidden(self, batch_size, image_size):
+         init_states = []
+         for i in range(self.num_layers):
+             init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
+         return init_states
+
+     @staticmethod
+     def _check_kernel_size_consistency(kernel_size):
+         if not (isinstance(kernel_size, tuple) or
+                 (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
+             raise ValueError('`kernel_size` must be tuple or list of tuples')
+
+     @staticmethod
+     def _extend_for_multilayer(param, num_layers):
+         if not isinstance(param, list):
+             param = [param] * num_layers
+         return param
no_veg/MultiTaskConvLSTM.py ADDED
@@ -0,0 +1,129 @@
+ from ConvLSTM import ConvLSTM
+ import torch
+ import torch.nn as nn
+ from collections import defaultdict
+
+
+ # MLP definition
+ class MLP_5D(nn.Module):
+     def __init__(self, height, width):
+         super(MLP_5D, self).__init__()
+         # Define the fully connected layers
+         self.fc1 = nn.Linear(64, 128)  # input features = 64, output features = 128
+         self.dropout1 = nn.Dropout(0.05)
+         self.fc2 = nn.Linear(128, 64)  # output features = 64
+         self.dropout2 = nn.Dropout(0.05)
+         self.fc3 = nn.Linear(64, 1)    # final output, reducing to 1 channel
+
+         self.height = height
+         self.width = width
+
+     def forward(self, x):
+         batch_size, timesteps, channels, height, width = x.shape
+
+         # Ensure the input spatial dimensions match the expected height and width
+         assert height == self.height and width == self.width, "Height and width mismatch"
+
+         # Reshape to (batch * timesteps * height * width, channels)
+         x = x.permute(0, 1, 3, 4, 2).reshape(-1, channels)
+
+         # Apply the MLP (fully connected layers)
+         x = self.fc1(x)
+         x = torch.nn.functional.softplus(x)
+         x = self.dropout1(x)
+         x = self.fc2(x)
+         x = torch.nn.functional.softplus(x)
+         x = self.dropout2(x)
+         x = self.fc3(x)
+         x = torch.nn.functional.softplus(x)
+
+         # Reshape back to (batch, timesteps, 1, height, width)
+         x = x.view(batch_size, timesteps, self.height, self.width, 1).permute(0, 1, 4, 2, 3)
+
+         return x
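+
+
+ # Illustrative shape check for MLP_5D (hypothetical values, using the repo's 81x97 grid):
+ #   mlp = MLP_5D(height=81, width=97)
+ #   out = mlp(torch.randn(2, 1, 64, 81, 97))   # -> torch.Size([2, 1, 1, 81, 97])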
+
+
+ # MultiTask ConvLSTM definition
+
+ class ConvLSTMNetwork(nn.Module):
+     def __init__(self, input_dim, hidden_dims, kernel_size, num_layers, output_channels, batch_first=True, pool_size=(2, 2)):
+         # Note: pool_size is accepted for API compatibility but is currently unused.
+         super(ConvLSTMNetwork, self).__init__()
+
+         # ConvLSTM module
+         self.convlstm = ConvLSTM(input_dim=input_dim,
+                                  hidden_dim=hidden_dims,
+                                  kernel_size=kernel_size,
+                                  num_layers=num_layers,
+                                  batch_first=batch_first,
+                                  bias=True,
+                                  return_all_layers=True)
+
+         # Batch normalization for each ConvLSTM layer's output
+         self.batch_norms = nn.ModuleList([
+             nn.BatchNorm3d(hidden_dim) for hidden_dim in hidden_dims
+         ])
+
+         # Final Conv3D layer for the regression pathway
+         self.conv3d = nn.Conv3d(in_channels=hidden_dims[-1],
+                                 out_channels=output_channels,
+                                 kernel_size=(1, 3, 3),
+                                 padding=(0, 1, 1))
+
+         # MLP for the regression output: (B, T, C, H, W) -> (B, T, 1, H, W)
+         self.mlp = MLP_5D(height=81, width=97)
+
+         # Classification head for the pixel-level zero-precipitation probability.
+         # It also produces (B, T, 1, H, W): the head takes (B, C, T, H, W) input,
+         # so dimensions are reordered before applying it, and a Sigmoid maps the
+         # outputs to probabilities between 0 and 1.
+         self.classification_head = nn.Sequential(
+             nn.Conv3d(output_channels, 1, kernel_size=(1, 1, 1)),  # from C channels to 1
+             nn.Sigmoid()
+         )
+
+         self.activation_variance = defaultdict(list)
+
+     def forward(self, x):
+         """
+         x: (B, T, input_dim, H, W)
+         """
+         # Forward through the ConvLSTM
+         layer_output_list, last_state_list = self.convlstm(x)
+
+         # Apply batch norms
+         for i, output in enumerate(layer_output_list):
+             # output: (B, T, C, H, W)
+             output = output.permute(0, 2, 1, 3, 4)  # (B, C, T, H, W) for BatchNorm3d
+             output = self.batch_norms[i](output)
+             output = output.permute(0, 2, 1, 3, 4)  # back to (B, T, C, H, W)
+
+             # Track variance across the spatial dimensions for activation monitoring
+             activation_variance = output.var(dim=(3, 4)).mean().item()
+             self.activation_variance[f"ConvLSTM_layer_{i}"].append(activation_variance)
+
+             layer_output_list[i] = output
+
+         # Take the output of the last ConvLSTM layer
+         final_output = layer_output_list[-1]  # (B, T, C, H, W)
+
+         # Pass through Conv3D, which needs (B, C, T, H, W)
+         final_output = final_output.permute(0, 2, 1, 3, 4)  # (B, C, T, H, W)
+         final_output = self.conv3d(final_output)
+         # Now final_output: (B, output_channels, T, H, W)
+
+         # Return to (B, T, C, H, W) for the MLP (regression)
+         final_output_t = final_output.permute(0, 2, 1, 3, 4)  # (B, T, C, H, W)
+
+         # Regression output
+         regression_output = self.mlp(final_output_t)  # (B, T, 1, H, W)
+
+         # Classification output: the classification head is defined for
+         # (B, C, T, H, W), which final_output already is, so no reordering is needed here
+         final_output_c = final_output  # still (B, output_channels, T, H, W)
+         classification_output = self.classification_head(final_output_c)
+         # classification_output: (B, 1, T, H, W)
+
+         # Permute the classification output to match the (B, T, 1, H, W) format
+         classification_output = classification_output.permute(0, 2, 1, 3, 4)  # (B, T, 1, H, W)
+
+         return regression_output, classification_output
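+
+
+ if __name__ == "__main__":
+     # Minimal smoke test on random data: a sketch only; the channel counts below
+     # are assumptions matching the no_veg configuration in example_inference.py.
+     model = ConvLSTMNetwork(input_dim=9,
+                             hidden_dims=[9, 32, 64],
+                             kernel_size=(3, 3),
+                             num_layers=3,
+                             output_channels=64,
+                             batch_first=True)
+     model.eval()
+     with torch.no_grad():
+         reg, clf = model(torch.randn(2, 1, 9, 81, 97))  # (B, T, C, H, W)
+     print(reg.shape, clf.shape)  # both: torch.Size([2, 1, 1, 81, 97])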
no_veg/MultiTaskConvLSTM_no_veg_variables.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:82e89dda4a11281b4a9dbecf081951c304977d3f481f7f6024a2edd3261b02e9
+ size 1317646
no_veg/data/normalized_test_data_no_veg_input.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e5b76c968c3a80b260f8db30d5e9c219c241ac26fd44856c3c87008394dac8e9
+ size 1644632048
no_veg/example_inference.py ADDED
@@ -0,0 +1,231 @@
+ # example_inference
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
+ from tqdm.auto import tqdm
+ from MultiTaskConvLSTM import ConvLSTMNetwork
+ from utils import (
+     mse, mae, nash_sutcliffe_efficiency, r2_score, pearson_correlation,
+     spearman_correlation, percentage_error, percentage_bias,
+     kendall_tau, spatial_correlation
+ )
+
+
+ device = 'cpu'
+
+ height = 81
+ width = 97
+
+ set_lookback = 1
+ set_forecast_horizon = 1
+
+ # Define variables for evaluation
+ batch_size = 16
+ time_steps_out = set_forecast_horizon
+ channels = 9
+
+ # Variable names
+ variable_names = ['10 metre U wind component', '10 metre V wind component', '2 metre dewpoint temperature', '2 metre temperature', 'Total column rain water', 'Total precipitation', 'Time-integrated surface latent heat net flux']
+
+ # Adjust input_dim and output_channels according to your data specifics
+ model = ConvLSTMNetwork(
+     input_dim=9 * set_lookback,
+     hidden_dims=[9, 32, 64],
+     kernel_size=(3, 3),
+     num_layers=3,
+     output_channels=64 * set_forecast_horizon,
+     batch_first=True
+ ).to(device)
+
+ # Define separate loss functions
+ loss_fn = nn.MSELoss()      # for the regression output
+ bce_loss_fn = nn.BCELoss()  # for the classification output
+
+ optimizer = optim.AdamW(model.parameters(), lr=0.005)
+
+ checkpoint = torch.load("MultiTaskConvLSTM_no_veg_variables.pth", map_location=device)
+ model.load_state_dict(checkpoint['model_state_dict'])
+
+ # Move the model to the device (CUDA or CPU, depending on your setup)
+ model.to(device)
+
+ # Ensure that the model is in evaluation mode for inference
+ model.eval()
+
+ print("Model loaded successfully")
+
+
+ # Thresholding constants (not used directly below)
+ threshold = 0.1
+ precip_index = 10
+
+
+ def evaluate(model, test_loader, reg_loss_fn, class_loss_fn, device, variable_names, height, width):
+     """
+     Evaluate the model on the test set for both the regression and the classification task.
+     """
+     model.eval()  # set the model to evaluation mode
+
+     test_reg_loss = 0.0
+     test_class_loss = 0.0
+     test_total_loss = 0.0
+     num_samples = 0
+
+     y_true_reg = []  # true values for regression
+     y_pred_reg = []  # predicted values for regression
+
+     y_pred_reg2 = []  # regression predictions masked by the classification head
+
+     y_true_class = []  # true values for classification
+     y_pred_class = []  # predicted probabilities for classification
+
+     # Disable gradient computation
+     with torch.no_grad():
+         for X_test, y_test, y_zero_test in tqdm(test_loader, desc="Evaluating on Test Set"):
+             # Move the batch to the device
+             X_test, y_test, y_zero_test = X_test.to(device), y_test.to(device), y_zero_test.to(device)
+
+             # Reshape inputs and targets from (B, T, C, H*W) to (B, T, C, H, W)
+             batch_size, time_steps_in, channels_in, grid_points = X_test.shape
+             batch_size, time_steps_out, channels_out, grid_points = y_test.shape
+             X_test = X_test.view(batch_size, time_steps_in, channels_in, height, width)
+             y_test = y_test.view(batch_size, time_steps_out, channels_out, height, width)
+             y_zero_test = y_zero_test.view(batch_size, time_steps_out, channels_out, height, width)
+
+             # Forward pass
+             regression_output, classification_output = model(X_test)
+
+             classification_predictions = (classification_output > 0.7).float()
+
+             # Compute the regression loss
+             reg_loss = reg_loss_fn(regression_output, y_test)
+
+             # Compute the classification loss
+             class_loss = class_loss_fn(classification_output, y_zero_test)
+
+             # Total loss
+             total_loss = reg_loss + class_loss
+
+             # Regression values where the classifier predicts class 0, else 1.0
+             regression_output2 = torch.where(classification_predictions == 0, regression_output, classification_predictions)
+
+             # Accumulate losses, weighted by batch size
+             test_reg_loss += reg_loss.item() * X_test.size(0)
+             test_class_loss += class_loss.item() * X_test.size(0)
+             test_total_loss += total_loss.item() * X_test.size(0)
+             num_samples += X_test.size(0)
+
+             # Collect true and predicted values for regression and classification
+             y_true_reg.append(y_test.cpu())
+             y_pred_reg.append(regression_output.cpu())
+             y_pred_reg2.append(regression_output2.cpu())
+             y_true_class.append(y_zero_test.cpu())
+             y_pred_class.append(classification_output.cpu())
+
+     # Normalize the losses by the total number of samples
+     test_reg_loss /= num_samples
+     test_class_loss /= num_samples
+     test_total_loss /= num_samples
+
+     print(f"Test Regression Loss: {test_reg_loss:.16f}")
+     print(f"Test Classification Loss: {test_class_loss:.16f}")
+     print(f"Test Total Loss: {test_total_loss:.16f}")
+
+     y_true_reg_flat = torch.cat(y_true_reg, dim=0).flatten()      # keep as PyTorch tensor
+     y_pred_reg_flat = torch.cat(y_pred_reg, dim=0).flatten()      # keep as PyTorch tensor
+     y_true_class_flat = torch.cat(y_true_class, dim=0).flatten()  # keep as PyTorch tensor
+     y_pred_class_flat = torch.cat(y_pred_class, dim=0).flatten()  # keep as PyTorch tensor
+
+     # Compute regression metrics
+     regression_metrics = {
+         "MSE": mse(y_true_reg_flat, y_pred_reg_flat),
+         "MAE": mae(y_true_reg_flat, y_pred_reg_flat),
+         "NSE": nash_sutcliffe_efficiency(y_true_reg_flat, y_pred_reg_flat),
+         "R2": r2_score(y_true_reg_flat, y_pred_reg_flat),
+         "Pearson": pearson_correlation(y_true_reg_flat, y_pred_reg_flat),
+         "Spearman": spearman_correlation(y_true_reg_flat, y_pred_reg_flat),
+         "Percentage Error": percentage_error(y_true_reg_flat, y_pred_reg_flat),
+         "Percentage Bias": percentage_bias(y_true_reg_flat, y_pred_reg_flat),
+         "Kendall Tau": kendall_tau(y_true_reg_flat, y_pred_reg_flat),
+         "Spatial Correlation": spatial_correlation(y_true_reg_flat, y_pred_reg_flat)}
+
+     print("\nRegression Metrics:")
+     for metric, value in regression_metrics.items():
+         print(f"{metric}: {value:.16f}")
+
+     # Compute classification metrics
+     classification_metrics = {
+         "Accuracy": accuracy_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
+         "Precision": precision_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
+         "Recall": recall_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
+         "F1": f1_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
+         "ROC-AUC": roc_auc_score(y_true_class_flat, y_pred_class_flat),
+     }
+
+     print("\nClassification Metrics:")
+     for metric, value in classification_metrics.items():
+         print(f"{metric}: {value:.16f}")
+
+     torch.save({
+         'y_true_reg': y_true_reg_flat,
+         'y_pred_reg': y_pred_reg_flat,
+         'y_true_class': y_true_class_flat,
+         'y_pred_class': y_pred_class_flat,
+     }, 'results')
+
+     return test_total_loss, regression_metrics, classification_metrics
+
+
+ """
+ EXPECTED DATALOADER BATCH FORMAT (normalized_test_data):
+
+ Each batch must be a tuple: (X_batch, y_batch, y_zero_batch)
+
+ X_batch contains the previous hour's variables. y_batch contains the next hour's precipitation.
+ y_zero_batch contains the next hour's precipitation thresholded as 0 for precipitation <= 0.1 mm/h and
+ 1 for precipitation > 0.1 mm/h.
+
+ Shapes BEFORE reshaping inside `evaluate`:
+     X_batch:      (B, T_in, C_in, G)    # G = H*W = 81*97 = 7857
+     y_batch:      (B, T_out, C_out, G)
+     y_zero_batch: (B, T_out, C_out, G)  # binary 0/1 "zero-precip" targets
+
+ If your preprocessing produces (B, T, C, H, W), reshape to (B, T, C, H*W) before inference.
+
+ DTypes:
+     X_batch, y_batch: torch.float32
+     y_zero_batch:     torch.float32 (will be used with BCELoss)
+
+ Reshaping done in `evaluate`:
+     X_test = X_batch.view(B, T_in, C_in, H, W)    -> (B, T_in, C_in, 81, 97)
+     y_test = y_batch.view(B, T_out, C_out, H, W)  -> (B, T_out, C_out, 81, 97)
+     y_zero_test = y_zero_batch.view(B, T_out, C_out, H, W)
+
+ Model input:
+     The model expects X_test shaped (B, T_in, input_dim, H, W),
+     where input_dim == 9 * set_lookback (with set_lookback=1 -> input_dim=9).
+
+ Notes:
+     - Make sure G == H*W (i.e., 7857 for 81x97).
+     - C_out for precipitation should be 1 (one target channel), and y_zero_batch
+       is the 0/1 mask for "zero precipitation" at each pixel and time step.
+     - y_zero_batch should contain labels in {0, 1} for BCELoss.
+ """
+
+ normalized_test_data = torch.load("data/normalized_test_data_no_veg_input.pth", map_location=device)
+
+
+ test_total_loss, regression_metrics, classification_metrics = evaluate(
+     model=model,
+     test_loader=normalized_test_data,
+     reg_loss_fn=loss_fn,
+     class_loss_fn=bce_loss_fn,
+     device=device,
+     variable_names=variable_names,
+     height=height,
+     width=width,
+ )
no_veg/utils.py ADDED
@@ -0,0 +1,92 @@
+ # Definition of evaluation metrics
+ import torch
+ from scipy.stats import pearsonr, spearmanr, kendalltau
+
+
+ def nash_sutcliffe_efficiency(observed, predicted):
+     # Ensure inputs are tensors on the CPU
+     observed = observed.cpu()
+     predicted = predicted.cpu()
+
+     # Compute the numerator and denominator
+     numerator = torch.sum((observed - predicted) ** 2)
+     denominator = torch.sum((observed - torch.mean(observed)) ** 2)
+
+     # Calculate NSE
+     nse = 1 - (numerator / denominator)
+     return nse.item()
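+
+
+ # For reference: NSE = 1 - sum((O_t - P_t)^2) / sum((O_t - mean(O))^2);
+ # NSE = 1 is a perfect fit, NSE = 0 means no more skill than predicting the observed mean.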
+
+
+ def pearson_correlation(y_true, y_pred):
+     y_true = y_true.view(-1).cpu().numpy()  # flatten and move to CPU
+     y_pred = y_pred.view(-1).cpu().numpy()  # flatten and move to CPU
+
+     return pearsonr(y_true, y_pred)[0]  # return the correlation coefficient
+
+
+ def spearman_correlation(y_true, y_pred):
+     y_true = y_true.view(-1).cpu().numpy()  # flatten and move to CPU
+     y_pred = y_pred.view(-1).cpu().numpy()  # flatten and move to CPU
+
+     return spearmanr(y_true, y_pred).correlation  # return the Spearman correlation
+
+
+ def mse(y_true, y_pred):
+     # Ensure inputs are tensors on the CPU
+     y_true = y_true.cpu()
+     y_pred = y_pred.cpu()
+
+     return torch.mean((y_true - y_pred) ** 2).item()
+
+
+ def mae(y_true, y_pred):
+     # Ensure inputs are tensors on the CPU
+     y_true = y_true.cpu()
+     y_pred = y_pred.cpu()
+
+     return torch.mean(torch.abs(y_true - y_pred)).item()
+
+
+ def percentage_error(y_true, y_pred):
+     # Ensure inputs are tensors on the CPU
+     y_true = y_true.cpu()
+     y_pred = y_pred.cpu()
+
+     return 100 * torch.mean((y_pred - y_true) / (y_true + 1e-6)).item()
+
+
+ def percentage_bias(y_true, y_pred):
+     # Ensure inputs are tensors on the CPU
+     y_true = y_true.cpu()
+     y_pred = y_pred.cpu()
+
+     return (100 * torch.sum(y_pred - y_true) / (torch.sum(y_true) + 1e-6)).item()
+
+
+ def kendall_tau(y_true, y_pred):
+     y_true = y_true.view(-1).cpu().numpy()  # flatten and move to CPU
+     y_pred = y_pred.view(-1).cpu().numpy()  # flatten and move to CPU
+
+     return kendalltau(y_true, y_pred).correlation  # return the Kendall tau
+
+
+ def r2_score(y_true, y_pred):
+     # Ensure inputs are tensors on the CPU
+     y_true = y_true.cpu()
+     y_pred = y_pred.cpu()
+
+     ss_total = torch.sum((y_true - torch.mean(y_true)) ** 2)
+     ss_residual = torch.sum((y_true - y_pred) ** 2)
+
+     return 1 - (ss_residual / (ss_total + 1e-6)).item()
+
+
+ def spatial_correlation(y_true, y_pred):
+     # Flatten the tensors
+     y_true_flat = y_true.view(-1).cpu()
+     y_pred_flat = y_pred.view(-1).cpu()
+
+     # Compute the numerator: sum(P * T)
+     numerator = torch.sum(y_pred_flat * y_true_flat)
+
+     # Compute the denominator: sqrt(sum(P^2) * sum(T^2))
+     denominator = torch.sqrt(torch.sum(y_pred_flat ** 2) * torch.sum(y_true_flat ** 2))
+
+     # Compute the correlation (epsilon avoids division by zero)
+     correlation = numerator / (denominator + 1e-6)
+
+     return correlation.item()
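+
+
+ # Note: spatial_correlation above is the uncentered cosine similarity between the
+ # flattened prediction (P) and target (T) fields:
+ #   cos = sum(P_i * T_i) / sqrt(sum(P_i^2) * sum(T_i^2))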
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch>=2.0
+ numpy>=1.24
+ scipy>=1.10  # needed by utils.py (pearsonr, spearmanr, kendalltau)
+ scikit-learn>=1.3
+ tqdm>=4.65
veg/ConvLSTM.py ADDED
@@ -0,0 +1,184 @@
+ # ConvLSTM definition
+
+
+ import torch.nn as nn
+ import torch
+
+
+ class ConvLSTMCell(nn.Module):
+
+     def __init__(self, input_dim, hidden_dim, kernel_size, bias):
+         # input_dim is the number of channels of the input tensor, hidden_dim is the number
+         # of channels of the hidden state, and bias is a boolean: whether or not to add a bias.
+
+         super(ConvLSTMCell, self).__init__()
+
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+
+         self.kernel_size = kernel_size
+         self.padding = kernel_size[0] // 2, kernel_size[1] // 2
+         self.bias = bias
+
+         self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
+                               out_channels=4 * self.hidden_dim,
+                               kernel_size=self.kernel_size,
+                               padding=self.padding,
+                               bias=self.bias)
+
+     def forward(self, input_tensor, cur_state):
+         h_cur, c_cur = cur_state
+
+         combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along the channel axis
+
+         combined_conv = self.conv(combined)
+         cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
+         i = torch.sigmoid(cc_i)
+         f = torch.sigmoid(cc_f)
+         o = torch.sigmoid(cc_o)
+         g = torch.tanh(cc_g)
+
+         c_next = f * c_cur + i * g
+         h_next = o * torch.tanh(c_next)
+
+         return h_next, c_next
+
+     def init_hidden(self, batch_size, image_size):
+         height, width = image_size
+         return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
+                 torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
+
+
+ class ConvLSTM(nn.Module):
+     """
+     Parameters:
+         input_dim: number of channels in the input
+         hidden_dim: number of hidden channels
+         kernel_size: size of the convolution kernel
+         num_layers: number of ConvLSTM layers stacked on each other
+         batch_first: whether or not dimension 0 is the batch dimension
+         bias: whether to use a bias in the convolutions
+         return_all_layers: return the outputs and states of all layers
+     Note: performs "same" padding.
+
+     Input:
+         A tensor of shape (B, T, C, H, W) or (T, B, C, H, W)
+     Output:
+         A tuple of two lists, each of length num_layers (or length 1 if return_all_layers is False):
+             0 - layer_output_list: the per-layer output sequences, each of shape (B, T, C, H, W)
+             1 - last_state_list: the per-layer final states; each element is a tuple (h, c)
+                 of the hidden state and the cell memory
+     Example:
+         >> x = torch.rand((32, 10, 64, 128, 128))
+         >> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
+         >> _, last_states = convlstm(x)
+         >> h = last_states[0][0]  # 0 for layer index, 0 for h index
+     """
+
+     def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
+                  batch_first=False, bias=True, return_all_layers=False):
+         super(ConvLSTM, self).__init__()
+
+         self._check_kernel_size_consistency(kernel_size)
+
+         # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
+         kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
+         hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
+         if not len(kernel_size) == len(hidden_dim) == num_layers:
+             raise ValueError('Inconsistent list length.')
+
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+         self.kernel_size = kernel_size
+         self.num_layers = num_layers
+         self.batch_first = batch_first
+         self.bias = bias
+         self.return_all_layers = return_all_layers
+
+         cell_list = []
+         for i in range(0, self.num_layers):
+             cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
+             cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
+                                           hidden_dim=self.hidden_dim[i],
+                                           kernel_size=self.kernel_size[i],
+                                           bias=self.bias))
+
+         self.cell_list = nn.ModuleList(cell_list)
+
+     def forward(self, input_tensor, hidden_state=None):
+         """
+         Parameters
+         ----------
+         input_tensor:
+             5-D tensor of shape (t, b, c, h, w) or (b, t, c, h, w)
+         hidden_state:
+             None. Stateful operation is not implemented yet.
+
+         Returns
+         -------
+         layer_output_list, last_state_list
+         """
+         if not self.batch_first:
+             # (t, b, c, h, w) -> (b, t, c, h, w)
+             input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
+
+         b, _, _, h, w = input_tensor.size()
+
+         # Stateful ConvLSTM is not implemented yet
+         if hidden_state is not None:
+             raise NotImplementedError()
+         else:
+             # Since the init is done in forward, we can pass the image size here
+             hidden_state = self._init_hidden(batch_size=b,
+                                              image_size=(h, w))
+
+         layer_output_list = []
+         last_state_list = []
+
+         seq_len = input_tensor.size(1)
+         cur_layer_input = input_tensor
+
+         for layer_idx in range(self.num_layers):
+
+             h, c = hidden_state[layer_idx]
+             output_inner = []
+             for t in range(seq_len):
+                 h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
+                                                  cur_state=[h, c])
+                 output_inner.append(h)
+
+             layer_output = torch.stack(output_inner, dim=1)
+             cur_layer_input = layer_output
+
+             layer_output_list.append(layer_output)
+             last_state_list.append([h, c])
+
+         if not self.return_all_layers:
+             layer_output_list = layer_output_list[-1:]
+             last_state_list = last_state_list[-1:]
+
+         return layer_output_list, last_state_list
+
+     def _init_hidden(self, batch_size, image_size):
+         init_states = []
+         for i in range(self.num_layers):
+             init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
+         return init_states
+
+     @staticmethod
+     def _check_kernel_size_consistency(kernel_size):
+         if not (isinstance(kernel_size, tuple) or
+                 (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
+             raise ValueError('`kernel_size` must be tuple or list of tuples')
+
+     @staticmethod
+     def _extend_for_multilayer(param, num_layers):
+         if not isinstance(param, list):
+             param = [param] * num_layers
+         return param
veg/MultiTaskConvLSTM.py ADDED
@@ -0,0 +1,129 @@
+ from ConvLSTM import ConvLSTM
+ import torch
+ import torch.nn as nn
+ from collections import defaultdict
+
+
+ # MLP definition
+ class MLP_5D(nn.Module):
+     def __init__(self, height, width):
+         super(MLP_5D, self).__init__()
+         # Define the fully connected layers
+         self.fc1 = nn.Linear(64, 128)  # input features = 64, output features = 128
+         self.dropout1 = nn.Dropout(0.05)
+         self.fc2 = nn.Linear(128, 64)  # output features = 64
+         self.dropout2 = nn.Dropout(0.05)
+         self.fc3 = nn.Linear(64, 1)    # final output, reducing to 1 channel
+
+         self.height = height
+         self.width = width
+
+     def forward(self, x):
+         batch_size, timesteps, channels, height, width = x.shape
+
+         # Ensure the input spatial dimensions match the expected height and width
+         assert height == self.height and width == self.width, "Height and width mismatch"
+
+         # Reshape to (batch * timesteps * height * width, channels)
+         x = x.permute(0, 1, 3, 4, 2).reshape(-1, channels)
+
+         # Apply the MLP (fully connected layers)
+         x = self.fc1(x)
+         x = torch.nn.functional.softplus(x)
+         x = self.dropout1(x)
+         x = self.fc2(x)
+         x = torch.nn.functional.softplus(x)
+         x = self.dropout2(x)
+         x = self.fc3(x)
+         x = torch.nn.functional.softplus(x)
+
+         # Reshape back to (batch, timesteps, 1, height, width)
+         x = x.view(batch_size, timesteps, self.height, self.width, 1).permute(0, 1, 4, 2, 3)
+
+         return x
+
+
+ # MultiTask ConvLSTM definition
+
+ class ConvLSTMNetwork(nn.Module):
+     def __init__(self, input_dim, hidden_dims, kernel_size, num_layers, output_channels, batch_first=True, pool_size=(2, 2)):
+         # Note: pool_size is accepted for API compatibility but is currently unused.
+         super(ConvLSTMNetwork, self).__init__()
+
+         # ConvLSTM module
+         self.convlstm = ConvLSTM(input_dim=input_dim,
+                                  hidden_dim=hidden_dims,
+                                  kernel_size=kernel_size,
+                                  num_layers=num_layers,
+                                  batch_first=batch_first,
+                                  bias=True,
+                                  return_all_layers=True)
+
+         # Batch normalization for each ConvLSTM layer's output
+         self.batch_norms = nn.ModuleList([
+             nn.BatchNorm3d(hidden_dim) for hidden_dim in hidden_dims
+         ])
+
+         # Final Conv3D layer for the regression pathway
+         self.conv3d = nn.Conv3d(in_channels=hidden_dims[-1],
+                                 out_channels=output_channels,
+                                 kernel_size=(1, 3, 3),
+                                 padding=(0, 1, 1))
+
+         # MLP for the regression output: (B, T, C, H, W) -> (B, T, 1, H, W)
+         self.mlp = MLP_5D(height=81, width=97)
+
+         # Classification head for the pixel-level zero-precipitation probability.
+         # It also produces (B, T, 1, H, W): the head takes (B, C, T, H, W) input,
+         # so dimensions are reordered before applying it, and a Sigmoid maps the
+         # outputs to probabilities between 0 and 1.
+         self.classification_head = nn.Sequential(
+             nn.Conv3d(output_channels, 1, kernel_size=(1, 1, 1)),  # from C channels to 1
+             nn.Sigmoid()
+         )
+
+         self.activation_variance = defaultdict(list)
+
+     def forward(self, x):
+         """
+         x: (B, T, input_dim, H, W)
+         """
+         # Forward through the ConvLSTM
+         layer_output_list, last_state_list = self.convlstm(x)
+
+         # Apply batch norms
+         for i, output in enumerate(layer_output_list):
+             # output: (B, T, C, H, W)
+             output = output.permute(0, 2, 1, 3, 4)  # (B, C, T, H, W) for BatchNorm3d
+             output = self.batch_norms[i](output)
+             output = output.permute(0, 2, 1, 3, 4)  # back to (B, T, C, H, W)
+
+             # Track variance across the spatial dimensions for activation monitoring
+             activation_variance = output.var(dim=(3, 4)).mean().item()
+             self.activation_variance[f"ConvLSTM_layer_{i}"].append(activation_variance)
+
+             layer_output_list[i] = output
+
+         # Take the output of the last ConvLSTM layer
+         final_output = layer_output_list[-1]  # (B, T, C, H, W)
+
+         # Pass through Conv3D, which needs (B, C, T, H, W)
+         final_output = final_output.permute(0, 2, 1, 3, 4)  # (B, C, T, H, W)
+         final_output = self.conv3d(final_output)
+         # Now final_output: (B, output_channels, T, H, W)
+
+         # Return to (B, T, C, H, W) for the MLP (regression)
+         final_output_t = final_output.permute(0, 2, 1, 3, 4)  # (B, T, C, H, W)
+
+         # Regression output
+         regression_output = self.mlp(final_output_t)  # (B, T, 1, H, W)
+
+         # Classification output: the classification head is defined for
+         # (B, C, T, H, W), which final_output already is, so no reordering is needed here
+         final_output_c = final_output  # still (B, output_channels, T, H, W)
+         classification_output = self.classification_head(final_output_c)
+         # classification_output: (B, 1, T, H, W)
+
+         # Permute the classification output to match the (B, T, 1, H, W) format
+         classification_output = classification_output.permute(0, 2, 1, 3, 4)  # (B, T, 1, H, W)
+
+         return regression_output, classification_output
veg/MultiTaskConvLSTM_veg_variables.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b7e8b5e8b33db227257dd794dd6ba5ff10d4759aeed1f758550d4f3acb69cc26
+ size 1383333
veg/data/normalized_test_data_veg_input.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:075c60cf83a2d7ef68720b72c77f41a034d6418b67085e016fff6a5bdfef878c
+ size 2631223074
veg/example_inference.py ADDED
@@ -0,0 +1,230 @@
+ # example_inference
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
+ from tqdm.auto import tqdm
+ from MultiTaskConvLSTM import ConvLSTMNetwork
+ from utils import (
+     mse, mae, nash_sutcliffe_efficiency, r2_score, pearson_correlation,
+     spearman_correlation, percentage_error, percentage_bias,
+     kendall_tau, spatial_correlation
+ )
+
+
+ device = 'cpu'
+
+ height = 81
+ width = 97
+
+ set_lookback = 1
+ set_forecast_horizon = 1
+
+ # Define variables for evaluation
+ batch_size = 16
+ time_steps_out = set_forecast_horizon
+ channels = 14
+
+ # Variable names
+ variable_names = ['10 metre U wind component', '10 metre V wind component', '2 metre dewpoint temperature', '2 metre temperature', 'UV visible albedo for direct radiation (climatological)', 'Total column rain water', 'Volumetric soil water layer 1', 'Leaf area index, high vegetation', 'Leaf area index, low vegetation', 'Forecast surface roughness', 'Total precipitation', 'Time-integrated surface latent heat net flux', 'Evaporation']
+
+ # Adjust input_dim and output_channels according to your data specifics
+ model = ConvLSTMNetwork(
+     input_dim=14 * set_lookback,
+     hidden_dims=[14, 32, 64],
+     kernel_size=(3, 3),
+     num_layers=3,
+     output_channels=64 * set_forecast_horizon,
+     batch_first=True
+ ).to(device)
+
+ # Define separate loss functions
+ loss_fn = nn.MSELoss()      # for the regression output
+ bce_loss_fn = nn.BCELoss()  # for the classification output
+
+ optimizer = optim.AdamW(model.parameters(), lr=0.005)
+
+ checkpoint = torch.load("MultiTaskConvLSTM_veg_variables.pth", map_location=device)
+ model.load_state_dict(checkpoint['model_state_dict'])
+
+ # Move the model to the device (CUDA or CPU, depending on your setup)
+ model.to(device)
+
+ # Ensure that the model is in evaluation mode for inference
+ model.eval()
+
+ print("Model loaded successfully")
+
+
+ # Thresholding constants (not used directly below)
+ threshold = 0.1
+ precip_index = 10
+
+
+ def evaluate(model, test_loader, reg_loss_fn, class_loss_fn, device, variable_names, height, width):
+     """
+     Evaluate the model on the test set for both the regression and the classification task.
+     """
+     model.eval()  # set the model to evaluation mode
+
+     test_reg_loss = 0.0
+     test_class_loss = 0.0
+     test_total_loss = 0.0
+     num_samples = 0
+
+     y_true_reg = []  # true values for regression
+     y_pred_reg = []  # predicted values for regression
+
+     y_pred_reg2 = []  # regression predictions masked by the classification head
+
+     y_true_class = []  # true values for classification
+     y_pred_class = []  # predicted probabilities for classification
+
+     # Disable gradient computation
+     with torch.no_grad():
+         for X_test, y_test, y_zero_test in tqdm(test_loader, desc="Evaluating on Test Set"):
+             # Move the batch to the device
+             X_test, y_test, y_zero_test = X_test.to(device), y_test.to(device), y_zero_test.to(device)
+
+             # Reshape inputs and targets from (B, T, C, H*W) to (B, T, C, H, W)
+             batch_size, time_steps_in, channels_in, grid_points = X_test.shape
+             batch_size, time_steps_out, channels_out, grid_points = y_test.shape
+             X_test = X_test.view(batch_size, time_steps_in, channels_in, height, width)
+             y_test = y_test.view(batch_size, time_steps_out, channels_out, height, width)
+             y_zero_test = y_zero_test.view(batch_size, time_steps_out, channels_out, height, width)
+
+             # Forward pass
+             regression_output, classification_output = model(X_test)
+
+             classification_predictions = (classification_output > 0.7).float()
+
+             # Compute the regression loss
+             reg_loss = reg_loss_fn(regression_output, y_test)
+
+             # Compute the classification loss
+             class_loss = class_loss_fn(classification_output, y_zero_test)
+
+             # Total loss
+             total_loss = reg_loss + class_loss
+
+             # Regression values where the classifier predicts class 0, else 1.0
+             regression_output2 = torch.where(classification_predictions == 0, regression_output, classification_predictions)
+
+             # Accumulate losses, weighted by batch size
+             test_reg_loss += reg_loss.item() * X_test.size(0)
+             test_class_loss += class_loss.item() * X_test.size(0)
+             test_total_loss += total_loss.item() * X_test.size(0)
+             num_samples += X_test.size(0)
+
+             # Collect true and predicted values for regression and classification
+             y_true_reg.append(y_test.cpu())
+             y_pred_reg.append(regression_output.cpu())
+             y_pred_reg2.append(regression_output2.cpu())
+             y_true_class.append(y_zero_test.cpu())
+             y_pred_class.append(classification_output.cpu())
+
+     # Normalize the losses by the total number of samples
+     test_reg_loss /= num_samples
+     test_class_loss /= num_samples
+     test_total_loss /= num_samples
+
+     print(f"Test Regression Loss: {test_reg_loss:.16f}")
+     print(f"Test Classification Loss: {test_class_loss:.16f}")
+     print(f"Test Total Loss: {test_total_loss:.16f}")
+
+     y_true_reg_flat = torch.cat(y_true_reg, dim=0).flatten()      # keep as PyTorch tensor
+     y_pred_reg_flat = torch.cat(y_pred_reg, dim=0).flatten()      # keep as PyTorch tensor
+     y_true_class_flat = torch.cat(y_true_class, dim=0).flatten()  # keep as PyTorch tensor
+     y_pred_class_flat = torch.cat(y_pred_class, dim=0).flatten()  # keep as PyTorch tensor
+
+     # Compute regression metrics
+     regression_metrics = {
+         "MSE": mse(y_true_reg_flat, y_pred_reg_flat),
+         "MAE": mae(y_true_reg_flat, y_pred_reg_flat),
+         "NSE": nash_sutcliffe_efficiency(y_true_reg_flat, y_pred_reg_flat),
+         "R2": r2_score(y_true_reg_flat, y_pred_reg_flat),
+         "Pearson": pearson_correlation(y_true_reg_flat, y_pred_reg_flat),
+         "Spearman": spearman_correlation(y_true_reg_flat, y_pred_reg_flat),
+         "Percentage Error": percentage_error(y_true_reg_flat, y_pred_reg_flat),
+         "Percentage Bias": percentage_bias(y_true_reg_flat, y_pred_reg_flat),
+         "Kendall Tau": kendall_tau(y_true_reg_flat, y_pred_reg_flat),
+         "Spatial Correlation": spatial_correlation(y_true_reg_flat, y_pred_reg_flat)}
+
+     print("\nRegression Metrics:")
+     for metric, value in regression_metrics.items():
+         print(f"{metric}: {value:.16f}")
+
+     # Compute classification metrics
+     classification_metrics = {
+         "Accuracy": accuracy_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
+         "Precision": precision_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
+         "Recall": recall_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
+         "F1": f1_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
+         "ROC-AUC": roc_auc_score(y_true_class_flat, y_pred_class_flat),
+     }
+
+     print("\nClassification Metrics:")
+     for metric, value in classification_metrics.items():
+         print(f"{metric}: {value:.16f}")
+
+     torch.save({
+         'y_true_reg': y_true_reg_flat,
+         'y_pred_reg': y_pred_reg_flat,
+         'y_true_class': y_true_class_flat,
+         'y_pred_class': y_pred_class_flat,
+     }, 'results')
+
+     return test_total_loss, regression_metrics, classification_metrics
+
+
+ """
+ EXPECTED DATALOADER BATCH FORMAT (normalized_test_data):
+
+ Each batch must be a tuple: (X_batch, y_batch, y_zero_batch)
+
+ X_batch contains the previous hour's variables. y_batch contains the next hour's precipitation.
+ y_zero_batch contains the next hour's precipitation thresholded as 0 for precipitation <= 0.1 mm/h and
+ 1 for precipitation > 0.1 mm/h.
+
+ Shapes BEFORE reshaping inside `evaluate`:
+     X_batch:      (B, T_in, C_in, G)    # G = H*W = 81*97 = 7857
+     y_batch:      (B, T_out, C_out, G)
+     y_zero_batch: (B, T_out, C_out, G)  # binary 0/1 "zero-precip" targets
+
+ If your preprocessing produces (B, T, C, H, W), reshape to (B, T, C, H*W) before inference.
+
+ DTypes:
+     X_batch, y_batch: torch.float32
+     y_zero_batch:     torch.float32 (will be used with BCELoss)
+
+ Reshaping done in `evaluate`:
+     X_test = X_batch.view(B, T_in, C_in, H, W)    -> (B, T_in, C_in, 81, 97)
+     y_test = y_batch.view(B, T_out, C_out, H, W)  -> (B, T_out, C_out, 81, 97)
+     y_zero_test = y_zero_batch.view(B, T_out, C_out, H, W)
+
+ Model input:
+     The model expects X_test shaped (B, T_in, input_dim, H, W),
+     where input_dim == 14 * set_lookback (with set_lookback=1 -> input_dim=14).
+
+ Notes:
+     - Make sure G == H*W (i.e., 7857 for 81x97).
+     - C_out for precipitation should be 1 (one target channel), and y_zero_batch
+       is the 0/1 mask for "zero precipitation" at each pixel and time step.
+     - y_zero_batch should contain labels in {0, 1} for BCELoss.
+ """
+
+ normalized_test_data = torch.load("data/normalized_test_data_veg_input.pth", map_location=device)
+
+
+ test_total_loss, regression_metrics, classification_metrics = evaluate(
+     model=model,
+     test_loader=normalized_test_data,
+     reg_loss_fn=loss_fn,
+     class_loss_fn=bce_loss_fn,
+     device=device,
+     variable_names=variable_names,
+     height=height,
+     width=width,
+ )
veg/utils.py ADDED
@@ -0,0 +1,92 @@
+ # Definition of evaluation metrics
+ import torch
+ from scipy.stats import pearsonr, spearmanr, kendalltau
+
+
+ def nash_sutcliffe_efficiency(observed, predicted):
+     # Ensure inputs are tensors on the CPU
+     observed = observed.cpu()
+     predicted = predicted.cpu()
+
+     # Compute the numerator and denominator
+     numerator = torch.sum((observed - predicted) ** 2)
+     denominator = torch.sum((observed - torch.mean(observed)) ** 2)
+
+     # Calculate NSE
+     nse = 1 - (numerator / denominator)
+     return nse.item()
+
+
+ def pearson_correlation(y_true, y_pred):
+     y_true = y_true.view(-1).cpu().numpy()  # flatten and move to CPU
+     y_pred = y_pred.view(-1).cpu().numpy()  # flatten and move to CPU
+
+     return pearsonr(y_true, y_pred)[0]  # return the correlation coefficient
+
+
+ def spearman_correlation(y_true, y_pred):
+     y_true = y_true.view(-1).cpu().numpy()  # flatten and move to CPU
+     y_pred = y_pred.view(-1).cpu().numpy()  # flatten and move to CPU
+
+     return spearmanr(y_true, y_pred).correlation  # return the Spearman correlation
+
+
+ def mse(y_true, y_pred):
+     # Ensure inputs are tensors on the CPU
+     y_true = y_true.cpu()
+     y_pred = y_pred.cpu()
+
+     return torch.mean((y_true - y_pred) ** 2).item()
+
+
+ def mae(y_true, y_pred):
+     # Ensure inputs are tensors on the CPU
+     y_true = y_true.cpu()
+     y_pred = y_pred.cpu()
+
+     return torch.mean(torch.abs(y_true - y_pred)).item()
+
+
+ def percentage_error(y_true, y_pred):
+     # Ensure inputs are tensors on the CPU
+     y_true = y_true.cpu()
+     y_pred = y_pred.cpu()
+
+     return 100 * torch.mean((y_pred - y_true) / (y_true + 1e-6)).item()
+
+
+ def percentage_bias(y_true, y_pred):
+     # Ensure inputs are tensors on the CPU
+     y_true = y_true.cpu()
+     y_pred = y_pred.cpu()
+
+     return (100 * torch.sum(y_pred - y_true) / (torch.sum(y_true) + 1e-6)).item()
+
+
+ def kendall_tau(y_true, y_pred):
+     y_true = y_true.view(-1).cpu().numpy()  # flatten and move to CPU
+     y_pred = y_pred.view(-1).cpu().numpy()  # flatten and move to CPU
+
+     return kendalltau(y_true, y_pred).correlation  # return the Kendall tau
+
+
+ def r2_score(y_true, y_pred):
+     # Ensure inputs are tensors on the CPU
+     y_true = y_true.cpu()
+     y_pred = y_pred.cpu()
+
+     ss_total = torch.sum((y_true - torch.mean(y_true)) ** 2)
+     ss_residual = torch.sum((y_true - y_pred) ** 2)
+
+     return 1 - (ss_residual / (ss_total + 1e-6)).item()
+
+
+ def spatial_correlation(y_true, y_pred):
+     # Flatten the tensors
+     y_true_flat = y_true.view(-1).cpu()
+     y_pred_flat = y_pred.view(-1).cpu()
+
+     # Compute the numerator: sum(P * T)
+     numerator = torch.sum(y_pred_flat * y_true_flat)
+
+     # Compute the denominator: sqrt(sum(P^2) * sum(T^2))
+     denominator = torch.sqrt(torch.sum(y_pred_flat ** 2) * torch.sum(y_true_flat ** 2))
+
+     # Compute the correlation (epsilon avoids division by zero)
+     correlation = numerator / (denominator + 1e-6)
+
+     return correlation.item()