# vicca / app.py
# Author: sayehghp
# Commit 27e3844: "Visualization"
import os
os.environ["OMP_NUM_THREADS"] = "1"
import tempfile
import gradio as gr
from vicca_api import run_vicca
def vicca_interface(image, text_prompt, box_threshold=0.2, text_threshold=0.2, num_samples=4):
    """Run VICCA on one uploaded chest X-ray and unpack the result for the UI.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.
        text_prompt: Report / description text forwarded to ``run_vicca``.
        box_threshold: Forwarded unchanged to ``run_vicca``.
        text_threshold: Forwarded unchanged to ``run_vicca``.
        num_samples: Forwarded to ``run_vicca`` after coercion to ``int``,
            because it arrives from a slider widget.

    Returns:
        A 5-tuple matching the Interface outputs, in this order:
        - the best generated image path;
        - the VG-annotated image path;
        - the combined attention overlay;
        - a list of ``(path, term)`` gallery items;
        - the raw ``run_vicca`` result dict.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of failing with AttributeError.
        raise gr.Error("Please upload an input chest X-ray image.")

    os.makedirs("uploads", exist_ok=True)
    # Use a unique file per request. A single shared "uploads/input.png" would be
    # overwritten if two requests were ever processed concurrently.
    fd, input_path = tempfile.mkstemp(prefix="input_", suffix=".png", dir="uploads")
    os.close(fd)  # PIL reopens the file by path; the raw descriptor is not needed.
    try:
        image.save(input_path)
        result = run_vicca(
            image_path=input_path,
            text_prompt=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            num_samples=int(num_samples),
        )
    finally:
        # Best-effort cleanup so uploads do not accumulate on disk.
        # NOTE(review): assumes none of the paths returned by run_vicca point at
        # the input file itself -- confirm before relying on this.
        try:
            os.remove(input_path)
        except OSError:
            pass

    best_gen = result.get("best_generated_image_path")
    vg_path = result.get("VG_annotated_image_path")
    attn = result.get("attention_overlays") or {}
    combined = attn.get("combined")
    per_term_dict = attn.get("per_term") or {}
    # gr.Gallery takes (image, caption) pairs; skip terms that produced no overlay path.
    gallery_items = [(path, term) for term, path in per_term_dict.items() if path]
    return best_gen, vg_path, combined, gallery_items, result
# Single-page Gradio UI. The input components map positionally onto
# vicca_interface's parameters. The output components map positionally onto
# the 5-tuple it returns.
demo = gr.Interface(
    fn=vicca_interface,
    title="VICCA",
    inputs=[
        gr.Image(label="Input CXR", type="pil"),
        gr.Textbox(label="Text prompt", lines=3),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Box threshold"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Text threshold"),
        gr.Slider(minimum=1, maximum=8, value=4, step=1, label="Number of samples"),
    ],
    outputs=[
        gr.Image(label="Best generated CXR"),
        gr.Image(label="VG annotated image"),
        gr.Image(label="Combined attention heatmap"),
        # Three thumbnails per row with a fixed 400px gallery height.
        gr.Gallery(label="Per-term overlays", columns=3, height=400),
        gr.JSON(label="Raw VICCA output"),
    ],
)
if __name__ == "__main__":
    # Bind to all network interfaces on Gradio's default port, with debug mode off.
    demo.launch(
        debug=False,
        server_port=7860,
        server_name="0.0.0.0",
    )
# Earlier JSON-only version of the interface, kept for reference only (not executed):
# def vicca_interface(image, text_prompt):
# """
# image: file from Gradio, we'll use its temp path
# text_prompt: report / description
# """
# # Gradio passes a PIL image or a file path depending on type
# # We'll request type='filepath' so this is already a path
# image_path = image
# result = run_vicca(
# image_path=image_path,
# text_prompt=text_prompt,
# )
# # You could also return the best generated image as an image output
# # For now, we expose the dict as JSON
# return result
# demo = gr.Interface(
# fn=vicca_interface,
# inputs=[
# gr.Image(type="filepath", label="Chest X-ray"),
# gr.Textbox(label="Report / pathology description", lines=3),
# ],
# outputs=gr.JSON(label="VICCA output"),
# title="VICCA – Visual Interpretation & Comprehension",
# description=(
# "Upload a chest X-ray and provide a text report / pathology description. "
# "The VICCA pipeline will run CXR generation, visual grounding, "
# "and ROI-level similarity scoring."
# ),
# )
# if __name__ == "__main__":
# demo.launch()
# if __name__ == "__main__":
# demo.launch(
# server_name="0.0.0.0",
# server_port=7860,
# debug=False
# )