File size: 7,307 Bytes
6434535 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 | from config import config
import numpy as np
from dataset import myC2BDataset
from transform import myTransform
from torch.utils.data import DataLoader
from diffusers import LCMScheduler
from tqdm import tqdm
import cv2 as cv
import torch
import time
import os
from monai.utils import set_determinism
set_determinism(42)
def eval():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 设置运行环境
output_path = os.path.join("lcm_output_bs", "BS")
masked_output_path = os.path.join("lcm_output_bs", "Masked_BS")
fusion_output_path = os.path.join("lcm_output_bs", "Fusion_BS")
cxr_path = os.path.join("SZCH-X-Rays", "CXR")
masked_cxr_path = os.path.join("SZCH-X-Rays", "Masked_CXR")
mask_path = os.path.join("SZCH-X-Rays", "Mask")
model = torch.load("masked_lcm-600-2024-12-19-myModel.pth").to(device).eval()
VQGAN = torch.load("2024-12-12-Mask-SZCH-VQGAN.pth").to(device).eval()
testset_list = "SZCH.txt"
myTestSet = myC2BDataset(testset_list, cxr_path, masked_cxr_path, myTransform['testTransform'])
myTestLoader = DataLoader(myTestSet, batch_size=1, shuffle=False)
# 设置噪声调度器
noise_scheduler = LCMScheduler(num_train_timesteps=config.num_train_timesteps,
clip_sample=config.clip_sample,
clip_sample_range=config.initial_clip_sample_range_g)
noise_scheduler.set_timesteps(config.num_infer_timesteps)
with torch.no_grad():
progress_bar = tqdm(enumerate(myTestLoader), total=len(myTestLoader), ncols=100)
total_start = time.time()
for step, batch in progress_bar:
cxr = batch[0].to(device=device, non_blocking=True).float()
masked_cxr = batch[1].to(device=device, non_blocking=True).float()
filename = batch[2][0]
cxr_copy = np.array(cxr.detach().to("cpu"))
cxr_copy = np.squeeze(cxr_copy) # HW
cxr_copy = cxr_copy * 0.5 + 0.5
cxr_copy *= 255
cxr_copy = cxr_copy.astype(np.int8)
cxr = VQGAN.encode_stage_2_inputs(cxr)
masked_cxr = VQGAN.encode_stage_2_inputs(masked_cxr)
noise = torch.randn_like(cxr).to(device)
sample = torch.cat((noise, cxr), dim=1).to(device) # BCHW
masked_sample = torch.cat((noise, masked_cxr), dim=1).to(device) # BCHW
for j, t in tqdm(enumerate(noise_scheduler.timesteps)):
residual = model(sample, torch.Tensor((t,)).to(device).long()).to(device)
masked_residual = model(masked_sample, torch.Tensor((t,)).to(device).long()).to(device)
# masked_residual = (1 - config.alpha) * residual + config.alpha * masked_residual
masked_residual = config.alpha * masked_residual + (1 - config.alpha) * torch.randn_like(
masked_residual).to(device) / torch.std(masked_residual)
noise_scheduler = LCMScheduler(num_train_timesteps=config.num_train_timesteps,
clip_sample=config.clip_sample,
clip_sample_range=
config.initial_clip_sample_range_g
+ config.clip_rate * j
)
noise_scheduler.set_timesteps(config.num_infer_timesteps)
sample = noise_scheduler.step(residual, t, sample).prev_sample
noise_scheduler = LCMScheduler(num_train_timesteps=config.num_train_timesteps,
clip_sample=config.clip_sample,
clip_sample_range=
config.initial_clip_sample_range_l
+ config.clip_rate * j
)
noise_scheduler.set_timesteps(config.num_infer_timesteps)
masked_sample = noise_scheduler.step(masked_residual, t, masked_sample).prev_sample
sample = torch.cat((sample[:, :4], cxr), dim=1) # BCHW
masked_sample = torch.cat((masked_sample[:, :4], masked_cxr), dim=1).to(device) # BCHW
if config.output_feature_map:
bs_show = np.array(sample[:, 0].detach().to("cpu"))
bs_show = np.squeeze(bs_show) # HW
bs_show = bs_show * 0.5 + 0.5
bs_show = np.clip(bs_show, 0, 1)
masked_bs_show = np.array(masked_sample[:, 0].detach().to("cpu"))
masked_bs_show = np.squeeze(masked_bs_show) # HW
masked_bs_show = masked_bs_show * 0.5 + 0.5
masked_bs_show = np.clip(masked_bs_show, 0, 1)
if not config.use_server:
cv.imshow("win1", bs_show)
cv.imshow("win2", masked_bs_show)
cv.waitKey(1)
mask = cv.imread(os.path.join(mask_path, filename), 0)
mask[mask < 255] = 0
bs = VQGAN.decode((sample[:, :4]))
bs = np.array(bs.detach().to("cpu"))
bs = np.squeeze(bs) # HW
bs = bs * 0.5 + 0.5
bs[cxr_copy == 0] = 0
masked_bs = VQGAN.decode((masked_sample[:, :4]))
masked_bs = np.array(masked_bs.detach().to("cpu"))
masked_bs = np.squeeze(masked_bs) # HW
masked_bs = masked_bs * 0.5 + 0.5
masked_bs[mask > 0] = masked_bs[mask > 0] + np.mean(bs[mask > 0]) - np.mean(masked_bs[mask > 0])
masked_bs[cxr_copy == 0] = 0
if not config.use_server:
cv.imshow("win3", bs)
cv.imshow("win4", masked_bs)
cv.waitKey(1)
bs *= 255
cv.imwrite(os.path.join(output_path, filename), bs)
masked_bs *= 255
cv.imwrite(os.path.join(masked_output_path, filename), masked_bs)
num_labels, labels, stats, _ = cv.connectedComponentsWithStats(mask)
min_area = 100
for i in range(1, num_labels):
if stats[i, cv.CC_STAT_AREA] < min_area:
labels[labels == i] = 0
mask[labels == 0] = 0
br = cv.boundingRect(mask)
p = (br[0] + br[2] // 2, br[1] + br[3] // 2)
masked_bs = np.clip(masked_bs, 0, 255)
masked_bs = cv.cvtColor(masked_bs, cv.COLOR_GRAY2BGR).astype(np.uint8)
bs = np.clip(bs, 0, 255)
bs = cv.cvtColor(bs, cv.COLOR_GRAY2BGR).astype(np.uint8)
fusion_bs = cv.seamlessClone(masked_bs, bs, mask, p, cv.MONOCHROME_TRANSFER)
# cv.rectangle(fusion_bs, br, (0, 255, 0), 2)
# fusion_bs[mask==255]=(255, 0, 0)
cv.imwrite(os.path.join(fusion_output_path, filename), fusion_bs)
total_time = time.time() - total_start
print(f"Total time: {total_time}.")
if __name__ == "__main__":
eval()
|