# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os

import h5py
from torch.utils import data
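

# Hypothetical helpers (a sketch for illustration only; the dataset class below
# does not call them and the names are assumptions). They show how the output
# directory glob pattern in ``Dedalus2DDataset.__init__`` is formed: the prefix
# gets one "?" wildcard per decimal digit of the number of entries in
# ``data_path``. Because "?" matches exactly one character, unpadded names can
# be skipped: with directories output-1 ... output-10 (10 entries) the pattern
# is "output-??", which selects only "output-10".
def _output_dir_pattern(prefix, num_entries):
    """Return ``prefix`` followed by one ``?`` per digit of ``num_entries``."""
    return prefix + "?" * len(str(num_entries))


def _matching_output_dirs(names, prefix="output-"):
    """Return the entries of ``names`` that the derived pattern would select."""
    import fnmatch  # local import keeps this sketch self-contained

    # The entry count stands in for ``len(os.listdir(data_path))``.
    pattern = _output_dir_pattern(prefix, len(names))
    return sorted(fnmatch.filter(names, pattern))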


class Dedalus2DDataset(data.Dataset):
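    """Dataset of 2D MHD simulation outputs written by Dedalus.

    Each item is one HDF5 file found under ``data_path/<output_names><digits>/*.h5``;
    ``__getitem__`` returns a dict mapping each requested field name to the
    array stored in that file's ``tasks`` group.
    """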

    def __init__(
        self,
        data_path,
        output_names="output-",
        field_names=("magnetic field", "velocity"),
        num_train=None,
        num_test=None,
        num=None,
        use_train=True,
    ):
        self.data_path = data_path
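        # Output directories are matched as ``output_names`` followed by one "?"
        # per digit of the number of entries in ``data_path`` (12 entries give
        # "output-??"), so directory names of exactly that width are expected.
        output_names = output_names + "?" * len(str(len(os.listdir(data_path))))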
        self.output_names = output_names
        raw_path = os.path.join(data_path, output_names, "*.h5")
        files_raw = sorted(glob.glob(raw_path))
        self.files_raw = files_raw
        self.num_files_raw = num_files_raw = len(files_raw)
        self.field_names = field_names
        self.use_train = use_train

        # Optionally keep only the first ``num`` files; None or a non-positive
        # value (e.g. -1) keeps the full dataset.
        if num is not None and num > 0:
            num_files_raw = min(num, num_files_raw)
            files_raw = files_raw[:num_files_raw]
            self.files_raw = files_raw
            self.num_files_raw = num_files_raw

        # Resolve split sizes. Values <= 1.0 are fractions of the file count
        # (so 1 means "all files", not "one file"); larger values are file
        # counts, clamped to the files that are available.
        if num_train is not None and num_train <= 1.0:
            # num_train is a fraction of the files
            num_train = int(num_train * num_files_raw)
        elif num_train is None or num_train > num_files_raw:
            num_train = num_files_raw

        if num_test is not None and num_test <= 1.0:
            # num_test is a fraction of the files
            num_test = int(num_test * num_files_raw)
        elif num_test is None or num_test > (num_files_raw - num_train):
            num_test = num_files_raw - num_train

        self.num_train = num_train
        self.train_files = self.files_raw[:num_train]
        self.num_test = num_test
        self.test_end = test_end = num_train + num_test
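        self.test_files = self.files_raw[num_train:test_end]

        # Fall back to the training files when no test files are available
        # (a list slice is never None, so check for emptiness instead).
        if self.use_train or not self.test_files:
            files = self.train_files
        else:
            files = self.test_files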
        self.files = files
        self.num_files = len(files)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """Return a dict mapping field name to array for the file at ``index``."""
        file = self.files[index]

        field_names = self.field_names
        fields = {}
        with h5py.File(file, mode="r") as h5file:
            data_file = h5file["tasks"]
            if field_names is None:
                # No selection given: load every task stored in the file.
                field_names = list(data_file.keys())
            for field_name in field_names:
                if field_name in data_file:
                    # Read the whole dataset into memory before the file closes.
                    fields[field_name] = data_file[field_name][:]
                else:
                    print(f"field name {field_name!r} not found in {file}")
        return fields

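    def get_coords(self, index):
        """Return the ``(t, x, y)`` coordinate arrays of the file at ``index``."""
        file = self.files[index]
        with h5py.File(file, mode="r") as h5file:
            data_file = h5file["tasks"]
            keys = list(data_file.keys())
            # Coordinates are read from the first task's dimension scales
            # (this assumes every task shares the same grid).
            dims = data_file[keys[0]].dims

            ndims = len(dims)
            # Time comes from the "sim_time" scale on the first axis; x and y
            # are the first scale attached to each of the last two axes.
            t = dims[0]["sim_time"][:]
            x = dims[ndims - 2][0][:]
            y = dims[ndims - 1][0][:]
        return t, x, y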
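

# Hypothetical helper (a sketch for illustration only; the dataset class above
# does not call it and the name is an assumption). It reproduces the train/test
# split arithmetic from ``Dedalus2DDataset.__init__`` on a plain list so the
# rules can be checked in isolation: sizes <= 1.0 are fractions of the file
# count (so ``1`` means all files, not one), larger sizes are counts, and a
# missing or oversized size falls back to "everything still available".
def _split_files(files, num_train=None, num_test=None):
    """Return ``(train_files, test_files)`` using the dataset's split rules."""
    num_files = len(files)

    if num_train is not None and num_train <= 1.0:
        num_train = int(num_train * num_files)  # fraction of all files
    elif num_train is None or num_train > num_files:
        num_train = num_files

    if num_test is not None and num_test <= 1.0:
        num_test = int(num_test * num_files)  # fraction of all files
    elif num_test is None or num_test > num_files - num_train:
        num_test = num_files - num_train

    # Test files are taken immediately after the training files.
    train_files = files[:num_train]
    test_files = files[num_train:num_train + num_test]
    return train_files, test_files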