CGAllenger commited on
Commit
8c5d0a5
·
verified ·
1 Parent(s): 1faf60c

unifying the process

Browse files
Files changed (1) hide show
  1. app.py +72 -43
app.py CHANGED
@@ -4,9 +4,25 @@ import tensorflow as tf
4
  from PIL import Image
5
  import efficientnet.tfkeras as efn
6
  import random
 
 
7
 
8
  # ==========================================
9
- # 1. MRI Model Setup (Your Existing Model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  # ==========================================
11
  print("Loading MRI model...")
12
  mri_model = tf.keras.models.load_model("mri.keras")
@@ -16,7 +32,6 @@ def predict_mri(image):
16
  if image is None:
17
  return None
18
 
19
- # Preprocess the MRI
20
  img = Image.fromarray(image).convert('L')
21
  img = img.resize((168, 168))
22
 
@@ -24,28 +39,21 @@ def predict_mri(image):
24
  img_array = np.expand_dims(img_array, axis=-1)
25
  img_array = np.expand_dims(img_array, axis=0)
26
 
27
- # Predict
28
  predictions = mri_model.predict(img_array)[0]
29
 
30
- # Apply the 3% to 7% random reduction
31
  confidences = {}
32
  for i in range(len(mri_class_names)):
33
  original_conf = float(predictions[i])
34
  random_drop = random.uniform(0.03, 0.07)
35
-
36
- # Ensure it doesn't drop below 0
37
  adjusted_conf = max(0.0, original_conf - random_drop)
38
-
39
- # Rounding to 4 decimal places
40
  confidences[mri_class_names[i]] = round(adjusted_conf, 4)
41
 
42
  return confidences
43
 
44
  # ==========================================
45
- # 2. X-Ray Model Setup (Using original EfficientNet library)
46
  # ==========================================
47
  print("Building X-Ray model architecture...")
48
-
49
  xray_class_names = [
50
  'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
51
  'Mass', 'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening',
@@ -76,7 +84,6 @@ def predict_xray(image):
76
  if image is None:
77
  return None
78
 
79
- # Preprocess the X-Ray input
80
  img = Image.fromarray(image).convert('RGB')
81
  img = img.resize((128, 128))
82
 
@@ -85,54 +92,76 @@ def predict_xray(image):
85
 
86
  img_array = efn.preprocess_input(img_array)
87
 
88
- # Predict
89
  predictions = xray_model.predict(img_array)[0]
90
 
91
- # Apply the 3% to 7% random reduction
92
  confidences = {}
93
  for i in range(len(xray_class_names)):
94
  original_conf = float(predictions[i])
95
  random_drop = random.uniform(0.03, 0.07)
96
-
97
- # Ensure it doesn't drop below 0
98
  adjusted_conf = max(0.0, original_conf - random_drop)
99
-
100
- # Rounding to 4 decimal places
101
  confidences[xray_class_names[i]] = round(adjusted_conf, 4)
102
 
103
  return confidences
104
 
105
  # ==========================================
106
- # 3. Define the Gradio Interface with Tabs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  # ==========================================
108
  with gr.Blocks(title="Medical Scan Classification") as interface:
109
  gr.Markdown("# 🩺 Medical Scan Classifier")
110
- gr.Markdown("Upload an **MRI Brain Scan** or a **Chest X-Ray** into the respective tabs below for AI-powered classification.")
111
 
112
- with gr.Tabs():
113
- # --- TAB 1: MRI ---
114
- with gr.TabItem("MRI Brain Scan"):
115
- with gr.Row():
116
- with gr.Column():
117
- mri_input = gr.Image(label="Upload MRI Brain Scan")
118
- mri_button = gr.Button("Classify MRI", variant="primary")
119
- with gr.Column():
120
- # CHANGE APPLIED HERE: num_top_classes changed to 1
121
- mri_output = gr.Label(num_top_classes=1, label="Top Predicted Condition")
122
-
123
- mri_button.click(fn=predict_mri, inputs=mri_input, outputs=mri_output)
124
 
125
- # --- TAB 2: X-Ray ---
126
- with gr.TabItem("Chest X-Ray"):
127
- with gr.Row():
128
- with gr.Column():
129
- xray_input = gr.Image(label="Upload Chest X-Ray")
130
- xray_button = gr.Button("Classify X-Ray", variant="primary")
131
- with gr.Column():
132
- xray_output = gr.Label(num_top_classes=2, label="Top 2 Predicted Conditions")
133
-
134
- xray_button.click(fn=predict_xray, inputs=xray_input, outputs=xray_output)
135
-
136
- # Launch the app
137
  if __name__ == "__main__":
138
  interface.launch()
 
4
  from PIL import Image
5
  import efficientnet.tfkeras as efn
6
  import random
7
+ import torch
8
+ from open_clip import create_model_and_transforms, get_tokenizer
9
 
10
  # ==========================================
11
+ # 1. Modality Router Setup (BiomedCLIP)
12
+ # ==========================================
13
+ print("Loading BiomedCLIP Router...")
14
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+ clip_model_name = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
16
+ clip_model, _, clip_preprocess = create_model_and_transforms(clip_model_name)
17
+ clip_model = clip_model.to(device)
18
+ clip_tokenizer = get_tokenizer(clip_model_name)
19
+
20
+ # Define the text embeddings for routing
21
+ router_labels = ['an MRI brain scan', 'a chest X-ray']
22
+ text_tokens = clip_tokenizer(router_labels).to(device)
23
+
24
+ # ==========================================
25
+ # 2. MRI Model Setup
26
  # ==========================================
27
  print("Loading MRI model...")
28
  mri_model = tf.keras.models.load_model("mri.keras")
 
32
  if image is None:
33
  return None
34
 
 
35
  img = Image.fromarray(image).convert('L')
36
  img = img.resize((168, 168))
37
 
 
39
  img_array = np.expand_dims(img_array, axis=-1)
40
  img_array = np.expand_dims(img_array, axis=0)
41
 
 
42
  predictions = mri_model.predict(img_array)[0]
43
 
 
44
  confidences = {}
45
  for i in range(len(mri_class_names)):
46
  original_conf = float(predictions[i])
47
  random_drop = random.uniform(0.03, 0.07)
 
 
48
  adjusted_conf = max(0.0, original_conf - random_drop)
 
 
49
  confidences[mri_class_names[i]] = round(adjusted_conf, 4)
50
 
51
  return confidences
52
 
53
  # ==========================================
54
+ # 3. X-Ray Model Setup
55
  # ==========================================
56
  print("Building X-Ray model architecture...")
 
57
  xray_class_names = [
58
  'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
59
  'Mass', 'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening',
 
84
  if image is None:
85
  return None
86
 
 
87
  img = Image.fromarray(image).convert('RGB')
88
  img = img.resize((128, 128))
89
 
 
92
 
93
  img_array = efn.preprocess_input(img_array)
94
 
 
95
  predictions = xray_model.predict(img_array)[0]
96
 
 
97
  confidences = {}
98
  for i in range(len(xray_class_names)):
99
  original_conf = float(predictions[i])
100
  random_drop = random.uniform(0.03, 0.07)
 
 
101
  adjusted_conf = max(0.0, original_conf - random_drop)
 
 
102
  confidences[xray_class_names[i]] = round(adjusted_conf, 4)
103
 
104
  return confidences
105
 
106
  # ==========================================
107
+ # 4. Master Routing Function
108
+ # ==========================================
109
+ def process_scan(image):
110
+ if image is None:
111
+ return "No image provided.", None
112
+
113
+ # Step A: Preprocess for CLIP
114
+ img_pil = Image.fromarray(image).convert('RGB')
115
+ img_tensor = clip_preprocess(img_pil).unsqueeze(0).to(device)
116
+
117
+ # Step B: Calculate Modality Probabilities
118
+ with torch.no_grad():
119
+ image_features = clip_model.encode_image(img_tensor)
120
+ text_features = clip_model.encode_text(text_tokens)
121
+
122
+ image_features /= image_features.norm(dim=-1, keepdim=True)
123
+ text_features /= text_features.norm(dim=-1, keepdim=True)
124
+
125
+ text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)[0]
126
+
127
+ mri_prob = text_probs[0].item()
128
+ xray_prob = text_probs[1].item()
129
+
130
+ # Step C: Route to Specific Model
131
+ if mri_prob > xray_prob:
132
+ modality_status = f"🧠 Modality Detected: MRI Brain Scan (Confidence: {mri_prob:.1%})"
133
+ diagnostic_results = predict_mri(image)
134
+ # We only want top 1 for MRI based on your previous UI setup
135
+ top_k = 1
136
+ else:
137
+ modality_status = f"🩻 Modality Detected: Chest X-Ray (Confidence: {xray_prob:.1%})"
138
+ diagnostic_results = predict_xray(image)
139
+ # We want top 2 for X-Ray based on your previous UI setup
140
+ top_k = 2
141
+
142
+ return modality_status, diagnostic_results
143
+
144
+ # ==========================================
145
+ # 5. Define the Unified Gradio Interface
146
  # ==========================================
147
  with gr.Blocks(title="Medical Scan Classification") as interface:
148
  gr.Markdown("# 🩺 Medical Scan Classifier")
149
+ gr.Markdown("Upload **any** scan (MRI Brain Scan or Chest X-Ray). The system will automatically detect the modality and route it to the appropriate diagnostic model.")
150
 
151
+ with gr.Row():
152
+ with gr.Column():
153
+ scan_input = gr.Image(label="Upload Medical Scan")
154
+ analyze_button = gr.Button("Analyze Scan", variant="primary")
155
+
156
+ with gr.Column():
157
+ modality_output = gr.Textbox(label="Detection Routing Status", interactive=False)
158
+ diagnostic_output = gr.Label(label="Predicted Conditions")
 
 
 
 
159
 
160
+ analyze_button.click(
161
+ fn=process_scan,
162
+ inputs=scan_input,
163
+ outputs=[modality_output, diagnostic_output]
164
+ )
165
+
 
 
 
 
 
 
166
  if __name__ == "__main__":
167
  interface.launch()