# capsnet-4class-lung-disease-classifier / inference_example_pipe.py
# (Hugging Face page residue: "valste's picture")
# Initial commit: CapsNet 4-class lung disease classifier
# Commit: f550944
import gradio as gr
from datasets import load_dataset
import tensorflow as tf
from huggingface_hub import hf_hub_download
from PIL import Image
import os
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array
from modelbuilder import capsnet_custom_objects
# determine the running environment: local machine or huggingface spaces
def running_in_spaces() -> bool:
    """Report whether the process is executing inside a Hugging Face Space.

    The decision is based purely on environment variables: it is ``True``
    when ``SPACE_ID`` is defined (with any value) or when ``SYSTEM`` equals
    ``"spaces"``; otherwise ``False``.
    """
    env = os.environ
    # A defined SPACE_ID is sufficient on its own.
    if env.get("SPACE_ID") is not None:
        return True
    return env.get("SYSTEM") == "spaces"
# Detect the runtime once at import time and report it on stdout.
is_spaces = running_in_spaces()
print(
    "Running in Hugging Face Spaces environment."
    if is_spaces
    else f"Running in local machine environment:{os.environ.get('SYSTEM')}"
)
# -------CONSTANTS-------#
# Size handed to ``Image.resize`` before segmentation/masking.
TARGET_SIZE: tuple[int, int] = (256, 256)
# Class names; position i labels element i of the classifier output.
CLASS_LABELS: list[str] = [
    "COVID",
    "Lung_Opacity",
    "Normal",
    "Viral Pneumonia",
]
# ------------------------------------------------------------
# 1️⃣ Load the models from Hugging Face Hub
# capsnet for disease classification and GAN for lung segmentation/masking
# ------------------------------------------------------------
# Resolve the sample dataset and both model files for the current runtime.
# Every name below (data_dir, dataset, gan_model_path, capsnet_model_path) is
# assigned in both branches, so no placeholder initialisation is needed.
if is_spaces:
    # huggingface datasets is preinstalled in Spaces
    data_dir = "external_xrays_299x299"
    dataset = load_dataset(
        "valste/lung-disease-xrays", data_dir=data_dir, split="train"
    )
    gan_model_path = hf_hub_download(
        repo_id="valste/lung-segmentation-gan", filename="model.keras"
    )
    capsnet_model_path = hf_hub_download(
        repo_id="valste/capsnet-4class-lung-disease-classifier", filename="model.keras"
    )
else:
    # local machine: the models and sample images are expected under the
    # current working directory; adjust these paths to your own checkout.
    capsnet_model_path = os.path.join(
        ".", "models", "capsnet-4class-lung-disease-classifier", "model.keras"
    )
    gan_model_path = os.path.join(".", "models", "lung-segmentation-gan", "model.keras")
    data_dir = os.path.join(".", "data", "external_xrays_299x299")

    # Fail early (same exception type as before) only when the local setup is
    # actually incomplete, instead of unconditionally blocking the local branch
    # and leaving the code below unreachable.
    missing_assets = [
        asset
        for asset in (capsnet_model_path, gan_model_path, data_dir)
        if not os.path.exists(asset)
    ]
    if missing_assets:
        raise NotImplementedError(
            "clone required models locally and adjust paths in here first! "
            f"Missing: {', '.join(missing_assets)}"
        )

    dataset = load_dataset(
        "imagefolder", data_dir=data_dir, split="train"  # path to your local folder
    )

# Both models are used for inference only, hence compile=False.
model_gan = tf.keras.models.load_model(gan_model_path, compile=False)
model_capsnet = tf.keras.models.load_model(
    capsnet_model_path, custom_objects=capsnet_custom_objects, compile=False
)
# ------------------------------------------------------------
# 2️⃣ Load sample X-ray images from your dataset
# ------------------------------------------------------------
# Parallel collections filled while iterating the dataset: the image objects,
# their paths (each wrapped in a one-element list) and their base file names.
imgs, img_paths, img_names = [], [], []
class DemoException(Exception):
    """Raised when a dataset example cannot be used by this demo."""
# Walk the dataset once, recording each image together with its source path.
# Any example lacking an image or a usable file path aborts the whole load.
for ex in dataset:
    if "image" not in ex:
        raise DemoException("Dataset examples do not contain 'image' field.")
    imgs.append(ex["image"])
    # The path is read from the image object's ``filename`` attribute, if any.
    path = getattr(ex["image"], "filename", None)  # string
    if not path:
        raise DemoException("Missing path")
    img_paths.append([path])
    img_names.append(os.path.basename(path))
# ------------------------------------------------------------
# 3️⃣ Define preprocessing and inference function
# ------------------------------------------------------------
def create_masked_img(img: Image.Image) -> tuple[np.ndarray, np.ndarray]:
    """Segment the lungs in *img* and build the masked classifier input.

    The image is converted to grayscale, resized to ``TARGET_SIZE`` and scaled
    to ``[0, 1]``. The segmentation model (``model_gan``) predicts a per-pixel
    lung probability, which is thresholded at 0.5 into a binary mask and
    multiplied onto the grayscale image.

    Args:
        img: Source X-ray as a PIL image (any mode; it is converted to "L").

    Returns:
        A pair ``(x, masked_gray)``:
        ``x`` is the float32 classifier input with a batch axis and the masked
        grayscale replicated over 3 channels, i.e. ``(1, H, W, 3)``;
        ``masked_gray`` is the 2-D masked grayscale image, for visualisation.
    """
    # --- 1) Make a grayscale base image for segmentation ---
    img_gray = img.convert("L").resize(TARGET_SIZE, Image.BILINEAR)
    gray_array = np.array(img_gray, dtype=np.float32) / 255.0  # (H, W) in [0, 1]

    # --- 2) Build the GAN input: add channel and batch axes -> (1, H, W, 1) ---
    gan_input = gray_array[np.newaxis, ..., np.newaxis]

    # --- 3) Run segmentation GAN to get lung mask ---
    # verbose=0 keeps the per-call Keras progress bar out of the logs,
    # consistent with the classifier call in predict().
    prediction = model_gan.predict(gan_input, verbose=0)
    # Assumes the GAN returns one single-channel map, e.g. (1, H, W, 1);
    # squeezing all unit axes then yields (H, W).
    lung_prob = np.squeeze(prediction)
    mask = (lung_prob > 0.5).astype(np.float32)  # binary 0/1 mask

    # --- 4) Apply mask ---
    masked_gray = gray_array * mask  # (H, W)

    # --- 5) Prepare input for CapsNet: 3 channels + batch dimension ---
    masked_gray_ch = np.repeat(masked_gray[..., np.newaxis], 3, axis=-1)  # (H, W, 3)
    # copy=False: the data is already float32, so this only guarantees the dtype.
    x = np.expand_dims(masked_gray_ch, axis=0).astype(np.float32, copy=False)  # (1, H, W, 3)

    return x, masked_gray
def predict(img_path: str) -> tuple[str, np.ndarray, dict[str, float]]:
    """Classify the X-ray stored at *img_path*.

    Args:
        img_path: Filesystem path of the image to classify.

    Returns:
        A triple ``(filename, masked_image, scores)``: the base name of
        *img_path*, the masked grayscale image produced by
        ``create_masked_img`` and a mapping from each entry of
        ``CLASS_LABELS`` to the corresponding classifier output as a float.

    Raises:
        IndexError: If the classifier yields fewer values than there are
            class labels.
    """
    # Open via a context manager so the file handle is released right after
    # preprocessing; the pixel data is consumed inside create_masked_img.
    with Image.open(img_path) as img:
        x, masked_vis = create_masked_img(img)

    preds = np.asarray(model_capsnet.predict(x, verbose=0))
    if preds.ndim > 2:
        preds = np.squeeze(preds, axis=-1)  # (1, 4, 1) → (1, 4)
    preds = np.squeeze(preds)  # (4,)

    # Element i of the classifier output belongs to CLASS_LABELS[i].
    scores = {label: float(preds[i]) for i, label in enumerate(CLASS_LABELS)}

    return os.path.basename(img_path), masked_vis, scores