# vicca / inference.py
# Author: sayehghp -- last commit e09b1c8 ("Visualization")
"""
Input: image and text
Middle output: bbox (VG), Gen Image and similarity score (CXRGen), Shift_x&y (DETR)
Output: Localization Score, Reliability Score
python inference.py \
--image_path VG/38708899-5132e206-88cb58cf-d55a7065-6cbc983d.jpg \
--text_prompt "Cardiomegaly with mild pulmonary vascular congestion."
"""
import sys, os
# ---------------------------------------------------------------------
# Make CheXbert's `src` folder importable (so `import utils` works)
# ---------------------------------------------------------------------
# BASE_DIR = os.path.dirname(__file__)
# CHEXBERT_SRC = os.path.join(BASE_DIR, "CheXbert", "src")
# if CHEXBERT_SRC not in sys.path:
# sys.path.insert(0, CHEXBERT_SRC)
# from label import label # now imports /app/CheXbert/src/label.py
import pandas as pd
import numpy as np
import time
import cv2
import argparse
from ast import literal_eval
# from nltk import tokenize
# sys.path.append('/home/gholipos-admin/Desktop/Thesis/Training_Code/VICCA')
from pathlib import Path
import shutil
from huggingface_hub import hf_hub_download
from weights_utils import get_weight
# def ensure_vicca_weights():
# """
# Download all VICCA weights from the vicca-weights repo into the paths
# expected by the original code, with caching and safe subfolder handling.
# """
# repo_id = "sayehghp/vicca-weights"
# base = Path(__file__).parent
# weight_files = [
# # CheXbert
# "CheXbert/checkpoint/chexbert.pth",
# # Uniformer
# "CXRGen/annotator/ckpts/upernet_global_small.pth",
# # Diffusion
# "CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth",
# # Encoders
# "CXRGen/ldm/modules/encoders/BiomedVLP-CXR-BERT/pytorch_model.bin",
# "VG/weights/BiomedVLP-CXR-BERT/pytorch_model.bin",
# # Lung UNet
# "CXRGen/LungDetection/models/unet-2v.pt",
# "CXRGen/LungDetection/models/unet-6v.pt",
# # DETR
# "DETR/output/checkpoint.pth",
# # VG weights
# "VG/weights/checkpoint0399.pth",
# "VG/weights/checkpoint0399_log4.pth",
# "VG/weights/checkpoint_best_regular.pth",
# ]
# for rel_path in weight_files:
# local_path = base / rel_path
# local_path.parent.mkdir(parents=True, exist_ok=True)
# if local_path.exists():
# continue # skip if already mirrored into repo tree
# # Split repo path
# if "/" in rel_path:
# subfolder, filename = rel_path.rsplit("/", 1)
# else:
# subfolder, filename = None, rel_path
# cached_path = hf_hub_download(
# repo_id=repo_id,
# filename=filename,
# subfolder=subfolder if subfolder else None
# )
# # Copy from HF cache → repo tree
# shutil.copy2(cached_path, local_path)
# # Run once at import time so all weights are present before anything loads them
# ensure_vicca_weights()
# ---- SHIM FOR basicsr / torchvision ----
import types
from torchvision.transforms import functional as F
# Shim for basicsr (see the SHIM banner above): it presumably imports
# rgb_to_grayscale from torchvision.transforms.functional_tensor, a module that
# newer torchvision releases no longer ship (TODO confirm exact version).
# Register a stand-in module exposing only rgb_to_grayscale (re-exported from
# torchvision.transforms.functional), but ONLY when the real module cannot be
# imported. Replacing a genuine functional_tensor with this one-function stub
# would break any other code that relies on the rest of that module.
try:
    from torchvision.transforms import functional_tensor as _functional_tensor  # noqa: F401
except ImportError:
    mod = types.ModuleType("torchvision.transforms.functional_tensor")
    mod.rgb_to_grayscale = F.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = mod
# ---- END SHIM ----
from CXRGen import sample_generation
from DETR import svc
from DETR.arguments import get_args_parser as get_detr_args_parser
from VG import localization
from ssim import ssim
import torch
from CheXbert.src.label import label
def get_args_parser():
    """
    Build the command-line parser for this inference script.

    Only ``--image_path`` and ``--text_prompt`` are required; the checkpoint
    paths, VG thresholds, sample count and output directory all have defaults.

    Returns:
        argparse.ArgumentParser: the (unparsed) parser, with ``-h/--help``.
    """
    cli = argparse.ArgumentParser(prog='Set the Input', add_help=True)

    # (flag, keyword arguments for add_argument) -- registration order is kept
    # so that --help lists the options in the same sequence.
    option_specs = [
        ('--weight_path_gencxr', dict(
            type=str,
            default="CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth",
            help="Path to the CXR generation trained model")),
        ('--weight_path_vg', dict(
            type=str,
            default="VG/weights/checkpoint0399_log4.pth",
            help="Path to the Visual Grounding trained model")),
        ('--image_path', dict(
            type=str, required=True,
            help="Path to the input image file.")),
        ('--text_prompt', dict(
            type=str, required=True,
            help="Text prompt describing pathology.")),
        ('--box_threshold', dict(
            type=float, default=0.2,
            help="Box threshold for VG")),
        ('--text_threshold', dict(
            type=float, default=0.2,
            help="Text threshold for VG")),
        ('--num_samples', dict(
            type=int, default=4,
            help="Number of generated image samples.")),
        ('--output_path', dict(
            type=str,
            default="CXRGen/test/samples/output/",
            help="Path to save generated files.")),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli
import re
# Sentence boundary: a run of newlines, semicolons or periods -- except a period
# with a digit on BOTH sides (a decimal such as "1.5 cm"), which must stay
# inside its sentence. "size 2. Next" still splits after the "2".
_SENTENCE_BOUNDARY = re.compile(r"(?:[\n;]|(?<!\d)\.|\.(?!\d))+")


def simple_sentence_split(text: str):
    """
    Very lightweight sentence splitter good enough for radiology reports.

    Splits on newlines, ';' and '.', but never on a period between two digits,
    so measurements like "1.5 cm" are not cut in half. Each piece is stripped
    and empty pieces are dropped.

    Args:
        text: Report or prompt text. Empty string or None yields [].

    Returns:
        list[str]: the non-empty, stripped sentence fragments in input order.
    """
    if not text:
        return []
    pieces = _SENTENCE_BOUNDARY.split(text)
    return [p.strip() for p in pieces if p.strip()]
# Label names in CheXbert output order: chexbert_pathology() maps the index of
# each positive prediction to path_list[index]. "No Finding" must remain the
# last entry because chexbert_pathology() inspects pathology[-1] for it.
path_list = [
    'Enlarged Cardiomediastinum',
    'Cardiomegaly',
    'Lung Opacity',
    'Lung Lesion',
    'Edema',
    'Consolidation',
    'Pneumonia',
    'Atelectasis',
    'Pneumothorax',
    'Pleural Effusion',
    'Pleural Other',
    'Fracture',
    'Support Devices',
    'No Finding',
]
# Resolve the CheXbert checkpoint path once, at import time, so that
# chexbert_pathology() reuses it instead of calling get_weight() per sentence.
# NOTE(review): this means merely importing this module calls get_weight(),
# which may download the file -- confirm get_weight's caching/download
# behaviour in weights_utils.
CHEXBERT_WEIGHTS = get_weight("CheXbert/checkpoint/chexbert.pth")
# def chexbert_pathology(text):
# sentences = list(set(tokenize.sent_tokenize(text)))
# path_dict = []
# for sentence in sentences:
# sentence = sentence.replace('\n',' ')
# sentence = sentence.replace('\s+',' ')
# chexbert_weight_path = get_weight("CheXbert/checkpoint/chexbert.pth")
# # pathology = np.array(label("CheXbert/checkpoint/chexbert.pth", sentence)).T[0]
# pathology = np.array(label(chexbert_weight_path, sentence)).T[0]
# if pathology[-1]==1 or len(list(set(pathology)))==1 or not any(e==1 for e in pathology):
# pass
# else:
# indice = [i for i, e in enumerate(pathology) if e==1]
# for ind in indice:
# path_dict.append(path_list[ind])
# return path_dict
def chexbert_pathology(text: str):
    """
    Run CheXbert on ``text`` sentence by sentence and return the positive labels.

    The text is split with ``simple_sentence_split``; whitespace inside each
    sentence is collapsed, and identical sentences are labelled only once.
    A sentence contributes nothing when CheXbert marks "No Finding" (the last
    entry of ``path_list``) as 1, or when no label equals 1. Otherwise every
    label whose value is 1 is collected.

    Args:
        text: Free-text report or prompt. Empty string or None yields [].

    Returns:
        list[str]: de-duplicated label names from ``path_list``, sorted
        alphabetically.
    """
    if not text:
        return []

    # Collapse whitespace runs with a real regex (a plain str.replace("\s+")
    # would only match the literal characters "\s+"), then de-duplicate.
    sentences = {re.sub(r"\s+", " ", s).strip() for s in simple_sentence_split(text)}

    path_terms = set()
    for sentence in sentences:
        pathology = np.array(label(CHEXBERT_WEIGHTS, sentence)).T[0]
        positives = [i for i, e in enumerate(pathology) if e == 1]
        # Skip when "No Finding" is active or nothing is positive. This also
        # covers "all labels identical": all-1 trips the first test and any
        # other constant value trips the second.
        if pathology[-1] == 1 or not positives:
            continue
        path_terms.update(path_list[i] for i in positives)
    return sorted(path_terms)
def extract_tensor(value):
    """
    Parse a similarity value read back from the CXRGen CSV into a Python value.

    The cell may hold the string form of a torch tensor (e.g. "tensor(0.83)",
    possibly followed by keyword metadata such as ", device='cuda:0'"), a
    plain literal string, or -- if pandas already parsed the column -- a number.

    Args:
        value: Cell from the ``similarity_rate`` column.

    Returns:
        The parsed literal (a number, or a list for multi-element tensors).
        Non-string inputs are returned unchanged.

    Raises:
        ValueError / SyntaxError: from ``ast.literal_eval`` when the remaining
        text is not a Python literal.
    """
    if not isinstance(value, str):
        return value
    text = value.strip()
    # Remove exactly one outer "tensor(...)" wrapper instead of deleting every ')'.
    if text.startswith("tensor(") and text.endswith(")"):
        text = text[len("tensor("):-1]
    # torch's repr appends keyword metadata for e.g. CUDA tensors
    # ("tensor(0.8, device='cuda:0')"); that tail is not a literal, so drop it.
    text = re.split(r",\s*(?:device|dtype|requires_grad|grad_fn)\s*=", text, maxsplit=1)[0]
    return literal_eval(text.strip())
def gen_cxr(weight_path, image_path, text_prompt, num_samples, output_path, device=None):
    """
    Generate synthetic CXR samples for ``text_prompt`` with CXRGen.

    Starts from CXRGen's default argument namespace, overrides the fields
    below, and runs ``sample_generation.main``.

    Args:
        weight_path: Repo-relative checkpoint path, resolved via ``get_weight``.
        image_path: Input image handed to the generator.
        text_prompt: Text conditioning the generation.
        num_samples: Number of samples to generate.
        output_path: Output location handed to the generator (the __main__
            script later reads "info_path_similarity.csv" from it).
        device: torch device or device string ("cuda", "cpu", ...) to run on.
            None (default) selects CUDA when available, otherwise CPU --
            previously any value passed here was silently ignored in favour of
            that automatic choice.
    """
    args = sample_generation.get_args_parser().parse_args([])
    args.image_path = image_path
    args.text_prompt = text_prompt
    args.num_samples = num_samples
    args.output_path = output_path
    args.weight_path = get_weight(weight_path)

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # torch.device() accepts both strings and existing torch.device objects.
    args.device = torch.device(device)

    sample_generation.main(args)
def cal_shift(img_org_path, img_gen_path):
    """
    Estimate the x/y shift between the original and a generated image with DETR.

    Takes DETR's default argument namespace, points it at both images and at
    the DETR checkpoint resolved by ``get_weight``, then delegates to
    ``svc.main``.

    Returns:
        tuple: ``(shift_x, shift_y)`` as produced by ``svc.main``.
    """
    detr_args = get_detr_args_parser().parse_args([])
    detr_args.read_checkpoint = get_weight("DETR/output/checkpoint.pth")
    detr_args.img_org, detr_args.img_gen = img_org_path, img_gen_path

    dx, dy = svc.main(detr_args)
    return dx, dy
def get_local_bbox(weight_path, image_path, text_prompt, box_threshold, text_threshold):
    """
    Ground ``text_prompt`` in the image with the visual-grounding (VG) model.

    Args:
        weight_path: Repo-relative VG checkpoint path, resolved via ``get_weight``.
        image_path: Image to localize in.
        text_prompt: Text whose phrases are grounded.
        box_threshold: Box threshold forwarded to VG.
        text_threshold: Text threshold forwarded to VG.

    Returns:
        tuple: ``(bbox, logits, phrases)`` exactly as returned by
        ``localization.main``.
    """
    vg_args = localization.get_args_parser().parse_args([])
    overrides = {
        "weight_path": get_weight(weight_path),
        "image_path": image_path,
        "text_prompt": text_prompt,
        "box_threshold": box_threshold,
        "text_threshold": text_threshold,
    }
    for field, value in overrides.items():
        setattr(vg_args, field, value)

    boxes, scores, matched_phrases = localization.main(vg_args)
    return boxes, scores, matched_phrases
def _box_ssim(image_org_cv, image_gen_cv, bbox, sx, sy):
    """
    SSIM between a box in the original image and the same-size box, shifted
    by (sx, sy), in the generated image.

    Args:
        image_org_cv: Original image as loaded by cv2.imread (grayscale).
        image_gen_cv: Generated image as loaded by cv2.imread (grayscale).
        bbox: Four values unpacked as (x1, y1, x2, y2) -- assumed to be pixel
            corner coordinates from VG; TODO confirm the box format.
        sx, sy: Shift returned by cal_shift.

    Returns:
        The ssim() score, or None when the box has a negative origin in either
        image, a crop is empty, or the two crops differ in shape.
    """
    x1, y1, x2, y2 = bbox
    bx1, by1, bw, bh = (int(v) for v in (x1, y1, x2 - x1, y2 - y1))
    bx2, by2 = int(x1 + sx), int(y1 + sy)
    # A negative start index makes NumPy slice from the far edge, which can
    # yield a same-shaped but unrelated region; reject such boxes outright.
    if min(bx1, by1, bx2, by2) < 0:
        return None
    roi_org = image_org_cv[by1:by1 + bh, bx1:bx1 + bw]
    roi_gen = image_gen_cv[by2:by2 + bh, bx2:bx2 + bw]
    if roi_org.size == 0 or roi_org.shape != roi_gen.shape:
        return None
    return ssim(roi_org, roi_gen)


def main():
    """Run the full pipeline from the command line (see the module docstring)."""
    args = get_args_parser().parse_args()

    # 1) Generate candidate images for the prompt (CXRGen).
    gen_cxr(args.weight_path_gencxr, args.image_path, args.text_prompt,
            args.num_samples, args.output_path)
    time.sleep(4)  # ensure outputs are written

    # 2) Keep the generated sample with the highest similarity rate.
    #    os.path.join works whether or not output_path ends with a separator.
    info_csv = os.path.join(args.output_path, "info_path_similarity.csv")
    df = pd.read_csv(info_csv)
    if df.empty:
        sys.exit(f"No generated samples listed in {info_csv}")
    sim_ratios = [extract_tensor(val) for val in df["similarity_rate"]]
    max_sim_index = sim_ratios.index(max(sim_ratios))
    # Positional lookup: max_sim_index is a list position, not an index label.
    max_sim_gen_path = df["gen_sample_path"].iloc[max_sim_index]

    # 3) Shift between the original and the chosen generated image (DETR).
    sx, sy = cal_shift(args.image_path, max_sim_gen_path)

    # 4) Ground the prompt in the original image (VG).
    boxes, logits, phrases = get_local_bbox(
        args.weight_path_vg,
        args.image_path,
        args.text_prompt,
        args.box_threshold,
        args.text_threshold
    )
    print("Boxes:", boxes)
    print("Phrases:", phrases)

    # 5) Compare every box with its shifted counterpart in the generated image.
    image_org_cv = cv2.imread(args.image_path, cv2.IMREAD_GRAYSCALE)
    image_gen_cv = cv2.imread(max_sim_gen_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None (instead of raising) for unreadable files.
    for path, img in ((args.image_path, image_org_cv), (max_sim_gen_path, image_gen_cv)):
        if img is None:
            sys.exit(f"Could not read image: {path}")

    ssim_scores = []
    for bbox in boxes:
        score = _box_ssim(image_org_cv, image_gen_cv, bbox, sx, sy)
        if score is not None:
            ssim_scores.append(score)

    if ssim_scores:
        print("SSIM scores per box:", ssim_scores)
        print("Localization Detection Scores per bbox:", boxes, logits)
    else:
        print("No valid SSIM scores (e.g., mismatched shapes or empty ROIs).")


if __name__ == "__main__":
    main()