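# DeBoneDiT / code / dit_eval.py
# Hugging Face upload by diaoquesang ("Upload 15 files", commit 6107278, verified), 3.05 kB.
# (File-viewer page controls "Raw / History / Blame / Contribute / Delete" are not part of the script.)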
from config import config
from transform import myTransform
from torch.utils.data import DataLoader
from BBDMScheduler import BBDMScheduler
from tqdm import tqdm
import cv2 as cv
import time
import os
from dataset import mySingleDataset
from monai.utils import set_determinism
import numpy as np
import torch
# Fix the global seed through MONAI so repeated evaluation runs are reproducible.
set_determinism(seed=42)
def eval():
    """Run diffusion inference over the test set and save the decoded images.

    For every entry of the test-set list the input image is encoded to a
    latent with ``VQGAN.encode_stage_2_inputs``, refined step by step over
    ``noise_scheduler.timesteps`` with the loaded model, decoded with
    ``VQGAN.decode`` and written as an 8-bit image into ``output_path`` under
    the input's filename.  Each step mixes a prediction conditioned on the
    input latent with one conditioned on random noise, weighted by ``w``.

    When ``config.use_server`` is false, the current latent is shown in an
    OpenCV window after every denoising step.  The total wall-clock time is
    printed at the end.

    NOTE: the name shadows the built-in ``eval``; it is kept unchanged because
    the script entry point calls it by this name.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # select the runtime device
    output_path = "dit_output_N8"
    w = 1.5  # 1.5 for SZCH-X-Rays; 1.0 for JSRT
    cxr_path = os.path.join("SZCH-X-Rays-741", "CXR")

    # cv.imwrite does not create missing directories (it just reports failure),
    # so make sure the destination exists before any inference work is done.
    os.makedirs(output_path, exist_ok=True)

    # SECURITY: torch.load unpickles arbitrary objects -- only load checkpoints
    # from a trusted source.
    # map_location lets checkpoints that were saved on a GPU load on CPU-only hosts.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which
    # rejects fully pickled modules; pass weights_only=False there if needed.
    model = torch.load("YOUR DEBONEDIT MODEL PATH", map_location=device).to(device).eval().requires_grad_(False)
    VQGAN = torch.load("YOUR DAE MODEL PATH", map_location=device).to(device).eval().requires_grad_(False)

    testset_list = "SZCH-X-Rays_testset.txt"
    myTestSet = mySingleDataset(testset_list, cxr_path, myTransform['testTransform'])
    myTestLoader = DataLoader(myTestSet, batch_size=1, shuffle=False)
    noise_scheduler = BBDMScheduler(num_train_timesteps=config.num_train_timesteps)

    with torch.no_grad():
        progress_bar = tqdm(myTestLoader, total=len(myTestLoader), ncols=100)
        total_start = time.time()
        for batch in progress_bar:
            # Use the selected device rather than a hard-coded "cuda" so the
            # CPU fallback chosen above actually works.
            noise_scheduler.set_timesteps(config.num_infer_timesteps, device=device,
                                          original_inference_steps=config.num_train_timesteps)
            cxr = batch[0].to(device=device, non_blocking=True).float()
            filename = batch[1][0]  # batch_size is 1, so take the only filename
            cxr = VQGAN.encode_stage_2_inputs(cxr)
            sample = cxr.clone()  # the iteration starts from the encoded input itself

            for t in tqdm(noise_scheduler.timesteps):
                # Build the timestep tensor once and share it between both passes.
                timestep = torch.Tensor((t,)).to(device).long()

                # Pass conditioned on random noise instead of the input latent.
                sample_uc = torch.cat((sample, torch.randn_like(cxr)), dim=1)
                residual_uc = model(sample_uc, timestep)[:, :4].to(device)

                # Pass conditioned on the input latent.
                sample_c = torch.cat((sample, cxr), dim=1)
                residual = model(sample_c, timestep)[:, :4].to(device)

                # Guidance: blend the two predictions with weight w.
                residual = w * residual + (1 - w) * residual_uc
                sample = noise_scheduler.step(residual, t, sample_c[:, :4], cxr).prev_sample

                if not config.use_server:
                    # Live preview: first three latent channels as an HWC image,
                    # mapped from [-1, 1] to [0, 1].
                    bs_show = np.transpose(np.squeeze(sample.cpu().detach().numpy())[0:3], (1, 2, 0))
                    bs_show = np.clip(bs_show * 0.5 + 0.5, 0, 1)
                    cv.imshow("win1", bs_show)
                    cv.waitKey(1)

            # NOTE(review): encoding uses encode_stage_2_inputs while decoding
            # uses decode -- confirm this pairing is intended.
            bs = VQGAN.decode(sample)
            bs = np.squeeze(bs.detach().cpu().numpy())  # HW
            # Map [-1, 1] -> [0, 255] and quantise to 8 bits.
            bs = (np.clip(bs * 0.5 + 0.5, 0, 1) * 255).astype(np.uint8)

            save_path = os.path.join(output_path, filename)
            # imwrite signals failure through its return value; surface it
            # instead of silently losing the result.
            if not cv.imwrite(save_path, bs):
                tqdm.write(f"Warning: failed to write {save_path}")

    total_time = time.time() - total_start
    print(f"Total time: {total_time}.")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    eval()