# Standard-library / third-party imports.
import os
import time
import cv2
import pandas as pd
import torch
import supervision as sv

# Project-local helpers.
from weights_utils import ensure_all_vicca_weights, get_weight
from vg_token_attention import run_token_ca_visualization
from VG.groundingdino.util.inference import annotate, load_image

# Fetch/verify all model checkpoints BEFORE importing `inference`:
# presumably `inference` (or its transitive imports) touches the weight
# files at import time — TODO confirm; keep this call ahead of the
# import below, do not reorder.
ensure_all_vicca_weights()

from inference import (
    gen_cxr,
    cal_shift,
    get_local_bbox,
    extract_tensor,
    chexbert_pathology,
)
|
def run_vicca(
    image_path: str,
    text_prompt: str,
    box_threshold: float = 0.2,
    text_threshold: float = 0.2,
    num_samples: int = 4,
    output_path: str = "CXRGen/test/samples/output/",
    attn_terms=None,
):
    """
    Top-level VICCA API used by app.py / Gradio.

    Pipeline: generate candidate CXR images from ``text_prompt``, select the
    candidate most similar to ``image_path``, ground the prompt's phrases on
    the original image (GroundingDINO), compute per-box SSIM between the
    original and the shifted generated image, and optionally render token
    cross-attention overlays.

    Args:
        image_path: Path to the input chest X-ray image.
        text_prompt: Report text used for both generation and grounding.
        box_threshold: GroundingDINO box-confidence threshold.
        text_threshold: GroundingDINO text-confidence threshold.
        num_samples: Number of generated candidates to produce.
        output_path: Directory where all artifacts are written (created if
            missing).
        attn_terms: Optional list of terms for the attention visualization.
            If None, terms are extracted from ``text_prompt`` via
            ``chexbert_pathology``.

    Returns:
        dict with keys: ``boxes``, ``logits``, ``phrases``, ``ssim_scores``,
        ``shift_x``, ``shift_y``, ``best_generated_image_path``,
        ``attention_overlays`` (None, {} on failure, or paths dict), and
        ``VG_annotated_image_path``.
    """
    weight_path_gencxr = get_weight("CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth")
    weight_path_vg = get_weight("VG/weights/checkpoint0399.pth")

    os.makedirs(output_path, exist_ok=True)

    # Stage 1: generate candidate CXR images conditioned on the prompt.
    gen_cxr(
        weight_path=weight_path_gencxr,
        image_path=image_path,
        text_prompt=text_prompt,
        num_samples=num_samples,
        output_path=output_path,
        device="cpu",
    )

    # HACK: gen_cxr appears to write its CSV asynchronously; wait for it to
    # flush before reading. TODO: replace with an explicit completion signal.
    time.sleep(4)
    csv_path = os.path.join(output_path, "info_path_similarity.csv")
    df = pd.read_csv(csv_path)

    # Pick the generated sample most similar to the input image.
    sim_ratios = [extract_tensor(val) for val in df["similarity_rate"]]
    max_sim_index = sim_ratios.index(max(sim_ratios))
    # .loc instead of chained indexing (df[col][idx]): same value, but the
    # documented pandas idiom and immune to chained-indexing pitfalls.
    max_sim_gen_path = df.loc[max_sim_index, "gen_sample_path"]

    # Stage 2: estimate the (x, y) translation between original and generated.
    sx, sy = cal_shift(image_path, max_sim_gen_path)

    # Stage 3: phrase grounding on the original image.
    image_source, _ = load_image(image_path)
    boxes, logits, phrases = get_local_bbox(
        weight_path_vg,
        image_path,
        text_prompt,
        box_threshold,
        text_threshold,
    )
    annotate_dict = dict(color=sv.ColorPalette.DEFAULT, thickness=8, text_thickness=1)

    annotated_frame = annotate(
        image_source=image_source,
        boxes=boxes,
        logits=logits,
        phrases=phrases,
        bbox_annot=annotate_dict,
    )
    VG_path = os.path.join(output_path, "VG_annotations.jpg")
    cv2.imwrite(VG_path, annotated_frame)

    # Stage 4: per-box SSIM between original and best generated image.
    image_org_cv = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    image_gen_cv = cv2.imread(max_sim_gen_path, cv2.IMREAD_GRAYSCALE)

    ssim_scores = []
    from ssim import ssim  # project-local module, only needed in this stage

    for bbox in boxes:
        # NOTE(review): assumes boxes are pixel-space (x1, y1, x2, y2) —
        # TODO confirm get_local_bbox's output format.
        x1, y1, x2, y2 = bbox
        bbox1 = [x1, y1, x2 - x1, y2 - y1]            # original ROI (x, y, w, h)
        bbox2 = [x1 + sx, y1 + sy, x2 - x1, y2 - y1]  # shifted ROI on generated image

        bx1, by1, bw1, bh1 = [int(val) for val in bbox1]
        bx2, by2, bw2, bh2 = [int(val) for val in bbox2]

        # BUG FIX: after shifting, coordinates can go negative; a negative
        # slice start silently wraps around and yields a bogus ROI. Clamp to 0
        # (the shape-equality check below then discards any truncated pair).
        bx1, by1 = max(bx1, 0), max(by1, 0)
        bx2, by2 = max(bx2, 0), max(by2, 0)

        roi_org = image_org_cv[by1:by1 + bh1, bx1:bx1 + bw1]
        roi_gen = image_gen_cv[by2:by2 + bh2, bx2:bx2 + bw2]

        # Only score when both crops are non-empty and identically shaped.
        if roi_org.shape == roi_gen.shape and roi_org.size > 0:
            score = ssim(roi_org, roi_gen)
            ssim_scores.append(score)

    # Stage 5: optional token cross-attention overlays.
    attn_paths = None
    # BUG FIX: the original unconditionally overwrote the caller-supplied
    # `attn_terms`, making the parameter dead; only derive terms from the
    # prompt when the caller provided none.
    if attn_terms is None:
        attn_terms = chexbert_pathology(text_prompt)
    if attn_terms:
        cfg_path = "VG/groundingdino/config/GroundingDINO_SwinT_OGC_2.py"
        vg_ckpt_path = get_weight("VG/weights/checkpoint0399_log4.pth")
        attn_out_dir = os.path.join(output_path, "attn_overlays")
        try:
            attn_paths = run_token_ca_visualization(
                cfg_path=cfg_path,
                ckpt_path=vg_ckpt_path,
                image_path=image_path,
                prompt=text_prompt,
                terms=attn_terms,
                out_dir=attn_out_dir,
                device="cuda" if torch.cuda.is_available() else "cpu",
                score_thresh=0.25,
                topk=100,
                term_agg="mean",
                save_per_term=True,
            )
        except RuntimeError as e:
            # Best-effort: the overlays are optional; report and continue.
            print("Token attention visualization failed:", e)
            attn_paths = {}

    return {
        "boxes": boxes,
        "logits": logits,
        "phrases": phrases,
        "ssim_scores": ssim_scores,
        "shift_x": sx,
        "shift_y": sy,
        "best_generated_image_path": max_sim_gen_path,
        "attention_overlays": attn_paths,
        "VG_annotated_image_path": VG_path,
    }