sayehghp commited on
Commit
e09b1c8
·
1 Parent(s): ecf378d

Visualization

Browse files
Files changed (3) hide show
  1. CXRGen/sample_generation.py +142 -36
  2. inference.py +5 -0
  3. vg_token_attention.py +11 -2
CXRGen/sample_generation.py CHANGED
@@ -56,21 +56,81 @@ def get_args_parser():
56
  apply_uniformer = UniformerDetector()
57
  apply_canny = CannyDetector()
58
 
59
- def process(input_image, prompt, model, num_samples, image_resolution=512, ddim_steps=10, guess_mode=False, strength=1, scale=9, seed=-1, eta=0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  with torch.no_grad():
61
  ddim_sampler = DDIMSampler(model)
 
62
  img = resize_image(HWC3(input_image), image_resolution)
63
- # detected_map = apply_uniformer(resize_image(input_image, image_resolution))
64
  H, W, C = img.shape
65
 
66
  detected_map = apply_canny(img, 100, 200)
67
  detected_map = HWC3(detected_map)
68
- # detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
69
 
70
- # control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
71
- control = torch.from_numpy(detected_map.copy()).float().cpu() / 255.0
72
  control = torch.stack([control for _ in range(num_samples)], dim=0)
73
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
74
 
75
  if seed == -1:
76
  seed = random.randint(0, 65535)
@@ -79,29 +139,45 @@ def process(input_image, prompt, model, num_samples, image_resolution=512, ddim_
79
  if config.save_memory:
80
  model.low_vram_shift(is_diffusing=False)
81
 
82
- cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt] * num_samples)]}
83
- #cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
84
- #un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
 
85
 
86
  shape = (4, H // 8, W // 8)
87
 
88
  if config.save_memory:
89
  model.low_vram_shift(is_diffusing=True)
90
 
91
- model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
92
- samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
93
- shape, cond, verbose=False, eta=eta,
94
- unconditional_guidance_scale=scale)
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  if config.save_memory:
97
  model.low_vram_shift(is_diffusing=False)
98
 
99
  x_samples = model.decode_first_stage(samples)
100
- x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
 
 
101
 
102
  results = [x_samples[i] for i in range(num_samples)]
 
103
  return [255 - detected_map] + results
104
 
 
105
  def imageEncoder(img):
106
  image_source, image = load_image(img)
107
  return image
@@ -121,42 +197,72 @@ def main(args):
121
  # if args.device == 'cuda':
122
  # model = model.cuda()
123
 
124
- # respect the passed device, but fall back safely
125
  if getattr(args, "device", "cpu") == "cuda" and torch.cuda.is_available():
126
  device = torch.device("cuda")
127
  else:
128
  device = torch.device("cpu")
129
 
130
- model = create_model('./CXRGen/models/cldm_v15_biovlp.yaml').cpu()
 
 
131
  state = load_state_dict(args.weight_path, location="cpu")
132
  model.load_state_dict(state, strict=False)
133
 
134
- # only move to GPU if we really decided to
135
- if device.type == "cuda":
136
- model = model.to(device)
137
- # # Decide device once
138
- # device = "cuda" if torch.cuda.is_available() else "cpu"
139
- # print(f"[VICCA] Using device: {device}", flush=True)
140
-
141
- # # Make sure the rest of the code sees the same device
142
- # args.device = device
143
-
144
- # # Create model on CPU then move to device
145
- # model = create_model("./CXRGen/models/cldm_v15_biovlp.yaml")
146
-
147
- # # Load weights with correct map_location
148
- # state_dict = load_state_dict(args.weight_path, location=device)
149
- # model.load_state_dict(state_dict, strict=False)
150
-
151
- # model = model.to(device)
152
  model.eval()
153
-
154
 
155
  prompt = args.text_prompt
156
  img_org = cv2.imread(args.image_path)
157
  img_w, img_h, c = img_org.shape
 
158
  input_img = lungsegment(args.image_path)
159
- gen_img = process(input_img, prompt, model, args.num_samples)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  if args.plot_gen_image:
162
  for i in range(1,len(gen_img)):
 
56
  apply_uniformer = UniformerDetector()
57
  apply_canny = CannyDetector()
58
 
59
+ # def process(input_image, prompt, model, num_samples, image_resolution=512, ddim_steps=10, guess_mode=False, strength=1, scale=9, seed=-1, eta=0):
60
+ # with torch.no_grad():
61
+ # ddim_sampler = DDIMSampler(model)
62
+ # img = resize_image(HWC3(input_image), image_resolution)
63
+ # # detected_map = apply_uniformer(resize_image(input_image, image_resolution))
64
+ # H, W, C = img.shape
65
+
66
+ # detected_map = apply_canny(img, 100, 200)
67
+ # detected_map = HWC3(detected_map)
68
+ # # detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
69
+
70
+ # # control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
71
+ # control = torch.from_numpy(detected_map.copy()).float().cpu() / 255.0
72
+ # control = torch.stack([control for _ in range(num_samples)], dim=0)
73
+ # control = einops.rearrange(control, 'b h w c -> b c h w').clone()
74
+
75
+ # if seed == -1:
76
+ # seed = random.randint(0, 65535)
77
+ # seed_everything(seed)
78
+
79
+ # if config.save_memory:
80
+ # model.low_vram_shift(is_diffusing=False)
81
+
82
+ # cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt] * num_samples)]}
83
+ # #cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
84
+ # #un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
85
+
86
+ # shape = (4, H // 8, W // 8)
87
+
88
+ # if config.save_memory:
89
+ # model.low_vram_shift(is_diffusing=True)
90
+
91
+ # model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
92
+ # samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
93
+ # shape, cond, verbose=False, eta=eta,
94
+ # unconditional_guidance_scale=scale)
95
+
96
+ # if config.save_memory:
97
+ # model.low_vram_shift(is_diffusing=False)
98
+
99
+ # x_samples = model.decode_first_stage(samples)
100
+ # x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
101
+
102
+ # results = [x_samples[i] for i in range(num_samples)]
103
+ # return [255 - detected_map] + results
104
+
105
+ def process(
106
+ input_image,
107
+ prompt,
108
+ model,
109
+ num_samples,
110
+ device,
111
+ image_resolution=512,
112
+ ddim_steps=10,
113
+ guess_mode=False,
114
+ strength=1,
115
+ scale=9,
116
+ seed=-1,
117
+ eta=0,
118
+ ):
119
+ model = model.to(device)
120
+
121
  with torch.no_grad():
122
  ddim_sampler = DDIMSampler(model)
123
+
124
  img = resize_image(HWC3(input_image), image_resolution)
 
125
  H, W, C = img.shape
126
 
127
  detected_map = apply_canny(img, 100, 200)
128
  detected_map = HWC3(detected_map)
 
129
 
130
+ control = torch.from_numpy(detected_map.copy()).float() / 255.0
 
131
  control = torch.stack([control for _ in range(num_samples)], dim=0)
132
+ control = einops.rearrange(control, "b h w c -> b c h w").clone()
133
+ control = control.to(device)
134
 
135
  if seed == -1:
136
  seed = random.randint(0, 65535)
 
139
  if config.save_memory:
140
  model.low_vram_shift(is_diffusing=False)
141
 
142
+ cond = {
143
+ "c_concat": [control],
144
+ "c_crossattn": [model.get_learned_conditioning([prompt] * num_samples)],
145
+ }
146
 
147
  shape = (4, H // 8, W // 8)
148
 
149
  if config.save_memory:
150
  model.low_vram_shift(is_diffusing=True)
151
 
152
+ model.control_scales = (
153
+ [strength * (0.825 ** float(12 - i)) for i in range(13)]
154
+ if guess_mode
155
+ else ([strength] * 13)
156
+ )
157
+
158
+ samples, intermediates = ddim_sampler.sample(
159
+ ddim_steps,
160
+ num_samples,
161
+ shape,
162
+ cond,
163
+ verbose=False,
164
+ eta=eta,
165
+ unconditional_guidance_scale=scale,
166
+ )
167
 
168
  if config.save_memory:
169
  model.low_vram_shift(is_diffusing=False)
170
 
171
  x_samples = model.decode_first_stage(samples)
172
+ x_samples = (
173
+ einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5
174
+ ).cpu().numpy().clip(0, 255).astype(np.uint8)
175
 
176
  results = [x_samples[i] for i in range(num_samples)]
177
+
178
  return [255 - detected_map] + results
179
 
180
+
181
  def imageEncoder(img):
182
  image_source, image = load_image(img)
183
  return image
 
197
  # if args.device == 'cuda':
198
  # model = model.cuda()
199
 
 
200
  if getattr(args, "device", "cpu") == "cuda" and torch.cuda.is_available():
201
  device = torch.device("cuda")
202
  else:
203
  device = torch.device("cpu")
204
 
205
+ print(f"[CXRGen] Using device: {device}", flush=True)
206
+
207
+ model = create_model("./CXRGen/models/cldm_v15_biovlp.yaml").cpu()
208
  state = load_state_dict(args.weight_path, location="cpu")
209
  model.load_state_dict(state, strict=False)
210
 
211
+ model = model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  model.eval()
 
213
 
214
  prompt = args.text_prompt
215
  img_org = cv2.imread(args.image_path)
216
  img_w, img_h, c = img_org.shape
217
+
218
  input_img = lungsegment(args.image_path)
219
+
220
+ gen_img = process(
221
+ input_img,
222
+ prompt,
223
+ model,
224
+ args.num_samples,
225
+ device=device,
226
+ )
227
+
228
+
229
+
230
+ # # respect the passed device, but fall back safely
231
+ # if getattr(args, "device", "cpu") == "cuda" and torch.cuda.is_available():
232
+ # device = torch.device("cuda")
233
+ # else:
234
+ # device = torch.device("cpu")
235
+
236
+ # model = create_model('./CXRGen/models/cldm_v15_biovlp.yaml').cpu()
237
+ # state = load_state_dict(args.weight_path, location="cpu")
238
+ # model.load_state_dict(state, strict=False)
239
+
240
+ # # only move to GPU if we really decided to
241
+ # if device.type == "cuda":
242
+ # model = model.to(device)
243
+ # # # Decide device once
244
+ # # device = "cuda" if torch.cuda.is_available() else "cpu"
245
+ # # print(f"[VICCA] Using device: {device}", flush=True)
246
+
247
+ # # # Make sure the rest of the code sees the same device
248
+ # # args.device = device
249
+
250
+ # # # Create model on CPU then move to device
251
+ # # model = create_model("./CXRGen/models/cldm_v15_biovlp.yaml")
252
+
253
+ # # # Load weights with correct map_location
254
+ # # state_dict = load_state_dict(args.weight_path, location=device)
255
+ # # model.load_state_dict(state_dict, strict=False)
256
+
257
+ # # model = model.to(device)
258
+ # model.eval()
259
+
260
+
261
+ # prompt = args.text_prompt
262
+ # img_org = cv2.imread(args.image_path)
263
+ # img_w, img_h, c = img_org.shape
264
+ # input_img = lungsegment(args.image_path)
265
+ # gen_img = process(input_img, prompt, model, args.num_samples)
266
 
267
  if args.plot_gen_image:
268
  for i in range(1,len(gen_img)):
inference.py CHANGED
@@ -115,6 +115,7 @@ from DETR import svc
115
  from DETR.arguments import get_args_parser as get_detr_args_parser
116
  from VG import localization
117
  from ssim import ssim
 
118
 
119
  from CheXbert.src.label import label
120
 
@@ -214,6 +215,10 @@ def gen_cxr(weight_path, image_path, text_prompt, num_samples, output_path, devi
214
  args.num_samples = num_samples
215
  args.output_path = output_path
216
  args.weight_path = get_weight(weight_path)
 
 
 
 
217
  args.device = device
218
  sample_generation.main(args)
219
 
 
115
  from DETR.arguments import get_args_parser as get_detr_args_parser
116
  from VG import localization
117
  from ssim import ssim
118
+ import torch
119
 
120
  from CheXbert.src.label import label
121
 
 
215
  args.num_samples = num_samples
216
  args.output_path = output_path
217
  args.weight_path = get_weight(weight_path)
218
+ if torch.cuda.is_available():
219
+ device = torch.device("cuda")
220
+ else:
221
+ device = torch.device("cpu")
222
  args.device = device
223
  sample_generation.main(args)
224
 
vg_token_attention.py CHANGED
@@ -269,9 +269,18 @@ def run_token_ca_visualization(
269
  """
270
  if isinstance(terms, str):
271
  terms = [terms]
272
- terms = [t.strip() for t in terms if t and t.strip()]
 
 
 
 
 
273
  if not terms:
274
- raise ValueError("No terms provided for attention visualization.")
 
 
 
 
275
 
276
  device = device or DEVICE_DEFAULT
277
  model = load_model(cfg_path, ckpt_path).to(device).eval()
 
269
  """
270
  if isinstance(terms, str):
271
  terms = [terms]
272
+
273
+ prompt_lower = prompt.lower()
274
+
275
+ # Keep only terms that actually appear in the prompt (case-insensitive)
276
+ terms = [t for t in terms if t.lower() in prompt_lower]
277
+
278
  if not terms:
279
+ print(f"[TokenCA] No configured terms found in prompt: {prompt!r}")
280
+ return {} # or an empty dict / list, whatever you expect upstream
281
+ # terms = [t.strip() for t in terms if t and t.strip()]
282
+ # if not terms:
283
+ # raise ValueError("No terms provided for attention visualization.")
284
 
285
  device = device or DEVICE_DEFAULT
286
  model = load_model(cfg_path, ckpt_path).to(device).eval()