# Hugging Face Space source: app.py ("Update app.py", commit 98b68eb by EdwardSamuel13, verified).
import sys
import types
sys.modules["audioop"] = types.ModuleType("audioop")
sys.modules["pyaudioop"] = types.ModuleType("pyaudioop")
import os
import gradio as gr
from PIL import Image
from datasets import load_dataset
import model_utils
import visualization_utils
# --- Configuration ---
# Hub checkpoint id / path handed to model_utils.load_model_and_processor.
MODEL_NAME_OR_PATH = "google/vit-base-patch16-224-in21k"
# Hub dataset whose label names are used as the class names.
DATASET_PATH = "pawlo2013/chest_xray"
# Local folders: the model directory for model_utils, and example images for the UI.
MODEL_DIR, EXAMPLES_FOLDER = "./models", "./examples"
# --- Load Data & Model ---
# Used only when the dataset's label names cannot be fetched.
# NOTE(review): the UI description advertises three categories (Normal /
# Viral Pneumonia / Bacterial Pneumonia) but this fallback lists two. If the
# saved model was trained on the dataset's labels, these names (and their
# order) probably need to match them — confirm against the dataset card.
DEFAULT_CLASS_NAMES = ["NORMAL", "PNEUMONIA"]


def _load_class_names(dataset_path):
    """Return the names of the ``label`` feature of *dataset_path*.

    First only the dataset metadata is fetched via ``load_dataset_builder``.
    This avoids downloading the whole training split just to read the label
    names. If the metadata does not declare features, the train split is
    loaded as before.

    Any failure prints a warning and returns a copy of
    ``DEFAULT_CLASS_NAMES``.
    """
    try:
        from datasets import load_dataset_builder

        features = load_dataset_builder(dataset_path).info.features
        if features is None:
            # DatasetInfo.features is optional; without it, read the
            # features from the loaded split like the original code did.
            features = load_dataset(dataset_path, split="train").features
        names = features["label"].names
    except Exception as e:
        print(f"Warning: Could not load dataset, using default class names. Error: {e}")
        return list(DEFAULT_CLASS_NAMES)
    print(f"Class names loaded: {names}")
    return names


print("Loading dataset information...")
class_names = _load_class_names(DATASET_PATH)
print("Loading model and processor...")
try:
    model, processor = model_utils.load_model_and_processor(
        MODEL_DIR, MODEL_NAME_OR_PATH, class_names
    )
except Exception as e:
    print(f"Error loading model: {e}")
    # The app cannot serve predictions without a model, so abort startup.
    # A bare ``raise`` re-raises the active exception with its original
    # traceback.
    raise
print("Model and processor loaded successfully!")
# --- Core Logic ---
def classify_and_visualize(img, device="cpu"):
    """Classify one chest X-ray and build its attention heatmap.

    Args:
        img: Image from the ``gr.Image(type="pil")`` input, or ``None`` when
            nothing has been uploaded.
        device: Device string forwarded to ``model_utils.predict`` and to
            ``visualization_utils.show_final_layer_attention_maps``.

    Returns:
        ``(confidences, heatmap)``:

        - ``confidences`` maps each entry of ``class_names`` (paired
          positionally with the predicted probabilities) to a plain ``float``.
        - ``heatmap`` is whatever the visualization helper returns.

        Returns ``(None, None)`` when ``img`` is ``None``. On any error the
        traceback is printed, and all-zero confidences with no heatmap are
        returned so the UI keeps responding.
    """
    if img is None:
        return None, None
    try:
        # The fourth value (predicted index) is not needed here.
        outputs, processed_input, probabilities, _ = model_utils.predict(
            model, processor, img, device
        )
        # gr.Label expects numeric confidences. float() also unwraps numpy
        # scalars or 0-d tensors in case predict returns those
        # (NOTE(review): return type of predict not visible here).
        result = {
            class_name: float(prob)
            for class_name, prob in zip(class_names, probabilities)
        }
        heatmap_img = visualization_utils.show_final_layer_attention_maps(
            outputs, processed_input, device
        )
    except Exception as e:
        import traceback

        print(f"Error in classification: {e}")
        # The message alone hides where in model_utils / visualization_utils
        # the failure happened, so also log the full stack trace.
        traceback.print_exc()
        return {class_name: 0.0 for class_name in class_names}, None
    return result, heatmap_img
def format_output(img):
    """Gradio callback: return ``(confidences, heatmap)`` for *img*.

    Delegates to ``classify_and_visualize``. If that call (or unpacking its
    two-item result) raises, the error is printed, and every class name is
    mapped to 0.0 with no heatmap.
    """
    try:
        label_scores, attention_map = classify_and_visualize(img)
    except Exception as err:
        print(f"Error in format_output: {err}")
        return dict.fromkeys(class_names, 0.0), None
    return label_scores, attention_map
# --- Helper Functions ---
def load_examples_from_folder(folder_path, extensions=(".png", ".jpg", ".jpeg")):
    """Return the image files directly inside *folder_path* (not recursive).

    Args:
        folder_path: Directory to scan.
        extensions: Tuple of filename suffixes to accept. Matching is
            case-insensitive, so ``photo.PNG`` is included.

    Returns:
        ``os.path.join(folder_path, name)`` for each matching regular file,
        sorted so the examples list has a stable order across platforms
        (``os.listdir`` order is arbitrary). Returns ``[]`` when
        *folder_path* does not exist or is not a directory.
    """
    # isdir (not exists) so that a file path yields [] instead of crashing
    # inside os.listdir.
    if not os.path.isdir(folder_path):
        return []
    suffixes = tuple(ext.lower() for ext in extensions)
    candidates = [
        os.path.join(folder_path, name)
        for name in os.listdir(folder_path)
        if name.lower().endswith(suffixes)
    ]
    # Skip sub-directories that merely have an image-like name.
    return sorted(path for path in candidates if os.path.isfile(path))
# Image paths offered as clickable examples in the interface (empty list when
# the examples folder is missing).
examples = load_examples_from_folder(EXAMPLES_FOLDER)
# --- UI Layout ---
# Heading passed to gr.Interface(title=...).
title = "Pneumonia Detection Assistant"
# HTML passed to gr.Interface(description=...).
# NOTE(review): this text promises three categories (Normal / Viral /
# Bacterial Pneumonia). The actual labels come from the dataset, or from the
# two-class fallback when the dataset cannot be loaded — confirm they agree.
description = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<p>Upload a Chest X-Ray image to analyze it for signs of Pneumonia.</p>
<p>The model classifies the image into <b>Normal</b>, <b>Viral Pneumonia</b>, or <b>Bacterial Pneumonia</b> categories
and provides an attention heatmap to show which areas influenced the decision.</p>
</div>
"""
# HTML passed to gr.Interface(article=...): the medical disclaimer box.
article = """
<div style="border: 2px solid #e74c3c; padding: 20px; border-radius: 10px; margin-top: 20px; background-color: #fce4e4;">
<h3 style="color: #c0392b; margin-top: 0;">⚠️ MEDICAL DISCLAIMER</h3>
<p style="color: #7f8c8d;">
This application uses Artificial Intelligence (Vision Transformer) for educational and research purposes only.
<b>It is NOT a diagnostic tool.</b> The results generated by this model should not be treated as medical advice.
Always consult with a qualified healthcare professional for medical diagnosis and treatment.
</p>
</div>
"""
# Soft theme with a blue primary / slate secondary palette, plus explicit
# primary-button fills for the normal and hover states.
theme = gr.themes.Soft(primary_hue="blue", secondary_hue="slate")
theme = theme.set(
    button_primary_background_fill="*primary_500",
    button_primary_background_fill_hover="*primary_600",
)
# The app: one uploaded PIL image in; format_output returns
# (label -> confidence mapping, attention heatmap image).
iface = gr.Interface(
    format_output,
    gr.Image(type="pil", label="Upload Chest X-Ray"),
    [
        gr.Label(label="Prediction Confidence", num_top_classes=3),
        gr.Image(label="Attention Heatmap Analysis"),
    ],
    # Page text and styling defined above.
    title=title,
    description=description,
    article=article,
    theme=theme,
    # Example outputs are not precomputed/cached at startup.
    examples=examples,
    cache_examples=False,
    # Flagging stays at Gradio's default; allow_flagging="never" was tried
    # here and left commented out.
)
# Launch
if __name__ == "__main__":
    # Bind on all interfaces, port 7860.
    # NOTE(review): share=True also requests a public gradio.live link when
    # run outside Hugging Face Spaces — confirm that exposure is intended.
    iface.launch(share=True, server_port=7860, server_name="0.0.0.0")