# Project-Ashoka / app.py
# Last change: "Update app.py (#7)" by Dinaliah (commit def1723, verified).
import io
import os
import sys
import traceback

import gradio as gr
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from tensorflow.keras.applications.resnet50 import preprocess_input
# =========================
# 1. FastAPI Init
# =========================
# ASGI application: the REST endpoint below is registered on it, and the
# Gradio UI is mounted onto this same app at the end of the module.
app = FastAPI(title="Ashoka Buried Penis Classifier API")
# =========================
# 2. Load Model
# =========================
# Path to the saved Keras model. It can be overridden with the MODEL_PATH
# environment variable; the default is the file name this app has always used.
MODEL_PATH = os.environ.get("MODEL_PATH", "cnn_kfold_best_model.h5")

print("Loading model...")
try:
    model = tf.keras.models.load_model(MODEL_PATH)
except Exception as e:  # startup boundary: any load failure is fatal
    print("Failed to load model:", e)
    # Keep the full stack trace. The bare message often hides the real cause
    # (e.g. a Keras/TF version mismatch or a missing custom object).
    traceback.print_exc()
    sys.exit(1)
else:
    print("Model loaded successfully")

# Class names, presumably index-aligned with the model output
# (0 = Normal, 1 = Buried). NOTE(review): not currently referenced by the
# prediction code in this file, which uses its own label strings.
class_names = ["Normal", "Buried"]
# =========================
# 3. Preprocessing
# =========================
def prepare_image(image: Image.Image):
    """Turn a PIL image into a one-image batch ready for ``model.predict``.

    The image is forced to 3-channel RGB and resized to 224x224. A leading
    batch axis is added, giving shape (1, 224, 224, 3), and the result is
    passed through ResNet50's ``preprocess_input``.

    NOTE(review): this must match the preprocessing used when the model was
    trained (ResNet50 ``preprocess_input``, default PIL resize filter). Confirm
    against the training code.
    """
    resized = image.convert("RGB").resize((224, 224))
    batch = np.array(resized)[np.newaxis, ...]
    return preprocess_input(batch)
# =========================
# 4. Prediction Logic
# =========================
def predict_image(image):
if image is None:
return "No image uploaded", 0.0, 0.0
processed = prepare_image(image)
prediction = model.predict(processed)[0][0]
prob_buried = float(prediction * 100)
prob_normal = float((1 - prediction) * 100)
label = "Buried Penis" if prediction > 0.5 else "Normal"
return label, round(prob_normal, 2), round(prob_buried, 2)
# =========================
# 5. FastAPI Endpoint
# =========================
@app.post("/predict")
async def api_predict(file: UploadFile = File(...)):
    """Classify one uploaded image file.

    Returns:
        ``{"class": <label>, "probabilities": {"normal": <pct>, "buried": <pct>}}``,
        using the same values that ``predict_image`` gives the Gradio UI.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    # NOTE(review): the whole upload is read into memory with no size limit.
    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy, so force a full decode here. That way truncated
        # or corrupt files are rejected as a client error, not a 500 later.
        image.load()
    except (OSError, Image.DecompressionBombError) as err:
        # PIL's UnidentifiedImageError (unknown/empty format) subclasses OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from err
    # Model inference is blocking work. Run it in a worker thread so it does
    # not stall the event loop that serves every other request.
    label, normal, buried = await run_in_threadpool(predict_image, image)
    return {
        "class": label,
        "probabilities": {
            "normal": normal,
            "buried": buried,
        },
    }
# =========================
# 6. Gradio UI
# =========================
# Browser UI around the same predict_image function used by /predict.
with gr.Blocks() as demo:
    # The heading previously read "Hipospadia Classifier", which does not match
    # the Buried Penis model, API title and labels used throughout this app.
    gr.Markdown(
        """
# Ashoka Buried Penis Classifier API - ResNet50

**Medical screening tool for Buried Penis**

⚠️ This tool is **NOT a diagnostic device**.
Results must be interpreted by **qualified medical professionals**.
"""
    )
    with gr.Row():
        # type="pil" hands predict_image a PIL image (or None when empty).
        image_input = gr.Image(
            type="pil",
            label="Upload / Drag & Drop Medical Image",
        )
    classify_btn = gr.Button("Analyze Image")
    result_label = gr.Textbox(label="Prediction Result")
    prob_normal = gr.Number(label="Normal Probability (%)")
    prob_buried = gr.Number(label="Buried Probability (%)")
    # predict_image returns (label, normal_pct, buried_pct) in this order.
    classify_btn.click(
        fn=predict_image,
        inputs=image_input,
        outputs=[result_label, prob_normal, prob_buried],
    )
# =========================
# 7. Mount Gradio to FastAPI
# =========================
# Serve the Gradio UI from the same app as the REST API, and rebind the
# result to ``app`` so one ASGI object serves both.
# NOTE(review): "/predict" is registered before this catch-all "/" mount.
# That relies on Starlette matching routes in registration order, so keep
# this mount last.
app = gr.mount_gradio_app(app, demo, path="/")