CGAllenger commited on
Commit
d26f326
·
verified ·
1 Parent(s): 1964b53

seperating mri and xray

Browse files
Files changed (1) hide show
  1. app.py +99 -30
app.py CHANGED
@@ -3,49 +3,118 @@ import numpy as np
3
  import tensorflow as tf
4
  from PIL import Image
5
 
6
- # 1. Load the Keras model directly from the local folder
7
- print("Loading model...")
8
- model = tf.keras.models.load_model("mri.keras")
 
 
 
9
 
10
- # 2. Class mappings based on the notebook training
11
- # Order: 'Glioma', 'Meningioma', 'Notumor', 'Pituitary'
12
- class_names = ['Glioma', 'Meningioma', 'No Tumor', 'Pituitary Tumor']
13
-
14
- def predict(image):
15
  if image is None:
16
  return None
17
 
18
- # 3. Preprocess the input
19
- # Expected: 168x168, grayscale, scaled by 1/255.0, with batch dimension
20
-
21
- # Convert image to grayscale and resize it to 168x168
22
  img = Image.fromarray(image).convert('L')
23
  img = img.resize((168, 168))
24
 
25
- # Convert to numpy array and normalize pixel values to [0, 1]
26
  img_array = np.array(img) / 255.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- # The model expects input shape: (batch_size, 168, 168, 1)
29
- img_array = np.expand_dims(img_array, axis=-1) # Add channel dimension
30
- img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- # 4. Make prediction
33
- predictions = model.predict(img_array)[0]
 
34
 
35
- # 5. Map the output to a clean, human-readable format for Gradio Interface
36
- # Convert probabilities to a dictionary mapping class name to confidence score
37
- confidences = {class_names[i]: float(predictions[i]) for i in range(len(class_names))}
 
 
 
 
38
  return confidences
39
 
40
- # 6. Define the Gradio interface
41
- interface = gr.Interface(
42
- fn=predict,
43
- inputs=gr.Image(label="Upload MRI Brain Scan"),
44
- outputs=gr.Label(num_top_classes=4, label="Prediction Confidence"),
45
- title="MRI Brain Tumor Classification",
46
- description="Upload an MRI scan to classify it into one of four categories: Glioma, Meningioma, No Tumor, or Pituitary Tumor.",
47
- flagging_mode="never"
48
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  # Launch the app
51
  if __name__ == "__main__":
 
3
  import tensorflow as tf
4
  from PIL import Image
5
 
6
+ # ==========================================
7
+ # 1. MRI Model Setup (Your Existing Model)
8
+ # ==========================================
9
+ print("Loading MRI model...")
10
+ mri_model = tf.keras.models.load_model("mri.keras")
11
+ mri_class_names = ['Glioma', 'Meningioma', 'No Tumor', 'Pituitary Tumor']
12
 
13
+ def predict_mri(image):
 
 
 
 
14
  if image is None:
15
  return None
16
 
17
+ # Preprocess the MRI
 
 
 
18
  img = Image.fromarray(image).convert('L')
19
  img = img.resize((168, 168))
20
 
 
21
  img_array = np.array(img) / 255.0
22
+ img_array = np.expand_dims(img_array, axis=-1)
23
+ img_array = np.expand_dims(img_array, axis=0)
24
+
25
+ # Predict
26
+ predictions = mri_model.predict(img_array)[0]
27
+ confidences = {mri_class_names[i]: float(predictions[i]) for i in range(len(mri_class_names))}
28
+ return confidences
29
+
30
+
31
+ # ==========================================
32
+ # 2. X-Ray Model Setup (Reconstructing from Weights)
33
+ # ==========================================
34
+ print("Building X-Ray model architecture...")
35
+
36
+ # The 14 classes from the NIH Chest X-Ray dataset
37
+ xray_class_names = [
38
+ 'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
39
+ 'Mass', 'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening',
40
+ 'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation'
41
+ ]
42
+
43
+ def build_xray_model():
44
+ # Based on the Kaggle dataset description, the weights were trained
45
+ # on an EfficientNetB1 with a 128x128 input size.
46
+ base_model = tf.keras.applications.EfficientNetB1(
47
+ input_shape=(128, 128, 3),
48
+ weights=None, # We are loading custom weights next
49
+ include_top=False
50
+ )
51
+
52
+ model = tf.keras.Sequential([
53
+ base_model,
54
+ tf.keras.layers.GlobalAveragePooling2D(),
55
+ tf.keras.layers.Dense(1024, activation='relu'),
56
+ tf.keras.layers.Dense(len(xray_class_names), activation='sigmoid') # Sigmoid for multi-label
57
+ ])
58
 
59
+ # Load the downloaded weights
60
+ model.load_weights("xray.h5")
61
+ return model
62
+
63
+ xray_model = build_xray_model()
64
+ print("X-Ray model loaded successfully.")
65
+
66
+ def predict_xray(image):
67
+ if image is None:
68
+ return None
69
+
70
+ # Preprocess the X-Ray input
71
+ img = Image.fromarray(image).convert('RGB') # EfficientNet expects 3 channels
72
+ img = img.resize((128, 128)) # The Kaggle dataset used 128x128
73
 
74
+ img_array = np.array(img)
75
+ # Keras EfficientNet applications have built-in rescaling,
76
+ # so we skip the / 255.0 step here.
77
 
78
+ img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
79
+
80
+ # Predict
81
+ predictions = xray_model.predict(img_array)[0]
82
+
83
+ # Map probabilities to class names
84
+ confidences = {xray_class_names[i]: float(predictions[i]) for i in range(len(xray_class_names))}
85
  return confidences
86
 
87
+
88
+ # ==========================================
89
+ # 3. Define the Gradio Interface with Tabs
90
+ # ==========================================
91
+ with gr.Blocks(title="Medical Scan Classification") as interface:
92
+ gr.Markdown("# 🩺 Medical Scan Classifier")
93
+ gr.Markdown("Upload an **MRI Brain Scan** or a **Chest X-Ray** into the respective tabs below for AI-powered classification.")
94
+
95
+ with gr.Tabs():
96
+ # --- TAB 1: MRI ---
97
+ with gr.TabItem("MRI Brain Scan"):
98
+ with gr.Row():
99
+ with gr.Column():
100
+ mri_input = gr.Image(label="Upload MRI Brain Scan")
101
+ mri_button = gr.Button("Classify MRI", variant="primary")
102
+ with gr.Column():
103
+ mri_output = gr.Label(num_top_classes=4, label="Prediction Confidence")
104
+
105
+ mri_button.click(fn=predict_mri, inputs=mri_input, outputs=mri_output)
106
+
107
+ # --- TAB 2: X-Ray ---
108
+ with gr.TabItem("Chest X-Ray"):
109
+ with gr.Row():
110
+ with gr.Column():
111
+ xray_input = gr.Image(label="Upload Chest X-Ray")
112
+ xray_button = gr.Button("Classify X-Ray", variant="primary")
113
+ with gr.Column():
114
+ # Displaying top 5 conditions since there are 14 possible labels
115
+ xray_output = gr.Label(num_top_classes=5, label="Top 5 Predicted Conditions")
116
+
117
+ xray_button.click(fn=predict_xray, inputs=xray_input, outputs=xray_output)
118
 
119
  # Launch the app
120
  if __name__ == "__main__":