"""FastAPI + Gradio front end for a binary image classifier.

Exposes the same prediction routine two ways:

* ``POST /predict`` -- JSON API that accepts one uploaded image file.
* ``/``             -- Gradio web UI mounted on the FastAPI application.

The Keras model is loaded once at import time; the process exits with
status 1 if the model file cannot be loaded.
"""

import io
import sys
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile, HTTPException
from PIL import Image
from tensorflow.keras.applications.resnet50 import preprocess_input

# =========================
# 1. FastAPI Init
# =========================
app = FastAPI(title="Ashoka Buried Penis Classifier API")

# =========================
# 2. Load Model
# =========================
MODEL_PATH = "cnn_kfold_best_model.h5"

# Spatial size every image is resized to before being fed to the model.
INPUT_SIZE = (224, 224)

# Model outputs strictly above this value are reported as the "Buried" class.
DECISION_THRESHOLD = 0.5

print("Loading model...")
try:
    model = tf.keras.models.load_model(MODEL_PATH)
    print("Model loaded successfully")
except Exception as e:
    # Top-level boundary: without a model nothing in this service can work,
    # so report the cause and stop instead of serving a broken application.
    print("Failed to load model:", e)
    sys.exit(1)

class_names = ["Normal", "Buried"]


# =========================
# 3. Preprocessing
# =========================
def prepare_image(image: Image.Image) -> np.ndarray:
    """Convert a PIL image into a single-item batch for the model.

    The image is converted to RGB, resized to ``INPUT_SIZE``, given a leading
    batch axis and passed through the ResNet50 ``preprocess_input`` function.

    Args:
        image: Source image in any mode PIL can convert to RGB.

    Returns:
        Float array of shape ``(1, 224, 224, 3)``.
    """
    rgb = image.convert("RGB").resize(INPUT_SIZE)
    # Cast explicitly so the arithmetic inside preprocess_input never depends
    # on an implicit uint8 -> float conversion.
    batch = np.expand_dims(np.asarray(rgb, dtype=np.float32), axis=0)
    return preprocess_input(batch)


# =========================
# 4. Prediction Logic
# =========================
def predict_image(image: Optional[Image.Image]) -> Tuple[str, float, float]:
    """Classify one image and report both class probabilities.

    Args:
        image: PIL image to classify, or ``None`` when nothing was uploaded
            (the Gradio UI passes ``None`` for an empty image input).

    Returns:
        ``(label, normal_percent, buried_percent)`` with both percentages
        rounded to two decimals. For ``None`` input the result is
        ``("No image uploaded", 0.0, 0.0)``.
    """
    if image is None:
        return "No image uploaded", 0.0, 0.0

    # The first element of the first output row is read as the probability of
    # the "Buried" class.
    # NOTE(review): assumes the model has a single sigmoid output unit --
    # confirm against the training code.
    score = float(model.predict(prepare_image(image))[0][0])

    label = "Buried Penis" if score > DECISION_THRESHOLD else "Normal"
    prob_normal = round((1.0 - score) * 100, 2)
    prob_buried = round(score * 100, 2)
    return label, prob_normal, prob_buried


# =========================
# 5. FastAPI Endpoint
# =========================
@app.post("/predict")
async def api_predict(file: UploadFile = File(...)):
    """Classify an uploaded image file.

    Args:
        file: Multipart upload containing the image.

    Returns:
        JSON object with the predicted ``class`` and the ``probabilities``
        (percent) for ``normal`` and ``buried``.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded as
            an image.
    """
    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Pillow decodes lazily; force it here so truncated or corrupt files
        # are rejected as a client error instead of failing later as a 500.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image",
        ) from exc

    # NOTE: model.predict is synchronous and blocks the event loop for the
    # duration of the inference.
    label, normal, buried = predict_image(image)

    return {
        "class": label,
        "probabilities": {
            "normal": normal,
            "buried": buried,
        },
    }


# =========================
# 6. Gradio UI
# =========================
# NOTE(review): the heading names "Hipospadia" while the API title and the
# labels refer to "Buried Penis" -- confirm which wording is intended.
UI_HEADER_MARKDOWN = (
    "# Ashoka Hipospadia Classifier API - ResNet50\n"
    "\n"
    "**Medical screening tool for Buried Penis**\n"
    "\n"
    "⚠️ This tool is **NOT a diagnostic device**.\n"
    "Results must be interpreted by **qualified medical professionals**.\n"
)

with gr.Blocks() as demo:
    gr.Markdown(UI_HEADER_MARKDOWN)

    # NOTE(review): original nesting was lost; assumes the row holds only the
    # image input -- confirm the intended layout.
    with gr.Row():
        image_input = gr.Image(
            type="pil",
            label="Upload / Drag & Drop Medical Image",
        )

    classify_btn = gr.Button("Analyze Image")

    result_label = gr.Textbox(label="Prediction Result")
    prob_normal = gr.Number(label="Normal Probability (%)")
    prob_buried = gr.Number(label="Buried Probability (%)")

    classify_btn.click(
        fn=predict_image,
        inputs=image_input,
        outputs=[result_label, prob_normal, prob_buried],
    )

# =========================
# 7. Mount Gradio to FastAPI
# =========================
app = gr.mount_gradio_app(app, demo, path="/")