ArthurY's picture
update source
c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from torch.utils.data import DataLoader, Dataset
class DatasetTopoDiff(Dataset):
def __init__(self, topologies, stress, strain, load_im, constraints):
self.topologies = topologies
self.constraints = constraints
self.image_size = topologies.shape[1]
self.stress = stress
self.strain = strain
self.load_im = load_im
def __len__(self):
return self.topologies.shape[0]
def __getitem__(self, idx):
cons = self.constraints[idx]
vol_frac = cons["VOL_FRAC"]
cons = np.zeros((5, self.image_size, self.image_size))
cons[0] = self.stress[idx]
cons[1] = self.strain[idx]
cons[2] = self.load_im[idx][:, :, 0]
cons[3] = self.load_im[idx][:, :, 1]
cons[4] = np.ones((self.image_size, self.image_size)) * vol_frac
return np.expand_dims(self.topologies[idx], 0) * 2 - 1, cons
def load_data_topodiff(
topologies, constraints, stress, strain, load_img, batch_size, deterministic=False
):
dataset = DatasetTopoDiff(topologies, stress, strain, load_img, constraints)
if deterministic:
loader = DataLoader(
dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
)
else:
loader = DataLoader(
dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
)
while True:
yield from loader