"""Gradio web app for the "Gatekeeper" cervix / non-cervix image classifier.

At import time the script loads ``resnet50_float32.tflite`` into a LiteRT
interpreter, builds a small Gradio Blocks interface around
``classify_image`` and launches it.
"""

# gradio is the library used to build the web interface
import gradio as gr

# numpy is used for numerical operations
import numpy as np

# ai_edge_litert is Google's official TFLite runtime
from ai_edge_litert.interpreter import Interpreter

# PIL is used for image loading and conversion
from PIL import Image

# ------------------------------------
# LOAD THE MODEL
# ------------------------------------
# The interpreter is created once and shared by every classification request.
interpreter = Interpreter(model_path="resnet50_float32.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Size passed to PIL's resize() before the image is fed to the model.
INPUT_SIZE = (224, 224)

print("Gatekeeper model loaded successfully")

# ------------------------------------
# THRESHOLD
# ------------------------------------
# cervix must score at least 0.55 to be accepted as a positive detection
CERVIX_THRESHOLD = 0.55


# ------------------------------------
# IMAGE PREPROCESSING FUNCTION
# ------------------------------------
def preprocess_image(image: np.ndarray) -> np.ndarray:
    """Turn an uploaded image array into a single-image float32 model input.

    Args:
        image: Image data as a numpy array (the Gradio image component
            below is configured with ``type="numpy"``).

    Returns:
        A float32 array with a leading batch axis of length 1, holding the
        RGB image resized to ``INPUT_SIZE`` and divided by 255.
    """
    img = Image.fromarray(image).convert("RGB").resize(INPUT_SIZE)
    # Scale 8-bit channel values into the 0..1 range.
    img = np.array(img, dtype=np.float32) / 255.0
    # Add the batch dimension in front.
    img = np.expand_dims(img, axis=0)
    return img


# ------------------------------------
# CLASSIFICATION FUNCTION
# ------------------------------------
def classify_image(image):
    """Run the gatekeeper model on one image and summarise the result.

    Args:
        image: Image as a numpy array, or ``None`` when nothing was uploaded.

    Returns:
        A ``(scores, prediction_text)`` tuple. ``scores`` maps ``"Cervix"``
        and ``"Non-Cervix"`` to their model scores rounded to 4 decimals;
        ``prediction_text`` is ``"Cervix Detected"`` when the cervix score
        is at least ``CERVIX_THRESHOLD`` and ``"Non-Cervix"`` otherwise.
        When ``image`` is ``None`` the result is
        ``(None, "Please upload an image first")``.
    """
    if image is None:
        return None, "Please upload an image first"

    # preprocess and run inference
    processed = preprocess_image(image)
    interpreter.set_tensor(input_details[0]['index'], processed)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]['index'])

    print(f"Raw model output: {output}")

    # Output layout: index 0 is the non-cervix score, index 1 the cervix score.
    prob_non_cervix = float(output[0][0])
    prob_cervix = float(output[0][1])

    print(f"Non-Cervix: {prob_non_cervix:.4f} | Cervix: {prob_cervix:.4f}")

    # simple threshold check
    if prob_cervix >= CERVIX_THRESHOLD:
        prediction_text = "Cervix Detected"
    else:
        prediction_text = "Non-Cervix"

    scores = {
        "Cervix": round(prob_cervix, 4),
        "Non-Cervix": round(prob_non_cervix, 4),
    }

    return scores, prediction_text


# ------------------------------------
# GRADIO USER INTERFACE
# ------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as app:

    # Blank lines around "---" keep it a horizontal rule rather than turning
    # the preceding text line into a heading.
    gr.Markdown("""
    # Gatekeeper Model
    ### Cervix Image Binary Classifier
    Upload an image to classify it as Cervix or Non-Cervix

    ---
    """)

    with gr.Row():

        # Left column: image upload and action buttons.
        with gr.Column():
            input_image = gr.Image(
                label="Upload Image",
                type="numpy"
            )
            classify_btn = gr.Button(
                "Run Classification",
                variant="primary",
                size="lg"
            )
            clear_btn = gr.Button(
                "Clear",
                variant="secondary",
                size="sm"
            )

        # Right column: per-class scores and the thresholded verdict.
        with gr.Column():
            output_scores = gr.Label(
                label="Confidence Scores",
                num_top_classes=2
            )
            output_text = gr.Textbox(
                label="Prediction",
                interactive=False,
                text_align="center"
            )

    gr.Markdown("""
    ---

    | Index | Label       | Meaning                          |
    |-------|-------------|----------------------------------|
    | 0     | Non-Cervix  | Image does NOT contain cervix    |
    | 1     | Cervix      | Image contains cervix            |

    ---

    Disclaimer: This tool is for research purposes only.
    It is not intended for clinical diagnosis or medical use.
    """)

    classify_btn.click(
        fn=classify_image,
        inputs=input_image,
        outputs=[output_scores, output_text]
    )

    # Reset the uploaded image, the score label and the prediction text.
    clear_btn.click(
        fn=lambda: (None, None, ""),
        inputs=None,
        outputs=[input_image, output_scores, output_text]
    )

app.launch()