valste's picture
fixed for local runs
f235f14
from base import *
# ------------------------------------------------------------
# 3️⃣ Define callback functions for the Gradio interface
# ------------------------------------------------------------
def _predict_and_display(img_path: str):
if not img_path:
# return empty values to clear outputs
return {}, None
filename_out, masked_vis, scores = predict(img_path)
return scores, masked_vis
def _clear_outputs():
    """Produce two blank updates, one for the label output and one for the mask output."""
    # A fresh gr.update(value=None) is built per output; together they form the
    # returned tuple that Gradio maps onto the two output components.
    return tuple(gr.update(value=None) for _ in range(2))
default_image_label = "Select or upload an image to classify"  # default label for img_input


def _change_input_label(img_path: str | None, default_image_label=default_image_label) -> dict:
    """Build a Gradio update that relabels the image input.

    Args:
        img_path: Path of the currently selected image, or None / "" when
            nothing is selected.
        default_image_label: Label to show when no image is selected.

    Returns:
        The dict produced by ``gr.update(...)``, carrying only a new ``label``
        (either the default text or "Image to classify: <file name>").
    """
    if not img_path:
        return gr.update(label=default_image_label)
    fname = os.path.basename(img_path)
    return gr.update(label=f"Image to classify: {fname}")
# ------------------------------------------------------------
# 4️⃣ Build Gradio Interface
# ------------------------------------------------------------
# Show at most 20 examples (fewer if fewer are available).
N_examples_to_use = min(20, len(img_paths))

with gr.Blocks() as demo:
    # Use the default label constant directly rather than reading it back out
    # of the dict returned by gr.update(), which ties this to Gradio internals.
    img_input = gr.Image(type="filepath", label=default_image_label)
    out_mask = gr.Image(type="numpy", label="Masked Lung Region")
    out_label = gr.Label(label="Predicted confidences")

    examples = gr.Examples(
        label="Example X-rays",
        examples=img_paths[:N_examples_to_use],  # each element is a list of one string (path): [[path1], [path2], ... ]
        example_labels=img_names[:N_examples_to_use],
        inputs=img_input,
        cache_examples=False,
    )

    # Event order of img_input:
    #   * for uploads:         select() -> upload() -> change()
    #   * for example clicks:  select() -> change()
    # All handlers hang off change(), so they fire for both uploads and example clicks.
    img_input.change(_change_input_label, inputs=img_input, outputs=img_input)
    # queue=False runs this outside the queue, presumably so stale results are
    # cleared right away while the (queued) prediction is still running.
    img_input.change(_clear_outputs, inputs=None, outputs=[out_label, out_mask], queue=False)
    img_input.change(_predict_and_display, inputs=img_input, outputs=[out_label, out_mask])
# ------------------------------------------------------------
# 5️⃣ Launch the app
# ------------------------------------------------------------
if __name__ == "__main__":
    import sys

    print("Python version used:", sys.version)

    # Options shared by every environment: debug output and permission to
    # serve the bundled example images.
    launch_kwargs = dict(debug=True, allowed_paths=[str(EXAMPLES_DIR)])
    if not is_spaces:
        # Outside Spaces: bind to loopback on a fixed port and also request a
        # public share link (note: share=True exposes the app publicly).
        launch_kwargs.update(share=True, server_name="127.0.0.1", server_port=7860)
    demo.launch(**launch_kwargs)