"""Ashoka Buried Penis classifier service.

Wraps a Keras image classifier in two front ends that share one prediction
path:

* ``POST /predict`` -- FastAPI JSON endpoint taking a multipart image upload;
* ``/`` -- Gradio web UI for drag-and-drop screening.

This is a screening aid only; results must be interpreted by qualified
medical professionals.
"""

import io
import sys
import threading

import gradio as gr
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from tensorflow.keras.applications.densenet import preprocess_input

# =========================
# 1. FastAPI Init
# =========================
app = FastAPI(title="Ashoka Buried Penis Classifier API")

# =========================
# 2. Load Model
# =========================
MODEL_PATH = "cnn_kfold_best_model_v2.h5"
# Spatial size images are resized to before inference.
# NOTE(review): assumed to match the size the model was trained on -- confirm.
IMAGE_SIZE = (224, 224)
# Model outputs strictly above this value are labelled "Buried Penis".
THRESHOLD = 0.5

print("Loading model...")
try:
    model = tf.keras.models.load_model(MODEL_PATH)
    print("Model loaded successfully")
except Exception as e:
    # Refuse to start without a model instead of serving requests that would
    # all fail later.
    print("Failed to load model:", e, file=sys.stderr)
    sys.exit(1)

# Serialises calls into the model. Inference can be reached both from Gradio
# worker threads and from the API threadpool; holding this lock keeps only one
# predict() in flight at a time.
_MODEL_LOCK = threading.Lock()

# NOTE(review): not referenced in this file; predict_image returns the labels
# "Normal" / "Buried Penis" directly.
class_names = ["Normal", "Buried"]


# =========================
# 3. Preprocessing
# =========================
def prepare_image(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a model-ready batch of one.

    The image is forced to 3-channel RGB, resized to ``IMAGE_SIZE`` and run
    through Keras' DenseNet ``preprocess_input``.

    Args:
        image: any PIL image (mode is normalised to RGB here).

    Returns:
        A float32 array of shape ``(1, 224, 224, 3)``.
    """
    rgb = image.convert("RGB").resize(IMAGE_SIZE)
    img_array = np.asarray(rgb, dtype=np.float32)
    img_array = np.expand_dims(img_array, axis=0)  # add the batch axis
    return preprocess_input(img_array)


# =========================
# 4. Prediction Logic
# =========================
def predict_image(image):
    """Classify a single image as "Normal" or "Buried Penis".

    Args:
        image: a PIL image, or ``None`` when nothing was uploaded.

    Returns:
        ``(label, normal_pct, buried_pct)``. The percentages are rounded to two
        decimals and sum to 100. For ``None`` input, returns
        ``("No image uploaded", 0.0, 0.0)``.
    """
    if image is None:
        return "No image uploaded", 0.0, 0.0

    processed = prepare_image(image)
    # NOTE(review): assumes the model ends in one sigmoid unit whose output is
    # P(buried) -- confirm against the training code.
    with _MODEL_LOCK:
        # verbose=0 stops Keras printing a progress bar on every request.
        prediction = float(model.predict(processed, verbose=0)[0][0])

    prob_buried = prediction * 100
    prob_normal = (1 - prediction) * 100
    label = "Buried Penis" if prediction > THRESHOLD else "Normal"
    return label, round(prob_normal, 2), round(prob_buried, 2)


# =========================
# 5. FastAPI Endpoint
# =========================
@app.post("/predict")
async def api_predict(file: UploadFile = File(...)):
    """Classify an uploaded image file.

    Returns:
        ``{"class": label, "probabilities": {"normal": pct, "buried": pct}}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; decode now so truncated or corrupt uploads are
        # rejected here rather than failing inside inference.
        image.load()
    except (OSError, Image.DecompressionBombError) as err:
        # OSError also covers PIL.UnidentifiedImageError (non-image or empty
        # upload).
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from err

    # Inference blocks; run it off the event loop so other requests, including
    # the mounted Gradio UI, are not stalled meanwhile.
    label, normal, buried = await run_in_threadpool(predict_image, image)

    return {
        "class": label,
        "probabilities": {
            "normal": normal,
            "buried": buried,
        },
    }


# =========================
# 6. Gradio UI
# =========================
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Ashoka Buried Penis Classifier API - DenseNet

**Medical screening tool for Buried Penis**

⚠️ This tool is **NOT a diagnostic device**.
Results must be interpreted by **qualified medical professionals**.
"""
    )

    with gr.Row():
        image_input = gr.Image(
            type="pil",
            label="Upload / Drag & Drop Medical Image",
        )

    classify_btn = gr.Button("Analyze Image")

    result_label = gr.Textbox(label="Prediction Result")
    prob_normal = gr.Number(label="Normal Probability (%)")
    prob_buried = gr.Number(label="Buried Probability (%)")

    # Reuses the same predict_image path as the JSON endpoint.
    classify_btn.click(
        fn=predict_image,
        inputs=image_input,
        outputs=[result_label, prob_normal, prob_buried],
    )

# =========================
# 7. Mount Gradio to FastAPI
# =========================
# /predict was registered above, before the UI is mounted at "/".
app = gr.mount_gradio_app(app, demo, path="/")