| | import gradio as gr |
| | from datasets import load_dataset |
| | import tensorflow as tf |
| | from huggingface_hub import hf_hub_download |
| | from PIL import Image |
| | import os |
| | import numpy as np |
| | from tensorflow.keras.preprocessing.image import img_to_array |
| | from modelbuilder import capsnet_custom_objects |
| |
|
| |
|
| | |
def running_in_spaces() -> bool:
    """Report whether this process appears to run on Hugging Face Spaces.

    Either signal is enough: a ``SPACE_ID`` environment variable being set,
    or the ``SYSTEM`` environment variable being exactly ``"spaces"``.
    """
    if "SPACE_ID" in os.environ:
        return True
    return os.environ.get("SYSTEM") == "spaces"
| |
|
| |
|
# Detect the runtime once at import time; the module-level branches further
# down choose where models and example data are loaded from based on this.
is_spaces = running_in_spaces()

if is_spaces:
    # Plain string: the original f-prefix had no placeholders (ruff F541).
    print("Running in Hugging Face Spaces environment.")
else:
    print(f"Running in local machine environment:{os.environ.get('SYSTEM')}")
| |
|
| | |
# (width, height) passed to PIL's resize before segmentation/classification.
TARGET_SIZE: tuple[int, int] = (256, 256)

# Output labels of the classifier: entry i names the score at index i of the
# prediction vector, so this order must not change.
CLASS_LABELS = [
    "COVID",
    "Lung_Opacity",
    "Normal",
    "Viral Pneumonia",
]
| |
|
| | |
| | |
| | |
| | |
# Resolve where the example images and model weights come from.
# On Hugging Face Spaces everything is fetched from the Hub. A local run
# expects the artifacts to have been cloned under ./models and ./data.
if is_spaces:
    # Example X-ray images shown in the demo, taken from the Hub dataset repo.
    data_dir = "external_xrays_299x299"
    dataset = load_dataset(
        "valste/lung-disease-xrays", data_dir=data_dir, split="train"
    )

    # Model weights; the returned values are file paths handed to load_model.
    gan_model_path = hf_hub_download(
        repo_id="valste/lung-segmentation-gan", filename="model.keras"
    )
    capsnet_model_path = hf_hub_download(
        repo_id="valste/capsnet-4class-lung-disease-classifier", filename="model.keras"
    )
else:
    # Local run: paths are relative to the current working directory.
    capsnet_model_path = os.path.join(
        ".", "models", "capsnet-4class-lung-disease-classifier", "model.keras"
    )
    gan_model_path = os.path.join(".", "models", "lung-segmentation-gan", "model.keras")
    data_dir = os.path.join(".", "data", "external_xrays_299x299")

    # Previously this branch raised unconditionally, leaving the code below
    # unreachable. Only fail when the expected artifacts are actually absent,
    # and say which ones, so a correctly prepared checkout works as-is.
    _missing = [
        path
        for path in (capsnet_model_path, gan_model_path, data_dir)
        if not os.path.exists(path)
    ]
    if _missing:
        raise NotImplementedError(
            "Clone the required models/data locally first (or adjust the paths "
            "in this file). Missing: " + ", ".join(_missing)
        )

    dataset = load_dataset("imagefolder", data_dir=data_dir, split="train")

# Both networks are only used for inference (predict calls), so they are
# loaded without compiling.
model_gan = tf.keras.models.load_model(gan_model_path, compile=False)
model_capsnet = tf.keras.models.load_model(
    capsnet_model_path, custom_objects=capsnet_custom_objects, compile=False
)
| |
|
| | |
| | |
| | |
# Filled by the dataset loop: the decoded images, each image's file path
# wrapped in a one-element list, and the bare file names.
imgs, img_paths, img_names = [], [], []
| |
|
class DemoException(Exception):
    """Raised when the example dataset cannot be turned into demo inputs."""
| | |
| |
|
# Gather every example image together with its source path and file name.
# Any example without an image, or whose image has no filename, aborts loading.
for example in dataset:
    if "image" not in example:
        raise DemoException("Dataset examples do not contain 'image' field.")

    image = example["image"]
    imgs.append(image)

    image_path = getattr(image, "filename", None)
    if not image_path:
        raise DemoException("Missing path")

    # Each path is wrapped in its own list (one entry per example row).
    img_paths.append([image_path])
    img_names.append(os.path.basename(image_path))
| |
|
| | |
| | |
| | |
| |
|
def create_masked_img(
    img: Image.Image, *, threshold: float = 0.5
) -> tuple[np.ndarray, np.ndarray]:
    """Segment the lungs in *img* with the GAN and zero out everything else.

    The image is converted to 8-bit grayscale ("L"), resized to
    ``TARGET_SIZE`` and scaled to [0, 1]. The segmentation model's output is
    binarised at *threshold* and multiplied into the grayscale image.

    Args:
        img: Input X-ray as a PIL image (any mode; converted to "L").
        threshold: Probability above which a pixel counts as lung.
            Defaults to 0.5, the value that used to be hard-coded.

    Returns:
        ``(x, masked_gray)``:
        ``x`` is a float32 batch of shape ``(1, H, W, 3)`` holding the masked
        image replicated over three channels; it is the classifier input.
        ``masked_gray`` is the float32 masked grayscale image.
        Assumes the GAN output squeezes to ``(H, W)`` — TODO confirm.
    """
    # Grayscale, fixed-size, [0, 1]-scaled image shared by both stages.
    img_gray = img.convert("L").resize(TARGET_SIZE, Image.BILINEAR)
    gray_array = np.array(img_gray, dtype=np.float32) / 255.0

    # Add batch and channel axes: (H, W) -> (1, H, W, 1).
    gan_input = gray_array[np.newaxis, ..., np.newaxis]

    # verbose=0 suppresses the per-call Keras progress bar, matching the
    # classifier call in predict().
    lung_prob = np.squeeze(model_gan.predict(gan_input, verbose=0))
    mask = (lung_prob > threshold).astype(np.float32)

    masked_gray = gray_array * mask

    # The classifier input gets three identical channels of the masked image.
    masked_rgb = np.repeat(masked_gray[..., np.newaxis], 3, axis=-1)
    x = masked_rgb[np.newaxis].astype(np.float32)

    return x, masked_gray
| |
|
| |
|
def predict(img_path: str) -> tuple[str, np.ndarray, dict[str, float]]:
    """Classify the chest X-ray stored at *img_path*.

    The lungs are segmented and masked via ``create_masked_img``. The masked
    image is then scored by the CapsNet classifier.

    Args:
        img_path: Path to an image file readable by PIL.

    Returns:
        ``(filename, masked_vis, scores)``:
        ``filename`` is the basename of *img_path*.
        ``masked_vis`` is the masked grayscale image that was classified.
        ``scores`` maps each label in ``CLASS_LABELS`` to the model's score
        for it. (The old annotation wrongly declared this as ``np.ndarray``.)

    Raises:
        ValueError: if the classifier output does not contain exactly one
            score per entry in ``CLASS_LABELS``.
    """
    # Image.open keeps the file handle open. The context manager releases it.
    # create_masked_img's convert() reads the pixels before the block exits.
    with Image.open(img_path) as img:
        x, masked_vis = create_masked_img(img)

    preds = np.asarray(model_capsnet.predict(x, verbose=0))
    # Drop a trailing singleton axis on rank>2 outputs, then the batch axis.
    if preds.ndim > 2:
        preds = np.squeeze(preds, axis=-1)
    preds = np.squeeze(preds)

    # Reject a shape mismatch explicitly. Otherwise it would surface as an
    # IndexError, or extra scores would be silently dropped.
    if preds.shape != (len(CLASS_LABELS),):
        raise ValueError(
            f"Expected {len(CLASS_LABELS)} class scores, "
            f"got array of shape {preds.shape}"
        )

    scores = {label: float(score) for label, score in zip(CLASS_LABELS, preds)}

    return os.path.basename(img_path), masked_vis, scores