Spaces:
Sleeping
Sleeping
File size: 2,920 Bytes
aa4df94 d77d1ca 5902df4 d7e9d91 def1723 d77d1ca d7e9d91 d77d1ca d7e9d91 5902df4 d7e9d91 5902df4 def1723 d7e9d91 5902df4 d7e9d91 def1723 5902df4 d7e9d91 e07fa3a d7e9d91 5902df4 d77d1ca d7e9d91 def1723 d7e9d91 def1723 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 | from fastapi import FastAPI, File, UploadFile, HTTPException
import tensorflow as tf
import numpy as np
from PIL import Image
import io
import sys
import gradio as gr
from tensorflow.keras.applications.resnet50 import preprocess_input
# =========================
# 1. FastAPI Init
# =========================
# ASGI application hosting the REST endpoint(s). Further down in this file it
# is rebound to the object returned by gr.mount_gradio_app.
app = FastAPI(
    title="Ashoka Buried Penis Classifier API",
)
# =========================
# 2. Load Model
# =========================
import traceback
from pathlib import Path

# Resolve the weights file relative to this script rather than the process's
# current working directory, so the app also starts when launched from elsewhere.
MODEL_PATH = Path(__file__).resolve().parent / "cnn_kfold_best_model.h5"

print("Loading model...")
try:
    model = tf.keras.models.load_model(str(MODEL_PATH))
except Exception as e:
    # The service cannot work without a model: report the error and full
    # traceback on stderr (not stdout) and abort startup.
    print("Failed to load model:", e, file=sys.stderr)
    traceback.print_exc()
    sys.exit(1)
else:
    print("Model loaded successfully")

# Human-readable class labels. NOTE(review): not referenced by predict_image,
# which hard-codes its own label strings; confirm this order matches training.
class_names = ["Normal", "Buried"]
# =========================
# 3. Preprocessing
# =========================
def prepare_image(image: Image.Image) -> np.ndarray:
    """Convert a PIL image into a single-image batch for the classifier.

    The image is forced to 3-channel RGB, resized to 224x224, given a leading
    batch axis and passed through Keras' ResNet50 ``preprocess_input``.

    Args:
        image: Any PIL image (mode is normalised to RGB here).

    Returns:
        A numpy array of shape ``(1, 224, 224, 3)`` ready for ``model.predict``.
    """
    rgb = image.convert("RGB").resize((224, 224))
    # (224, 224, 3) -> (1, 224, 224, 3): the model expects a batch dimension.
    batch = np.array(rgb)[np.newaxis, ...]
    return preprocess_input(batch)
# =========================
# 4. Prediction Logic
# =========================
def predict_image(image):
    """Classify one image and return a label plus class percentages.

    Shared by the REST endpoint and the Gradio UI.

    Args:
        image: A PIL image, or ``None`` when no image was supplied.

    Returns:
        A tuple ``(label, prob_normal, prob_buried)``:
        - ``label`` is ``"Buried Penis"`` when the model score exceeds 0.5,
          otherwise ``"Normal"``;
        - the probabilities are percentages rounded to two decimals.
        For ``None`` input the result is ``("No image uploaded", 0.0, 0.0)``.
    """
    if image is None:
        return "No image uploaded", 0.0, 0.0

    processed = prepare_image(image)
    # verbose=0: by default Keras prints a progress bar on every predict()
    # call, which adds a log line per request to the server output.
    # NOTE(review): assumes the model ends in a single sigmoid unit whose
    # output is P(Buried); class_names lists "Normal" first, so confirm this
    # against the label order used at training time.
    prediction = model.predict(processed, verbose=0)[0][0]
    prob_buried = float(prediction * 100)
    prob_normal = float((1 - prediction) * 100)
    label = "Buried Penis" if prediction > 0.5 else "Normal"
    return label, round(prob_normal, 2), round(prob_buried, 2)
# =========================
# 5. FastAPI Endpoint
# =========================
@app.post("/predict")
async def api_predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Args:
        file: Multipart upload containing the image bytes.

    Returns:
        ``{"class": label, "probabilities": {"normal": pct, "buried": pct}}``,
        with the values produced by ``predict_image``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    from fastapi.concurrency import run_in_threadpool

    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; force a full decode now so truncated or corrupt
        # files are rejected as client errors instead of failing mid-inference.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        # PIL.UnidentifiedImageError (empty / non-image upload) subclasses OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from exc

    # Model inference is blocking CPU work; run it in a worker thread so the
    # event loop (which also serves the mounted Gradio UI) is not stalled.
    label, normal, buried = await run_in_threadpool(predict_image, image)
    return {
        "class": label,
        "probabilities": {
            "normal": normal,
            "buried": buried,
        },
    }
# =========================
# 6. Gradio UI
# =========================
# Browser front end. "Analyze Image" sends the uploaded PIL image to
# predict_image and fills the three output fields with its returned tuple.
with gr.Blocks() as demo:
    # Heading fixed: it previously named a different condition ("Hipospadia"),
    # contradicting the API title, the description and the model's classes.
    gr.Markdown("""
# Ashoka Buried Penis Classifier API - ResNet50
**Medical screening tool for Buried Penis**
⚠️ This tool is **NOT a diagnostic device**.
Results must be interpreted by **qualified medical professionals**.
""")
    with gr.Row():
        image_input = gr.Image(
            type="pil",
            label="Upload / Drag & Drop Medical Image",
        )
    classify_btn = gr.Button("Analyze Image")
    result_label = gr.Textbox(label="Prediction Result")
    prob_normal = gr.Number(label="Normal Probability (%)")
    prob_buried = gr.Number(label="Buried Probability (%)")
    # Output order must match predict_image's (label, normal, buried) tuple.
    classify_btn.click(
        fn=predict_image,
        inputs=image_input,
        outputs=[result_label, prob_normal, prob_buried],
    )
# =========================
# 7. Mount Gradio to FastAPI
# =========================
# Attach the Gradio Blocks UI (`demo`) to the FastAPI app at the root path and
# rebind `app` to the application object gr.mount_gradio_app returns.
app = gr.mount_gradio_app(
    app,
    demo,
    path="/",
)
|