# gradio is the library used to build the web interface
import gradio as gr
# numpy is used for numerical operations
import numpy as np
# ai_edge_litert is Google's official TFLite runtime
from ai_edge_litert.interpreter import Interpreter
# PIL is used for image loading and conversion
from PIL import Image
# ------------------------------------
# LOAD THE MODEL
# ------------------------------------
# The model path is relative, so it is resolved against the current working
# directory of the process.
interpreter = Interpreter(model_path="resnet50_float32.tflite")
# Tensors must be allocated before any set_tensor()/invoke() call.
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Derive the resize target from the model itself instead of hard-coding it, so
# a model exported at a different resolution keeps working. preprocess_image()
# builds a single-image NHWC RGB batch, so the model must expect (1, H, W, 3);
# anything else would fail on every request, so fail loudly at startup instead.
_input_shape = tuple(int(dim) for dim in input_details[0]["shape"])
if len(_input_shape) != 4 or _input_shape[0] != 1 or _input_shape[3] != 3:
    raise RuntimeError(
        f"Unsupported model input shape {_input_shape}; expected (1, H, W, 3)"
    )
# PIL's Image.resize() takes (width, height); the tensor shape is (N, H, W, C).
INPUT_SIZE = (_input_shape[2], _input_shape[1])
print("Gatekeeper model loaded successfully")
# ------------------------------------
# THRESHOLD
# ------------------------------------
# Minimum cervix score (inclusive: ``>=``) for an image to be reported as
# "Cervix Detected"; any lower score is reported as "Non-Cervix".
CERVIX_THRESHOLD: float = 0.55
# ------------------------------------
# IMAGE PREPROCESSING FUNCTION
# ------------------------------------
def preprocess_image(image):
    """Turn an uploaded image array into a single-image model input batch.

    The array is converted to a 3-channel RGB PIL image, resized to
    ``INPUT_SIZE`` (PIL order: width, height), cast to float32 and divided by
    255 so the 0-255 RGB values land in [0, 1], and finally given a leading
    batch axis of length 1.

    Args:
        image: Image as a NumPy array, as delivered by
            ``gr.Image(type="numpy")``.

    Returns:
        float32 array of shape ``(1, INPUT_SIZE[1], INPUT_SIZE[0], 3)``.
    """
    rgb_image = Image.fromarray(image).convert("RGB")
    resized = rgb_image.resize(INPUT_SIZE)
    pixels = np.asarray(resized, dtype=np.float32)
    scaled = pixels / 255.0
    # Prepend the batch dimension the interpreter expects.
    return scaled[np.newaxis, ...]
# ------------------------------------
# CLASSIFICATION FUNCTION
# ------------------------------------
def classify_image(image):
    """Run the gatekeeper model on one image and report Cervix vs Non-Cervix.

    Args:
        image: Image array from the Gradio upload widget, or ``None`` when
            nothing has been uploaded yet.

    Returns:
        A ``(scores, prediction_text)`` pair. ``scores`` maps "Cervix" and
        "Non-Cervix" to the model's scores rounded to 4 decimals (``None``
        when no image was given). ``prediction_text`` is the thresholded
        verdict, or a prompt to upload an image.
    """
    # Guard clause: nothing to classify yet.
    if image is None:
        return None, "Please upload an image first"

    batch = preprocess_image(image)
    interpreter.set_tensor(input_details[0]['index'], batch)
    interpreter.invoke()
    raw = interpreter.get_tensor(output_details[0]['index'])
    print(f"Raw model output: {raw}")

    # Output index 0 is Non-Cervix and index 1 is Cervix (see the UI legend).
    first_row = raw[0]
    prob_non_cervix = float(first_row[0])
    prob_cervix = float(first_row[1])
    print(f"Non-Cervix: {prob_non_cervix:.4f} | Cervix: {prob_cervix:.4f}")

    # Inclusive threshold: a score exactly equal to CERVIX_THRESHOLD is Cervix.
    verdict = (
        "Cervix Detected" if prob_cervix >= CERVIX_THRESHOLD else "Non-Cervix"
    )

    confidences = {
        "Cervix": round(prob_cervix, 4),
        "Non-Cervix": round(prob_non_cervix, 4),
    }
    return confidences, verdict
# ------------------------------------
# GRADIO USER INTERFACE
# ------------------------------------
# Page header. The blank line before "---" is required: without it Markdown
# treats "---" as a setext underline and renders the preceding sentence as a
# level-2 heading instead of drawing a horizontal rule below it.
HEADER_MARKDOWN = """
# Gatekeeper Model
### Cervix Image Binary Classifier
Upload an image to classify it as Cervix or Non-Cervix

---
"""

# Page footer: legend for the model's two output indices (classify_image reads
# index 0 as Non-Cervix and index 1 as Cervix) plus a usage disclaimer. The
# blank line before the second "---" cleanly ends the table before the rule.
FOOTER_MARKDOWN = """
---
| Index | Label | Meaning |
|-------|-------------|----------------------------------|
| 0 | Non-Cervix | Image does NOT contain cervix |
| 1 | Cervix | Image contains cervix |

---
Disclaimer: This tool is for research purposes only.
It is not intended for clinical diagnosis or medical use.
"""


def _clear_outputs():
    """Reset the image input, the score label and the prediction textbox."""
    return None, None, ""


with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(HEADER_MARKDOWN)
    with gr.Row():
        # Left column: image upload and action buttons.
        with gr.Column():
            input_image = gr.Image(
                label="Upload Image",
                type="numpy",
            )
            classify_btn = gr.Button(
                "Run Classification",
                variant="primary",
                size="lg",
            )
            clear_btn = gr.Button(
                "Clear",
                variant="secondary",
                size="sm",
            )
        # Right column: per-class scores and the thresholded verdict.
        with gr.Column():
            output_scores = gr.Label(
                label="Confidence Scores",
                num_top_classes=2,
            )
            output_text = gr.Textbox(
                label="Prediction",
                interactive=False,
                text_align="center",
            )
    # NOTE(review): the original indentation was lost; the footer is placed at
    # page level (full width, below both columns). Confirm this was intended.
    gr.Markdown(FOOTER_MARKDOWN)
    classify_btn.click(
        fn=classify_image,
        inputs=input_image,
        outputs=[output_scores, output_text],
    )
    clear_btn.click(
        fn=_clear_outputs,
        inputs=None,
        outputs=[input_image, output_scores, output_text],
    )

app.launch()