| from config import config
|
|
|
| from transform import myTransform
|
| from torch.utils.data import DataLoader
|
| from BBDMScheduler import BBDMScheduler
|
| from tqdm import tqdm
|
| import cv2 as cv
|
| import time
|
| import os
|
| from dataset import mySingleDataset
|
| from monai.utils import set_determinism
|
|
|
| import numpy as np
|
| import torch
|
|
|
# Seed the random number generators once at import time so that repeated runs
# draw the same noise for the unconditional branch in eval().
set_determinism(seed=42)
|
|
|
|
|
def eval():
    """Run guided latent-diffusion inference over the test set and save the results.

    For every entry of ``SZCH-X-Rays_testset.txt`` the function:

    1. encodes the input image with ``VQGAN.encode_stage_2_inputs``;
    2. iterates over ``noise_scheduler.timesteps``, calling the model twice per
       step (once with the encoded image as condition, once with Gaussian noise
       in its place) and blending both predictions as
       ``w * conditional + (1 - w) * unconditional``;
    3. decodes the final latent with ``VQGAN.decode``, maps it from ``[-1, 1]``
       to ``[0, 255]`` and writes it to ``dit_output_N8/<filename>``.

    When ``config.use_server`` is false, the first three latent channels are
    shown in an OpenCV window after every denoising step. The total wall-clock
    time of the loop is printed at the end.

    The function takes no arguments and returns ``None``. Its name shadows the
    built-in ``eval``; it is kept unchanged so existing callers keep working.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_path = "dit_output_N8"
    w = 1.5  # guidance weight; values > 1 extrapolate away from the unconditional prediction
    cxr_path = os.path.join("SZCH-X-Rays-741", "CXR")
    testset_list = "SZCH-X-Rays_testset.txt"

    # cv.imwrite does not create missing directories and only returns False on
    # failure, so make sure the destination exists before any work is done.
    os.makedirs(output_path, exist_ok=True)

    # NOTE(security): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources. map_location lets GPU-saved checkpoints
    # load on CPU-only machines, matching the device fallback above.
    model = torch.load("YOUR DEBONEDIT MODEL PATH", map_location=device)
    model = model.to(device).eval().requires_grad_(False)
    VQGAN = torch.load("YOUR DAE MODEL PATH", map_location=device)
    VQGAN = VQGAN.to(device).eval().requires_grad_(False)

    myTestSet = mySingleDataset(testset_list, cxr_path, myTransform['testTransform'])
    myTestLoader = DataLoader(myTestSet, batch_size=1, shuffle=False)
    noise_scheduler = BBDMScheduler(num_train_timesteps=config.num_train_timesteps)

    with torch.no_grad():
        progress_bar = tqdm(enumerate(myTestLoader), total=len(myTestLoader), ncols=100)
        total_start = time.time()

        for _, batch in progress_bar:
            # Reset the schedule for every image, on the device actually in
            # use (previously hard-coded to "cuda", which breaks on CPU).
            noise_scheduler.set_timesteps(
                config.num_infer_timesteps,
                device=device,
                original_inference_steps=config.num_train_timesteps,
            )

            cxr = batch[0].to(device=device, non_blocking=True).float()
            filename = batch[1][0]  # batch_size is 1, so take the single file name

            # From here on `cxr` is the encoded (latent) representation.
            cxr = VQGAN.encode_stage_2_inputs(cxr)
            sample = cxr.clone()

            for t in tqdm(noise_scheduler.timesteps):
                # Build the timestep tensor once and reuse it for both passes.
                timestep = torch.Tensor((t,)).to(device).long()

                # Unconditional pass: the condition is replaced by pure noise.
                uncond_input = torch.cat((sample, torch.randn_like(cxr)), dim=1)
                residual_uc = model(uncond_input, timestep)[:, :4]

                # Conditional pass: the encoded input image is the condition.
                cond_input = torch.cat((sample, cxr), dim=1)
                residual = model(cond_input, timestep)[:, :4]

                # Blend the two predictions with the guidance weight.
                residual = w * residual + (1 - w) * residual_uc

                # The first four channels of the concatenated input are the
                # current latent (assumes a 4-channel latent, as the [:, :4]
                # slices on the model output do).
                sample = noise_scheduler.step(residual, t, cond_input[:, :4], cxr).prev_sample

                if not config.use_server:
                    # Live preview of the first three latent channels as H x W x 3.
                    preview = np.squeeze(sample.cpu().detach().numpy())[0:3]
                    preview = np.transpose(preview, (1, 2, 0))
                    preview = np.clip(preview * 0.5 + 0.5, 0, 1)
                    cv.imshow("win1", preview)
                    cv.waitKey(1)

            # Decode the final latent and map [-1, 1] -> [0, 255] uint8.
            decoded = VQGAN.decode(sample)
            image = np.squeeze(decoded.detach().cpu().numpy())
            image = np.clip(image * 0.5 + 0.5, 0, 1) * 255
            image = image.astype(np.uint8)

            out_file = os.path.join(output_path, filename)
            if not cv.imwrite(out_file, image):
                # Keep going (best effort), but do not lose results silently.
                print(f"Warning: failed to write {out_file}")

        total_time = time.time() - total_start

    print(f"Total time: {total_time}.")
|
|
|
|
|
# Script entry point: run inference only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    eval()
|
|
|