import sys import types sys.modules["audioop"] = types.ModuleType("audioop") sys.modules["pyaudioop"] = types.ModuleType("pyaudioop") import os import gradio as gr from PIL import Image from datasets import load_dataset import model_utils import visualization_utils # --- Configuration --- MODEL_NAME_OR_PATH = "google/vit-base-patch16-224-in21k" DATASET_PATH = "pawlo2013/chest_xray" MODEL_DIR = "./models" EXAMPLES_FOLDER = "./examples" # --- Load Data & Model --- print("Loading dataset information...") try: # We load the dataset mainly to get the class names correctly train_dataset = load_dataset(DATASET_PATH, split="train") class_names = train_dataset.features["label"].names print(f"Class names loaded: {class_names}") except Exception as e: print(f"Warning: Could not load dataset, using default class names. Error: {e}") # Fallback class names based on typical chest X-ray classification class_names = ["NORMAL", "PNEUMONIA"] print("Loading model and processor...") try: model, processor = model_utils.load_model_and_processor( MODEL_DIR, MODEL_NAME_OR_PATH, class_names ) print("Model and processor loaded successfully!") except Exception as e: print(f"Error loading model: {e}") raise e # --- Core Logic --- def classify_and_visualize(img, device="cpu"): if img is None: return None, None try: # Get predictions outputs, processed_input, probabilities, prediction_idx = model_utils.predict( model, processor, img, device ) # Format probabilities result = {class_name: prob for class_name, prob in zip(class_names, probabilities)} # Generate heatmap heatmap_img = visualization_utils.show_final_layer_attention_maps( outputs, processed_input, device ) return result, heatmap_img except Exception as e: print(f"Error in classification: {e}") # Return empty result and None for heatmap on error empty_result = {class_name: 0.0 for class_name in class_names} return empty_result, None def format_output(img): try: probs, heatmap = classify_and_visualize(img) return probs, heatmap except Exception as e: print(f"Error in format_output: {e}") # Return empty results on error empty_result = {class_name: 0.0 for class_name in class_names} return empty_result, None # --- Helper Functions --- def load_examples_from_folder(folder_path): examples = [] if os.path.exists(folder_path): for file in os.listdir(folder_path): if file.lower().endswith((".png", ".jpg", ".jpeg")): examples.append(os.path.join(folder_path, file)) return examples examples = load_examples_from_folder(EXAMPLES_FOLDER) # --- UI Layout --- title = "Pneumonia Detection Assistant" description = """
Upload a Chest X-Ray image to analyze it for signs of Pneumonia.
The model classifies the image into Normal, Viral Pneumonia, or Bacterial Pneumonia categories and provides an attention heatmap to show which areas influenced the decision.
This application uses Artificial Intelligence (Vision Transformer) for educational and research purposes only. It is NOT a diagnostic tool. The results generated by this model should not be treated as medical advice. Always consult with a qualified healthcare professional for medical diagnosis and treatment.