# Page residue from the hosting site, kept as a comment so the file parses:
# author: kmunzwa, commit 770b5e4 "Update app.py" (verified)
# gradio is the library used to build the web interface
import gradio as gr
# numpy is used for numerical operations
import numpy as np
# ai_edge_litert is Google's official TFLite runtime
from ai_edge_litert.interpreter import Interpreter
# PIL is used for image loading and conversion
from PIL import Image
# ------------------------------------
# LOAD THE MODEL
# ------------------------------------
def _load_interpreter(model_file):
    """Build a TFLite interpreter for *model_file* with its tensors allocated."""
    tflite = Interpreter(model_path=model_file)
    tflite.allocate_tensors()
    return tflite


interpreter = _load_interpreter("resnet50_float32.tflite")
# Tensor descriptors are looked up once here and reused by every request.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)
# (width, height) passed to PIL's resize in preprocess_image.
INPUT_SIZE = (224, 224)
print("Gatekeeper model loaded successfully")
# ------------------------------------
# THRESHOLD
# ------------------------------------
# Minimum cervix score (compared with >=) for an image to be reported as
# "Cervix Detected"; anything below it is reported as "Non-Cervix".
CERVIX_THRESHOLD: float = 0.55
# ------------------------------------
# IMAGE PREPROCESSING FUNCTION
# ------------------------------------
def preprocess_image(image):
    """Turn an uploaded image array into a single-image model input batch.

    The image is coerced to 3-channel RGB, resized to ``INPUT_SIZE``
    (width, height), converted to float32 and scaled from [0, 255] to
    [0.0, 1.0], then given a leading batch axis of length 1.

    Args:
        image: Image array accepted by ``PIL.Image.fromarray`` (here, the
            value produced by the ``gr.Image(type="numpy")`` input).

    Returns:
        float32 ``np.ndarray`` of shape (1, height, width, 3).
    """
    rgb = Image.fromarray(image).convert("RGB")
    resized = rgb.resize(INPUT_SIZE)
    # RGB mode is 8 bits per channel, so dividing by 255 maps into [0, 1].
    pixels = np.asarray(resized, dtype=np.float32)
    scaled = pixels / 255.0
    return scaled[np.newaxis, ...]
# ------------------------------------
# CLASSIFICATION FUNCTION
# ------------------------------------
def classify_image(image):
    """Run the gatekeeper model on one image and report both class scores.

    Args:
        image: Uploaded image array from the Gradio input, or ``None`` when
            nothing has been uploaded.

    Returns:
        A ``(scores, prediction_text)`` pair. ``scores`` maps "Cervix" and
        "Non-Cervix" to the model outputs rounded to 4 decimals, or is
        ``None`` when no image was given. ``prediction_text`` is
        "Cervix Detected" when the cervix score is at least
        ``CERVIX_THRESHOLD``, "Non-Cervix" otherwise, or an upload prompt
        when ``image`` is ``None``.
    """
    if image is None:
        return None, "Please upload an image first"

    # Push the preprocessed batch through the shared interpreter.
    batch = preprocess_image(image)
    interpreter.set_tensor(input_details[0]['index'], batch)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]['index'])
    print(f"Raw model output: {output}")

    # First (only) batch row: index 0 = Non-Cervix, index 1 = Cervix.
    # NOTE(review): assumes these are already probabilities (e.g. softmax
    # outputs), not logits — confirm against the exported model.
    row = output[0]
    prob_non_cervix, prob_cervix = float(row[0]), float(row[1])
    print(f"Non-Cervix: {prob_non_cervix:.4f} | Cervix: {prob_cervix:.4f}")

    prediction_text = (
        "Cervix Detected" if prob_cervix >= CERVIX_THRESHOLD else "Non-Cervix"
    )
    scores = {
        "Cervix": round(prob_cervix, 4),
        "Non-Cervix": round(prob_non_cervix, 4),
    }
    return scores, prediction_text
# ------------------------------------
# GRADIO USER INTERFACE
# ------------------------------------
# Markdown rendered above the controls.
_HEADER_MD = """
# Gatekeeper Model
### Cervix Image Binary Classifier
Upload an image to classify it as Cervix or Non-Cervix
---
"""

# Markdown rendered below the controls: label legend and usage disclaimer.
_FOOTER_MD = """
---
| Index | Label | Meaning |
|-------|-------------|----------------------------------|
| 0 | Non-Cervix | Image does NOT contain cervix |
| 1 | Cervix | Image contains cervix |
---
Disclaimer: This tool is for research purposes only.
It is not intended for clinical diagnosis or medical use.
"""

with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Left column: the image upload plus the two action buttons.
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Upload Image")
            classify_btn = gr.Button(
                "Run Classification", size="lg", variant="primary"
            )
            clear_btn = gr.Button("Clear", size="sm", variant="secondary")

        # Right column: per-class scores and the thresholded verdict.
        with gr.Column():
            output_scores = gr.Label(num_top_classes=2, label="Confidence Scores")
            output_text = gr.Textbox(
                label="Prediction", text_align="center", interactive=False
            )

    gr.Markdown(_FOOTER_MD)

    # classify_image returns (scores, prediction_text) in this output order.
    classify_btn.click(
        fn=classify_image,
        inputs=input_image,
        outputs=[output_scores, output_text],
    )
    # Blank the image, the score label and the prediction box in one step.
    clear_btn.click(
        fn=lambda: (None, None, ""),
        inputs=None,
        outputs=[input_image, output_scores, output_text],
    )

app.launch()