| import os
|
| import torch
|
| import torchvision
|
| from torch import nn
|
| from torchvision.utils import save_image
|
| from torch.utils.data import DataLoader
|
| from torch.optim import Adam, SGD
|
| from Diffusion.diffuser import DeformDDPM
|
| from Diffusion.networks import get_net, STN
|
| from torchvision.transforms import Lambda
|
| import Diffusion.losses as losses
|
| import random
|
| import glob
|
| import numpy as np
|
| import utils
|
|
|
| from Dataloader.dataloader0 import get_dataloader
|
| from Dataloader.dataLoader import *
|
| from Dataloader.dataloader_utils import thresh_img
|
| import yaml
|
| import argparse
|
|
|
# Minimal torchvision pipeline that only converts inputs to tensors.
# NOTE(review): not passed to any dataset in this part of the script (every
# dataset below is constructed with transforms=None / transform=None); kept
# in case code outside this view refers to `tsfm`.
_to_tensor = torchvision.transforms.ToTensor()
tsfm = torchvision.transforms.Compose([_to_tensor])
|
# Training split of the 'lct' dataset. `get_dataloader` returns a callable
# (presumably a Dataset class -- TODO confirm) that is instantiated with a
# target resolution of 128 on each of three axes and a small noise scale.
# NOTE(review): `dataset` and `train_loader` are rebound further down to the
# paired 'leg' ROI dataset before any batch is drawn, so this loader is never
# iterated in this script -- confirm whether it is still needed.
Data_Loader = get_dataloader(data_name='lct', mode='train')

dataset = Data_Loader(target_res=[128, 128, 128], transforms=None, noise_scale=4.0e-05)

# Shuffled batches of 32; the final incomplete batch is discarded.
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)
|
|
|
|
|
# OminiDataset_v1 with no transform, wrapped in a shuffled loader of batch 32.
# NOTE(review): neither `dataset2` nor `train_loader2` is used in the visible
# part of this script.
dataset2 = OminiDataset_v1(transform=None)
train_loader2 = DataLoader(
    dataset2,
    batch_size=32,
    shuffle=True,
)
|
|
|
|
|
# Paired dataset restricted to the 'leg' ROI. This rebinds `dataset` and
# `train_loader`, replacing the 'lct' loader constructed earlier in the script.
dataset = OminiDataset_paired(transform=None, ROIs=['leg'])
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Sanity check: draw a single batch, report its shapes, dtypes and value
# ranges, then stop the script.
batch = next(iter(train_loader), None)
if batch is None:
    # An empty loader previously made this check silently "pass"; fail loudly
    # (non-zero exit status, message on stderr) instead.
    raise SystemExit("train_loader produced no batches; check the paired 'leg' dataset")

# Assumes each batch is a pair (x0, x1) of tensors -- TODO confirm against
# OminiDataset_paired.
x0, x1 = batch
print(x0.shape, x1.shape)
print(x0.dtype, x1.dtype)
print(x0.min(), x0.max())
print(x1.min(), x1.max())

# `raise SystemExit` rather than the `exit()` helper: `exit` is injected by the
# `site` module and is unavailable under `python -S` or in frozen/embedded
# interpreters. Exit status is 0, same as `exit()`.
raise SystemExit
|
|
|
|
|