File size: 4,583 Bytes
0f8411f
 
 
 
 
b94ff1b
27e3844
0f8411f
 
b94ff1b
d80de43
0f8411f
 
 
 
 
 
 
 
 
b94ff1b
0f8411f
 
 
 
 
 
 
 
 
b94ff1b
0f8411f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d80de43
0f8411f
 
 
 
 
 
 
 
daaccda
27e3844
 
 
 
0f8411f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b94ff1b
 
 
 
e90d12e
b94ff1b
 
ecf378d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b94ff1b
0f8411f
 
 
 
 
 
 
 
b94ff1b
27e3844
0f8411f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# vicca_api.py
import os
import time
import cv2
import pandas as pd
import torch
import supervision as sv

from weights_utils import ensure_all_vicca_weights, get_weight
from vg_token_attention import run_token_ca_visualization
from VG.groundingdino.util.inference import annotate, load_image

# Make sure all heavy weights are present once per container.
# NOTE(review): this call intentionally sits BETWEEN import groups — it looks
# like `inference` needs the weights on disk at import time, so the download
# must happen before the import below. Do not "clean up" the import order.
ensure_all_vicca_weights()

from inference import (
    gen_cxr,
    cal_shift,
    get_local_bbox,
    extract_tensor,
    chexbert_pathology,
)

def run_vicca(
    image_path: str,
    text_prompt: str,
    box_threshold: float = 0.2,
    text_threshold: float = 0.2,
    num_samples: int = 4,
    output_path: str = "CXRGen/test/samples/output/",
    attn_terms=None,
):
    """
    Top-level VICCA API used by app.py / Gradio.

    Pipeline:
      1) generate ``num_samples`` candidate CXRs from (image, prompt),
      2) pick the candidate most similar to the input image,
      3) estimate the (x, y) shift between input and best candidate,
      4) run visual grounding on the input image and save an annotated copy,
      5) compute per-bbox SSIM between input and shift-compensated candidate,
      6) optionally render token cross-attention overlays for pathology terms.

    Args:
        image_path: Path to the input chest X-ray image.
        text_prompt: Report / finding text used for generation and grounding.
        box_threshold: Grounding box confidence threshold.
        text_threshold: Grounding text confidence threshold.
        num_samples: Number of candidate CXRs to generate.
        output_path: Directory where all artifacts are written (created if
            missing).
        attn_terms: Optional iterable of terms for the attention overlays.
            If None, terms are extracted from ``text_prompt`` via CheXbert.
            (BUGFIX: previously this argument was unconditionally overwritten
            and therefore ignored.)

    Returns:
        dict with grounding boxes/logits/phrases, per-box SSIM scores, the
        estimated shift, the best generated image path, attention overlay
        paths (None when no terms), and the annotated grounding image path.

    Raises:
        FileNotFoundError: if the similarity CSV never appears.
    """
    from ssim import ssim  # local import to avoid cycles

    # Use locally cached VICCA weights instead of re-downloading from Hub
    weight_path_gencxr = get_weight("CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth")
    weight_path_vg = get_weight("VG/weights/checkpoint0399.pth")

    os.makedirs(output_path, exist_ok=True)

    # 1) Generate CXR samples (force CPU to avoid GPU OOM on T4)
    gen_cxr(
        weight_path=weight_path_gencxr,
        image_path=image_path,
        text_prompt=text_prompt,
        num_samples=num_samples,
        output_path=output_path,
        device="cpu",  # important: run diffusion on CPU
    )

    # 2) Pick best sample by similarity. The CSV is written by the generation
    # step; poll for it (bounded) instead of the old blind `time.sleep(4)`,
    # which raced the writer on slow machines.
    csv_path = os.path.join(output_path, "info_path_similarity.csv")
    for _ in range(60):  # up to ~30 s
        if os.path.exists(csv_path):
            break
        time.sleep(0.5)
    else:
        raise FileNotFoundError(f"Similarity CSV was never written: {csv_path}")
    time.sleep(0.5)  # brief settle so the writer can finish flushing
    df = pd.read_csv(csv_path)

    sim_ratios = [extract_tensor(val) for val in df["similarity_rate"]]
    max_sim_index = sim_ratios.index(max(sim_ratios))
    max_sim_gen_path = df.loc[max_sim_index, "gen_sample_path"]

    # 3) Compute shifts between original and best generated image (CPU is fine)
    sx, sy = cal_shift(image_path, max_sim_gen_path)

    # 4) Visual grounding on the original image
    image_source, _ = load_image(image_path)
    boxes, logits, phrases = get_local_bbox(
        weight_path_vg,
        image_path,
        text_prompt,
        box_threshold,
        text_threshold,
    )
    annotate_dict = dict(color=sv.ColorPalette.DEFAULT, thickness=8, text_thickness=1)

    annotated_frame = annotate(
        image_source=image_source,
        boxes=boxes,
        logits=logits,
        phrases=phrases,
        bbox_annot=annotate_dict,
    )
    VG_path = os.path.join(output_path, "VG_annotations.jpg")
    cv2.imwrite(VG_path, annotated_frame)

    # 5) SSIM per bbox: compare each grounded ROI in the original against the
    # shift-compensated ROI in the generated image.
    image_org_cv = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    image_gen_cv = cv2.imread(max_sim_gen_path, cv2.IMREAD_GRAYSCALE)

    ssim_scores = []
    for bbox in boxes:
        x1, y1, x2, y2 = bbox
        bw, bh = int(x2 - x1), int(y2 - y1)
        bx1, by1 = int(x1), int(y1)              # ROI in the original
        bx2, by2 = int(x1 + sx), int(y1 + sy)    # shifted ROI in the generated

        roi_org = image_org_cv[by1:by1 + bh, bx1:bx1 + bw]
        roi_gen = image_gen_cv[by2:by2 + bh, bx2:bx2 + bw]

        # Skip degenerate or edge-clipped ROIs whose shapes no longer match
        if roi_org.shape == roi_gen.shape and roi_org.size > 0:
            ssim_scores.append(ssim(roi_org, roi_gen))

    # 6) Optional attention visualization. Respect a caller-supplied
    # `attn_terms`; only fall back to CheXbert extraction when none was given.
    attn_paths = None
    if attn_terms is None:
        attn_terms = chexbert_pathology(text_prompt)
    if attn_terms:
        cfg_path = "VG/groundingdino/config/GroundingDINO_SwinT_OGC_2.py"
        vg_ckpt_path = get_weight("VG/weights/checkpoint0399_log4.pth")
        attn_out_dir = os.path.join(output_path, "attn_overlays")
        try:
            attn_paths = run_token_ca_visualization(
                cfg_path=cfg_path,
                ckpt_path=vg_ckpt_path,
                image_path=image_path,     # or max_sim_gen_path for the generated CXR
                prompt=text_prompt,
                terms=attn_terms,
                out_dir=attn_out_dir,
                device="cuda" if torch.cuda.is_available() else "cpu",
                score_thresh=0.25,
                topk=100,
                term_agg="mean",
                save_per_term=True,
            )
        except RuntimeError as e:
            # Best-effort feature: log and continue with an empty mapping
            print("Token attention visualization failed:", e)
            attn_paths = {}

    return {
        "boxes": boxes,
        "logits": logits,
        "phrases": phrases,
        "ssim_scores": ssim_scores,
        "shift_x": sx,
        "shift_y": sy,
        "best_generated_image_path": max_sim_gen_path,
        "attention_overlays": attn_paths,
        "VG_annotated_image_path": VG_path,
    }