kmunzwa commited on
Commit
3d40fd0
·
verified ·
1 Parent(s): b436cce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -118
app.py CHANGED
@@ -1,14 +1,17 @@
1
- # We import gradio which is the library we use to build the web interface
2
  import gradio as gr
3
 
4
- # numpy is used for numerical operations and array manipulation
5
  import numpy as np
6
 
7
- # ai_edge_litert is Google's official replacement for tf.lite
8
- # it is lighter, faster, and works correctly on Python 3.13
9
- from ai_edge_litert.interpreter import Interpreter
10
 
11
- # PIL (Pillow) is used to handle image loading and resizing
 
 
 
 
12
  from PIL import Image
13
 
14
 
@@ -16,58 +19,48 @@ from PIL import Image
16
  # LOAD THE MODEL
17
  # ------------------------------------
18
 
19
- # This loads the tflite model file from the current directory
20
- # Interpreter is the ai_edge_litert equivalent of tf.lite.Interpreter
21
- interpreter = Interpreter(model_path="best_gatekeeper_v2.pth")
22
-
23
- # This allocates memory for the model's input and output tensors
24
- # You must always call this before running inference
25
- interpreter.allocate_tensors()
26
-
27
- # This gets the details of the input tensor
28
- # It tells us the expected shape, data type, and index of the input
29
- input_details = interpreter.get_input_details()
30
-
31
- # This gets the details of the output tensor
32
- # It tells us the shape and index of the output so we can read predictions
33
- output_details = interpreter.get_output_details()
34
 
35
- # This is the image size the model expects
36
- # ResNet50 was trained on 224x224 images so we keep this the same
37
- INPUT_SIZE = (224, 224)
38
 
 
 
 
 
39
 
40
- # ------------------------------------
41
- # IMAGE PREPROCESSING FUNCTION
42
- # ------------------------------------
 
43
 
44
- def preprocess_image(image):
45
- # image comes in as a numpy array from gradio
46
- # we convert it to a PIL Image object so we can resize it easily
47
- # we also make sure it is in RGB format (3 channels: red, green, blue)
48
- img = Image.fromarray(image).convert("RGB")
49
 
50
- # we resize the image to match what the model expects
51
- # if the image is not 224x224 the model will throw a shape error
52
- img = img.resize(INPUT_SIZE)
53
 
54
- # we convert the PIL image back to a numpy array
55
- # dtype=np.float32 is important because the model expects 32-bit floats
56
- img = np.array(img, dtype=np.float32)
57
 
58
- # we divide all pixel values by 255
59
- # this converts pixel values from the range [0, 255] to [0, 1]
60
- # this is called normalization and it helps the model perform correctly
61
- img = img / 255.0
62
 
63
- # the model expects a batch of images, not a single image
64
- # so we add an extra dimension at position 0
65
- # this changes the shape from (224, 224, 3) to (1, 224, 224, 3)
66
- # the 1 represents a batch size of 1 (one image at a time)
67
- img = np.expand_dims(img, axis=0)
68
 
69
- # we return the fully preprocessed image ready for inference
70
- return img
 
 
 
 
71
 
72
 
73
  # ------------------------------------
@@ -75,53 +68,40 @@ def preprocess_image(image):
75
  # ------------------------------------
76
 
77
  def classify_image(image):
78
- # if the user clicks the button without uploading an image
79
- # we return None for the scores and a warning message
80
  if image is None:
81
  return None, "Please upload an image first"
82
 
83
- # we send the image through our preprocessing function
84
- processed = preprocess_image(image)
85
-
86
- # we load the preprocessed image into the model's input tensor
87
- # input_details[0]['index'] gives us the correct tensor index to write to
88
- interpreter.set_tensor(input_details[0]['index'], processed)
89
-
90
- # this actually runs the model on the input we just loaded
91
- interpreter.invoke()
92
 
93
- # this reads the output from the model after inference is complete
94
- # output_details[0]['index'] gives us the correct tensor index to read from
95
- output = interpreter.get_tensor(output_details[0]['index'])
96
 
97
- # we print the raw output to the console for debugging purposes
98
- # this is useful to confirm the model is producing expected values
99
- print(f"Raw model output: {output}")
 
100
 
101
- # index 0 of the output corresponds to Non-Cervix probability
102
- # we convert it to a plain Python float for easier handling
103
- prob_non_cervix = float(output[0][0])
104
 
105
- # index 1 of the output corresponds to Cervix probability
106
- prob_cervix = float(output[0][1])
107
 
108
- # we compare the two probabilities to determine the final prediction
109
- # whichever class has the higher probability is our prediction
110
  if prob_cervix > prob_non_cervix:
111
  prediction_text = "Cervix Detected"
112
  else:
113
  prediction_text = "Non-Cervix"
114
 
115
- # we build a dictionary of class names mapped to their confidence scores
116
- # gradio's Label component accepts this format and displays it as a bar chart
117
- # we round to 4 decimal places to keep the display clean
118
  scores = {
119
  "Cervix": round(prob_cervix, 4),
120
- "Non-Cervix": round(prob_non_cervix, 4)
121
  }
122
 
123
- # we return both the scores dictionary and the prediction text
124
- # these map to the two output components in the gradio interface
125
  return scores, prediction_text
126
 
127
 
@@ -129,11 +109,8 @@ def classify_image(image):
129
  # GRADIO USER INTERFACE
130
  # ------------------------------------
131
 
132
- # gr.Blocks gives us full control over the layout of the interface
133
- # theme=gr.themes.Soft() gives it a clean and soft visual style
134
  with gr.Blocks(theme=gr.themes.Soft()) as app:
135
 
136
- # gr.Markdown renders formatted text at the top of the page
137
  gr.Markdown("""
138
  # Gatekeeper Model
139
  ### Cervix Image Binary Classifier
@@ -141,58 +118,35 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
141
  ---
142
  """)
143
 
144
- # gr.Row arranges the components inside it horizontally side by side
145
  with gr.Row():
146
 
147
- # the first column holds the input components on the left side
148
  with gr.Column():
149
-
150
- # gr.Image creates an image upload box
151
- # type="numpy" means the image will be passed to our function
152
- # as a numpy array which is what we need for preprocessing
153
  input_image = gr.Image(
154
  label="Upload Image",
155
  type="numpy"
156
  )
157
-
158
- # this is the main button the user clicks to run the model
159
- # variant="primary" makes it stand out visually as the main action
160
- # size="lg" makes it large and easy to click
161
  classify_btn = gr.Button(
162
  "Run Classification",
163
  variant="primary",
164
  size="lg"
165
  )
166
-
167
- # this is a secondary button to reset the interface
168
- # variant="secondary" gives it a less prominent visual style
169
  clear_btn = gr.Button(
170
  "Clear",
171
  variant="secondary",
172
  size="sm"
173
  )
174
 
175
- # the second column holds the output components on the right side
176
  with gr.Column():
177
-
178
- # gr.Label displays the confidence scores as a visual bar chart
179
- # num_top_classes=2 tells it to show both classes
180
  output_scores = gr.Label(
181
  label="Confidence Scores",
182
  num_top_classes=2
183
  )
184
-
185
- # gr.Textbox displays the final prediction as plain text
186
- # interactive=False means the user cannot edit it
187
- # it is read-only output only
188
  output_text = gr.Textbox(
189
  label="Prediction",
190
  interactive=False,
191
  text_align="center"
192
  )
193
 
194
- # this adds a reference table at the bottom so users understand
195
- # what the two class indices mean
196
  gr.Markdown("""
197
  ---
198
  | Index | Label | Meaning |
@@ -205,29 +159,16 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
205
  It is not intended for clinical diagnosis or medical use.
206
  """)
207
 
208
- # ------------------------------------
209
- # BUTTON ACTIONS
210
- # ------------------------------------
211
-
212
- # this connects the classify button to the classify_image function
213
- # inputs tells gradio which component to read from
214
- # outputs tells gradio which components to write the results to
215
  classify_btn.click(
216
  fn=classify_image,
217
  inputs=input_image,
218
  outputs=[output_scores, output_text]
219
  )
220
 
221
- # this connects the clear button to a simple lambda function
222
- # a lambda is a small anonymous function defined in one line
223
- # it returns None for the image, None for scores, and empty string for text
224
- # this effectively resets all three components back to their empty state
225
  clear_btn.click(
226
  fn=lambda: (None, None, ""),
227
  inputs=None,
228
  outputs=[input_image, output_scores, output_text]
229
  )
230
 
231
- # this starts the gradio web server and launches the interface
232
- # on hugging face spaces this is called automatically
233
  app.launch()
 
1
+ # gradio is the library used to build the web interface
2
  import gradio as gr
3
 
4
+ # numpy is used for numerical operations
5
  import numpy as np
6
 
7
+ # torch is the core PyTorch library used to run the model
8
+ import torch
 
9
 
10
+ # torchvision provides the ResNet50 architecture and image transforms
11
+ import torchvision.transforms as transforms
12
+ from torchvision import models
13
+
14
+ # PIL is used for image loading and conversion
15
  from PIL import Image
16
 
17
 
 
19
  # LOAD THE MODEL
20
  # ------------------------------------
21
 
22
+ # we detect whether a GPU is available and fall back to CPU if not
23
+ # hugging face free tier runs on CPU so this will almost always be cpu
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ print(f"Running on: {device}")
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # we recreate the ResNet50 architecture
28
+ # weights=None because we will load our own trained weights below
29
+ model = models.resnet50(weights=None)
30
 
31
+ # the original ResNet50 outputs 1000 classes (ImageNet)
32
+ # we replace the final fully connected layer to output 2 classes:
33
+ # class 0 = Non-Cervix, class 1 = Cervix
34
+ model.fc = torch.nn.Linear(model.fc.in_features, 2)
35
 
36
+ # we load the saved weights from the .pth file
37
+ # map_location=device ensures it loads correctly even without a GPU
38
+ state_dict = torch.load("best_gatekeeper_v2.pth", map_location=device)
39
+ model.load_state_dict(state_dict)
40
 
41
+ # we move the model to the correct device (CPU or GPU)
42
+ model = model.to(device)
 
 
 
43
 
44
+ # we set the model to evaluation mode
45
+ # this disables dropout and batch normalisation training behaviour
46
+ model.eval()
47
 
48
+ print("Gatekeeper model loaded successfully")
 
 
49
 
50
+ # this is the image size ResNet50 expects
51
+ INPUT_SIZE = 224
 
 
52
 
53
+ # these are the standard ImageNet normalisation values
54
+ # ResNet50 was pretrained on ImageNet so we use the same values
55
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
56
+ IMAGENET_STD = [0.229, 0.224, 0.225]
 
57
 
58
+ # we define the preprocessing pipeline using torchvision transforms
59
+ preprocess = transforms.Compose([
60
+ transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
61
+ transforms.ToTensor(), # converts [0,255] → [0,1]
62
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
63
+ ])
64
 
65
 
66
  # ------------------------------------
 
68
  # ------------------------------------
69
 
70
  def classify_image(image):
71
+ # if the user submits without an image return a warning
 
72
  if image is None:
73
  return None, "Please upload an image first"
74
 
75
+ # convert the numpy array from gradio to a PIL Image in RGB format
76
+ img = Image.fromarray(image).convert("RGB")
 
 
 
 
 
 
 
77
 
78
+ # apply the preprocessing pipeline and add a batch dimension
79
+ # unsqueeze(0) changes shape from (3, 224, 224) to (1, 3, 224, 224)
80
+ tensor = preprocess(img).unsqueeze(0).to(device)
81
 
82
+ # run inference without computing gradients (saves memory and is faster)
83
+ with torch.no_grad():
84
+ output = model(tensor) # raw logits shape: (1, 2)
85
+ probs = torch.softmax(output, dim=1)[0] # convert to probabilities
86
 
87
+ # extract individual class probabilities as plain Python floats
88
+ prob_non_cervix = float(probs[0])
89
+ prob_cervix = float(probs[1])
90
 
91
+ print(f"Non-Cervix: {prob_non_cervix:.4f} | Cervix: {prob_cervix:.4f}")
 
92
 
93
+ # determine the final prediction label
 
94
  if prob_cervix > prob_non_cervix:
95
  prediction_text = "Cervix Detected"
96
  else:
97
  prediction_text = "Non-Cervix"
98
 
99
+ # build a dictionary for gradio's Label component (displays as bar chart)
 
 
100
  scores = {
101
  "Cervix": round(prob_cervix, 4),
102
+ "Non-Cervix": round(prob_non_cervix, 4),
103
  }
104
 
 
 
105
  return scores, prediction_text
106
 
107
 
 
109
  # GRADIO USER INTERFACE
110
  # ------------------------------------
111
 
 
 
112
  with gr.Blocks(theme=gr.themes.Soft()) as app:
113
 
 
114
  gr.Markdown("""
115
  # Gatekeeper Model
116
  ### Cervix Image Binary Classifier
 
118
  ---
119
  """)
120
 
 
121
  with gr.Row():
122
 
 
123
  with gr.Column():
 
 
 
 
124
  input_image = gr.Image(
125
  label="Upload Image",
126
  type="numpy"
127
  )
 
 
 
 
128
  classify_btn = gr.Button(
129
  "Run Classification",
130
  variant="primary",
131
  size="lg"
132
  )
 
 
 
133
  clear_btn = gr.Button(
134
  "Clear",
135
  variant="secondary",
136
  size="sm"
137
  )
138
 
 
139
  with gr.Column():
 
 
 
140
  output_scores = gr.Label(
141
  label="Confidence Scores",
142
  num_top_classes=2
143
  )
 
 
 
 
144
  output_text = gr.Textbox(
145
  label="Prediction",
146
  interactive=False,
147
  text_align="center"
148
  )
149
 
 
 
150
  gr.Markdown("""
151
  ---
152
  | Index | Label | Meaning |
 
159
  It is not intended for clinical diagnosis or medical use.
160
  """)
161
 
 
 
 
 
 
 
 
162
  classify_btn.click(
163
  fn=classify_image,
164
  inputs=input_image,
165
  outputs=[output_scores, output_text]
166
  )
167
 
 
 
 
 
168
  clear_btn.click(
169
  fn=lambda: (None, None, ""),
170
  inputs=None,
171
  outputs=[input_image, output_scores, output_text]
172
  )
173
 
 
 
174
  app.launch()