File size: 4,418 Bytes
a57e1d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torch.utils.data as data_utils

class train_Dataset(data_utils.Dataset):
    """Training split (first 80% of samples) of a single-variable spatio-temporal dataset.

    Each item is a pair of consecutive windows cut from one sample's time
    axis: ``input_length`` frames of history followed by ``target_length``
    frames to predict, both spatially subsampled by ``downsample_factor``.

    Args:
        data: tensor expected to have shape [num_samples, time_steps, H, W]
            (it must support ``unsqueeze``, e.g. a ``torch.Tensor``).
        args: mapping with required keys ``'input_length'``,
            ``'target_length'`` and ``'downsample_factor'`` (spatial stride),
            and optional ``'variables_input'`` / ``'variables_output'``
            (variable-axis selectors, default ``[0]``; a bare int is
            accepted and treated as a one-element list).
    """

    def __init__(self, data, args):
        super().__init__()
        self.args = args
        self.input_length = args['input_length']
        self.target_length = args['target_length']
        self.downsample_factor = args['downsample_factor']

        # Add a variables axis: [num_samples, time_steps, 1, H, W].  Only one
        # variable (vorticity) is present, so 0 is the only valid selector.
        self.data = data.unsqueeze(2)

        # Keep the first 80% of samples for training.
        total_samples = self.data.shape[0]
        self.start_index = 0
        self.end_index = int(0.8 * total_samples)
        self.data = self.data[self.start_index:self.end_index]

        self.num_samples = self.data.shape[0]
        self.num_time_steps = self.data.shape[1]
        self.variables_input = self._as_index_list(args.get('variables_input', [0]))
        self.variables_output = self._as_index_list(args.get('variables_output', [0]))

        # Every split point t that leaves a full history window
        # [t - input_length, t) and a full target window
        # [t, t + target_length) inside the time axis.
        last_t = self.num_time_steps - self.target_length
        self.sample_indices = [
            (s, t)
            for s in range(self.num_samples)
            for t in range(self.input_length, last_t + 1)
        ]

    @staticmethod
    def _as_index_list(selector):
        """Wrap a bare int variable selector in a list so the variable axis is kept."""
        return [selector] if isinstance(selector, int) else selector

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, idx):
        """Return ``(input_seq, target_seq)`` as float32 tensors.

        Shapes are [input_length, len(variables_input), H', W'] and
        [target_length, len(variables_output), H', W'], where H' and W' are
        H and W subsampled with stride ``downsample_factor``.
        """
        s, t = self.sample_indices[idx]

        # Select the sample first so the list-valued variable selector is the
        # only advanced index and the time axis stays in front.
        sample = self.data[s]  # [time_steps, variables, H, W]
        input_seq = sample[t - self.input_length:t, self.variables_input]
        target_seq = sample[t:t + self.target_length, self.variables_output]

        # Spatial downsampling by striding over H and W.
        dsf = self.downsample_factor
        input_seq = input_seq[:, :, ::dsf, ::dsf].float()
        target_seq = target_seq[:, :, ::dsf, ::dsf].float()

        return input_seq, target_seq

class test_Dataset(data_utils.Dataset):
    """Test split (last 10% of samples) of a single-variable spatio-temporal dataset.

    Each item is a pair of consecutive windows cut from one sample's time
    axis: ``input_length`` frames of history followed by ``target_length``
    frames to predict, both spatially subsampled by ``downsample_factor``.

    Args:
        data: tensor expected to have shape [num_samples, time_steps, H, W]
            (it must support ``unsqueeze``, e.g. a ``torch.Tensor``).
        args: mapping with required keys ``'input_length'``,
            ``'target_length'`` and ``'downsample_factor'`` (spatial stride),
            and optional ``'variables_input'`` / ``'variables_output'``
            (variable-axis selectors, default ``[0]``; a bare int is
            accepted and treated as a one-element list).
    """

    def __init__(self, data, args):
        super().__init__()
        self.args = args
        self.input_length = args['input_length']
        self.target_length = args['target_length']
        self.downsample_factor = args['downsample_factor']

        # Add a variables axis: [num_samples, time_steps, 1, H, W].  Only one
        # variable is present, so 0 is the only valid selector.
        self.data = data.unsqueeze(2)

        # Keep the last 10% of samples for testing.
        total_samples = self.data.shape[0]
        self.start_index = int(0.9 * total_samples)
        self.end_index = total_samples
        self.data = self.data[self.start_index:self.end_index]

        self.num_samples = self.data.shape[0]
        self.num_time_steps = self.data.shape[1]
        self.variables_input = self._as_index_list(args.get('variables_input', [0]))
        self.variables_output = self._as_index_list(args.get('variables_output', [0]))

        # Every split point t that leaves a full history window
        # [t - input_length, t) and a full target window
        # [t, t + target_length) inside the time axis.
        last_t = self.num_time_steps - self.target_length
        self.sample_indices = [
            (s, t)
            for s in range(self.num_samples)
            for t in range(self.input_length, last_t + 1)
        ]

    @staticmethod
    def _as_index_list(selector):
        """Wrap a bare int variable selector in a list so the variable axis is kept."""
        return [selector] if isinstance(selector, int) else selector

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, idx):
        """Return ``(input_seq, target_seq)`` as float32 tensors.

        Shapes are [input_length, len(variables_input), H', W'] and
        [target_length, len(variables_output), H', W'], where H' and W' are
        H and W subsampled with stride ``downsample_factor``.
        """
        s, t = self.sample_indices[idx]

        # Select the sample first so the list-valued variable selector is the
        # only advanced index and the time axis stays in front.
        sample = self.data[s]  # [time_steps, variables, H, W]
        input_seq = sample[t - self.input_length:t, self.variables_input]
        target_seq = sample[t:t + self.target_length, self.variables_output]

        # Spatial downsampling by striding over H and W.
        dsf = self.downsample_factor
        input_seq = input_seq[:, :, ::dsf, ::dsf].float()
        target_seq = target_seq[:, :, ::dsf, ::dsf].float()

        return input_seq, target_seq