# vicca_api.py
import os
import time
import cv2
import pandas as pd
import torch
import supervision as sv
from weights_utils import ensure_all_vicca_weights, get_weight
from vg_token_attention import run_token_ca_visualization
from VG.groundingdino.util.inference import annotate, load_image
# Make sure all heavy weights are present once per container
ensure_all_vicca_weights()
from inference import (
gen_cxr,
cal_shift,
get_local_bbox,
extract_tensor,
chexbert_pathology,
)
def run_vicca(
    image_path: str,
    text_prompt: str,
    box_threshold: float = 0.2,
    text_threshold: float = 0.2,
    num_samples: int = 4,
    output_path: str = "CXRGen/test/samples/output/",
    attn_terms=None,
):
    """
    Top-level VICCA API used by app.py / Gradio.

    Pipeline:
      1. Generate `num_samples` CXR images from (image, prompt) on CPU.
      2. Pick the generated sample most similar to the input (via the
         similarity CSV written by `gen_cxr`).
      3. Estimate the (x, y) shift between original and best sample.
      4. Run visual grounding of the prompt on the original image and
         save an annotated frame.
      5. Compute SSIM between each grounded ROI and its shifted
         counterpart in the generated image.
      6. Optionally render token cross-attention overlays.

    Args:
        image_path: Path to the input chest X-ray image.
        text_prompt: Report/finding text used for generation and grounding.
        box_threshold: GroundingDINO box confidence threshold.
        text_threshold: GroundingDINO text confidence threshold.
        num_samples: Number of CXR samples to generate.
        output_path: Directory for generated samples and artifacts.
        attn_terms: Optional terms for attention visualization. When None,
            pathology terms are derived from the prompt via CheXbert.

    Returns:
        dict with keys: "boxes", "logits", "phrases", "ssim_scores",
        "shift_x", "shift_y", "best_generated_image_path",
        "attention_overlays" (None, dict of paths, or {} on failure),
        and "VG_annotated_image_path".
    """
    # Use locally cached VICCA weights instead of re-downloading from Hub.
    weight_path_gencxr = get_weight("CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth")
    weight_path_vg = get_weight("VG/weights/checkpoint0399.pth")
    os.makedirs(output_path, exist_ok=True)

    # 1) Generate CXR samples (force CPU to avoid GPU OOM on T4).
    gen_cxr(
        weight_path=weight_path_gencxr,
        image_path=image_path,
        text_prompt=text_prompt,
        num_samples=num_samples,
        output_path=output_path,
        device="cpu",  # important: run diffusion on CPU
    )

    # 2) Pick best sample by similarity. `gen_cxr` writes the CSV
    # asynchronously; keep the original grace period, then poll a few
    # extra seconds instead of crashing in read_csv if the writer lags.
    time.sleep(4)
    csv_path = os.path.join(output_path, "info_path_similarity.csv")
    for _ in range(12):  # up to ~6s of additional waiting
        if os.path.exists(csv_path):
            break
        time.sleep(0.5)
    df = pd.read_csv(csv_path)
    sim_ratios = [extract_tensor(val) for val in df["similarity_rate"]]
    max_sim_index = sim_ratios.index(max(sim_ratios))
    max_sim_gen_path = df["gen_sample_path"][max_sim_index]

    # 3) Compute shift between original and best generated image (CPU-safe).
    sx, sy = cal_shift(image_path, max_sim_gen_path)
    image_source, _ = load_image(image_path)

    # 4) Visual grounding of the prompt on the original image.
    boxes, logits, phrases = get_local_bbox(
        weight_path_vg,
        image_path,
        text_prompt,
        box_threshold,
        text_threshold,
    )
    annotate_dict = dict(color=sv.ColorPalette.DEFAULT, thickness=8, text_thickness=1)
    annotated_frame = annotate(
        image_source=image_source,
        boxes=boxes,
        logits=logits,
        phrases=phrases,
        bbox_annot=annotate_dict,
    )
    VG_path = os.path.join(output_path, "VG_annotations.jpg")
    cv2.imwrite(VG_path, annotated_frame)

    # 5) SSIM per bbox: compare each grounded ROI in the original image
    # against the same-size ROI shifted by (sx, sy) in the generated one.
    image_org_cv = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    image_gen_cv = cv2.imread(max_sim_gen_path, cv2.IMREAD_GRAYSCALE)
    ssim_scores = []
    from ssim import ssim  # local import to avoid cycles

    for bbox in boxes:
        x1, y1, x2, y2 = bbox
        bx1, by1 = int(x1), int(y1)
        bw, bh = int(x2 - x1), int(y2 - y1)
        bx2, by2 = int(x1 + sx), int(y1 + sy)
        # Guard against negative coordinates: Python's negative indexing
        # would silently wrap around and crop the wrong region.
        if min(bx1, by1, bx2, by2) < 0:
            continue
        roi_org = image_org_cv[by1:by1 + bh, bx1:bx1 + bw]
        roi_gen = image_gen_cv[by2:by2 + bh, bx2:bx2 + bw]
        # Edge boxes may clip to different shapes or be empty — skip those.
        if roi_org.shape == roi_gen.shape and roi_org.size > 0:
            ssim_scores.append(ssim(roi_org, roi_gen))

    # 6) Attention visualization. Honor caller-supplied terms; only derive
    # them via CheXbert when the caller passed none (previously the
    # parameter was unconditionally overwritten and therefore dead).
    attn_paths = None
    if attn_terms is None:
        attn_terms = chexbert_pathology(text_prompt)
    if attn_terms:
        cfg_path = "VG/groundingdino/config/GroundingDINO_SwinT_OGC_2.py"
        vg_ckpt_path = get_weight("VG/weights/checkpoint0399_log4.pth")
        attn_out_dir = os.path.join(output_path, "attn_overlays")
        try:
            attn_paths = run_token_ca_visualization(
                cfg_path=cfg_path,
                ckpt_path=vg_ckpt_path,
                image_path=image_path,  # or max_sim_gen_path for the generated CXR
                prompt=text_prompt,
                terms=attn_terms,
                out_dir=attn_out_dir,
                device="cuda" if torch.cuda.is_available() else "cpu",
                score_thresh=0.25,
                topk=100,
                term_agg="mean",
                save_per_term=True,
            )
        except RuntimeError as e:
            # Best-effort feature: log and return an empty mapping rather
            # than failing the whole pipeline.
            print("Token attention visualization failed:", e)
            attn_paths = {}

    return {
        "boxes": boxes,
        "logits": logits,
        "phrases": phrases,
        "ssim_scores": ssim_scores,
        "shift_x": sx,
        "shift_y": sy,
        "best_generated_image_path": max_sim_gen_path,
        "attention_overlays": attn_paths,
        "VG_annotated_image_path": VG_path,
    }