# vicca_api.py
"""Top-level VICCA pipeline API used by app.py / Gradio."""

import os
import time

import cv2
import pandas as pd
import supervision as sv
import torch

from VG.groundingdino.util.inference import annotate, load_image
from vg_token_attention import run_token_ca_visualization
from weights_utils import ensure_all_vicca_weights, get_weight

# Make sure all heavy weights are present once per container.
ensure_all_vicca_weights()

# Imported after the weight download so `inference` finds its checkpoints on disk.
from inference import (
    gen_cxr,
    cal_shift,
    get_local_bbox,
    extract_tensor,
    chexbert_pathology,
)


def _wait_for_file(path: str, timeout: float = 30.0, poll: float = 0.5) -> None:
    """Block until *path* exists, then wait one extra poll interval so a
    just-created file has a chance to be fully flushed.

    Raises FileNotFoundError if the file does not appear within *timeout*
    seconds. Replaces the previous unconditional ``time.sleep(4)`` race hack.
    """
    deadline = time.monotonic() + timeout
    while not os.path.exists(path):
        if time.monotonic() > deadline:
            raise FileNotFoundError(f"Timed out waiting for {path!r}")
        time.sleep(poll)
    time.sleep(poll)  # brief settle time in case the writer is still flushing


def _best_sample_path(csv_path: str) -> str:
    """Return the generated-sample path with the highest similarity score
    recorded in the info CSV written by ``gen_cxr``."""
    df = pd.read_csv(csv_path)
    sim_ratios = [extract_tensor(val) for val in df["similarity_rate"]]
    # Single pass instead of list.index(max(...)); ties resolve to the first
    # maximum, matching the original behavior.
    best = max(range(len(sim_ratios)), key=sim_ratios.__getitem__)
    # .iloc: positional access — the chained df[col][i] form silently breaks
    # on a non-default index.
    return df["gen_sample_path"].iloc[best]


def _bbox_ssim_scores(image_path, gen_path, boxes, sx, sy):
    """Compute SSIM between each bbox ROI of the original image and the
    (sx, sy)-shifted ROI of the generated image.

    ROIs with mismatched shapes or zero area are skipped, so the returned
    list may be shorter than *boxes*.
    NOTE(review): assumes boxes are pixel-space (x1, y1, x2, y2) — confirm
    against get_local_bbox's output convention.
    """
    from ssim import ssim  # local import to avoid cycles

    image_org = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    image_gen = cv2.imread(gen_path, cv2.IMREAD_GRAYSCALE)

    scores = []
    for x1, y1, x2, y2 in boxes:
        bw, bh = int(x2 - x1), int(y2 - y1)
        bx1, by1 = int(x1), int(y1)
        bx2, by2 = int(x1 + sx), int(y1 + sy)
        roi_org = image_org[by1:by1 + bh, bx1:bx1 + bw]
        roi_gen = image_gen[by2:by2 + bh, bx2:bx2 + bw]
        if roi_org.shape == roi_gen.shape and roi_org.size > 0:
            scores.append(ssim(roi_org, roi_gen))
    return scores


def run_vicca(
    image_path: str,
    text_prompt: str,
    box_threshold: float = 0.2,
    text_threshold: float = 0.2,
    num_samples: int = 4,
    output_path: str = "CXRGen/test/samples/output/",
    attn_terms=None,
):
    """
    Top-level VICCA API used by app.py / Gradio.

    Parameters
    ----------
    image_path : str
        Path to the source chest X-ray.
    text_prompt : str
        Report text used for generation and visual grounding.
    box_threshold, text_threshold : float
        GroundingDINO detection thresholds.
    num_samples : int
        Number of CXR samples to generate.
    output_path : str
        Directory receiving generated samples and annotation images.
    attn_terms : list[str] | None
        Terms for token cross-attention overlays. When None, defaults to
        CheXbert pathologies extracted from *text_prompt*.

    Returns
    -------
    dict
        Grounding boxes/logits/phrases, per-box SSIM scores, the computed
        (shift_x, shift_y), the best generated sample path, attention
        overlay paths (or None/{} when unavailable), and the annotated
        image path.
    """
    # Use locally cached VICCA weights instead of re-downloading from the Hub.
    weight_path_gencxr = get_weight("CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth")
    weight_path_vg = get_weight("VG/weights/checkpoint0399.pth")
    os.makedirs(output_path, exist_ok=True)

    # 1) Generate CXR samples (force CPU to avoid GPU OOM on T4).
    gen_cxr(
        weight_path=weight_path_gencxr,
        image_path=image_path,
        text_prompt=text_prompt,
        num_samples=num_samples,
        output_path=output_path,
        device="cpu",  # important: run diffusion on CPU
    )

    # 2) Pick the best sample by similarity. Poll for the CSV instead of a
    #    fixed 4-second sleep, which raced against the writer.
    csv_path = os.path.join(output_path, "info_path_similarity.csv")
    _wait_for_file(csv_path)
    max_sim_gen_path = _best_sample_path(csv_path)

    # 3) Compute registration shift between original and best generated image.
    sx, sy = cal_shift(image_path, max_sim_gen_path)
    image_source, _ = load_image(image_path)

    # 4) Visual grounding.
    boxes, logits, phrases = get_local_bbox(
        weight_path_vg,
        image_path,
        text_prompt,
        box_threshold,
        text_threshold,
    )
    annotate_dict = dict(color=sv.ColorPalette.DEFAULT, thickness=8, text_thickness=1)
    annotated_frame = annotate(
        image_source=image_source,
        boxes=boxes,
        logits=logits,
        phrases=phrases,
        bbox_annot=annotate_dict,
    )
    VG_path = os.path.join(output_path, "VG_annotations.jpg")
    cv2.imwrite(VG_path, annotated_frame)

    # 5) SSIM per bbox between the original and the shifted generated image.
    ssim_scores = _bbox_ssim_scores(image_path, max_sim_gen_path, boxes, sx, sy)

    # 6) Optional attention visualization. BUG FIX: a caller-supplied
    #    `attn_terms` used to be unconditionally overwritten by the CheXbert
    #    extraction; now it is only derived from the prompt when not given.
    attn_paths = None
    if attn_terms is None:
        attn_terms = chexbert_pathology(text_prompt)
    if attn_terms:
        cfg_path = "VG/groundingdino/config/GroundingDINO_SwinT_OGC_2.py"
        vg_ckpt_path = get_weight("VG/weights/checkpoint0399_log4.pth")
        attn_out_dir = os.path.join(output_path, "attn_overlays")
        try:
            attn_paths = run_token_ca_visualization(
                cfg_path=cfg_path,
                ckpt_path=vg_ckpt_path,
                image_path=image_path,  # or max_sim_gen_path for the generated CXR
                prompt=text_prompt,
                terms=attn_terms,
                out_dir=attn_out_dir,
                device="cuda" if torch.cuda.is_available() else "cpu",
                score_thresh=0.25,
                topk=100,
                term_agg="mean",
                save_per_term=True,
            )
        except RuntimeError as e:
            # Best-effort feature: log and fall back to an empty mapping.
            print("Token attention visualization failed:", e)
            attn_paths = {}

    return {
        "boxes": boxes,
        "logits": logits,
        "phrases": phrases,
        "ssim_scores": ssim_scores,
        "shift_x": sx,
        "shift_y": sy,
        "best_generated_image_path": max_sim_gen_path,
        "attention_overlays": attn_paths,
        "VG_annotated_image_path": VG_path,
    }