"""Smoke-test script: load one batch from ``OminiDataset_paired`` and print its stats.

Run directly (``python <this file>``). Importing the module only performs the
imports below; the inspection runs from ``main()``.
"""
import os
import torch
import torchvision
from torch import nn
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD
from Diffusion.diffuser import DeformDDPM
from Diffusion.networks import get_net, STN
from torchvision.transforms import Lambda
import Diffusion.losses as losses
import random
import glob
import numpy as np
import utils
from Dataloader.dataloader0 import get_dataloader
# NOTE(review): star import kept in its original position because it may shadow
# earlier names; presumably provides OminiDataset_paired — prefer an explicit import.
from Dataloader.dataLoader import *
from Dataloader.dataloader_utils import thresh_img
import yaml
import argparse


def _print_batch_summary(x0, x1):
    """Print the shapes and dtypes of both batch members and the value range of ``x0``."""
    print(x0.shape, x1.shape)
    print(x0.dtype, x1.dtype)
    print(x0.min(), x0.max())


def main(rois=("leg",), batch_size=32):
    """Build a shuffled loader over ``OminiDataset_paired`` and summarize its first batch.

    Args:
        rois: ROI names passed as ``ROIs`` to ``OminiDataset_paired`` (default: leg only).
        batch_size: Batch size for the ``DataLoader``.
    """
    dataset = OminiDataset_paired(transform=None, ROIs=list(rois))
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Only the first batch is inspected; next() replaces the for/break/exit() pattern.
    batch = next(iter(train_loader), None)
    if batch is None:
        print("train_loader yielded no batches (empty dataset?)")
        return

    # The dataset is expected to yield pairs; unpacking fails loudly otherwise.
    x0, x1 = batch
    _print_batch_summary(x0, x1)


if __name__ == "__main__":
    main()