emiraran commited on
Commit
4ba86b2
·
verified ·
1 Parent(s): 3bc158a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -158
app.py CHANGED
@@ -1,167 +1,146 @@
1
- """
2
- BraTS 2020 Brain Tumor Segmentation - Gradio Web Interface
3
- Interactive inference for brain MRI tumor segmentation
4
- """
5
-
6
  import gradio as gr
7
  import numpy as np
8
  import tensorflow as tf
9
- import h5py
10
- import matplotlib.pyplot as plt
11
- import io
12
- from PIL import Image
13
- import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- # Model configuration
16
- os.environ["SM_FRAMEWORK"] = "tf.keras"
17
- import segmentation_models as sm
18
 
19
- class BraTSInference:
20
- def __init__(self):
21
- print("Loading model...")
22
- self.model = sm.Unet(
23
- 'efficientnetb0',
24
- input_shape=(256, 256, 4),
25
- classes=3,
26
- activation='sigmoid',
27
- encoder_weights=None
28
- )
29
- self.model.load_weights('best_model_weights.h5')
30
- self.class_names = ['Necrotic', 'Enhancing Tumor', 'Edema']
31
-
32
- def predict_from_h5(self, h5_file):
33
- """H5 dosyasından segmentasyon yap"""
34
- try:
35
- # H5 dosyasını oku
36
- with h5py.File(h5_file.name, 'r') as f:
37
- image = f['image'][:]
38
-
39
- # Resize
40
- image_resized = tf.image.resize(image, (256, 256), method='bilinear').numpy()
41
-
42
- # Normalize (opsiyonel)
43
- image_resized = (image_resized - image_resized.min()) / (image_resized.max() - image_resized.min() + 1e-7)
44
-
45
- # Predict
46
- image_batch = np.expand_dims(image_resized, axis=0)
47
- prediction = self.model.predict(image_batch, verbose=0)[0]
48
-
49
- # Görselleştir
50
- fig, axes = plt.subplots(2, 4, figsize=(14, 7))
51
- fig.suptitle('BraTS 2020 Brain Tumor Segmentation', fontsize=14, fontweight='bold')
52
-
53
- # MRI channels
54
- modalities = ['T1', 'T1c', 'T2', 'FLAIR']
55
- for i, mod in enumerate(modalities):
56
- axes[0, i].imshow(image_resized[:, :, i], cmap='gray')
57
- axes[0, i].set_title(mod)
58
- axes[0, i].axis('off')
59
-
60
- # Predictions
61
- for i, class_name in enumerate(self.class_names):
62
- axes[1, i].imshow(image_resized[:, :, 0], cmap='gray', alpha=0.5)
63
- axes[1, i].imshow(prediction[:, :, i], cmap='hot', alpha=0.5)
64
- axes[1, i].set_title(f'{class_name}')
65
- axes[1, i].axis('off')
66
-
67
- plt.tight_layout()
68
-
69
- # Save to buffer
70
- buf = io.BytesIO()
71
- plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
72
- buf.seek(0)
73
- plt.close()
74
-
75
- # Metrics
76
- metrics_text = "📊 SEGMENTATION RESULTS\n\n"
77
- for i, class_name in enumerate(self.class_names):
78
- pred_binary = (prediction[:, :, i] > 0.5).astype(np.float32)
79
- area_pixels = np.sum(pred_binary)
80
- area_mm2 = area_pixels * 0.94 * 0.94
81
- percentage = (area_pixels / (256*256)) * 100
82
-
83
- metrics_text += f"{class_name}:\n"
84
- metrics_text += f" • Area: {area_mm2:.2f} mm²\n"
85
- metrics_text += f" • Coverage: {percentage:.2f}%\n\n"
86
-
87
- metrics_text += "⚠️ Disclaimer: This model requires radiologist validation"
88
-
89
- return Image.open(buf), metrics_text
90
-
91
- except Exception as e:
92
- return None, f"❌ Error: {str(e)}"
93
 
94
- # Initialize
95
- inference = BraTSInference()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- # Gradio Interface
98
- def create_interface():
99
- with gr.Blocks(title="BraTS 2020 Segmentation") as demo:
100
- gr.Markdown("# 🧠 BraTS 2020 Brain Tumor Segmentation")
101
- gr.Markdown("""
102
- Advanced U-Net model for multi-class brain tumor segmentation
103
- - **Necrotic Tissue**: Dead tumor region
104
- - **Enhancing Tumor**: Active tumor with contrast enhancement
105
- - **Edema**: Tumor surrounding tissue
106
- """)
107
-
108
- with gr.Row():
109
- with gr.Column():
110
- file_input = gr.File(
111
- label="📁 Upload H5 File",
112
- file_types=[".h5"]
113
- )
114
- submit_btn = gr.Button("🔍 Analyze", variant="primary")
115
-
116
- with gr.Column():
117
- output_image = gr.Image(
118
- label="📊 Segmentation Results",
119
- type="pil"
120
- )
121
- output_text = gr.Textbox(
122
- label="📈 Metrics",
123
- lines=8,
124
- interactive=False
125
- )
126
-
127
- submit_btn.click(
128
- fn=inference.predict_from_h5,
129
- inputs=file_input,
130
- outputs=[output_image, output_text]
131
- )
132
-
133
- # Examples
134
- gr.Examples(
135
- examples=[
136
- ["volume_100_slice_100.h5"],
137
- ["volume_100_slice_105.h5"],
138
- ["volume_100_slice_110.h5"],
139
- ["volume_100_slice_115.h5"],
140
- ["volume_100_slice_120.h5"],
141
- ],
142
- inputs=file_input,
143
- label="📂 Example MRI Files (Preloaded from BraTS 2020)"
144
- )
145
-
146
- gr.Markdown("""
147
- ### 📋 Model Information
148
- - **Architecture**: U-Net with EfficientNetB0 Encoder
149
- - **Performance**: 90.34% Sensitivity | 99.96% Specificity
150
- - **Dataset**: BraTS 2020 Training (14,299 slices)
151
-
152
- ### ⚠️ Important Notes
153
- - Model trained on BraTS 2020 dataset
154
- - Requires radiologist validation for clinical use
155
- - Not for autonomous clinical decisions
156
-
157
- ### 📚 References
158
- - [BraTS Challenge](https://www.med.upenn.edu/cbica/brats2020/)
159
- - [Segmentation Models](https://github.com/qubvel/segmentation_models)
160
- - [U-Net Paper](https://arxiv.org/abs/1505.04597)
161
- """)
162
-
163
- return demo
 
 
 
 
 
 
164
 
165
  if __name__ == "__main__":
166
- demo = create_interface()
167
- demo.launch(share=True)
 
 
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import tensorflow as tf
4
+ import tensorflow.keras.backend as K
5
+ from tensorflow.keras.layers import Dense, BatchNormalization, Dropout, Input
6
+ from tensorflow.keras.models import Model
7
+ import cv2
8
+
9
+ # Sınıf isimleri
10
+ class_names = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]
11
+
12
+ # Model yapısını reconstruct et
13
+ def build_model():
14
+ img_size = (224, 224)
15
+ inputs = tf.keras.Input(shape=img_size + (3,))
16
+ base_model = tf.keras.applications.efficientnet.EfficientNetB3(
17
+ include_top=False,
18
+ weights="imagenet",
19
+ input_tensor=inputs,
20
+ pooling='max'
21
+ )
22
+ base_model.trainable = True
23
+
24
+ x = base_model.output
25
+ x = Dense(256, activation='relu')(x)
26
+ x = BatchNormalization()(x)
27
+ x = Dropout(0.3)(x)
28
+ outputs = Dense(len(class_names), activation='softmax')(x)
29
+
30
+ model = Model(inputs, outputs)
31
+ return model
32
 
33
+ # Model oluştur ve weights yükle
34
+ model = build_model()
35
+ model.load_weights("best_weights_balanced.h5")
36
 
37
+ # Son conv layer'ı otomatik bul
38
+ def get_last_conv_layer_name(model):
39
+ """Find the last convolutional layer in the model"""
40
+ for layer in reversed(model.layers):
41
+ if 'conv' in layer.name.lower():
42
+ return layer.name
43
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ # Grad-CAM fonksiyonu
46
+ def get_gradcam(img_array, model, last_conv_layer_name):
47
+ grad_model = tf.keras.models.Model(
48
+ [model.inputs],
49
+ [model.get_layer(last_conv_layer_name).output, model.output]
50
+ )
51
+
52
+ with tf.GradientTape() as tape:
53
+ conv_outputs, predictions = grad_model(img_array)
54
+ pred_index = tf.argmax(predictions[0])
55
+ class_channel = predictions[:, pred_index]
56
+
57
+ grads = tape.gradient(class_channel, conv_outputs)
58
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
59
+
60
+ conv_outputs = conv_outputs[0]
61
+ heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
62
+ heatmap = tf.squeeze(heatmap)
63
+
64
+ # Normalize to 0-1
65
+ heatmap_min = tf.math.reduce_min(heatmap)
66
+ heatmap_max = tf.math.reduce_max(heatmap)
67
+ heatmap = (heatmap - heatmap_min) / (heatmap_max - heatmap_min + K.epsilon())
68
+
69
+ return heatmap.numpy(), pred_index.numpy()
70
 
71
+ def predict_and_explain(img):
72
+ # Görüntüyü hazırla
73
+ img_resized = cv2.resize(img, (224, 224))
74
+
75
+ # Gradio'dan gelen image 0-255 range'de
76
+ # preprocess_input bu range'i normalize ediyor
77
+ img_array = np.expand_dims(img_resized, axis=0)
78
+ img_array = tf.keras.applications.efficientnet.preprocess_input(img_array.astype(np.float32))
79
+
80
+ # Tahmin
81
+ predictions = model.predict(img_array, verbose=0)
82
+ pred_class = np.argmax(predictions[0])
83
+ confidence = predictions[0][pred_class]
84
+
85
+ # Grad-CAM - son conv layer'ı bul
86
+ last_conv_layer_name = get_last_conv_layer_name(model)
87
+ heatmap, _ = get_gradcam(img_array, model, last_conv_layer_name)
88
+ heatmap = cv2.resize(heatmap, (img_resized.shape[1], img_resized.shape[0]))
89
+ # Heatmap'ı ters çevir: kırmızı = model odaklandığı yer
90
+ heatmap = 1 - heatmap
91
+ heatmap = np.uint8(255 * heatmap)
92
+ heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
93
+
94
+ # Original image for overlay (normalize to 0-255)
95
+ img_for_display = cv2.resize(img, (224, 224))
96
+ if img_for_display.max() <= 1.0:
97
+ img_for_display = (img_for_display * 255).astype(np.uint8)
98
+
99
+ # Overlay
100
+ superimposed = cv2.addWeighted(img_for_display, 0.6, heatmap_colored, 0.4, 0)
101
+
102
+ # Sonuçlar
103
+ results = {class_names[i]: float(predictions[0][i]) for i in range(4)}
104
+
105
+ return results, superimposed
106
+
107
+ # Gradio arayüzü
108
+ demo = gr.Interface(
109
+ fn=predict_and_explain,
110
+ inputs=gr.Image(label="Upload Brain MRI Image"),
111
+ outputs=[
112
+ gr.Label(num_top_classes=4, label="Prediction Confidence"),
113
+ gr.Image(label="Grad-CAM Explanation (Red = High Attention)")
114
+ ],
115
+ title="🧠 Brain Tumor MRI Classification (99% Accuracy)",
116
+ description="""
117
+ **EfficientNetB3 + Grad-CAM Explainable AI**
118
+
119
+ This model classifies brain MRI images into 4 categories:
120
+ - **Glioma** - Tumor from glial cells (malignant)
121
+ - **Meningioma** - Tumor from meninges (usually benign)
122
+ - **Pituitary** - Pituitary gland tumor
123
+ - **No Tumor** - Normal brain tissue
124
+
125
+ **Model Performance** (Test Accuracy: 99.11%):
126
+ - Sensitivity: >96% for all tumor types
127
+ - Specificity: >99% for all classes
128
+ - Zero false negatives for tumor detection
129
+
130
+ Grad-CAM visualization shows which regions the model focuses on for its decision.
131
+
132
+ ⚠️ **DISCLAIMER**: This tool is for research and educational purposes only.
133
+ NOT approved for clinical diagnosis. Always consult qualified medical professionals.
134
+
135
+ 📊 **Usage Instructions**:
136
+ 1. Upload a brain MRI image (axial T1/T2 view preferred)
137
+ 2. Model will predict tumor type with confidence score
138
+ 3. Grad-CAM heatmap shows areas of focus (red = high attention)
139
+ 4. If confidence < 80%, consider expert review
140
+ """,
141
+ examples=[], # Örnek görüntü ekleyebilirsin
142
+ theme=gr.themes.Soft()
143
+ )
144
 
145
  if __name__ == "__main__":
146
+ demo.launch()