# vicca / vicca_api.py
# Author: sayehghp
# Commit: daaccda — "Solved thickness"
# vicca_api.py
import os
import time
import cv2
import pandas as pd
import torch
import supervision as sv
from weights_utils import ensure_all_vicca_weights, get_weight
from vg_token_attention import run_token_ca_visualization
from VG.groundingdino.util.inference import annotate, load_image
# Make sure all heavy weights are present once per container
ensure_all_vicca_weights()
from inference import (
gen_cxr,
cal_shift,
get_local_bbox,
extract_tensor,
chexbert_pathology,
)
def _read_similarity_csv(csv_path, timeout: float = 10.0, poll: float = 0.5):
    """Read gen_cxr's similarity CSV, retrying until it is fully written.

    gen_cxr writes ``info_path_similarity.csv`` asynchronously; a fixed
    sleep is racy, so instead we retry the read until the file exists and
    parses, up to ``timeout`` seconds.

    Raises:
        FileNotFoundError / pd.errors.EmptyDataError: if the CSV never
            becomes readable within the timeout.
    """
    deadline = time.time() + timeout
    while True:
        try:
            return pd.read_csv(csv_path)
        except (FileNotFoundError, pd.errors.EmptyDataError):
            if time.time() >= deadline:
                raise
            time.sleep(poll)


def _bbox_ssim_scores(image_path, gen_path, boxes, sx, sy):
    """Compute per-bbox SSIM between the original and best generated CXR.

    Each grounding box is compared against the same-sized region in the
    generated image, offset by the registration shift (sx, sy). Boxes whose
    ROIs mismatch in shape, are empty, or would start at a negative
    coordinate are skipped.
    """
    from ssim import ssim  # local import to avoid cycles

    image_org_cv = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    image_gen_cv = cv2.imread(gen_path, cv2.IMREAD_GRAYSCALE)
    scores = []
    for bbox in boxes:
        x1, y1, x2, y2 = bbox
        bw, bh = int(x2 - x1), int(y2 - y1)
        ox, oy = int(x1), int(y1)                 # ROI origin in the original
        gx, gy = int(x1 + sx), int(y1 + sy)       # shifted origin in the generated
        # Negative start indices would wrap around under Python slicing and
        # silently compare the wrong region — skip such boxes instead.
        if min(ox, oy, gx, gy) < 0:
            continue
        roi_org = image_org_cv[oy:oy + bh, ox:ox + bw]
        roi_gen = image_gen_cv[gy:gy + bh, gx:gx + bw]
        if roi_org.shape == roi_gen.shape and roi_org.size > 0:
            scores.append(ssim(roi_org, roi_gen))
    return scores


def run_vicca(
    image_path: str,
    text_prompt: str,
    box_threshold: float = 0.2,
    text_threshold: float = 0.2,
    num_samples: int = 4,
    output_path: str = "CXRGen/test/samples/output/",
    attn_terms=None,
):
    """
    Top-level VICCA API used by app.py / Gradio.

    Pipeline: (1) generate CXR samples conditioned on ``text_prompt``,
    (2) pick the sample most similar to ``image_path``, (3) compute the
    registration shift between the two, (4) run visual grounding on the
    report text and annotate the image, (5) score per-box SSIM, and
    optionally render token cross-attention overlays.

    Args:
        image_path: Path to the input chest X-ray.
        text_prompt: Radiology report text used for generation and grounding.
        box_threshold: GroundingDINO box confidence threshold.
        text_threshold: GroundingDINO text confidence threshold.
        num_samples: Number of CXR samples to generate.
        output_path: Directory where generated artifacts are written.
        attn_terms: Terms to visualize attention for; if None, derived from
            the prompt via CheXbert pathology extraction.

    Returns:
        dict with grounding boxes/logits/phrases, per-box SSIM scores, the
        (x, y) shift, the best generated image path, the annotated-image
        path, and attention overlay paths (None if skipped, {} on failure).
    """
    # Use locally cached VICCA weights instead of re-downloading from Hub.
    weight_path_gencxr = get_weight("CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth")
    weight_path_vg = get_weight("VG/weights/checkpoint0399.pth")
    os.makedirs(output_path, exist_ok=True)

    # 1) Generate CXR samples (force CPU to avoid GPU OOM on T4).
    gen_cxr(
        weight_path=weight_path_gencxr,
        image_path=image_path,
        text_prompt=text_prompt,
        num_samples=num_samples,
        output_path=output_path,
        device="cpu",  # important: run diffusion on CPU
    )

    # 2) Pick best sample by similarity (retrying until the CSV is written,
    #    instead of a fixed sleep).
    csv_path = os.path.join(output_path, "info_path_similarity.csv")
    df = _read_similarity_csv(csv_path)
    sim_ratios = [extract_tensor(val) for val in df["similarity_rate"]]
    max_sim_index = sim_ratios.index(max(sim_ratios))
    max_sim_gen_path = df["gen_sample_path"].iloc[max_sim_index]

    # 3) Compute shifts (still fine on CPU).
    sx, sy = cal_shift(image_path, max_sim_gen_path)
    image_source, _ = load_image(image_path)

    # 4) Visual grounding.
    boxes, logits, phrases = get_local_bbox(
        weight_path_vg,
        image_path,
        text_prompt,
        box_threshold,
        text_threshold,
    )
    annotate_dict = dict(color=sv.ColorPalette.DEFAULT, thickness=8, text_thickness=1)
    annotated_frame = annotate(
        image_source=image_source,
        boxes=boxes,
        logits=logits,
        phrases=phrases,
        bbox_annot=annotate_dict,
    )
    VG_path = os.path.join(output_path, "VG_annotations.jpg")
    cv2.imwrite(VG_path, annotated_frame)

    # 5) SSIM per bbox between original and best generated image.
    ssim_scores = _bbox_ssim_scores(image_path, max_sim_gen_path, boxes, sx, sy)

    # Optional: attention visualization for pathology terms.
    # Honor a caller-supplied attn_terms; only fall back to CheXbert when
    # absent (previously the argument was unconditionally overwritten).
    attn_paths = None
    if attn_terms is None:
        attn_terms = chexbert_pathology(text_prompt)
    if attn_terms:
        cfg_path = "VG/groundingdino/config/GroundingDINO_SwinT_OGC_2.py"
        vg_ckpt_path = get_weight("VG/weights/checkpoint0399_log4.pth")
        attn_out_dir = os.path.join(output_path, "attn_overlays")
        try:
            attn_paths = run_token_ca_visualization(
                cfg_path=cfg_path,
                ckpt_path=vg_ckpt_path,
                image_path=image_path,  # or max_sim_gen_path if you prefer generated CXR
                prompt=text_prompt,
                terms=attn_terms,
                out_dir=attn_out_dir,
                device="cuda" if torch.cuda.is_available() else "cpu",
                score_thresh=0.25,
                topk=100,
                term_agg="mean",
                save_per_term=True,
            )
        except RuntimeError as e:
            # Best-effort visualization: report the failure but keep going.
            print("Token attention visualization failed:", e)
            attn_paths = {}

    return {
        "boxes": boxes,
        "logits": logits,
        "phrases": phrases,
        "ssim_scores": ssim_scores,
        "shift_x": sx,
        "shift_y": sy,
        "best_generated_image_path": max_sim_gen_path,
        "attention_overlays": attn_paths,
        "VG_annotated_image_path": VG_path,
    }