# Spaces: Running
# pathlib locates the model file relative to this script
from pathlib import Path

# gradio is the library used to build the web interface
import gradio as gr
# numpy is used for numerical operations
import numpy as np
# ai_edge_litert is Google's official TFLite runtime
from ai_edge_litert.interpreter import Interpreter
# PIL is used for image loading and conversion
from PIL import Image
# ------------------------------------
# LOAD THE MODEL
# ------------------------------------
# Resolve the model file next to this script rather than the current working
# directory, so the app also starts when launched from another folder.
MODEL_PATH = Path(__file__).resolve().with_name("resnet50_float32.tflite")
interpreter = Interpreter(model_path=str(MODEL_PATH))
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Network input size in PIL's (width, height) order, used by preprocess_image().
INPUT_SIZE = (224, 224)

# preprocess_image() produces a float32 batch of shape (1, height, width, 3).
# Check the model's declared input once at startup, so a mismatch fails here
# with a clear message instead of surfacing in set_tensor() on every request.
_expected_input_shape = (1, INPUT_SIZE[1], INPUT_SIZE[0], 3)
_model_input_shape = tuple(int(dim) for dim in input_details[0]["shape"])
if _model_input_shape != _expected_input_shape:
    raise RuntimeError(
        f"Model input shape {_model_input_shape} does not match the "
        f"preprocessed image shape {_expected_input_shape}; update INPUT_SIZE."
    )
_model_input_dtype = np.dtype(input_details[0]["dtype"])
if _model_input_dtype != np.float32:
    raise RuntimeError(
        f"Model expects {_model_input_dtype} input, "
        "but preprocess_image() produces float32."
    )
print("Gatekeeper model loaded successfully")
# ------------------------------------
# THRESHOLD
# ------------------------------------
# Minimum "Cervix" score an image needs to be reported as "Cervix Detected";
# any lower score is reported as "Non-Cervix".
CERVIX_THRESHOLD: float = 0.55
# ------------------------------------
# IMAGE PREPROCESSING FUNCTION
# ------------------------------------
def preprocess_image(image):
    """Turn an uploaded image array into a model-ready input batch.

    The image is converted to 3-channel RGB, resized to ``INPUT_SIZE``
    (PIL's ``(width, height)`` order), and its 8-bit pixel values are
    scaled into ``[0, 1]``.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``; the Gradio
            ``Image`` component with ``type="numpy"`` supplies one.

    Returns:
        ``float32`` array of shape ``(1, height, width, 3)``.
    """
    rgb = Image.fromarray(image).convert("RGB")
    pixels = np.asarray(rgb.resize(INPUT_SIZE), dtype=np.float32)
    batch = pixels[np.newaxis, ...]  # prepend the batch axis
    return batch / 255.0
# ------------------------------------
# CLASSIFICATION FUNCTION
# ------------------------------------
def classify_image(image):
    """Run the gatekeeper model on one uploaded image.

    Args:
        image: The array from the Gradio ``Image`` component, or ``None``
            when nothing has been uploaded.

    Returns:
        A ``(scores, prediction_text)`` pair. ``scores`` maps ``"Cervix"`` and
        ``"Non-Cervix"`` to the model's output values rounded to 4 decimals
        (``None`` when no image was given). ``prediction_text`` is the message
        shown in the Prediction box.
    """
    # Guard clause: nothing to classify yet.
    if image is None:
        return None, "Please upload an image first"

    batch = preprocess_image(image)
    interpreter.set_tensor(input_details[0]["index"], batch)
    interpreter.invoke()
    raw = interpreter.get_tensor(output_details[0]["index"])
    print(f"Raw model output: {raw}")

    # Column 0 is Non-Cervix and column 1 is Cervix, matching the label table
    # shown in the UI.
    first_row = raw[0]
    prob_non_cervix = float(first_row[0])
    prob_cervix = float(first_row[1])
    print(f"Non-Cervix: {prob_non_cervix:.4f} | Cervix: {prob_cervix:.4f}")

    is_cervix = prob_cervix >= CERVIX_THRESHOLD
    prediction_text = "Cervix Detected" if is_cervix else "Non-Cervix"

    return {
        "Cervix": round(prob_cervix, 4),
        "Non-Cervix": round(prob_non_cervix, 4),
    }, prediction_text
# ------------------------------------
# GRADIO USER INTERFACE
# ------------------------------------
# Layout:
# - a title banner;
# - a two-column row (upload and buttons on the left, scores and prediction
#   on the right);
# - a label legend and disclaimer.
# In the Markdown strings, every "---" has a blank line above it. A "---"
# directly under a line of text would turn that text into a setext heading
# instead of drawing a horizontal rule.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # Gatekeeper Model
    ### Cervix Image Binary Classifier
    Upload an image to classify it as Cervix or Non-Cervix

    ---
    """)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload Image",
                type="numpy",
            )
            classify_btn = gr.Button(
                "Run Classification",
                variant="primary",
                size="lg",
            )
            clear_btn = gr.Button(
                "Clear",
                variant="secondary",
                size="sm",
            )
        with gr.Column():
            output_scores = gr.Label(
                label="Confidence Scores",
                num_top_classes=2,
            )
            output_text = gr.Textbox(
                label="Prediction",
                interactive=False,
                text_align="center",
            )
    gr.Markdown("""
    ---

    | Index | Label | Meaning |
    |-------|-------------|----------------------------------|
    | 0 | Non-Cervix | Image does NOT contain cervix |
    | 1 | Cervix | Image contains cervix |

    ---

    Disclaimer: This tool is for research purposes only.
    It is not intended for clinical diagnosis or medical use.
    """)
    # "Run Classification" passes the uploaded image to classify_image. Its
    # (scores, prediction_text) result fills the two output components.
    classify_btn.click(
        fn=classify_image,
        inputs=input_image,
        outputs=[output_scores, output_text],
    )
    # "Clear" resets the image, the score label and the prediction text.
    clear_btn.click(
        fn=lambda: (None, None, ""),
        inputs=None,
        outputs=[input_image, output_scores, output_text],
    )

app.launch()