kmunzwa commited on
Commit
555c25e
·
verified ·
1 Parent(s): 6c05378

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +239 -0
app.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # We import gradio which is the library we use to build the web interface
2
+ import gradio as gr
3
+
4
+ # numpy is used for numerical operations and array manipulation
5
+ import numpy as np
6
+
7
+ # tensorflow is used to load and run the tflite model
8
+ import tensorflow as tf
9
+
10
+ # PIL (Pillow) is used to handle image loading and resizing
11
+ from PIL import Image
12
+
13
+ # os is used to set environment variables
14
+ import os
15
+
16
+ # This line tells TensorFlow to suppress unnecessary log messages
17
+ # 3 means only show critical errors
18
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
19
+
20
+
21
+ # ------------------------------------
22
+ # LOAD THE MODEL
23
+ # ------------------------------------
24
+
25
+ # This loads the tflite model file from the current directory
26
+ # The Interpreter class is what TensorFlow Lite uses to run models
27
+ interpreter = tf.lite.Interpreter(model_path="gatekeeper_model.tflite")
28
+
29
+ # This allocates memory for the model's input and output tensors
30
+ # You must always call this before running inference
31
+ interpreter.allocate_tensors()
32
+
33
+ # This gets the details of the input tensor
34
+ # It tells us the expected shape, data type, and index of the input
35
+ input_details = interpreter.get_input_details()
36
+
37
+ # This gets the details of the output tensor
38
+ # It tells us the shape and index of the output so we can read predictions
39
+ output_details = interpreter.get_output_details()
40
+
41
+ # This is the image size the model expects
42
+ # ResNet50 was trained on 224x224 images so we keep this the same
43
+ INPUT_SIZE = (224, 224)
44
+
45
+
46
+ # ------------------------------------
47
+ # IMAGE PREPROCESSING FUNCTION
48
+ # ------------------------------------
49
+
50
+ def preprocess_image(image):
51
+ # image comes in as a numpy array from gradio
52
+ # we convert it to a PIL Image object so we can resize it easily
53
+ # we also make sure it is in RGB format (3 channels: red, green, blue)
54
+ img = Image.fromarray(image).convert("RGB")
55
+
56
+ # we resize the image to match what the model expects
57
+ # if the image is not 224x224 the model will throw a shape error
58
+ img = img.resize(INPUT_SIZE)
59
+
60
+ # we convert the PIL image back to a numpy array
61
+ # dtype=np.float32 is important because the model expects 32-bit floats
62
+ img = np.array(img, dtype=np.float32)
63
+
64
+ # we divide all pixel values by 255
65
+ # this converts pixel values from the range [0, 255] to [0, 1]
66
+ # this is called normalization and it helps the model perform correctly
67
+ img = img / 255.0
68
+
69
+ # the model expects a batch of images, not a single image
70
+ # so we add an extra dimension at position 0
71
+ # this changes the shape from (224, 224, 3) to (1, 224, 224, 3)
72
+ # the 1 represents a batch size of 1 (one image at a time)
73
+ img = np.expand_dims(img, axis=0)
74
+
75
+ # we return the fully preprocessed image ready for inference
76
+ return img
77
+
78
+
79
+ # ------------------------------------
80
+ # CLASSIFICATION FUNCTION
81
+ # ------------------------------------
82
+
83
+ def classify_image(image):
84
+ # if the user clicks the button without uploading an image
85
+ # we return None for the scores and a warning message
86
+ if image is None:
87
+ return None, "Please upload an image first"
88
+
89
+ # we send the image through our preprocessing function
90
+ processed = preprocess_image(image)
91
+
92
+ # we load the preprocessed image into the model's input tensor
93
+ # input_details[0]['index'] gives us the correct tensor index to write to
94
+ interpreter.set_tensor(input_details[0]['index'], processed)
95
+
96
+ # this actually runs the model on the input we just loaded
97
+ interpreter.invoke()
98
+
99
+ # this reads the output from the model after inference is complete
100
+ # output_details[0]['index'] gives us the correct tensor index to read from
101
+ output = interpreter.get_tensor(output_details[0]['index'])
102
+
103
+ # we print the raw output to the console for debugging purposes
104
+ # this is useful to confirm the model is producing expected values
105
+ print(f"Raw model output: {output}")
106
+
107
+ # index 0 of the output corresponds to Non-Cervix probability
108
+ # we convert it to a plain Python float for easier handling
109
+ prob_non_cervix = float(output[0][0])
110
+
111
+ # index 1 of the output corresponds to Cervix probability
112
+ prob_cervix = float(output[0][1])
113
+
114
+ # we compare the two probabilities to determine the final prediction
115
+ # whichever class has the higher probability is our prediction
116
+ if prob_cervix > prob_non_cervix:
117
+ prediction_text = "Cervix Detected"
118
+ else:
119
+ prediction_text = "Non-Cervix"
120
+
121
+ # we build a dictionary of class names mapped to their confidence scores
122
+ # gradio's Label component accepts this format and displays it as a bar chart
123
+ # we round to 4 decimal places to keep the display clean
124
+ scores = {
125
+ "Cervix": round(prob_cervix, 4),
126
+ "Non-Cervix": round(prob_non_cervix, 4)
127
+ }
128
+
129
+ # we return both the scores dictionary and the prediction text
130
+ # these map to the two output components in the gradio interface
131
+ return scores, prediction_text
132
+
133
+
134
+ # ------------------------------------
135
+ # GRADIO USER INTERFACE
136
+ # ------------------------------------
137
+
138
+ # gr.Blocks gives us full control over the layout of the interface
139
+ # theme=gr.themes.Soft() gives it a clean and soft visual style
140
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
141
+
142
+ # gr.Markdown renders formatted text at the top of the page
143
+ gr.Markdown("""
144
+ # Gatekeeper Model
145
+ ### Cervix Image Binary Classifier
146
+ Upload an image to classify it as Cervix or Non-Cervix
147
+ ---
148
+ """)
149
+
150
+ # gr.Row arranges the components inside it horizontally side by side
151
+ with gr.Row():
152
+
153
+ # the first column holds the input components on the left side
154
+ with gr.Column():
155
+
156
+ # gr.Image creates an image upload box
157
+ # type="numpy" means the image will be passed to our function
158
+ # as a numpy array which is what we need for preprocessing
159
+ input_image = gr.Image(
160
+ label="Upload Image",
161
+ type="numpy"
162
+ )
163
+
164
+ # this is the main button the user clicks to run the model
165
+ # variant="primary" makes it stand out visually as the main action
166
+ # size="lg" makes it large and easy to click
167
+ classify_btn = gr.Button(
168
+ "Run Classification",
169
+ variant="primary",
170
+ size="lg"
171
+ )
172
+
173
+ # this is a secondary button to reset the interface
174
+ # variant="secondary" gives it a less prominent visual style
175
+ clear_btn = gr.Button(
176
+ "Clear",
177
+ variant="secondary",
178
+ size="sm"
179
+ )
180
+
181
+ # the second column holds the output components on the right side
182
+ with gr.Column():
183
+
184
+ # gr.Label displays the confidence scores as a visual bar chart
185
+ # num_top_classes=2 tells it to show both classes
186
+ output_scores = gr.Label(
187
+ label="Confidence Scores",
188
+ num_top_classes=2
189
+ )
190
+
191
+ # gr.Textbox displays the final prediction as plain text
192
+ # interactive=False means the user cannot edit it
193
+ # it is read-only output only
194
+ output_text = gr.Textbox(
195
+ label="Prediction",
196
+ interactive=False,
197
+ text_align="center"
198
+ )
199
+
200
+ # this adds a reference table at the bottom so users understand
201
+ # what the two class indices mean
202
+ gr.Markdown("""
203
+ ---
204
+ | Index | Label | Meaning |
205
+ |-------|-------------|----------------------------------|
206
+ | 0 | Non-Cervix | Image does NOT contain cervix |
207
+ | 1 | Cervix | Image contains cervix |
208
+
209
+ ---
210
+ Disclaimer: This tool is for research purposes only.
211
+ It is not intended for clinical diagnosis or medical use.
212
+ """)
213
+
214
+ # ------------------------------------
215
+ # BUTTON ACTIONS
216
+ # ------------------------------------
217
+
218
+ # this connects the classify button to the classify_image function
219
+ # inputs tells gradio which component to read from
220
+ # outputs tells gradio which components to write the results to
221
+ classify_btn.click(
222
+ fn=classify_image,
223
+ inputs=input_image,
224
+ outputs=[output_scores, output_text]
225
+ )
226
+
227
+ # this connects the clear button to a simple lambda function
228
+ # a lambda is a small anonymous function defined in one line
229
+ # it returns None for the image, None for scores, and empty string for text
230
+ # this effectively resets all three components back to their empty state
231
+ clear_btn.click(
232
+ fn=lambda: (None, None, ""),
233
+ inputs=None,
234
+ outputs=[input_image, output_scores, output_text]
235
+ )
236
+
237
+ # this starts the gradio web server and launches the interface
238
+ # on hugging face spaces this is called automatically
239
+ app.launch()