"""Gradio demo: lung X-ray disease classification.

Pipeline: a segmentation GAN produces a lung mask, the masked grayscale
image is fed to a 4-class CapsNet classifier. Models and sample images are
pulled from the Hugging Face Hub when running in Spaces, or from local
clones otherwise.
"""

import gradio as gr
from datasets import load_dataset
import tensorflow as tf
from huggingface_hub import hf_hub_download
from PIL import Image
import os
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array
from modelbuilder import capsnet_custom_objects


# determine the running environment: local machine or huggingface spaces
def running_in_spaces() -> bool:
    """Return True if app is running inside Hugging Face Spaces."""
    return (
        os.environ.get("SPACE_ID") is not None
        or os.environ.get("SYSTEM") == "spaces"
    )


is_spaces = running_in_spaces()

if is_spaces:
    print("Running in Hugging Face Spaces environment.")
else:
    print(f"Running in local machine environment:{os.environ.get('SYSTEM')}")

# -------CONSTANTS-------#
TARGET_SIZE = (256, 256)  # target size for masked images (GAN input resolution)
CLASS_LABELS = ["COVID", "Lung_Opacity", "Normal", "Viral Pneumonia"]

# ------------------------------------------------------------
# 1️⃣ Load the models from Hugging Face Hub
# capsnet for disease classification and GAN for lung segmentation/masking
# ------------------------------------------------------------
gan_model_path = None
capsnet_model_path = None
dataset = None
data_dir = None

if is_spaces:
    # huggingface datasets is preinstalled in Spaces
    data_dir = "external_xrays_299x299"
    dataset = load_dataset(
        "valste/lung-disease-xrays", data_dir=data_dir, split="train"
    )
    gan_model_path = hf_hub_download(
        repo_id="valste/lung-segmentation-gan", filename="model.keras"
    )
    capsnet_model_path = hf_hub_download(
        repo_id="valste/capsnet-4class-lung-disease-classifier",
        filename="model.keras",
    )
else:
    # Deliberate guard: the local paths below are scaffolding. Clone the
    # required models, adjust the paths, then delete this raise.
    raise NotImplementedError(
        "clone required models locally and adjust paths in here first! "
        "... and remove this line afterwards ;)"
    )
    # local machine (unreachable until the raise above is removed)
    capsnet_model_path = os.path.join(
        ".", "models", "capsnet-4class-lung-disease-classifier", "model.keras"
    )
    gan_model_path = os.path.join(
        ".", "models", "lung-segmentation-gan", "model.keras"
    )
    data_dir = os.path.join(".", "data", "external_xrays_299x299")
    dataset = load_dataset(
        "imagefolder", data_dir=data_dir, split="train"  # path to your local folder
    )

# compile=False: inference only, no optimizer/loss state needed
model_gan = tf.keras.models.load_model(gan_model_path, compile=False)
model_capsnet = tf.keras.models.load_model(
    capsnet_model_path, custom_objects=capsnet_custom_objects, compile=False
)

# ------------------------------------------------------------
# 2️⃣ Load sample X-ray images from your dataset
# ------------------------------------------------------------
imgs = []
img_paths = []  # each entry wrapped in a list (gradio examples row format)
img_names = []


class DemoException(Exception):
    """Raised when the demo dataset is missing expected fields."""

    pass


for ex in dataset:
    if "image" not in ex:
        raise DemoException("Dataset examples do not contain 'image' field.")
    imgs.append(ex["image"])
    # PIL images loaded from disk carry their source path in .filename
    path = getattr(ex["image"], "filename", None)  # string
    if not path:
        raise DemoException("Missing path")
    img_paths.append([path])
    img_names.append(os.path.basename(path))


# ------------------------------------------------------------
# 3️⃣ Define preprocessing and inference function
# ------------------------------------------------------------
def create_masked_img(img: Image.Image) -> tuple[np.ndarray, np.ndarray]:
    """Segment the lungs in *img* and return the masked image.

    Args:
        img: input X-ray as a PIL image (any mode/size).

    Returns:
        Tuple of:
        - CapsNet-ready batch, shape (1, 256, 256, 3), float32 in [0, 1];
        - masked grayscale image for display, shape (256, 256), float32.
    """
    # --- 1) Make a grayscale base image for segmentation ---
    img_gray = img.convert("L")  # grayscale (299, 299)
    img_gray = img_gray.resize(TARGET_SIZE, Image.BILINEAR)  # (256, 256)
    gray_array = np.array(img_gray, dtype=np.float32) / 255.0  # (H, W) in [0,1]

    # --- 2) Build a 1-channel version for GAN input ---
    gan_input = gray_array[..., np.newaxis]  # adding channel dim (H, W) → (H, W, 1)
    gan_input = np.expand_dims(gan_input, axis=0)  # adding batch dim (1, H, W, 1)

    # --- 3) Run segmentation GAN to get lung mask ---
    prediction = model_gan.predict(gan_input, verbose=0)  # (1, 256, 256, 1)
    lung_prob = np.squeeze(prediction)  # (H, W)
    mask = (lung_prob > 0.5).astype(np.float32)  # (H, W), 0/1

    # --- 4) Apply mask ---
    masked_gray = gray_array * mask  # (H, W)

    # --- 5) Prepare input for CapsNet: 3 channel + batch dimension ---
    masked_gray_ch = np.repeat(
        masked_gray[..., np.newaxis], 3, axis=-1
    )  # adding channel dim (H, W, 3)
    x = np.expand_dims(masked_gray_ch, axis=0).astype(
        np.float32
    )  # (1, H, W, 3) ✅ for CapsNet

    return x, masked_gray


def predict(img_path: str) -> tuple[str, np.ndarray, dict[str, float]]:
    """Classify the X-ray at *img_path*.

    Args:
        img_path: path to an X-ray image file.

    Returns:
        Tuple of (basename of the input file, masked grayscale image for
        display, mapping from class label to probability score).
    """
    with Image.open(img_path) as img:
        x, masked_vis = create_masked_img(img)

    preds = model_capsnet.predict(x, verbose=0)
    preds = np.asarray(preds)
    if preds.ndim > 2:
        preds = np.squeeze(preds, axis=-1)  # (1, 4, 1) → (1, 4)
    preds = np.squeeze(preds)  # (4,)

    scores = {CLASS_LABELS[i]: float(preds[i]) for i in range(len(CLASS_LABELS))}
    filename_out = os.path.basename(img_path)
    return filename_out, masked_vis, scores