CGAllenger commited on
Commit
3b6abf7
·
verified ·
1 Parent(s): 51308a7

building wrapper sepearately

Browse files
Files changed (1) hide show
  1. app.py +14 -15
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import numpy as np
3
  import tensorflow as tf
4
  from PIL import Image
 
5
 
6
  # ==========================================
7
  # 1. MRI Model Setup (Your Existing Model)
@@ -29,11 +30,10 @@ def predict_mri(image):
29
 
30
 
31
  # ==========================================
32
- # 2. X-Ray Model Setup (Reconstructing from Weights)
33
  # ==========================================
34
  print("Building X-Ray model architecture...")
35
 
36
- # The 14 classes from the NIH Chest X-Ray dataset
37
  xray_class_names = [
38
  'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
39
  'Mass', 'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening',
@@ -41,11 +41,11 @@ xray_class_names = [
41
  ]
42
 
43
  def build_xray_model():
44
- # Based on the Kaggle dataset description, the weights were trained
45
- # on an EfficientNetB1 with a 128x128 input size.
46
- base_model = tf.keras.applications.EfficientNetB1(
47
  input_shape=(128, 128, 3),
48
- weights=None, # We are loading custom weights next
49
  include_top=False
50
  )
51
 
@@ -53,10 +53,10 @@ def build_xray_model():
53
  base_model,
54
  tf.keras.layers.GlobalAveragePooling2D(),
55
  tf.keras.layers.Dense(1024, activation='relu'),
56
- tf.keras.layers.Dense(len(xray_class_names), activation='sigmoid') # Sigmoid for multi-label
57
  ])
58
 
59
- # Load the downloaded weights
60
  model.load_weights("xray.h5")
61
  return model
62
 
@@ -68,19 +68,19 @@ def predict_xray(image):
68
  return None
69
 
70
  # Preprocess the X-Ray input
71
- img = Image.fromarray(image).convert('RGB') # EfficientNet expects 3 channels
72
- img = img.resize((128, 128)) # The Kaggle dataset used 128x128
73
 
74
  img_array = np.array(img)
75
- # Keras EfficientNet applications have built-in rescaling,
76
- # so we skip the / 255.0 step here.
77
 
78
- img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
 
79
 
80
  # Predict
81
  predictions = xray_model.predict(img_array)[0]
82
 
83
- # Map probabilities to class names
84
  confidences = {xray_class_names[i]: float(predictions[i]) for i in range(len(xray_class_names))}
85
  return confidences
86
 
@@ -111,7 +111,6 @@ with gr.Blocks(title="Medical Scan Classification") as interface:
111
  xray_input = gr.Image(label="Upload Chest X-Ray")
112
  xray_button = gr.Button("Classify X-Ray", variant="primary")
113
  with gr.Column():
114
- # Displaying top 5 conditions since there are 14 possible labels
115
  xray_output = gr.Label(num_top_classes=5, label="Top 5 Predicted Conditions")
116
 
117
  xray_button.click(fn=predict_xray, inputs=xray_input, outputs=xray_output)
 
2
  import numpy as np
3
  import tensorflow as tf
4
  from PIL import Image
5
+ import efficientnet.tfkeras as efn # <-- Add this import
6
 
7
  # ==========================================
8
  # 1. MRI Model Setup (Your Existing Model)
 
30
 
31
 
32
  # ==========================================
33
+ # 2. X-Ray Model Setup (Using original EfficientNet library)
34
  # ==========================================
35
  print("Building X-Ray model architecture...")
36
 
 
37
  xray_class_names = [
38
  'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
39
  'Mass', 'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening',
 
41
  ]
42
 
43
  def build_xray_model():
44
+ # Use the 'efn' library instead of tf.keras.applications
45
+ # This guarantees the architecture has exactly 437 weights as expected.
46
+ base_model = efn.EfficientNetB1(
47
  input_shape=(128, 128, 3),
48
+ weights=None,
49
  include_top=False
50
  )
51
 
 
53
  base_model,
54
  tf.keras.layers.GlobalAveragePooling2D(),
55
  tf.keras.layers.Dense(1024, activation='relu'),
56
+ tf.keras.layers.Dense(len(xray_class_names), activation='sigmoid')
57
  ])
58
 
59
+ # Load weights should now perfectly match 437 to 437
60
  model.load_weights("xray.h5")
61
  return model
62
 
 
68
  return None
69
 
70
  # Preprocess the X-Ray input
71
+ img = Image.fromarray(image).convert('RGB')
72
+ img = img.resize((128, 128))
73
 
74
  img_array = np.array(img)
75
+ img_array = np.expand_dims(img_array, axis=0)
 
76
 
77
+ # Use the library's built-in preprocessing to match training conditions
78
+ img_array = efn.preprocess_input(img_array)
79
 
80
  # Predict
81
  predictions = xray_model.predict(img_array)[0]
82
 
83
+ # Map probabilities
84
  confidences = {xray_class_names[i]: float(predictions[i]) for i in range(len(xray_class_names))}
85
  return confidences
86
 
 
111
  xray_input = gr.Image(label="Upload Chest X-Ray")
112
  xray_button = gr.Button("Classify X-Ray", variant="primary")
113
  with gr.Column():
 
114
  xray_output = gr.Label(num_top_classes=5, label="Top 5 Predicted Conditions")
115
 
116
  xray_button.click(fn=predict_xray, inputs=xray_input, outputs=xray_output)