EdwardSamuel13 commited on
Commit
0c17dc3
·
verified ·
1 Parent(s): 1360822

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -129
app.py CHANGED
@@ -1,130 +1,136 @@
1
- import os
2
- import gradio as gr
3
- from PIL import Image
4
- from datasets import load_dataset
5
- import model_utils
6
- import visualization_utils
7
-
8
- # --- Configuration ---
9
- MODEL_NAME_OR_PATH = "google/vit-base-patch16-224-in21k"
10
- DATASET_PATH = "pawlo2013/chest_xray"
11
- MODEL_DIR = "./models"
12
- EXAMPLES_FOLDER = "./examples"
13
-
14
- # --- Load Data & Model ---
15
- print("Loading dataset information...")
16
- try:
17
- # We load the dataset mainly to get the class names correctly
18
- train_dataset = load_dataset(DATASET_PATH, split="train")
19
- class_names = train_dataset.features["label"].names
20
- print(f"Class names loaded: {class_names}")
21
- except Exception as e:
22
- print(f"Warning: Could not load dataset, using default class names. Error: {e}")
23
- # Fallback class names based on typical chest X-ray classification
24
- class_names = ["NORMAL", "PNEUMONIA"]
25
-
26
- print("Loading model and processor...")
27
- try:
28
- model, processor = model_utils.load_model_and_processor(
29
- MODEL_DIR, MODEL_NAME_OR_PATH, class_names
30
- )
31
- print("Model and processor loaded successfully!")
32
- except Exception as e:
33
- print(f"Error loading model: {e}")
34
- raise e
35
-
36
- # --- Core Logic ---
37
- def classify_and_visualize(img, device="cpu"):
38
- if img is None:
39
- return None, None
40
-
41
- try:
42
- # Get predictions
43
- outputs, processed_input, probabilities, prediction_idx = model_utils.predict(
44
- model, processor, img, device
45
- )
46
-
47
- # Format probabilities
48
- result = {class_name: prob for class_name, prob in zip(class_names, probabilities)}
49
-
50
- # Generate heatmap
51
- heatmap_img = visualization_utils.show_final_layer_attention_maps(
52
- outputs, processed_input, device
53
- )
54
-
55
- return result, heatmap_img
56
-
57
- except Exception as e:
58
- print(f"Error in classification: {e}")
59
- # Return empty result and None for heatmap on error
60
- empty_result = {class_name: 0.0 for class_name in class_names}
61
- return empty_result, None
62
-
63
- def format_output(img):
64
- try:
65
- probs, heatmap = classify_and_visualize(img)
66
- return probs, heatmap
67
- except Exception as e:
68
- print(f"Error in format_output: {e}")
69
- # Return empty results on error
70
- empty_result = {class_name: 0.0 for class_name in class_names}
71
- return empty_result, None
72
-
73
- # --- Helper Functions ---
74
- def load_examples_from_folder(folder_path):
75
- examples = []
76
- if os.path.exists(folder_path):
77
- for file in os.listdir(folder_path):
78
- if file.lower().endswith((".png", ".jpg", ".jpeg")):
79
- examples.append(os.path.join(folder_path, file))
80
- return examples
81
-
82
- examples = load_examples_from_folder(EXAMPLES_FOLDER)
83
-
84
- # --- UI Layout ---
85
- title = "Pneumonia Detection Assistant"
86
- description = """
87
- <div style="text-align: center; max-width: 700px; margin: 0 auto;">
88
- <p>Upload a Chest X-Ray image to analyze it for signs of Pneumonia.</p>
89
- <p>The model classifies the image into <b>Normal</b>, <b>Viral Pneumonia</b>, or <b>Bacterial Pneumonia</b> categories
90
- and provides an attention heatmap to show which areas influenced the decision.</p>
91
- </div>
92
- """
93
- article = """
94
- <div style="border: 2px solid #e74c3c; padding: 20px; border-radius: 10px; margin-top: 20px; background-color: #fce4e4;">
95
- <h3 style="color: #c0392b; margin-top: 0;">⚠️ MEDICAL DISCLAIMER</h3>
96
- <p style="color: #7f8c8d;">
97
- This application uses Artificial Intelligence (Vision Transformer) for educational and research purposes only.
98
- <b>It is NOT a diagnostic tool.</b> The results generated by this model should not be treated as medical advice.
99
- Always consult with a qualified healthcare professional for medical diagnosis and treatment.
100
- </p>
101
- </div>
102
- """
103
-
104
- theme = gr.themes.Soft(
105
- primary_hue="blue",
106
- secondary_hue="slate",
107
- ).set(
108
- button_primary_background_fill="*primary_500",
109
- button_primary_background_fill_hover="*primary_600",
110
- )
111
-
112
- iface = gr.Interface(
113
- fn=format_output,
114
- inputs=gr.Image(type="pil", label="Upload Chest X-Ray"),
115
- outputs=[
116
- gr.Label(label="Prediction Confidence", num_top_classes=3),
117
- gr.Image(label="Attention Heatmap Analysis"),
118
- ],
119
- examples=examples,
120
- cache_examples=False,
121
- title=title,
122
- description=description,
123
- article=article,
124
- theme=theme,
125
- # allow_flagging="never"
126
- )
127
-
128
- # Launch
129
- if __name__ == "__main__":
 
 
 
 
 
 
130
  iface.launch()
 
1
+ import sys
2
+ import types
3
+
4
+ sys.modules["audioop"] = types.ModuleType("audioop")
5
+ sys.modules["pyaudioop"] = types.ModuleType("pyaudioop")
6
+
7
+ import os
8
+ import gradio as gr
9
+ from PIL import Image
10
+ from datasets import load_dataset
11
+ import model_utils
12
+ import visualization_utils
13
+
14
+ # --- Configuration ---
15
+ MODEL_NAME_OR_PATH = "google/vit-base-patch16-224-in21k"
16
+ DATASET_PATH = "pawlo2013/chest_xray"
17
+ MODEL_DIR = "./models"
18
+ EXAMPLES_FOLDER = "./examples"
19
+
20
+ # --- Load Data & Model ---
21
+ print("Loading dataset information...")
22
+ try:
23
+ # We load the dataset mainly to get the class names correctly
24
+ train_dataset = load_dataset(DATASET_PATH, split="train")
25
+ class_names = train_dataset.features["label"].names
26
+ print(f"Class names loaded: {class_names}")
27
+ except Exception as e:
28
+ print(f"Warning: Could not load dataset, using default class names. Error: {e}")
29
+ # Fallback class names based on typical chest X-ray classification
30
+ class_names = ["NORMAL", "PNEUMONIA"]
31
+
32
+ print("Loading model and processor...")
33
+ try:
34
+ model, processor = model_utils.load_model_and_processor(
35
+ MODEL_DIR, MODEL_NAME_OR_PATH, class_names
36
+ )
37
+ print("Model and processor loaded successfully!")
38
+ except Exception as e:
39
+ print(f"Error loading model: {e}")
40
+ raise e
41
+
42
+ # --- Core Logic ---
43
+ def classify_and_visualize(img, device="cpu"):
44
+ if img is None:
45
+ return None, None
46
+
47
+ try:
48
+ # Get predictions
49
+ outputs, processed_input, probabilities, prediction_idx = model_utils.predict(
50
+ model, processor, img, device
51
+ )
52
+
53
+ # Format probabilities
54
+ result = {class_name: prob for class_name, prob in zip(class_names, probabilities)}
55
+
56
+ # Generate heatmap
57
+ heatmap_img = visualization_utils.show_final_layer_attention_maps(
58
+ outputs, processed_input, device
59
+ )
60
+
61
+ return result, heatmap_img
62
+
63
+ except Exception as e:
64
+ print(f"Error in classification: {e}")
65
+ # Return empty result and None for heatmap on error
66
+ empty_result = {class_name: 0.0 for class_name in class_names}
67
+ return empty_result, None
68
+
69
+ def format_output(img):
70
+ try:
71
+ probs, heatmap = classify_and_visualize(img)
72
+ return probs, heatmap
73
+ except Exception as e:
74
+ print(f"Error in format_output: {e}")
75
+ # Return empty results on error
76
+ empty_result = {class_name: 0.0 for class_name in class_names}
77
+ return empty_result, None
78
+
79
+ # --- Helper Functions ---
80
+ def load_examples_from_folder(folder_path):
81
+ examples = []
82
+ if os.path.exists(folder_path):
83
+ for file in os.listdir(folder_path):
84
+ if file.lower().endswith((".png", ".jpg", ".jpeg")):
85
+ examples.append(os.path.join(folder_path, file))
86
+ return examples
87
+
88
+ examples = load_examples_from_folder(EXAMPLES_FOLDER)
89
+
90
+ # --- UI Layout ---
91
+ title = "Pneumonia Detection Assistant"
92
+ description = """
93
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
94
+ <p>Upload a Chest X-Ray image to analyze it for signs of Pneumonia.</p>
95
+ <p>The model classifies the image into <b>Normal</b>, <b>Viral Pneumonia</b>, or <b>Bacterial Pneumonia</b> categories
96
+ and provides an attention heatmap to show which areas influenced the decision.</p>
97
+ </div>
98
+ """
99
+ article = """
100
+ <div style="border: 2px solid #e74c3c; padding: 20px; border-radius: 10px; margin-top: 20px; background-color: #fce4e4;">
101
+ <h3 style="color: #c0392b; margin-top: 0;">⚠️ MEDICAL DISCLAIMER</h3>
102
+ <p style="color: #7f8c8d;">
103
+ This application uses Artificial Intelligence (Vision Transformer) for educational and research purposes only.
104
+ <b>It is NOT a diagnostic tool.</b> The results generated by this model should not be treated as medical advice.
105
+ Always consult with a qualified healthcare professional for medical diagnosis and treatment.
106
+ </p>
107
+ </div>
108
+ """
109
+
110
+ theme = gr.themes.Soft(
111
+ primary_hue="blue",
112
+ secondary_hue="slate",
113
+ ).set(
114
+ button_primary_background_fill="*primary_500",
115
+ button_primary_background_fill_hover="*primary_600",
116
+ )
117
+
118
+ iface = gr.Interface(
119
+ fn=format_output,
120
+ inputs=gr.Image(type="pil", label="Upload Chest X-Ray"),
121
+ outputs=[
122
+ gr.Label(label="Prediction Confidence", num_top_classes=3),
123
+ gr.Image(label="Attention Heatmap Analysis"),
124
+ ],
125
+ examples=examples,
126
+ cache_examples=False,
127
+ title=title,
128
+ description=description,
129
+ article=article,
130
+ theme=theme,
131
+ # allow_flagging="never"
132
+ )
133
+
134
+ # Launch
135
+ if __name__ == "__main__":
136
  iface.launch()