File size: 1,546 Bytes
75854b3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | import os
import torch
import torchvision
from torch import nn
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD
from Diffusion.diffuser import DeformDDPM
from Diffusion.networks import get_net, STN
from torchvision.transforms import Lambda
import Diffusion.losses as losses
import random
import glob
import numpy as np
import utils
from Dataloader.dataloader0 import get_dataloader
from Dataloader.dataLoader import *
from Dataloader.dataloader_utils import thresh_img
import yaml
import argparse
# Transform pipeline consisting of a single ToTensor step.
# NOTE(review): `tsfm` is not referenced anywhere else in the visible script;
# confirm it is still needed before relying on it.
tsfm = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
# Training split of the 'lct' dataset, built with target_res=[128] * 3 and no
# transforms. It is bound to its own names (lct_dataset / lct_loader) because
# the script later binds a different, paired dataset to `dataset` /
# `train_loader`; reusing those names here silently discarded this loader.
Data_Loader = get_dataloader(data_name='lct', mode='train')
lct_dataset = Data_Loader(
    target_res=[128] * 3,
    transforms=None,
    # Forwarded verbatim; its meaning is defined by the dataset class.
    noise_scale=4.0e-05,
)
lct_loader = DataLoader(
    lct_dataset,
    batch_size=32,
    shuffle=True,
    drop_last=True,  # drop the final incomplete batch so every batch has 32 samples
)
# Shuffled loader over OminiDataset_v1 (no transform), 32 samples per batch.
# NOTE(review): train_loader2 appears unused in the visible script — confirm.
dataset2 = OminiDataset_v1(transform=None)
train_loader2 = DataLoader(dataset=dataset2, shuffle=True, batch_size=32)
# Paired dataset restricted to the 'leg' ROI, loaded in shuffled batches of 32.
dataset = OminiDataset_paired(transform=None, ROIs=['leg'])
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Smoke test: pull a single batch and report shapes, dtypes and value ranges.
# NOTE(review): assumes each batch unpacks into exactly two tensors (x0, x1);
# confirm against OminiDataset_paired.__getitem__.
first_batch = next(iter(train_loader), None)
if first_batch is None:
    print('train_loader produced no batches')
else:
    x0, x1 = first_batch
    print(x0.shape, x1.shape)
    print(x0.dtype, x1.dtype)
    print(x0.min(), x0.max())
    print(x1.min(), x1.max())

# Terminate explicitly. `exit()` is injected by the `site` module for
# interactive use and is unavailable under `python -S`; SystemExit is builtin.
raise SystemExit
|