easylearning committed on
Commit
0c82ea5
·
verified ·
1 Parent(s): e7f7ebb

Update Exp3_Kuroshio_forecasting/dataloader_api/dataloader_kuroshio_128.py

Browse files
Exp3_Kuroshio_forecasting/dataloader_api/dataloader_kuroshio_128.py CHANGED
@@ -6,33 +6,75 @@ import netCDF4 as nc
6
  import numpy as np
7
 
8
  class OceanCurrentDataset(Dataset):
9
- def __init__(self, data_path, input_steps=10, output_steps=10, transform=None):
 
 
 
 
 
 
 
 
 
 
 
 
10
  self.data_path = data_path
11
  self.input_steps = input_steps
12
  self.output_steps = output_steps
13
  self.transform = transform
14
  self.total_steps = input_steps + output_steps
 
 
15
 
16
- self.data = self._load_and_process_data()
17
- self.mean, self.std = 0, 1
 
 
 
 
 
 
 
18
 
19
  def _load_and_process_data(self):
 
20
  with nc.Dataset(self.data_path, 'r') as ds:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def process_var(var):
22
- arr = var[:]
23
  if '_FillValue' in var.ncattrs():
24
  fill_value = var._FillValue
25
  arr = np.ma.masked_values(arr, fill_value).filled(np.nan)
26
  return torch.nan_to_num(torch.FloatTensor(arr), nan=0.0)
27
 
28
- ugos = process_var(ds['ugos'])
 
29
  vgos = process_var(ds['vgos'])
30
 
31
- # [time, channels, lat, lon]
32
- return torch.stack([ugos, vgos], dim=1)
 
33
 
34
  def _compute_stats(self):
35
- return torch.mean(self.data[:10000]), torch.std(self.data[:10000])
 
 
36
 
37
  def __len__(self):
38
  return len(self.data) - self.total_steps + 1
@@ -40,8 +82,10 @@ class OceanCurrentDataset(Dataset):
40
  def __getitem__(self, idx):
41
  window = self.data[idx:idx+self.total_steps] # [T_total, C, H, W]
42
 
 
43
  window = (window - self.mean) / self.std
44
 
 
45
  input_seq = window[:self.input_steps]
46
  target_seq = window[self.input_steps:]
47
 
@@ -49,27 +93,42 @@ class OceanCurrentDataset(Dataset):
49
  input_seq = self.transform(input_seq)
50
  target_seq = self.transform(target_seq)
51
 
52
- return input_seq[:,:,::2,::2], target_seq[:,:,::2,::2]
 
53
 
54
  def create_dataloaders(config):
55
- full_dataset = OceanCurrentDataset(
 
 
 
 
 
 
56
  data_path=config['data_path'],
57
  input_steps=config['input_steps'],
58
- output_steps=config['output_steps']
 
 
59
  )
60
 
61
- train_size = 10000 - config['input_steps'] - config['output_steps'] + 1
62
- val_size = 500
63
- test_size = len(full_dataset) - train_size - val_size
64
 
65
- train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
66
- full_dataset, [train_size, val_size, test_size],
67
- generator=torch.Generator().manual_seed(config['seed'])
 
 
 
 
 
 
68
  )
69
 
 
70
  train_sampler = DistributedSampler(train_dataset, shuffle=True)
71
  val_sampler = DistributedSampler(val_dataset, shuffle=False)
72
- test_sampler = DistributedSampler(test_dataset, shuffle=False)
73
 
74
  dataloader_train = DataLoader(
75
  train_dataset,
@@ -90,7 +149,7 @@ def create_dataloaders(config):
90
  )
91
 
92
  dataloader_test = DataLoader(
93
- test_dataset,
94
  batch_size=config['val_batch_size'],
95
  sampler=test_sampler,
96
  num_workers=config['num_workers'],
@@ -98,4 +157,4 @@ def create_dataloaders(config):
98
  drop_last=True
99
  )
100
 
101
- return dataloader_train, dataloader_val, dataloader_test, full_dataset.mean, full_dataset.std
 
6
  import numpy as np
7
 
8
  class OceanCurrentDataset(Dataset):
9
+ def __init__(self, data_path, input_steps=10, output_steps=10,
10
+ start_time=None, end_time=None, mean=None, std=None, transform=None):
11
+ """
12
+ Ocean current dataset class
13
+ :param data_path: Path to NetCDF file
14
+ :param input_steps: Number of input time steps
15
+ :param output_steps: Number of prediction time steps
16
+ :param start_time: Start time (days since 1950-01-01)
17
+ :param end_time: End time (days since 1950-01-01)
18
+ :param mean: Precomputed mean
19
+ :param std: Precomputed standard deviation
20
+ :param transform: Data augmentation transform
21
+ """
22
  self.data_path = data_path
23
  self.input_steps = input_steps
24
  self.output_steps = output_steps
25
  self.transform = transform
26
  self.total_steps = input_steps + output_steps
27
+ self.start_time = start_time
28
+ self.end_time = end_time
29
 
30
+ # Load and preprocess data
31
+ self.data, self.time_values = self._load_and_process_data()
32
+
33
+ # Set statistics
34
+ if mean is not None and std is not None:
35
+ self.mean = 0
36
+ self.std = 1
37
+ else:
38
+ self.mean, self.std = 0, 1
39
 
40
  def _load_and_process_data(self):
41
+ """Load and process NetCDF data"""
42
  with nc.Dataset(self.data_path, 'r') as ds:
43
+ # Read time variable
44
+ time_var = ds['time']
45
+ time_values = time_var[:]
46
+
47
+ # Filter indices within time range
48
+ if self.start_time is not None or self.end_time is not None:
49
+ time_mask = np.ones_like(time_values, dtype=bool)
50
+ if self.start_time is not None:
51
+ time_mask &= (time_values >= self.start_time)
52
+ if self.end_time is not None:
53
+ time_mask &= (time_values <= self.end_time)
54
+ valid_indices = np.where(time_mask)[0]
55
+ else:
56
+ valid_indices = np.arange(len(time_values))
57
+
58
+ # Handle missing values
59
  def process_var(var):
60
+ arr = var[:][valid_indices] # Slice by time range
61
  if '_FillValue' in var.ncattrs():
62
  fill_value = var._FillValue
63
  arr = np.ma.masked_values(arr, fill_value).filled(np.nan)
64
  return torch.nan_to_num(torch.FloatTensor(arr), nan=0.0)
65
 
66
+ # Load and merge UV components
67
+ ugos = process_var(ds['ugos']) # (time, lat, lon)
68
  vgos = process_var(ds['vgos'])
69
 
70
+ # Adjust dimension order [time, channels, lat, lon]
71
+ data = torch.stack([ugos, vgos], dim=1)
72
+ return data, time_values[valid_indices]
73
 
74
  def _compute_stats(self):
75
+ """Calculate dataset statistics"""
76
+ # Compute mean and std using entire dataset
77
+ return torch.mean(self.data), torch.std(self.data)
78
 
79
  def __len__(self):
80
  return len(self.data) - self.total_steps + 1
 
82
  def __getitem__(self, idx):
83
  window = self.data[idx:idx+self.total_steps] # [T_total, C, H, W]
84
 
85
+ # Normalization
86
  window = (window - self.mean) / self.std
87
 
88
+ # Split input and output
89
  input_seq = window[:self.input_steps]
90
  target_seq = window[self.input_steps:]
91
 
 
93
  input_seq = self.transform(input_seq)
94
  target_seq = self.transform(target_seq)
95
 
96
+ # Spatial downsampling
97
+ return input_seq[:, :, ::2, ::2], target_seq[:, :, ::2, ::2]
98
 
99
  def create_dataloaders(config):
100
+ # Define time range (days since 1950-01-01)
101
+ # 1993-01-01 ≈ 15706 (confirmed from first value of time variable in raw data)
102
+ # 2020-12-31 = 25932 (71*365 + 18 leap days = 25933 is 2021-01-01, minus 1 day); NOTE(review): the code below uses 25931, which is 2020-12-30 — confirm the intended boundary
103
+ # 2023-12-31 ≈ 27027 (end time of raw data)
104
+
105
+ # Create training set (1993-2020)
106
+ train_dataset = OceanCurrentDataset(
107
  data_path=config['data_path'],
108
  input_steps=config['input_steps'],
109
+ output_steps=config['output_steps'],
110
+ start_time=15706, # 1993-01-01
111
+ end_time=25931 # 2020-12-31
112
  )
113
 
114
+ # Get training set statistics
115
+ train_mean, train_std = train_dataset.mean, train_dataset.std
 
116
 
117
+ # Create validation and test sets (2021–2023 inclusive; 2024 is not included)
118
+ val_dataset = OceanCurrentDataset(
119
+ data_path=config['data_path'],
120
+ input_steps=config['input_steps'],
121
+ output_steps=config['output_steps'],
122
+ start_time=25932, # 2021-01-01
123
+ end_time=27027, # 2023-12-31
124
+ mean=train_mean,
125
+ std=train_std
126
  )
127
 
128
+
129
  train_sampler = DistributedSampler(train_dataset, shuffle=True)
130
  val_sampler = DistributedSampler(val_dataset, shuffle=False)
131
+ test_sampler = DistributedSampler(val_dataset, shuffle=False)
132
 
133
  dataloader_train = DataLoader(
134
  train_dataset,
 
149
  )
150
 
151
  dataloader_test = DataLoader(
152
+ val_dataset,
153
  batch_size=config['val_batch_size'],
154
  sampler=test_sampler,
155
  num_workers=config['num_workers'],
 
157
  drop_last=True
158
  )
159
 
160
+ return dataloader_train, dataloader_val, dataloader_test, train_mean, train_std