Spaces:
Build error
Build error
File size: 4,632 Bytes
0c17dc3 99ff9e0 e1bc07e 99ff9e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import sys
import types
sys.modules["audioop"] = types.ModuleType("audioop")
sys.modules["pyaudioop"] = types.ModuleType("pyaudioop")
import os
import gradio as gr
from PIL import Image
from datasets import load_dataset
import model_utils
import visualization_utils
# --- Configuration ---
# Model name or path handed to model_utils.load_model_and_processor alongside
# MODEL_DIR; how the two are combined is decided inside model_utils.
MODEL_NAME_OR_PATH: str = "google/vit-base-patch16-224-in21k"
# Hugging Face Hub dataset consulted only for its label names.
DATASET_PATH: str = "pawlo2013/chest_xray"
# Directory passed to model_utils.load_model_and_processor. This is a relative
# path, so it resolves against the current working directory.
MODEL_DIR: str = "./models"
# Folder scanned for example images offered in the Gradio interface. This is a
# relative path, resolved against the current working directory.
EXAMPLES_FOLDER: str = "./examples"
# --- Load Data & Model ---
def _load_class_names(dataset_path):
    """Return the names of the ``label`` ClassLabel feature of *dataset_path*.

    The dataset metadata is read first with ``datasets.load_dataset_builder``.
    Per the ``datasets`` docs, this inspects the dataset without downloading its
    data files. If that fails, or the builder exposes no features, the train
    split is loaded in full, which is what this script always did before.

    Raises:
        Exception: whatever ``load_dataset`` raises when the full load also
            fails, or a lookup error if there is no usable ``label`` feature.
    """
    features = None
    try:
        from datasets import load_dataset_builder

        features = load_dataset_builder(dataset_path).info.features
    except Exception as e:
        print(f"Could not read dataset metadata, loading the train split instead. Error: {e}")
    if features is None:
        # Some Hub datasets only expose their features once the data is prepared.
        features = load_dataset(dataset_path, split="train").features
    return features["label"].names


print("Loading dataset information...")
try:
    # class_names is handed to the model loader below and later zipped with the
    # model's probabilities, so its order must follow the dataset's label order.
    class_names = _load_class_names(DATASET_PATH)
    print(f"Class names loaded: {class_names}")
except Exception as e:
    print(f"Warning: Could not load dataset, using default class names. Error: {e}")
    # Fallback class names based on typical chest X-ray classification.
    # NOTE(review): the UI description advertises three categories (Normal,
    # Viral Pneumonia, Bacterial Pneumonia) but this fallback has two. If the
    # model saved in MODEL_DIR was trained on the dataset's labels, this list
    # may not match it. Confirm the real label set.
    class_names = ["NORMAL", "PNEUMONIA"]

print("Loading model and processor...")
try:
    model, processor = model_utils.load_model_and_processor(
        MODEL_DIR, MODEL_NAME_OR_PATH, class_names
    )
    print("Model and processor loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    # A bare raise re-raises the active exception unchanged. The app cannot run
    # without a model, so startup should fail loudly.
    raise
# --- Core Logic ---
def classify_and_visualize(img, device="cpu"):
    """Classify an uploaded image and build its attention heatmap.

    Args:
        img: The image received from the Gradio input, or None when the input
            is empty.
        device: Device string forwarded to ``model_utils.predict`` and
            ``visualization_utils.show_final_layer_attention_maps``.

    Returns:
        A ``(probabilities, heatmap)`` pair:
        - ``probabilities`` is a ``{class_name: float}`` dict for ``gr.Label``.
        - ``heatmap`` is the image produced by ``visualization_utils``.
        Returns ``(None, None)`` when *img* is None. If prediction or
        visualization raises, returns all-zero probabilities and a ``None``
        heatmap, so the UI stays responsive.
    """
    if img is None:
        return None, None
    try:
        # The predicted index returned by model_utils.predict is not used here.
        outputs, processed_input, probabilities, _ = model_utils.predict(
            model, processor, img, device
        )
        # The element type of `probabilities` is decided by model_utils (not
        # visible here). float() normalizes numpy/torch scalars to plain floats
        # for gr.Label.
        result = {
            class_name: float(prob)
            for class_name, prob in zip(class_names, probabilities)
        }
        heatmap_img = visualization_utils.show_final_layer_attention_maps(
            outputs, processed_input, device
        )
        return result, heatmap_img
    except Exception as e:
        import traceback

        print(f"Error in classification: {e}")
        # The error is swallowed for the UI's sake, so keep the full traceback
        # in the logs to make the failure debuggable.
        traceback.print_exc()
        empty_result = {class_name: 0.0 for class_name in class_names}
        return empty_result, None
def format_output(img):
    """Gradio callback: return ``(probabilities, heatmap)`` for *img*.

    Delegates to ``classify_and_visualize``. Any exception that escapes it is
    printed and replaced by an all-zero probability dict and no heatmap.
    """
    try:
        label_scores, attention_map = classify_and_visualize(img)
    except Exception as err:
        print(f"Error in format_output: {err}")
        # Every known class gets confidence 0.0, in class_names order.
        zeros = dict.fromkeys(class_names, 0.0)
        return zeros, None
    return label_scores, attention_map
# --- Helper Functions ---
def load_examples_from_folder(folder_path):
    """Return the image files in *folder_path* for use as Gradio examples.

    Args:
        folder_path: Directory to scan (not recursively).

    Returns:
        A sorted list of paths, each ``os.path.join(folder_path, name)``, for
        regular files whose names end in ``.png``, ``.jpg`` or ``.jpeg``
        (case-insensitive). Returns an empty list when *folder_path* is not an
        existing directory.
    """
    # Use isdir rather than exists: for a path that points at a file, this
    # yields no examples instead of letting os.listdir raise
    # NotADirectoryError.
    if not os.path.isdir(folder_path):
        return []
    image_suffixes = (".png", ".jpg", ".jpeg")
    # os.listdir returns entries in arbitrary order; sort them so the examples
    # gallery is stable across runs.
    return sorted(
        os.path.join(folder_path, name)
        for name in os.listdir(folder_path)
        if name.lower().endswith(image_suffixes)
        and os.path.isfile(os.path.join(folder_path, name))
    )
examples = load_examples_from_folder(EXAMPLES_FOLDER)

# --- UI Layout ---
title = "Pneumonia Detection Assistant"
description = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<p>Upload a Chest X-Ray image to analyze it for signs of Pneumonia.</p>
<p>The model classifies the image into <b>Normal</b>, <b>Viral Pneumonia</b>, or <b>Bacterial Pneumonia</b> categories
and provides an attention heatmap to show which areas influenced the decision.</p>
</div>
"""
article = """
<div style="border: 2px solid #e74c3c; padding: 20px; border-radius: 10px; margin-top: 20px; background-color: #fce4e4;">
<h3 style="color: #c0392b; margin-top: 0;">⚠️ MEDICAL DISCLAIMER</h3>
<p style="color: #7f8c8d;">
This application uses Artificial Intelligence (Vision Transformer) for educational and research purposes only.
<b>It is NOT a diagnostic tool.</b> The results generated by this model should not be treated as medical advice.
Always consult with a qualified healthcare professional for medical diagnosis and treatment.
</p>
</div>
"""

# Soft theme with blue primary and slate secondary hues. Primary buttons use
# the 500 shade, switching to the 600 shade on hover.
theme = gr.themes.Soft(primary_hue="blue", secondary_hue="slate")
theme = theme.set(
    button_primary_background_fill="*primary_500",
    button_primary_background_fill_hover="*primary_600",
)


def _build_interface():
    """Assemble the image-in, (confidences, heatmap)-out interface around format_output."""
    image_input = gr.Image(type="pil", label="Upload Chest X-Ray")
    confidence_output = gr.Label(label="Prediction Confidence", num_top_classes=3)
    heatmap_output = gr.Image(label="Attention Heatmap Analysis")
    # No flagging option is passed, so Gradio's default flagging behavior applies.
    return gr.Interface(
        fn=format_output,
        inputs=image_input,
        outputs=[confidence_output, heatmap_output],
        examples=examples,
        cache_examples=False,
        title=title,
        description=description,
        article=article,
        theme=theme,
    )


iface = _build_interface()
# Start the web server only when this file is run as a script, not on import.
if __name__ == "__main__":
    # Listen on all interfaces, port 7860, and request a Gradio share link.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, share=True)
    iface.launch(**launch_options)
|