# Omini3D / dataloader_tester.py
# Uploaded by: maxmo2009
# Commit message: Initial upload: OmniMorph codebase
# Commit: 75854b3 (verified)
import os
import torch
import torchvision
from torch import nn
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD
from Diffusion.diffuser import DeformDDPM
from Diffusion.networks import get_net, STN
from torchvision.transforms import Lambda
import Diffusion.losses as losses
import random
import glob
import numpy as np
import utils
from Dataloader.dataloader0 import get_dataloader
from Dataloader.dataLoader import *
from Dataloader.dataloader_utils import thresh_img
import yaml
import argparse
# Local import: the script must terminate via sys.exit, not the site-module
# `exit()` helper, which is meant only for the interactive interpreter.
import sys

# Currently unused: every dataset below is constructed with transform(s)=None.
# Kept so a ToTensor transform can be passed in when needed.
tsfm = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
    ]
)

# --- Smoke-test construction of the 'lct' training loader -------------------
# Named distinctly so it is not silently shadowed by the paired loader below.
Data_Loader = get_dataloader(data_name='lct', mode='train')
lct_dataset = Data_Loader(
    target_res=[128] * 3,
    transforms=None,
    noise_scale=4.0e-05,
)
lct_loader = DataLoader(
    lct_dataset,
    batch_size=32,
    shuffle=True,
    drop_last=True,
)

# --- Smoke-test construction of OminiDataset_v1 ------------------------------
dataset2 = OminiDataset_v1(transform=None)
train_loader2 = DataLoader(dataset2, batch_size=32, shuffle=True)

# --- Paired dataset restricted to the 'leg' ROI; this is the loader inspected.
dataset = OminiDataset_paired(transform=None, ROIs=['leg'])
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# Optional debugging hooks on the paired dataset:
# print(dataset.get_all_ROI())
# print(dataset.getitem())
# print(dataset.get_ALLdata())
# print(dataset.getitem(idx=11))
# exit()

# Pull exactly one batch and report its shapes, dtypes and value ranges.
first_batch = next(iter(train_loader), None)
if first_batch is None:
    # Without this check an empty loader prints nothing and exits with
    # status 0, hiding a broken data path; fail loudly instead.
    sys.exit("train_loader yielded no batches")

# NOTE(review): assumes each paired batch unpacks into two tensors (x0, x1)
# — confirm against OminiDataset_paired.__getitem__.
x0, x1 = first_batch
print(x0.shape, x1.shape)
print(x0.dtype, x1.dtype)
print(x0.min(), x0.max())
print(x1.min(), x1.max())
sys.exit()