Update Exp3_Kuroshio_forecasting/dataloader_api/dataloader_kuroshio_128.py
Browse files
Exp3_Kuroshio_forecasting/dataloader_api/dataloader_kuroshio_128.py
CHANGED
|
@@ -6,33 +6,75 @@ import netCDF4 as nc
|
|
| 6 |
import numpy as np
|
| 7 |
|
| 8 |
class OceanCurrentDataset(Dataset):
|
| 9 |
-
def __init__(self, data_path, input_steps=10, output_steps=10,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
self.data_path = data_path
|
| 11 |
self.input_steps = input_steps
|
| 12 |
self.output_steps = output_steps
|
| 13 |
self.transform = transform
|
| 14 |
self.total_steps = input_steps + output_steps
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def _load_and_process_data(self):
|
|
|
|
| 20 |
with nc.Dataset(self.data_path, 'r') as ds:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def process_var(var):
|
| 22 |
-
arr = var[:]
|
| 23 |
if '_FillValue' in var.ncattrs():
|
| 24 |
fill_value = var._FillValue
|
| 25 |
arr = np.ma.masked_values(arr, fill_value).filled(np.nan)
|
| 26 |
return torch.nan_to_num(torch.FloatTensor(arr), nan=0.0)
|
| 27 |
|
| 28 |
-
|
|
|
|
| 29 |
vgos = process_var(ds['vgos'])
|
| 30 |
|
| 31 |
-
#
|
| 32 |
-
|
|
|
|
| 33 |
|
| 34 |
def _compute_stats(self):
|
| 35 |
-
|
|
|
|
|
|
|
| 36 |
|
| 37 |
def __len__(self):
|
| 38 |
return len(self.data) - self.total_steps + 1
|
|
@@ -40,8 +82,10 @@ class OceanCurrentDataset(Dataset):
|
|
| 40 |
def __getitem__(self, idx):
|
| 41 |
window = self.data[idx:idx+self.total_steps] # [T_total, C, H, W]
|
| 42 |
|
|
|
|
| 43 |
window = (window - self.mean) / self.std
|
| 44 |
|
|
|
|
| 45 |
input_seq = window[:self.input_steps]
|
| 46 |
target_seq = window[self.input_steps:]
|
| 47 |
|
|
@@ -49,27 +93,42 @@ class OceanCurrentDataset(Dataset):
|
|
| 49 |
input_seq = self.transform(input_seq)
|
| 50 |
target_seq = self.transform(target_seq)
|
| 51 |
|
| 52 |
-
|
|
|
|
| 53 |
|
| 54 |
def create_dataloaders(config):
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
data_path=config['data_path'],
|
| 57 |
input_steps=config['input_steps'],
|
| 58 |
-
output_steps=config['output_steps']
|
|
|
|
|
|
|
| 59 |
)
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
test_size = len(full_dataset) - train_size - val_size
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
)
|
| 69 |
|
|
|
|
| 70 |
train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
| 71 |
val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
| 72 |
-
test_sampler = DistributedSampler(
|
| 73 |
|
| 74 |
dataloader_train = DataLoader(
|
| 75 |
train_dataset,
|
|
@@ -90,7 +149,7 @@ def create_dataloaders(config):
|
|
| 90 |
)
|
| 91 |
|
| 92 |
dataloader_test = DataLoader(
|
| 93 |
-
|
| 94 |
batch_size=config['val_batch_size'],
|
| 95 |
sampler=test_sampler,
|
| 96 |
num_workers=config['num_workers'],
|
|
@@ -98,4 +157,4 @@ def create_dataloaders(config):
|
|
| 98 |
drop_last=True
|
| 99 |
)
|
| 100 |
|
| 101 |
-
return dataloader_train, dataloader_val, dataloader_test,
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
|
| 8 |
class OceanCurrentDataset(Dataset):
    """Sliding-window dataset over the ``ugos``/``vgos`` fields of a NetCDF file.

    Each sample is a window of ``input_steps + output_steps`` consecutive time
    steps, normalized with ``(x - mean) / std``, split into an input sequence
    and a target sequence, and spatially subsampled by a factor of two.
    """

    def __init__(self, data_path, input_steps=10, output_steps=10,
                 start_time=None, end_time=None, mean=None, std=None, transform=None):
        """
        Ocean current dataset class

        :param data_path: Path to NetCDF file
        :param input_steps: Number of input time steps
        :param output_steps: Number of prediction time steps
        :param start_time: Inclusive lower bound on the file's ``time`` values
            (the caller in this file passes days since 1950-01-01); None = no bound
        :param end_time: Inclusive upper bound on the file's ``time`` values;
            None = no bound
        :param mean: Precomputed mean used for normalization
        :param std: Precomputed standard deviation used for normalization
        :param transform: Optional callable applied separately to the input
            and target sequences
        """
        self.data_path = data_path
        self.input_steps = input_steps
        self.output_steps = output_steps
        self.transform = transform
        self.total_steps = input_steps + output_steps
        self.start_time = start_time
        self.end_time = end_time

        # Load and preprocess data
        self.data, self.time_values = self._load_and_process_data()

        # Set statistics. Caller-supplied statistics are honoured only when
        # BOTH are given; otherwise normalization is the identity (mean 0,
        # std 1). _compute_stats() is deliberately not called here so that
        # datasets built without explicit statistics keep returning
        # unnormalized values, as before.
        if mean is not None and std is not None:
            self.mean, self.std = mean, std
        else:
            self.mean, self.std = 0, 1

    def _load_and_process_data(self):
        """Load the time axis and the ``ugos``/``vgos`` variables from the NetCDF file.

        :return: ``(data, time_values)``; ``data`` is a float tensor with the
            two variables stacked on dim 1, ``time_values`` holds the time
            values of the selected steps
        """
        with nc.Dataset(self.data_path, 'r') as ds:
            # Read time variable
            time_var = ds['time']
            time_values = time_var[:]

            # Filter indices within the (inclusive) time range
            if self.start_time is not None or self.end_time is not None:
                time_mask = np.ones_like(time_values, dtype=bool)
                if self.start_time is not None:
                    time_mask &= (time_values >= self.start_time)
                if self.end_time is not None:
                    time_mask &= (time_values <= self.end_time)
                valid_indices = np.where(time_mask)[0]
            else:
                valid_indices = np.arange(len(time_values))

            def process_var(var):
                """Read one variable, keep the selected steps, zero out missing values."""
                # NOTE(review): this reads the whole variable before slicing
                # by time; fine for small files, memory-heavy for large ones.
                arr = var[:][valid_indices]
                if '_FillValue' in var.ncattrs():
                    fill_value = var._FillValue
                    arr = np.ma.masked_values(arr, fill_value).filled(np.nan)
                # Missing values (NaN) are replaced by 0.0
                return torch.nan_to_num(torch.FloatTensor(arr), nan=0.0)

            # Load and merge UV components
            ugos = process_var(ds['ugos'])  # expected (time, lat, lon)
            vgos = process_var(ds['vgos'])

            # Stack to [time, channels, lat, lon]
            data = torch.stack([ugos, vgos], dim=1)
            return data, time_values[valid_indices]

    def _compute_stats(self):
        """Return ``(mean, std)`` computed over every element of ``self.data``.

        Not called by the constructor; available for callers that want to
        derive statistics and pass them to another dataset via ``mean``/``std``.
        """
        return torch.mean(self.data), torch.std(self.data)

    def __len__(self):
        """Number of complete windows; 0 when the series is shorter than one window."""
        # Clamp at zero: a negative value would make len() raise ValueError.
        return max(0, len(self.data) - self.total_steps + 1)

    def __getitem__(self, idx):
        """Return the ``(input_seq, target_seq)`` pair for the window starting at ``idx``.

        :param idx: Window start index; negative values count from the end
        :raises IndexError: if ``idx`` does not address a complete window
        """
        num_windows = len(self)
        if idx < 0:
            idx += num_windows
        if not 0 <= idx < num_windows:
            # Without this check the slice below would silently return a
            # truncated (or empty) window, and plain iteration never stops.
            raise IndexError(
                f"window index out of range for dataset with {num_windows} windows"
            )

        window = self.data[idx:idx + self.total_steps]  # [T_total, C, H, W]

        # Normalization
        window = (window - self.mean) / self.std

        # Split input and output
        input_seq = window[:self.input_steps]
        target_seq = window[self.input_steps:]

        if self.transform is not None:
            input_seq = self.transform(input_seq)
            target_seq = self.transform(target_seq)

        # Spatial downsampling: keep every second point on the last two dims
        return input_seq[:, :, ::2, ::2], target_seq[:, :, ::2, ::2]
|
| 98 |
|
| 99 |
def create_dataloaders(config):
|
| 100 |
+
# Define time range (days since 1950-01-01)
|
| 101 |
+
# 1993-01-01 ≈ 15706 (confirmed from first value of time variable in raw data)
|
| 102 |
+
# 2020-12-31 = 25932 (365*71 + 18 leap days - 1 day); NOTE: the value 25931 used below is therefore 2020-12-30 - verify the train/validation boundary
|
| 103 |
+
# 2023-12-31 ≈ 27027 (end time of raw data)
|
| 104 |
+
|
| 105 |
+
# Create training set (1993-2020)
|
| 106 |
+
train_dataset = OceanCurrentDataset(
|
| 107 |
data_path=config['data_path'],
|
| 108 |
input_steps=config['input_steps'],
|
| 109 |
+
output_steps=config['output_steps'],
|
| 110 |
+
start_time=15706, # 1993-01-01
|
| 111 |
+
end_time=25931 # 2020-12-31
|
| 112 |
)
|
| 113 |
|
| 114 |
+
# Get training set statistics
|
| 115 |
+
train_mean, train_std = train_dataset.mean, train_dataset.std
|
|
|
|
| 116 |
|
| 117 |
+
# Create the validation/test set (2021-2023; data from 2024 is not included)
|
| 118 |
+
val_dataset = OceanCurrentDataset(
|
| 119 |
+
data_path=config['data_path'],
|
| 120 |
+
input_steps=config['input_steps'],
|
| 121 |
+
output_steps=config['output_steps'],
|
| 122 |
+
start_time=25932, # 2021-01-01
|
| 123 |
+
end_time=27027, # 2023-12-31
|
| 124 |
+
mean=train_mean,
|
| 125 |
+
std=train_std
|
| 126 |
)
|
| 127 |
|
| 128 |
+
|
| 129 |
train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
| 130 |
val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
| 131 |
+
test_sampler = DistributedSampler(val_dataset, shuffle=False)
|
| 132 |
|
| 133 |
dataloader_train = DataLoader(
|
| 134 |
train_dataset,
|
|
|
|
| 149 |
)
|
| 150 |
|
| 151 |
dataloader_test = DataLoader(
|
| 152 |
+
val_dataset,
|
| 153 |
batch_size=config['val_batch_size'],
|
| 154 |
sampler=test_sampler,
|
| 155 |
num_workers=config['num_workers'],
|
|
|
|
| 157 |
drop_last=True
|
| 158 |
)
|
| 159 |
|
| 160 |
+
return dataloader_train, dataloader_val, dataloader_test, train_mean, train_std
|