File size: 4,632 Bytes
0c17dc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99ff9e0
 
 
e1bc07e
99ff9e0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import sys
import types

# Register empty placeholder modules under the names "audioop" and
# "pyaudioop" so that any later `import audioop` / `import pyaudioop`
# resolves to these stubs instead of failing.
# NOTE(review): presumably a workaround for the stdlib `audioop` module being
# removed in Python 3.13 while one of the packages imported further down still
# imports it -- confirm which dependency needs this. The stubs have no
# attributes, so any code that actually calls into audioop will still fail.
sys.modules["audioop"] = types.ModuleType("audioop")
sys.modules["pyaudioop"] = types.ModuleType("pyaudioop")

import os
import gradio as gr
from PIL import Image
from datasets import load_dataset
import model_utils
import visualization_utils

# --- Configuration ---
# Model identifier passed as the second argument to
# model_utils.load_model_and_processor.
MODEL_NAME_OR_PATH = "google/vit-base-patch16-224-in21k"
# Dataset identifier passed to datasets.load_dataset; only its label names are used.
DATASET_PATH = "pawlo2013/chest_xray"
# Directory passed as the first argument to model_utils.load_model_and_processor
# (presumably where local weights are stored -- confirm in model_utils).
MODEL_DIR = "./models"
# Folder scanned for example images that are offered in the UI.
EXAMPLES_FOLDER = "./examples"

# --- Load Data & Model ---
# NOTE: everything in this section runs at import time, so importing this
# module triggers both the dataset load and the model load.
print("Loading dataset information...")
try:
    # We load the dataset mainly to get the class names correctly
    # (the "train" split is loaded and its "label" feature's names are read).
    train_dataset = load_dataset(DATASET_PATH, split="train")
    class_names = train_dataset.features["label"].names
    print(f"Class names loaded: {class_names}")
except Exception as e:
    # Best-effort: any failure while loading the dataset is reported and the
    # app continues with hard-coded class names.
    print(f"Warning: Could not load dataset, using default class names. Error: {e}")
    # Fallback class names based on typical chest X-ray classification
    # NOTE(review): this fallback has two labels, while the UI description in
    # this file advertises three categories -- confirm which matches the model.
    class_names = ["NORMAL", "PNEUMONIA"]

print("Loading model and processor...")
try:
    # class_names is forwarded to the loader together with the model locations.
    model, processor = model_utils.load_model_and_processor(
        MODEL_DIR, MODEL_NAME_OR_PATH, class_names
    )
    print("Model and processor loaded successfully!")
except Exception as e:
    # Unlike the dataset step there is no fallback here: the error is printed
    # and re-raised, which aborts module import.
    print(f"Error loading model: {e}")
    raise e

# --- Core Logic ---
def classify_and_visualize(img, device="cpu"):
    """Run the classifier on one image and build its attention heatmap.

    Args:
        img: Image forwarded to ``model_utils.predict``. ``None`` means no
            image was supplied.
        device: Device argument forwarded to ``model_utils.predict`` and to
            ``visualization_utils.show_final_layer_attention_maps``.

    Returns:
        A ``(scores, heatmap)`` pair:

        * ``(None, None)`` when ``img`` is ``None``.
        * On success, ``scores`` pairs each entry of ``class_names`` with the
          corresponding predicted probability and ``heatmap`` is the value
          returned by the visualization helper.
        * If anything raises along the way, the error is printed and the
          pair is (every class mapped to ``0.0``, ``None``).
    """
    if img is None:
        return None, None

    try:
        # The helper also returns the predicted index; it is not needed here.
        outputs, processed_input, probabilities, _prediction_idx = (
            model_utils.predict(model, processor, img, device)
        )

        # Pair class names with probabilities positionally.
        scores = dict(zip(class_names, probabilities))

        heatmap_img = visualization_utils.show_final_layer_attention_maps(
            outputs, processed_input, device
        )
    except Exception as e:
        print(f"Error in classification: {e}")
        # Zeroed scores and no heatmap signal the failure to the caller.
        return dict.fromkeys(class_names, 0.0), None

    return scores, heatmap_img

def format_output(img):
    """Callback wrapper around ``classify_and_visualize``.

    Args:
        img: Image passed straight through to ``classify_and_visualize``.

    Returns:
        The ``(probabilities, heatmap)`` pair produced by
        ``classify_and_visualize``. If that call (or unpacking its result)
        raises, the error is printed and the pair is (every class mapped to
        ``0.0``, ``None``).
    """
    try:
        label_scores, attention_map = classify_and_visualize(img)
    except Exception as e:
        print(f"Error in format_output: {e}")
        # Same zeroed placeholder that the classifier uses on failure.
        return dict.fromkeys(class_names, 0.0), None

    return label_scores, attention_map

# --- Helper Functions ---
def load_examples_from_folder(folder_path):
    """Collect example image paths from a folder (non-recursive).

    Args:
        folder_path: Directory to scan.

    Returns:
        A list of ``os.path.join(folder_path, name)`` paths for every regular
        file whose name ends in ``.png``, ``.jpg`` or ``.jpeg``
        (case-insensitive), ordered by file name. The list is empty when
        ``folder_path`` is not an existing directory.
    """
    # isdir rather than exists: a path that points at a regular file should
    # yield no examples instead of making os.listdir raise.
    if not os.path.isdir(folder_path):
        return []

    image_suffixes = (".png", ".jpg", ".jpeg")
    examples = []
    # os.listdir returns entries in arbitrary order; sorting keeps the
    # example list stable between runs and platforms.
    for name in sorted(os.listdir(folder_path)):
        path = os.path.join(folder_path, name)
        # Skip sub-directories that merely happen to be named like an image.
        if name.lower().endswith(image_suffixes) and os.path.isfile(path):
            examples.append(path)
    return examples

# Collected once at import time; handed to gr.Interface as its `examples`.
examples = load_examples_from_folder(EXAMPLES_FOLDER)

# --- UI Layout ---
# title, description and article are passed to gr.Interface under the
# parameters of the same names; description and article contain raw HTML.
title = "Pneumonia Detection Assistant"
# NOTE(review): the description names three categories, while the fallback
# class_names list in this file has only two -- confirm that the text matches
# the labels of the model that is actually loaded.
description = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
    <p>Upload a Chest X-Ray image to analyze it for signs of Pneumonia.</p>
    <p>The model classifies the image into <b>Normal</b>, <b>Viral Pneumonia</b>, or <b>Bacterial Pneumonia</b> categories 
    and provides an attention heatmap to show which areas influenced the decision.</p>
</div>
"""
# Medical disclaimer block.
article = """
<div style="border: 2px solid #e74c3c; padding: 20px; border-radius: 10px; margin-top: 20px; background-color: #fce4e4;">
    <h3 style="color: #c0392b; margin-top: 0;">⚠️ MEDICAL DISCLAIMER</h3>
    <p style="color: #7f8c8d;">
        This application uses Artificial Intelligence (Vision Transformer) for educational and research purposes only. 
        <b>It is NOT a diagnostic tool.</b> The results generated by this model should not be treated as medical advice.
        Always consult with a qualified healthcare professional for medical diagnosis and treatment.
    </p>
</div>
"""

# Soft theme with blue/slate hues; .set() overrides the primary button's
# normal and hover background fills.
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="slate",
).set(
    button_primary_background_fill="*primary_500",
    button_primary_background_fill_hover="*primary_600",
)

# Single-function interface: one image in, (label scores, heatmap image) out.
# format_output returns a 2-tuple whose order matches the `outputs` list.
iface = gr.Interface(
    fn=format_output,
    # type="pil": the upload is delivered to the callback as a PIL image.
    inputs=gr.Image(type="pil", label="Upload Chest X-Ray"),
    outputs=[
        # Shows at most the three highest-scoring classes.
        gr.Label(label="Prediction Confidence", num_top_classes=3),
        gr.Image(label="Attention Heatmap Analysis"),
    ],
    # Image paths gathered from EXAMPLES_FOLDER (may be an empty list).
    examples=examples,
    # Example outputs are not cached.
    cache_examples=False,
    title=title,
    description=description,
    article=article,
    theme=theme,
    # Disabled option kept for reference:
    # allow_flagging="never"
)

# Launch
if __name__ == "__main__":
    # Listen on all interfaces, port 7860, and request a public share link.
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)