Lilly Makkos
committed on
Commit
·
ef16512
0
Parent(s):
fresh new main branch
Browse files- .gitattributes +6 -0
- README.md +56 -0
- no_veg/ConvLSTM.py +184 -0
- no_veg/MultiTaskConvLSTM.py +129 -0
- no_veg/MultiTaskConvLSTM_no_veg_variables.pth +3 -0
- no_veg/data/normalized_test_data_no_veg_input.pth +3 -0
- no_veg/example_inference.py +231 -0
- no_veg/utils.py +92 -0
- requirements.txt +4 -0
- veg/ConvLSTM.py +184 -0
- veg/MultiTaskConvLSTM.py +129 -0
- veg/MultiTaskConvLSTM_veg_variables.pth +3 -0
- veg/data/normalized_test_data_veg_input.pth +3 -0
- veg/example_inference.py +230 -0
- veg/utils.py +92 -0
.gitattributes
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.safetensor filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
veg/MultiTaskConvLSTM_veg_variables.pth filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
no_veg/MultiTaskConvLSTM_no_veg_variables.pth filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MultiTask ConvLSTM for Precipitation Prediction
|
| 2 |
+
|
| 3 |
+
This repository contains two MultiTask ConvLSTM models:
|
| 4 |
+
- **veg/**: Model trained with vegetation input variables
|
| 5 |
+
- **no_veg/**: Model trained without vegetation input variables
|
| 6 |
+
|
| 7 |
+
Both directories include:
|
| 8 |
+
- `ConvLSTM.py`: base ConvLSTM layers
|
| 9 |
+
- `MultiTaskConvLSTM.py`: MultiTask ConvLSTM model definition
|
| 10 |
+
- `example_inference.py`: inference script
|
| 11 |
+
- `data/`: example `.pth` files (test)
|
| 12 |
+
|
| 13 |
+
These scripts are provided for reproducibility of the model architecture and workflow.
|
| 14 |
+
Exact runtime and performance may vary depending on hardware.
|
| 15 |
+
|
| 16 |
+
## Example Data
|
| 17 |
+
|
| 18 |
+
We provide large test `.pth` files
|
| 19 |
+
so you can immediately run the inference script without preprocessing.
|
| 20 |
+
These files are already preprocessed and normalized from the ECMWF ERA5 reanalysis data.
|
| 21 |
+
|
| 22 |
+
Each `.pth` file loads as a list of batches:
|
| 23 |
+
|
| 24 |
+
- `X_batch`: shape `(B, T_in, C_in, H*W)`
|
| 25 |
+
- `y_batch`: shape `(B, T_out, C_out, H*W)`
|
| 26 |
+
- `y_zero_batch`: shape `(B, T_out, C_out, H*W)`
|
| 27 |
+
|
| 28 |
+
with `H=81`, `W=97`. Inside `evaluate(...)`, these are reshaped to `(B, T, C, H, W)`.
|
| 29 |
+
|
| 30 |
+
---
|
| 31 |
+
|
| 32 |
+
## How to Use
|
| 33 |
+
|
| 34 |
+
Ensure all files are in the correct directory then run the example_inference.py file.
|
| 35 |
+
|
| 36 |
+
# 1 Get the repo
|
| 37 |
+
git clone https://huggingface.co/<your-username>/MultiTaskConvLSTM
|
| 38 |
+
cd MultiTaskConvLSTM
|
| 39 |
+
|
| 40 |
+
# 2 Install minimal deps
|
| 41 |
+
pip install -r requirements.txt
|
| 42 |
+
|
| 43 |
+
# 3 Run inference (choose one variant)
|
| 44 |
+
python veg/example_inference.py
|
| 45 |
+
# or
|
| 46 |
+
python no_veg/example_inference.py
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
## Citation If you use this model, please cite: > Lilly Horvath-Makkos (2025). [title] [journal] BibTeX:
|
| 50 |
+
bibtex
|
| 51 |
+
@article{horvathmakkos2025,
|
| 52 |
+
title={Title},
|
| 53 |
+
author={Horvath-Makkos, Lilly},
|
| 54 |
+
journal={Journal},
|
| 55 |
+
year={2025}
|
| 56 |
+
}
|
no_veg/ConvLSTM.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ConvLSTM definition
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ConvLSTMCell(nn.Module):
    """A single ConvLSTM cell.

    Args:
        input_dim: number of channels of each input tensor.
        hidden_dim: number of channels of the hidden state.
        kernel_size: (kh, kw) of the convolutional kernel.
        bias: whether the convolution uses a bias term.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # "Same" padding so the spatial size is preserved.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # One convolution produces all four gate pre-activations at once.
        self.conv = nn.Conv2d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=4 * self.hidden_dim,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias,
        )

    def forward(self, input_tensor, cur_state):
        """One LSTM step; returns the next (h, c) pair."""
        h_cur, c_cur = cur_state

        # Concatenate input and hidden state along the channel axis.
        stacked = torch.cat([input_tensor, h_cur], dim=1)
        gates = self.conv(stacked)
        gate_i, gate_f, gate_o, gate_g = torch.split(gates, self.hidden_dim, dim=1)

        i = torch.sigmoid(gate_i)
        f = torch.sigmoid(gate_f)
        o = torch.sigmoid(gate_o)
        g = torch.tanh(gate_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero-initialized (h, c) on the same device as the weights."""
        height, width = image_size
        shape = (batch_size, self.hidden_dim, height, width)
        dev = self.conv.weight.device
        return torch.zeros(*shape, device=dev), torch.zeros(*shape, device=dev)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class ConvLSTM(nn.Module):
    """Stacked multi-layer ConvLSTM.

    Parameters:
        input_dim: number of channels in the input.
        hidden_dim: hidden channels, an int or a list (one per layer).
        kernel_size: convolution kernel, a tuple or a list of tuples.
        num_layers: number of stacked ConvLSTM layers.
        batch_first: whether dimension 0 of the input is the batch.
        bias: whether the convolutions use a bias term.
        return_all_layers: return outputs/states for every layer, not just
            the last one.
    Note: same padding is used, so spatial size is preserved.

    Input:
        A tensor of size (b, t, c, h, w) if batch_first else (t, b, c, h, w).
    Output:
        A tuple (layer_output_list, last_state_list); each list has length
        num_layers (or 1 if return_all_layers is False).  Each entry of
        last_state_list is a [h, c] pair.

    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Broadcast per-layer settings so both lists have len == num_layers.
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        # Layer i consumes the hidden channels of layer i-1.
        self.cell_list = nn.ModuleList(
            ConvLSTMCell(
                input_dim=self.input_dim if idx == 0 else self.hidden_dim[idx - 1],
                hidden_dim=self.hidden_dim[idx],
                kernel_size=self.kernel_size[idx],
                bias=self.bias,
            )
            for idx in range(self.num_layers)
        )

    def forward(self, input_tensor, hidden_state=None):
        """Run the stacked ConvLSTM over a sequence.

        Parameters
        ----------
        input_tensor:
            5-D tensor of shape (b, t, c, h, w) if ``batch_first`` else
            (t, b, c, h, w).
        hidden_state:
            Must be None; stateful operation is not implemented.

        Returns
        -------
        (layer_output_list, last_state_list)
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        batch, _, _, height, width = input_tensor.size()

        if hidden_state is not None:
            # Stateful ConvLSTM is not supported.
            raise NotImplementedError()
        # Hidden state is (re)initialized here so the image size can vary
        # between calls.
        hidden_state = self._init_hidden(batch_size=batch,
                                         image_size=(height, width))

        seq_len = input_tensor.size(1)
        layer_outputs = []
        last_states = []
        cur_input = input_tensor

        for layer_idx in range(self.num_layers):
            h, c = hidden_state[layer_idx]
            cell = self.cell_list[layer_idx]

            step_outputs = []
            for t in range(seq_len):
                h, c = cell(input_tensor=cur_input[:, t, :, :, :],
                            cur_state=[h, c])
                step_outputs.append(h)

            # (b, t, hidden, h, w); also the input of the next layer.
            cur_input = torch.stack(step_outputs, dim=1)
            layer_outputs.append(cur_input)
            last_states.append([h, c])

        if not self.return_all_layers:
            layer_outputs = layer_outputs[-1:]
            last_states = last_states[-1:]

        return layer_outputs, last_states

    def _init_hidden(self, batch_size, image_size):
        """Zero states for every layer, via each cell's init_hidden."""
        return [cell.init_hidden(batch_size, image_size)
                for cell in self.cell_list]

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        is_tuple = isinstance(kernel_size, tuple)
        is_tuple_list = (isinstance(kernel_size, list)
                         and all(isinstance(elem, tuple) for elem in kernel_size))
        if not (is_tuple or is_tuple_list):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        """Repeat a scalar setting once per layer; pass lists through."""
        if not isinstance(param, list):
            param = [param] * num_layers
        return param
|
no_veg/MultiTaskConvLSTM.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ConvLSTM import ConvLSTM
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
|
| 6 |
+
#MLP definition
|
| 7 |
+
class MLP_5D(nn.Module):
    """Pixel-wise MLP applied to a 5-D tensor.

    Maps (B, T, 64, H, W) -> (B, T, 1, H, W): the 64 channels at each
    (batch, time, pixel) location are passed through fully connected layers
    with softplus activations (so the output is strictly positive) and
    reduced to a single channel.

    Args:
        height, width: expected spatial size of the input.
    """

    def __init__(self, height, width):
        super(MLP_5D, self).__init__()
        # Fully connected layers: 64 input channels -> 128 -> 64 -> 1.
        # (BUGFIX: a stale comment previously claimed 41 input channels.)
        self.fc1 = nn.Linear(64, 128)
        self.dropout1 = nn.Dropout(0.05)
        self.fc2 = nn.Linear(128, 64)
        self.dropout2 = nn.Dropout(0.05)
        self.fc3 = nn.Linear(64, 1)  # final reduction to 1 channel

        self.height = height
        self.width = width

    def forward(self, x):
        """x: (B, T, C=64, H, W) -> (B, T, 1, H, W); raises ValueError on a
        spatial-size mismatch."""
        batch_size, timesteps, channels, height, width = x.shape

        # Validate spatial dimensions explicitly; `assert` is stripped
        # under `python -O`.
        if height != self.height or width != self.width:
            raise ValueError(
                f"Height and width mismatch: expected "
                f"({self.height}, {self.width}), got ({height}, {width})"
            )

        # Flatten to (B*T*H*W, C) so the linear layers act per pixel.
        x = x.permute(0, 1, 3, 4, 2).reshape(-1, channels)

        # MLP with softplus activations and light dropout.
        x = torch.nn.functional.softplus(self.fc1(x))
        x = self.dropout1(x)
        x = torch.nn.functional.softplus(self.fc2(x))
        x = self.dropout2(x)
        x = torch.nn.functional.softplus(self.fc3(x))

        # Restore (B, T, 1, H, W).
        x = x.view(batch_size, timesteps, self.height, self.width, 1)
        x = x.permute(0, 1, 4, 2, 3)
        return x
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# MultiTask ConvLSTM definition
|
| 47 |
+
|
| 48 |
+
class ConvLSTMNetwork(nn.Module):
    """MultiTask ConvLSTM: a regression head and a classification head.

    Parameters
    ----------
    input_dim: channels of the input sequence.
    hidden_dims: list of hidden channel counts, one per ConvLSTM layer.
    kernel_size: convolution kernel used by every ConvLSTM layer.
    num_layers: number of stacked ConvLSTM layers.
    output_channels: channels produced by the shared Conv3d trunk.
    batch_first: whether dim 0 of the input is the batch axis.
    pool_size: unused; kept for backward compatibility with existing callers.
    height, width: spatial grid size for the regression MLP.  Generalized
        from the previously hard-coded 81x97; the defaults preserve the
        original behavior.
    """

    def __init__(self, input_dim, hidden_dims, kernel_size, num_layers,
                 output_channels, batch_first=True, pool_size=(2, 2),
                 height=81, width=97):
        super(ConvLSTMNetwork, self).__init__()

        # ConvLSTM trunk; return_all_layers=True so batch norm (and the
        # variance diagnostics below) can be applied to every layer.
        self.convlstm = ConvLSTM(input_dim=input_dim,
                                 hidden_dim=hidden_dims,
                                 kernel_size=kernel_size,
                                 num_layers=num_layers,
                                 batch_first=batch_first,
                                 bias=True,
                                 return_all_layers=True)

        # One BatchNorm3d per ConvLSTM layer output.
        self.batch_norms = nn.ModuleList([
            nn.BatchNorm3d(hidden_dim) for hidden_dim in hidden_dims
        ])

        # Conv3d trunk shared by both task heads.
        self.conv3d = nn.Conv3d(in_channels=hidden_dims[-1],
                                out_channels=output_channels,
                                kernel_size=(1, 3, 3),
                                padding=(0, 1, 1))

        # Regression head: pixel-wise MLP (B,T,C,H,W) -> (B,T,1,H,W).
        self.mlp = MLP_5D(height=height, width=width)

        # Classification head: per-pixel probability for the zero/non-zero
        # precipitation label.  Takes (B,C,T,H,W), squeezes channels to 1,
        # then Sigmoid maps to [0, 1].
        self.classification_head = nn.Sequential(
            nn.Conv3d(output_channels, 1, kernel_size=(1, 1, 1)),
            nn.Sigmoid()
        )

        # Diagnostic: per-layer history of mean spatial activation variance.
        self.activation_variance = defaultdict(list)

    def forward(self, x):
        """
        x: (B, T, input_dim, H, W)
        Returns (regression_output, classification_output), both shaped
        (B, T, 1, H, W).
        """
        layer_output_list, last_state_list = self.convlstm(x)

        # Batch-normalize each layer's output and record its variance.
        for i, output in enumerate(layer_output_list):
            # output: (B, T, C, H, W)
            output = output.permute(0, 2, 1, 3, 4)  # (B,C,T,H,W) for BatchNorm3d
            output = self.batch_norms[i](output)
            output = output.permute(0, 2, 1, 3, 4)  # back to (B,T,C,H,W)

            # Mean variance across the spatial dims, logged per layer.
            activation_variance = output.var(dim=(3, 4)).mean().item()
            self.activation_variance[f"ConvLSTM_layer_{i}"].append(activation_variance)

            layer_output_list[i] = output

        # The last ConvLSTM layer feeds both heads.
        final_output = layer_output_list[-1]                # (B,T,C,H,W)
        final_output = final_output.permute(0, 2, 1, 3, 4)  # (B,C,T,H,W)
        final_output = self.conv3d(final_output)            # (B,out_ch,T,H,W)

        # Regression head expects (B,T,C,H,W).
        regression_output = self.mlp(final_output.permute(0, 2, 1, 3, 4))

        # Classification head consumes (B,C,T,H,W) directly,
        # then its output is permuted to match (B,T,1,H,W).
        classification_output = self.classification_head(final_output)
        classification_output = classification_output.permute(0, 2, 1, 3, 4)

        return regression_output, classification_output
|
no_veg/MultiTaskConvLSTM_no_veg_variables.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:82e89dda4a11281b4a9dbecf081951c304977d3f481f7f6024a2edd3261b02e9
|
| 3 |
+
size 1317646
|
no_veg/data/normalized_test_data_no_veg_input.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e5b76c968c3a80b260f8db30d5e9c219c241ac26fd44856c3c87008394dac8e9
|
| 3 |
+
size 1644632048
|
no_veg/example_inference.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# example_inference
# BUGFIX: the original had `import toch.nn as nn` (typo -> ImportError)
# and imported torch twice.
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

from MultiTaskConvLSTM import ConvLSTMNetwork
from utils import (
    mse, mae, nash_sutcliffe_efficiency, r2_score, pearson_correlation,
    spearman_correlation, percentage_error, percentage_bias,
    kendall_tau, spatial_correlation
)


device = 'cpu'

# Spatial grid of the reanalysis crop used for training.
height = 81
width = 97

set_lookback = 1
set_forecast_horizon = 1

# Evaluation configuration.
batch_size = 16
time_steps_out = set_forecast_horizon
channels = 9

# Input variable names.
# NOTE(review): 7 names are listed but channels == 9 — confirm the full
# variable list against the training pipeline.
variable_names = ['10 metre U wind component', '10 metre V wind component', '2 metre dewpoint temperature', '2 metre temperature', 'Total column rain water', 'Total precipitation', 'Time-integrated surface latent heat net flux']

# Adjust input_dim and output_channels according to your data specifics.
model = ConvLSTMNetwork(
    input_dim=9 * set_lookback,
    hidden_dims=[9, 32, 64],
    kernel_size=(3, 3),
    num_layers=3,
    output_channels=64 * set_forecast_horizon,
    batch_first=True
).to(device)

# Separate loss functions for the two heads.
loss_fn = nn.MSELoss()      # regression output
bce_loss_fn = nn.BCELoss()  # classification output

optimizer = optim.AdamW(model.parameters(), lr=0.005)

# BUGFIX: the checkpoint on disk is 'MultiTaskConvLSTM_no_veg_variables.pth'
# (the extension was missing); load onto the evaluation device so a
# GPU-saved checkpoint still loads on a CPU-only host.
checkpoint = torch.load("MultiTaskConvLSTM_no_veg_variables.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Move the model to the chosen device (CPU here).
model.to(device)

# Evaluation mode disables dropout for inference.
model.eval()

print("Model loaded successfully")


threshold = 0.1
precip_index = 10
|
| 63 |
+
|
| 64 |
+
def evaluate(model, test_loader, reg_loss_fn, class_loss_fn, device, variable_names, height, width):
|
| 65 |
+
"""
|
| 66 |
+
Evaluate the model on the test set for both regression and classification tasks.
|
| 67 |
+
"""
|
| 68 |
+
model.eval() # Set the model to evaluation model
|
| 69 |
+
|
| 70 |
+
# input_to_true = {'zero_to_non_zero': 0, 'non_zero_to_zero': 0}
|
| 71 |
+
# input_to_pred_REG = {'zero_to_non_zero': 0, 'non_zero_to_zero': 0}
|
| 72 |
+
# input_to_pred_CLASS = {'zero_to_non_zero': 0, 'non_zero_to_zero': 0}
|
| 73 |
+
|
| 74 |
+
test_reg_loss = 0.0
|
| 75 |
+
test_class_loss = 0.0
|
| 76 |
+
test_total_loss = 0.0
|
| 77 |
+
|
| 78 |
+
y_true_reg = [] # List to store true values for regression
|
| 79 |
+
y_pred_reg = [] # List to store predicted values for regression
|
| 80 |
+
|
| 81 |
+
y_pred_reg2 = []
|
| 82 |
+
|
| 83 |
+
y_true_class = [] # List to store true values for classification
|
| 84 |
+
y_pred_class = [] # List to store predicted probabilities for classification
|
| 85 |
+
|
| 86 |
+
# Disable gradient computation
|
| 87 |
+
with torch.no_grad():
|
| 88 |
+
for X_test, y_test, y_zero_test in tqdm(test_loader, desc="Evaluating on Test Set"):
|
| 89 |
+
# Move the batch to the device
|
| 90 |
+
X_test, y_test, y_zero_test = X_test.to(device), y_test.to(device), y_zero_test.to(device)
|
| 91 |
+
|
| 92 |
+
# Reshape inputs and targets
|
| 93 |
+
batch_size, time_steps_in, channels_in, grid_points = X_test.shape
|
| 94 |
+
batch_size, time_steps_out, channels_out, grid_points = y_test.shape
|
| 95 |
+
X_test = X_test.view(batch_size, time_steps_in, channels_in, height, width)
|
| 96 |
+
y_test = y_test.view(batch_size, time_steps_out, channels_out, height, width)
|
| 97 |
+
y_zero_test = y_zero_test.view(batch_size, time_steps_out, channels_out, height, width)
|
| 98 |
+
|
| 99 |
+
# Forward pass
|
| 100 |
+
regression_output, classification_output = model(X_test)
|
| 101 |
+
|
| 102 |
+
classification_predictions = (classification_output > 0.7).float()
|
| 103 |
+
|
| 104 |
+
# Compute regression loss
|
| 105 |
+
reg_loss = reg_loss_fn(regression_output, y_test)
|
| 106 |
+
|
| 107 |
+
# Compute classification loss
|
| 108 |
+
class_loss = class_loss_fn(classification_output, y_zero_test)
|
| 109 |
+
|
| 110 |
+
# Total loss
|
| 111 |
+
total_loss = reg_loss + class_loss
|
| 112 |
+
|
| 113 |
+
regression_output2 = torch.where(classification_predictions == 0, regression_output, classification_predictions)
|
| 114 |
+
|
| 115 |
+
# Accumulate losses
|
| 116 |
+
test_reg_loss += reg_loss.item() * X_test.size(0)
|
| 117 |
+
test_class_loss += class_loss.item() * X_test.size(0)
|
| 118 |
+
test_total_loss += total_loss.item() * X_test.size(0)
|
| 119 |
+
|
| 120 |
+
# Collect true and predicted values for regression and classification
|
| 121 |
+
y_true_reg.append(y_test.cpu())
|
| 122 |
+
y_pred_reg.append(regression_output.cpu())
|
| 123 |
+
y_pred_reg2.append(regression_output2.cpu())
|
| 124 |
+
y_true_class.append(y_zero_test.cpu())
|
| 125 |
+
y_pred_class.append(classification_output.cpu())
|
| 126 |
+
|
| 127 |
+
# Normalize losses by the total dataset size
|
| 128 |
+
test_reg_loss /= len(test_loader)
|
| 129 |
+
test_class_loss /= len(test_loader)
|
| 130 |
+
test_total_loss /= len(test_loader)
|
| 131 |
+
|
| 132 |
+
print(f"Test Regression Loss: {test_reg_loss:.16f}")
|
| 133 |
+
print(f"Test Classification Loss: {test_class_loss:.16f}")
|
| 134 |
+
print(f"Test Total Loss: {test_total_loss:.16f}")
|
| 135 |
+
|
| 136 |
+
y_true_reg_flat = torch.cat(y_true_reg, dim=0).flatten() # Keep as PyTorch tensor
|
| 137 |
+
y_pred_reg_flat = torch.cat(y_pred_reg, dim=0).flatten() # Keep as PyTorch tensor
|
| 138 |
+
y_true_class_flat = torch.cat(y_true_class, dim=0).flatten() # Keep as PyTorch tensor
|
| 139 |
+
y_pred_class_flat = torch.cat(y_pred_class, dim=0).flatten() # Keep as PyTorch tensor
|
| 140 |
+
|
| 141 |
+
# Compute regression metrics
|
| 142 |
+
regression_metrics = {
|
| 143 |
+
"MSE": mse(y_true_reg_flat, y_pred_reg_flat),
|
| 144 |
+
"MAE": mae(y_true_reg_flat, y_pred_reg_flat),
|
| 145 |
+
"NSE": nash_sutcliffe_efficiency(y_true_reg_flat, y_pred_reg_flat),
|
| 146 |
+
"R2": r2_score(y_true_reg_flat, y_pred_reg_flat),
|
| 147 |
+
"Pearson": pearson_correlation(y_true_reg_flat, y_pred_reg_flat),
|
| 148 |
+
"Spearman": spearman_correlation(y_true_reg_flat, y_pred_reg_flat),
|
| 149 |
+
"NSE": nash_sutcliffe_efficiency(y_true_reg_flat, y_pred_reg_flat),
|
| 150 |
+
"Percentage Error": percentage_error(y_true_reg_flat, y_pred_reg_flat),
|
| 151 |
+
"Percentage Bias": percentage_bias(y_true_reg_flat, y_pred_reg_flat),
|
| 152 |
+
"Kendall Tau": kendall_tau(y_true_reg_flat, y_pred_reg_flat),
|
| 153 |
+
"Spatial Correlation": spatial_correlation(y_true_reg_flat, y_pred_reg_flat)}
|
| 154 |
+
|
| 155 |
+
print("\nRegression Metrics:")
|
| 156 |
+
for metric, value in regression_metrics.items():
|
| 157 |
+
print(f"{metric}: {value:.16f}")
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# Compute classification metrics
|
| 161 |
+
classification_metrics = {
|
| 162 |
+
"Accuracy": accuracy_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
|
| 163 |
+
"Precision": precision_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
|
| 164 |
+
"Recall": recall_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
|
| 165 |
+
"F1": f1_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
|
| 166 |
+
"ROC-AUC": roc_auc_score(y_true_class_flat, y_pred_class_flat),
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
print("\nClassification Metrics:")
|
| 170 |
+
for metric, value in classification_metrics.items():
|
| 171 |
+
print(f"{metric}: {value:.16f}")
|
| 172 |
+
|
| 173 |
+
torch.save({
|
| 174 |
+
'y_true_reg': y_true_reg_flat,
|
| 175 |
+
'y_pred_reg': y_pred_reg_flat,
|
| 176 |
+
'y_true_class': y_true_class_flat,
|
| 177 |
+
'y_pred_class': y_pred_class_flat,
|
| 178 |
+
}, 'results')
|
| 179 |
+
|
| 180 |
+
return test_total_loss, regression_metrics, classification_metrics
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
"""
|
| 184 |
+
EXPECTED DATALOADER BATCH FORMAT (normalized_test_data):
|
| 185 |
+
|
| 186 |
+
Each batch must be a tuple: (X_batch, y_batch, y_zero_batch)
|
| 187 |
+
|
| 188 |
+
X_batch contains the previous hours variables. y_batch contains the next hour's precipitation.
|
| 189 |
+
y_zero_batch contains the next hour's precipitation thresholded as 0 for precipiation <=0.1mm/h and
|
| 190 |
+
1 for precipitation >0.1mm.
|
| 191 |
+
|
| 192 |
+
Shapes BEFORE reshaping inside `evaluate`:
|
| 193 |
+
X_batch: (B, T_in, C_in, G) # G = H*W = 81*97 = 7857
|
| 194 |
+
y_batch: (B, T_out, C_out, G)
|
| 195 |
+
y_zero_batch: (B, T_out, C_out, G) # binary 0/1 "zero-precip" targets
|
| 196 |
+
|
| 197 |
+
If your preprocessing produces (B,T, C, H, W), reshape to (B, T, C, H*W) before inference.
|
| 198 |
+
|
| 199 |
+
DTypes:
|
| 200 |
+
X_batch, y_batch: torch.float32
|
| 201 |
+
y_zero_batch: torch.float32 (will be used with BCELoss)
|
| 202 |
+
|
| 203 |
+
Reshaping done in 'evaluate':
|
| 204 |
+
X_test = X_batch.view(B, T_in, C_in, H, W) -> (B, T_in, C_in, 81, 97)
|
| 205 |
+
y_test = y_batch.view(B, T_out, C_out, H, W) -> (B, T_out, C_out, 81, 97)
|
| 206 |
+
y_zero_test = y_zero_batch.view(B, T_out, C_out, H, W)
|
| 207 |
+
|
| 208 |
+
Model input:
|
| 209 |
+
model expects X_test shaped (B, T_in, input_dim, H, W)
|
| 210 |
+
where input_dim == 9 * set_lookback (with set_lookback=1 -> input_dim=9)
|
| 211 |
+
|
| 212 |
+
Notes:
|
| 213 |
+
• Make sure G == H*W (i.e., 7857 for 81x97).
|
| 214 |
+
• C_out for precipitation should be 1 (one target channel), and y_zero_batch
|
| 215 |
+
is the 0/1 mask for “zero precipitation” at each pixel & time.
|
| 216 |
+
• y_zero_batch should be probabilities/labels in {0,1} for BCELoss.
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
normalized_test_data = torch.load("data/normalized_test_data_no_veg_input.pth")
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
test_total_loss, regression_metrics, classification_metrics = evaluate(
|
| 223 |
+
model=model,
|
| 224 |
+
test_loader=normalized_test_data,
|
| 225 |
+
reg_loss_fn=loss_fn,
|
| 226 |
+
class_loss_fn=bce_loss_fn,
|
| 227 |
+
device=device,
|
| 228 |
+
variable_names=variable_names,
|
| 229 |
+
height=height,
|
| 230 |
+
width=width,
|
| 231 |
+
)
|
no_veg/utils.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Definition of evaluation metrics
|
| 2 |
+
from scipy.stats import pearsonr, spearmanr
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch
|
| 5 |
+
from scipy.stats import kendalltau
|
| 6 |
+
import scipy.stats as stats
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def nash_sutcliffe_efficiency(observed, predicted):
    """Nash–Sutcliffe efficiency of *predicted* against *observed*.

    1.0 is a perfect fit, 0.0 matches the mean of the observations,
    negative values are worse than the mean. Returns a Python float.
    """
    obs = observed.cpu()
    pred = predicted.cpu()

    residual_ss = ((obs - pred) ** 2).sum()
    variance_ss = ((obs - obs.mean()) ** 2).sum()

    return (1 - residual_ss / variance_ss).item()
|
| 21 |
+
|
| 22 |
+
def pearson_correlation(y_true, y_pred):
    """Pearson correlation coefficient of the two tensors, flattened."""
    truth = y_true.reshape(-1).cpu().numpy()
    pred = y_pred.reshape(-1).cpu().numpy()

    coefficient, _ = pearsonr(truth, pred)
    return coefficient
|
| 27 |
+
|
| 28 |
+
def spearman_correlation(y_true, y_pred):
    """Spearman rank correlation of the two tensors, flattened."""
    truth = y_true.reshape(-1).cpu().numpy()
    pred = y_pred.reshape(-1).cpu().numpy()

    result = spearmanr(truth, pred)
    return result.correlation
|
| 33 |
+
|
| 34 |
+
def mse(y_true, y_pred):
    """Mean squared error between two tensors, returned as a Python float."""
    diff = y_true.cpu() - y_pred.cpu()
    return (diff ** 2).mean().item()
|
| 40 |
+
|
| 41 |
+
def mae(y_true, y_pred):
    """Mean absolute error between two tensors, returned as a Python float."""
    diff = y_true.cpu() - y_pred.cpu()
    return diff.abs().mean().item()
|
| 47 |
+
|
| 48 |
+
def percentage_error(y_true, y_pred):
    """Mean relative error in percent; 1e-6 guards against zero targets."""
    truth = y_true.cpu()
    pred = y_pred.cpu()

    relative = (pred - truth) / (truth + 1e-6)
    return 100 * relative.mean().item()
|
| 54 |
+
|
| 55 |
+
def percentage_bias(y_true, y_pred):
    """Percent bias: 100 * sum(pred - true) / sum(true).

    A small epsilon (1e-6) protects against an all-zero target.
    Fixed: previously returned a 0-dim tensor; now returns a Python float
    like every other metric in this module.
    """
    y_true = y_true.cpu()
    y_pred = y_pred.cpu()

    bias = 100 * torch.sum(y_pred - y_true) / (torch.sum(y_true) + 1e-6)
    return bias.item()
|
| 61 |
+
|
| 62 |
+
def kendall_tau(y_true, y_pred):
    """Kendall's tau rank correlation of the two tensors, flattened."""
    truth = y_true.reshape(-1).cpu().numpy()
    pred = y_pred.reshape(-1).cpu().numpy()

    return kendalltau(truth, pred).correlation
|
| 67 |
+
|
| 68 |
+
def r2_score(y_true, y_pred):
    """Coefficient of determination R²; 1e-6 guards a constant target."""
    truth = y_true.cpu()
    pred = y_pred.cpu()

    ss_total = ((truth - truth.mean()) ** 2).sum()
    ss_residual = ((truth - pred) ** 2).sum()

    return 1 - (ss_residual / (ss_total + 1e-6)).item()
|
| 77 |
+
|
| 78 |
+
def spatial_correlation(y_true, y_pred):
    """Uncentered (cosine-similarity style) correlation of two fields.

    Computes sum(P * T) / sqrt(sum(P^2) * sum(T^2)) over the flattened
    tensors. Fixed: the original comment promised an epsilon against
    division by zero but never applied one, so an all-zero field produced
    NaN; a zero denominator now yields 0.0.
    """
    t = y_true.reshape(-1).cpu()
    p = y_pred.reshape(-1).cpu()

    numerator = torch.sum(p * t)
    denominator = torch.sqrt(torch.sum(p ** 2) * torch.sum(t ** 2))

    # Epsilon keeps the division finite for all-zero inputs.
    correlation = numerator / (denominator + 1e-8)
    return correlation.item()
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0
|
| 2 |
+
numpy>=1.24
|
| 3 |
+
scikit-learn>=1.3
|
| 4 |
+
tqdm>=4.65
scipy>=1.10
|
veg/ConvLSTM.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ConvLSTM definition
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ConvLSTMCell(nn.Module):
    """A single convolutional LSTM cell.

    Args:
        input_dim:   number of channels of the input tensor.
        hidden_dim:  number of channels of the hidden/cell state.
        kernel_size: (kh, kw) of the convolutional kernel.
        bias:        whether the convolution uses a bias term.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # "Same" padding so spatial dimensions are preserved.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # A single convolution computes all four gates at once
        # (4 * hidden_dim output channels) from [input ++ hidden].
        self.conv = nn.Conv2d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=4 * self.hidden_dim,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias,
        )

    def forward(self, input_tensor, cur_state):
        """One LSTM step; cur_state is (h, c). Returns (h_next, c_next)."""
        h_cur, c_cur = cur_state

        # Stack input and previous hidden state along the channel axis.
        stacked = torch.cat([input_tensor, h_cur], dim=1)
        gates = self.conv(stacked)

        # Split the joint convolution into input/forget/output/candidate parts.
        gate_i, gate_f, gate_o, gate_g = torch.split(gates, self.hidden_dim, dim=1)
        i = torch.sigmoid(gate_i)
        f = torch.sigmoid(gate_f)
        o = torch.sigmoid(gate_o)
        g = torch.tanh(gate_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero-initialized (h, c) states on the cell's device."""
        height, width = image_size
        dev = self.conv.weight.device
        h0 = torch.zeros(batch_size, self.hidden_dim, height, width, device=dev)
        c0 = torch.zeros(batch_size, self.hidden_dim, height, width, device=dev)
        return h0, c0
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class ConvLSTM(nn.Module):

    """
    Multi-layer convolutional LSTM.

    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels (int, or list with one entry per layer)
        kernel_size: Size of kernel in convolutions (tuple, or list of tuples per layer)
        num_layers: Number of LSTM layers stacked on each other
        batch_first: Whether or not dimension 0 is the batch or not
        bias: Bias or no bias in Convolution
        return_all_layers: Return the list of computations for all layers
    Note: Will do same padding.

    Input:
        A tensor of size B, T, C, H, W or T, B, C, H, W
    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
            0 - layer_output_list is the list of lists of length T of each output
            1 - last_state_list is the list of last states
                    each element of the list is a tuple (h, c) for hidden state and memory
    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        # Build one ConvLSTMCell per layer; layer i>0 consumes layer i-1's
        # hidden channels as its input channels.
        cell_list = []
        for i in range(0, self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        Run the stacked ConvLSTM over a full sequence.

        Parameters
        ----------
        input_tensor:
            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
        hidden_state:
            Must be None; stateful operation (carrying state between calls)
            is not implemented.

        Returns
        -------
        (layer_output_list, last_state_list) as described in the class docstring.
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        b, _, _, h, w = input_tensor.size()

        # Stateful ConvLSTM (caller-provided initial state) is not implemented;
        # states are always zero-initialized here, sized from the input.
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Since the init is done in forward. Can send image size here
            hidden_state = self._init_hidden(batch_size=b,
                                             image_size=(h, w))

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            # NOTE: `h` and `w` are reused here as the hidden/cell states,
            # shadowing the spatial sizes read above (which are no longer needed).
            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                 cur_state=[h, c])
                output_inner.append(h)

            # Stack per-timestep hidden states into (B, T, C, H, W); this
            # becomes the next layer's input sequence.
            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            # Keep only the last layer's sequence output and final state.
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        # Zero-initialized (h, c) for every layer.
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        # Accept a single (kh, kw) tuple or a list of such tuples.
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        # Replicate a scalar/tuple parameter once per layer.
        if not isinstance(param, list):
            param = [param] * num_layers
        return param
|
veg/MultiTaskConvLSTM.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ConvLSTM import ConvLSTM
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
|
| 6 |
+
#MLP definition
|
| 7 |
+
class MLP_5D(nn.Module):
    """Pixel-wise MLP applied to a 5-D tensor.

    Each (batch, timestep, pixel) feature vector of 64 channels is passed
    through 64 -> 128 -> 64 -> 1 fully connected layers with softplus
    activations and light dropout, collapsing the channel axis to 1.

    Input:  (B, T, 64, height, width)
    Output: (B, T, 1, height, width)
    """

    def __init__(self, height, width):
        super(MLP_5D, self).__init__()
        self.fc1 = nn.Linear(64, 128)   # 64 input channels -> 128 features
        self.dropout1 = nn.Dropout(0.05)
        self.fc2 = nn.Linear(128, 64)
        self.dropout2 = nn.Dropout(0.05)
        self.fc3 = nn.Linear(64, 1)     # collapse to a single output channel

        self.height = height
        self.width = width

    def forward(self, x):
        batch_size, timesteps, channels, height, width = x.shape

        # The MLP is shape-agnostic, but the reshape below requires the
        # configured grid size.
        assert height == self.height and width == self.width, "Height and width mismatch"

        # Flatten everything but the channel axis: one row per pixel/timestep.
        flat = x.permute(0, 1, 3, 4, 2).reshape(-1, channels)

        # Softplus keeps intermediate and final activations non-negative.
        hidden = self.dropout1(torch.nn.functional.softplus(self.fc1(flat)))
        hidden = self.dropout2(torch.nn.functional.softplus(self.fc2(hidden)))
        out = torch.nn.functional.softplus(self.fc3(hidden))

        # Restore the 5-D layout with a single channel.
        out = out.view(batch_size, timesteps, self.height, self.width, 1)
        return out.permute(0, 1, 4, 2, 3)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# MultiTask ConvLSTM definition
|
| 47 |
+
|
| 48 |
+
class ConvLSTMNetwork(nn.Module):
    """Multi-task ConvLSTM for precipitation prediction.

    A stack of ConvLSTM layers (each followed by BatchNorm3d) feeds a shared
    Conv3d trunk and two heads:
      * regression head (per-pixel MLP): precipitation amount, (B, T, 1, H, W);
      * classification head (1x1x1 Conv3d + Sigmoid): per-pixel probability
        of the "zero precipitation" class, (B, T, 1, H, W).

    Args:
        input_dim: channels of the input tensor per timestep.
        hidden_dims: list of hidden channel counts, one per ConvLSTM layer.
        kernel_size: convolution kernel size, e.g. (3, 3).
        num_layers: number of stacked ConvLSTM layers (== len(hidden_dims)).
        output_channels: channels produced by the shared Conv3d trunk.
        batch_first: whether inputs are laid out (B, T, C, H, W).
        pool_size: unused; kept for backward compatibility with callers.
        height, width: spatial grid size expected by the regression MLP.
            New, backward-compatible parameters (previously hard-coded 81x97).
    """

    def __init__(self, input_dim, hidden_dims, kernel_size, num_layers,
                 output_channels, batch_first=True, pool_size=(2, 2),
                 height=81, width=97):
        super(ConvLSTMNetwork, self).__init__()

        # Shared ConvLSTM trunk; return_all_layers=True so every layer's
        # sequence output is available for batch norm + variance tracking.
        self.convlstm = ConvLSTM(input_dim=input_dim,
                                 hidden_dim=hidden_dims,
                                 kernel_size=kernel_size,
                                 num_layers=num_layers,
                                 batch_first=batch_first,
                                 bias=True,
                                 return_all_layers=True)

        # One BatchNorm3d per ConvLSTM layer output.
        self.batch_norms = nn.ModuleList([
            nn.BatchNorm3d(hidden_dim) for hidden_dim in hidden_dims
        ])

        # Shared Conv3d over (C, T, H, W); the (1, 3, 3) kernel mixes space
        # only and leaves the time axis untouched.
        self.conv3d = nn.Conv3d(in_channels=hidden_dims[-1],
                                out_channels=output_channels,
                                kernel_size=(1, 3, 3),
                                padding=(0, 1, 1))

        # Per-pixel MLP mapping trunk channels -> 1 regression channel.
        # Generalized: grid size is now configurable instead of fixed 81x97.
        self.mlp = MLP_5D(height=height, width=width)

        # Classification head: collapse channels to 1, squash to [0, 1].
        self.classification_head = nn.Sequential(
            nn.Conv3d(output_channels, 1, kernel_size=(1, 1, 1)),
            nn.Sigmoid()
        )

        # Per-layer activation variances recorded on every forward pass.
        # NOTE(review): this grows without bound over the module's lifetime;
        # clear it periodically during long training runs.
        self.activation_variance = defaultdict(list)

    def forward(self, x):
        """
        x: (B, T, input_dim, H, W)
        Returns (regression_output, classification_output), both (B, T, 1, H, W).
        """
        layer_output_list, last_state_list = self.convlstm(x)

        # Batch-normalize each layer's output and log its spatial variance.
        for i, output in enumerate(layer_output_list):
            output = output.permute(0, 2, 1, 3, 4)   # (B, C, T, H, W) for BatchNorm3d
            output = self.batch_norms[i](output)
            output = output.permute(0, 2, 1, 3, 4)   # back to (B, T, C, H, W)

            # Mean variance across the spatial dims — a training diagnostic.
            activation_variance = output.var(dim=(3, 4)).mean().item()
            self.activation_variance[f"ConvLSTM_layer_{i}"].append(activation_variance)

            layer_output_list[i] = output

        # The last layer's (normalized) sequence output feeds both heads.
        final_output = layer_output_list[-1]                  # (B, T, C, H, W)
        final_output = final_output.permute(0, 2, 1, 3, 4)    # (B, C, T, H, W)
        final_output = self.conv3d(final_output)              # (B, output_channels, T, H, W)

        # Regression head consumes (B, T, C, H, W).
        regression_output = self.mlp(final_output.permute(0, 2, 1, 3, 4))  # (B, T, 1, H, W)

        # Classification head consumes (B, C, T, H, W) directly.
        classification_output = self.classification_head(final_output)     # (B, 1, T, H, W)
        classification_output = classification_output.permute(0, 2, 1, 3, 4)  # (B, T, 1, H, W)

        return regression_output, classification_output
|
veg/MultiTaskConvLSTM_veg_variables.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b7e8b5e8b33db227257dd794dd6ba5ff10d4759aeed1f758550d4f3acb69cc26
|
| 3 |
+
size 1383333
|
veg/data/normalized_test_data_veg_input.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:075c60cf83a2d7ef68720b72c77f41a034d6418b67085e016fff6a5bdfef878c
|
| 3 |
+
size 2631223074
|
veg/example_inference.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# example_inference
# Loads the trained MultiTask ConvLSTM (vegetation-input variant) and
# prepares it for evaluation on the normalized test set.
import torch
import torch.nn as nn  # fixed: was "import toch.nn as nn" (typo -> ImportError)
import torch.optim as optim
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

from MultiTaskConvLSTM import ConvLSTMNetwork
from utils import (
    mse, mae, nash_sutcliffe_efficiency, r2_score, pearson_correlation,
    spearman_correlation, percentage_error, percentage_bias,
    kendall_tau, spatial_correlation
)


device = 'cpu'

# Spatial grid of the target domain (81 x 97 = 7857 grid points).
height = 81
width = 97

set_lookback = 1          # number of past hours fed to the model
set_forecast_horizon = 1  # number of hours predicted ahead

# Evaluation configuration
batch_size = 16
time_steps_out = set_forecast_horizon
channels = 14  # input channels per timestep
# NOTE(review): variable_names below lists 13 entries while channels = 14 —
# confirm whether one variable name is missing from the list.

# Input variable names (order matters; 'Total precipitation' is index 10).
variable_names = ['10 metre U wind component', '10 metre V wind component', '2 metre dewpoint temperature', '2 metre temperature', 'UV visible albedo for direct radiation (climatological)', 'Total column rain water', 'Volumetric soil water layer 1', 'Leaf area index, high vegetation', 'Leaf area index, low vegetation', 'Forecast surface roughness', 'Total precipitation', 'Time-integrated surface latent heat net flux', 'Evaporation']

# Adjust input_dim and output_channels according to your data specifics.
model = ConvLSTMNetwork(
    input_dim=14 * set_lookback,
    hidden_dims=[14, 32, 64],
    kernel_size=(3, 3),
    num_layers=3,
    output_channels=64 * set_forecast_horizon,
    batch_first=True
).to(device)

# Separate loss functions for the two heads.
loss_fn = nn.MSELoss()      # regression output
bce_loss_fn = nn.BCELoss()  # classification output

optimizer = optim.AdamW(model.parameters(), lr=0.005)

# Fixed: the checkpoint file in this repository carries a .pth extension
# ("MultiTaskConvLSTM_veg_variables.pth"); the extensionless path failed.
# map_location keeps loading working regardless of the saving device.
checkpoint = torch.load("MultiTaskConvLSTM_veg_variables.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

model.to(device)
model.eval()  # evaluation mode for inference

print("Model loaded successfully")


threshold = 0.1    # NOTE(review): defined but unused — the code below cuts
                   # classification probabilities at 0.7; confirm intent.
precip_index = 10  # index of 'Total precipitation' in variable_names
|
| 62 |
+
|
| 63 |
+
def evaluate(model, test_loader, reg_loss_fn, class_loss_fn, device, variable_names, height, width):
    """
    Evaluate the multi-task model on the test set for both the regression
    and the classification task.

    Args:
        model: network returning (regression_output, classification_output).
        test_loader: iterable of (X, y, y_zero) batches with a flattened
            grid axis (B, T, C, G); see the format notes below.
        reg_loss_fn: loss for the regression head (e.g. nn.MSELoss).
        class_loss_fn: loss for the classification head (e.g. nn.BCELoss).
        device: torch device to run on.
        variable_names: input variable names (currently unused; kept for
            backward compatibility with existing callers).
        height, width: spatial grid used to unflatten G = height * width.

    Returns:
        (test_total_loss, regression_metrics, classification_metrics)
        where the loss is a per-sample average and the metrics are dicts.

    Side effects: prints all losses/metrics and saves the flattened
    predictions and targets to the file 'results'.
    """
    model.eval()  # evaluation mode (dropout off, BN running statistics)

    test_reg_loss = 0.0
    test_class_loss = 0.0
    test_total_loss = 0.0
    n_samples = 0  # total sample count, for an unbiased per-sample average

    y_true_reg = []    # true regression targets
    y_pred_reg = []    # raw regression predictions
    y_pred_reg2 = []   # regression predictions gated by the classifier
    y_true_class = []  # true zero-precip labels
    y_pred_class = []  # predicted zero-precip probabilities

    with torch.no_grad():
        for X_test, y_test, y_zero_test in tqdm(test_loader, desc="Evaluating on Test Set"):
            X_test = X_test.to(device)
            y_test = y_test.to(device)
            y_zero_test = y_zero_test.to(device)

            # Unflatten the grid axis: (B, T, C, G) -> (B, T, C, H, W).
            b, time_steps_in, channels_in, _ = X_test.shape
            _, time_steps_out, channels_out, _ = y_test.shape
            X_test = X_test.view(b, time_steps_in, channels_in, height, width)
            y_test = y_test.view(b, time_steps_out, channels_out, height, width)
            y_zero_test = y_zero_test.view(b, time_steps_out, channels_out, height, width)

            # Forward pass through both heads.
            regression_output, classification_output = model(X_test)

            # Hard classification decision.
            # NOTE(review): 0.7 is a probability cutoff and differs from the
            # module-level `threshold = 0.1` — confirm the intended value.
            classification_predictions = (classification_output > 0.7).float()

            reg_loss = reg_loss_fn(regression_output, y_test)
            class_loss = class_loss_fn(classification_output, y_zero_test)
            total_loss = reg_loss + class_loss

            # Combined product: where the classifier predicts 0 keep the
            # regression value, otherwise substitute the class value (1.0).
            regression_output2 = torch.where(classification_predictions == 0,
                                             regression_output,
                                             classification_predictions)

            # Weight per-batch mean losses by batch size so the final number
            # is a true per-sample mean even with a smaller last batch.
            test_reg_loss += reg_loss.item() * b
            test_class_loss += class_loss.item() * b
            test_total_loss += total_loss.item() * b
            n_samples += b

            y_true_reg.append(y_test.cpu())
            y_pred_reg.append(regression_output.cpu())
            y_pred_reg2.append(regression_output2.cpu())
            y_true_class.append(y_zero_test.cpu())
            y_pred_class.append(classification_output.cpu())

    # Fixed: the sums above are weighted by batch size, so they must be
    # divided by the number of samples, not (as before) by the number of
    # batches, which biased the averages whenever batch sizes varied.
    denom = max(n_samples, 1)
    test_reg_loss /= denom
    test_class_loss /= denom
    test_total_loss /= denom

    print(f"Test Regression Loss: {test_reg_loss:.16f}")
    print(f"Test Classification Loss: {test_class_loss:.16f}")
    print(f"Test Total Loss: {test_total_loss:.16f}")

    # Flatten everything for the scalar metrics (kept as torch tensors).
    y_true_reg_flat = torch.cat(y_true_reg, dim=0).flatten()
    y_pred_reg_flat = torch.cat(y_pred_reg, dim=0).flatten()
    y_true_class_flat = torch.cat(y_true_class, dim=0).flatten()
    y_pred_class_flat = torch.cat(y_pred_class, dim=0).flatten()

    # Regression metrics (duplicate "NSE" entry removed — it was listed twice).
    regression_metrics = {
        "MSE": mse(y_true_reg_flat, y_pred_reg_flat),
        "MAE": mae(y_true_reg_flat, y_pred_reg_flat),
        "NSE": nash_sutcliffe_efficiency(y_true_reg_flat, y_pred_reg_flat),
        "R2": r2_score(y_true_reg_flat, y_pred_reg_flat),
        "Pearson": pearson_correlation(y_true_reg_flat, y_pred_reg_flat),
        "Spearman": spearman_correlation(y_true_reg_flat, y_pred_reg_flat),
        "Percentage Error": percentage_error(y_true_reg_flat, y_pred_reg_flat),
        "Percentage Bias": percentage_bias(y_true_reg_flat, y_pred_reg_flat),
        "Kendall Tau": kendall_tau(y_true_reg_flat, y_pred_reg_flat),
        "Spatial Correlation": spatial_correlation(y_true_reg_flat, y_pred_reg_flat)}

    print("\nRegression Metrics:")
    for metric, value in regression_metrics.items():
        print(f"{metric}: {value:.16f}")

    # Classification metrics at the same 0.7 probability cutoff.
    classification_metrics = {
        "Accuracy": accuracy_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
        "Precision": precision_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
        "Recall": recall_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
        "F1": f1_score(y_true_class_flat, (y_pred_class_flat > 0.7)),
        "ROC-AUC": roc_auc_score(y_true_class_flat, y_pred_class_flat),
    }

    print("\nClassification Metrics:")
    for metric, value in classification_metrics.items():
        print(f"{metric}: {value:.16f}")

    # Persist flattened targets/predictions for offline analysis.
    torch.save({
        'y_true_reg': y_true_reg_flat,
        'y_pred_reg': y_pred_reg_flat,
        'y_true_class': y_true_class_flat,
        'y_pred_class': y_pred_class_flat,
    }, 'results')

    return test_total_loss, regression_metrics, classification_metrics
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
"""
|
| 183 |
+
EXPECTED DATALOADER BATCH FORMAT (normalized_test_data):
|
| 184 |
+
|
| 185 |
+
Each batch must be a tuple: (X_batch, y_batch, y_zero_batch)
|
| 186 |
+
|
| 187 |
+
X_batch contains the previous hours variables. y_batch contains the next hour's precipitation.
|
| 188 |
+
y_zero_batch contains the next hour's precipitation thresholded as 0 for precipitation <= 0.1 mm/h and
1 for precipitation > 0.1 mm/h.
|
| 190 |
+
|
| 191 |
+
Shapes BEFORE reshaping inside `evaluate`:
|
| 192 |
+
X_batch: (B, T_in, C_in, G) # G = H*W = 81*97 = 7857
|
| 193 |
+
y_batch: (B, T_out, C_out, G)
|
| 194 |
+
y_zero_batch: (B, T_out, C_out, G) # binary 0/1 "zero-precip" targets
|
| 195 |
+
|
| 196 |
+
If your preprocessing produces (B,T, C, H, W), reshape to (B, T, C, H*W) before inference.
|
| 197 |
+
|
| 198 |
+
DTypes:
|
| 199 |
+
X_batch, y_batch: torch.float32
|
| 200 |
+
y_zero_batch: torch.float32 (will be used with BCELoss)
|
| 201 |
+
|
| 202 |
+
Reshaping done in 'evaluate':
|
| 203 |
+
X_test = X_batch.view(B, T_in, C_in, H, W) -> (B, T_in, C_in, 81, 97)
|
| 204 |
+
y_test = y_batch.view(B, T_out, C_out, H, W) -> (B, T_out, C_out, 81, 97)
|
| 205 |
+
y_zero_test = y_zero_batch.view(B, T_out, C_out, H, W)
|
| 206 |
+
|
| 207 |
+
Model input:
|
| 208 |
+
model expects X_test shaped (B, T_in, input_dim, H, W)
|
| 209 |
+
where input_dim == 14 * set_lookback (with set_lookback=1 -> input_dim=14)
|
| 210 |
+
|
| 211 |
+
Notes:
|
| 212 |
+
• Make sure G == H*W (i.e., 7857 for 81x97).
|
| 213 |
+
• C_out for precipitation should be 1 (one target channel), and y_zero_batch
|
| 214 |
+
is the 0/1 mask for “zero precipitation” at each pixel & time.
|
| 215 |
+
• y_zero_batch should be probabilities/labels in {0,1} for BCELoss.
|
| 216 |
+
"""
|
| 217 |
+
|
| 218 |
+
# Load the preprocessed, normalized test set: an iterable of
# (X_batch, y_batch, y_zero_batch) tuples in the format documented above.
# weights_only=False is required on torch>=2.6, where the new default
# (weights_only=True) rejects the arbitrary Python objects stored here.
normalized_test_data = torch.load(
    "data/normalized_test_data_veg_input.pth", weights_only=False
)

# Run the full evaluation: prints losses and regression/classification
# metrics, and saves the flattened predictions to 'results'.
test_total_loss, regression_metrics, classification_metrics = evaluate(
    model=model,
    test_loader=normalized_test_data,
    reg_loss_fn=loss_fn,
    class_loss_fn=bce_loss_fn,
    device=device,
    variable_names=variable_names,
    height=height,
    width=width,
)
|
veg/utils.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Definition of evaluation metrics
|
| 2 |
+
from scipy.stats import pearsonr, spearmanr
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch
|
| 5 |
+
from scipy.stats import kendalltau
|
| 6 |
+
import scipy.stats as stats
|
| 7 |
+
|
| 8 |
+
|
def nash_sutcliffe_efficiency(observed, predicted):
    """Return the Nash–Sutcliffe efficiency (NSE) as a Python float.

    NSE = 1 - sum((obs - pred)^2) / sum((obs - mean(obs))^2).
    1.0 is a perfect match, 0.0 means the prediction is no better than the
    mean of the observations, and negative values are worse than the mean.

    Both arguments are torch tensors of matching shape (any device).
    """
    # Work on the CPU so the metric is device-agnostic.
    observed = observed.cpu()
    predicted = predicted.cpu()

    numerator = torch.sum((observed - predicted) ** 2)
    denominator = torch.sum((observed - torch.mean(observed)) ** 2)

    # Epsilon guards against division by zero when the observations are
    # constant (zero variance), matching the 1e-6 convention used by the
    # other metrics in this module (r2_score, percentage_bias, ...).
    nse = 1 - (numerator / (denominator + 1e-6))
    return nse.item()
| 21 |
+
|
def pearson_correlation(y_true, y_pred):
    """Return the Pearson correlation coefficient between two tensors."""
    # scipy expects 1-D numpy arrays, so flatten and move off the GPU first.
    true_np = y_true.view(-1).cpu().numpy()
    pred_np = y_pred.view(-1).cpu().numpy()

    corr, _p_value = pearsonr(true_np, pred_np)
    return corr
| 27 |
+
|
def spearman_correlation(y_true, y_pred):
    """Return the Spearman rank correlation between two tensors."""
    # Flatten to 1-D CPU numpy arrays for scipy.
    true_np = y_true.view(-1).cpu().numpy()
    pred_np = y_pred.view(-1).cpu().numpy()

    result = spearmanr(true_np, pred_np)
    return result.correlation
| 33 |
+
|
def mse(y_true, y_pred):
    """Return the mean squared error between two tensors as a Python float."""
    # Compute on the CPU regardless of the tensors' original device.
    diff = y_true.cpu() - y_pred.cpu()
    return (diff ** 2).mean().item()
| 40 |
+
|
def mae(y_true, y_pred):
    """Return the mean absolute error between two tensors as a Python float."""
    # Compute on the CPU regardless of the tensors' original device.
    abs_diff = torch.abs(y_true.cpu() - y_pred.cpu())
    return abs_diff.mean().item()
| 47 |
+
|
def percentage_error(y_true, y_pred):
    """Return the mean signed percentage error, 100 * mean((pred - true)/true).

    A 1e-6 epsilon in the denominator avoids division by zero where the
    target is exactly zero.
    """
    truth = y_true.cpu()
    pred = y_pred.cpu()

    relative = (pred - truth) / (truth + 1e-6)
    return 100 * relative.mean().item()
| 54 |
+
|
def percentage_bias(y_true, y_pred):
    """Return the percentage bias (PBIAS) as a Python float.

    PBIAS = 100 * sum(pred - true) / sum(true). Positive values mean the
    model over-predicts in aggregate; negative values mean under-prediction.
    A 1e-6 epsilon guards against sum(true) == 0.
    """
    # Compute on the CPU regardless of the tensors' original device.
    y_true = y_true.cpu()
    y_pred = y_pred.cpu()

    # .item() so the metric is a plain float, consistent with every other
    # metric in this module (the original returned a 0-dim tensor).
    return (100 * torch.sum(y_pred - y_true) / (torch.sum(y_true) + 1e-6)).item()
| 61 |
+
|
def kendall_tau(y_true, y_pred):
    """Return Kendall's tau rank correlation between two tensors."""
    # Flatten to 1-D CPU numpy arrays for scipy.
    true_np = y_true.view(-1).cpu().numpy()
    pred_np = y_pred.view(-1).cpu().numpy()

    return kendalltau(true_np, pred_np).correlation
| 67 |
+
|
def r2_score(y_true, y_pred):
    """Return the coefficient of determination R^2 as a Python float.

    R^2 = 1 - SS_res / SS_tot, with a 1e-6 epsilon guarding a zero total
    sum of squares (constant targets).
    """
    truth = y_true.cpu()
    pred = y_pred.cpu()

    ss_total = ((truth - truth.mean()) ** 2).sum()
    ss_residual = ((truth - pred) ** 2).sum()

    return 1 - (ss_residual / (ss_total + 1e-6)).item()
| 77 |
+
|
def spatial_correlation(y_true, y_pred):
    """Return the uncentred spatial correlation of two fields as a float.

    Computes sum(P*T) / sqrt(sum(P^2) * sum(T^2)) over the flattened
    fields — the cosine similarity of prediction and target, in [-1, 1].
    """
    # Flatten the (possibly multi-dimensional) fields and move to the CPU.
    y_true_flat = y_true.view(-1).cpu()
    y_pred_flat = y_pred.view(-1).cpu()

    # Numerator: sum(P * T)
    numerator = torch.sum(y_pred_flat * y_true_flat)

    # Denominator: sqrt(sum(P^2) * sum(T^2))
    denominator = torch.sqrt(torch.sum(y_pred_flat ** 2) * torch.sum(y_true_flat ** 2))

    # Epsilon is now actually applied: the original comment promised it but
    # divided by the bare denominator, yielding NaN for all-zero fields.
    correlation = numerator / (denominator + 1e-6)

    return correlation.item()