Project-Ashoka / app.py
Xaviant's picture
Update app.py
a1adf42 verified
raw
history blame
2.82 kB
import asyncio
import io
import sys
import traceback

import gradio as gr
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from tensorflow.keras.applications.densenet import preprocess_input
# =========================
# 1. FastAPI Init
# =========================
# Root ASGI application. The /predict endpoint is registered on it below, and
# the Gradio UI is mounted onto it at the end of the file.
API_TITLE = "Ashoka Buried Penis Classifier API"
app = FastAPI(title=API_TITLE)
# =========================
# 2. Load Model
# =========================
# The model is loaded once at import time. Any failure terminates the process
# with exit code 1, so the service never starts without a model.
MODEL_PATH = "cnn_kfold_best_model_v2.h5"

print("Loading model...")
try:
    model = tf.keras.models.load_model(MODEL_PATH)
except Exception as e:  # top-level boundary: any load failure is fatal
    # Report on stderr with the full traceback; the message alone rarely shows
    # which layer / custom object failed to deserialize.
    print("Failed to load model:", e, file=sys.stderr)
    traceback.print_exc()
    sys.exit(1)
print("Model loaded successfully")

# Class names by output index: 0 -> "Normal", 1 -> "Buried".
# NOTE(review): predict_image does not read this list; it hard-codes its labels.
class_names = ["Normal", "Buried"]
# =========================
# 3. Preprocessing
# =========================
def prepare_image(image: Image.Image):
    """Turn a PIL image into a one-image batch ready for the DenseNet model.

    The image is forced to RGB, resized to 224x224, given a leading batch
    axis, and then passed through Keras' DenseNet ``preprocess_input``.

    Args:
        image: Any PIL image; its mode is normalised to RGB here.

    Returns:
        A numpy array of shape (1, 224, 224, 3) after DenseNet preprocessing.
    """
    rgb = image.convert("RGB").resize((224, 224))
    batch = np.array(rgb)[np.newaxis, ...]
    return preprocess_input(batch)
# =========================
# 4. Prediction Logic
# =========================
def predict_image(image):
    """Classify a single image as "Buried Penis" or "Normal".

    Args:
        image: A PIL image, or ``None`` when nothing was supplied (e.g. the
            Gradio image input is empty).

    Returns:
        A tuple ``(label, prob_normal, prob_buried)``. Both probabilities are
        percentages rounded to 2 decimals. The label is "Buried Penis" when
        the model score is above 0.5, otherwise "Normal". For ``None`` input
        the result is ``("No image uploaded", 0.0, 0.0)``.
    """
    if image is None:
        return "No image uploaded", 0.0, 0.0

    processed = prepare_image(image)
    # verbose=0: without it Keras prints a progress bar on every request.
    # NOTE(review): assumes the model has a single sigmoid output equal to
    # P(buried) -- confirm against the training code.
    # Convert to a Python float once so the arithmetic below is not done in
    # float32.
    score = float(model.predict(processed, verbose=0)[0][0])

    prob_buried = score * 100
    prob_normal = (1 - score) * 100
    label = "Buried Penis" if score > 0.5 else "Normal"
    return label, round(prob_normal, 2), round(prob_buried, 2)
# =========================
# 5. FastAPI Endpoint
# =========================
@app.post("/predict")
async def api_predict(file: UploadFile = File(...)):
    """Classify an uploaded image over HTTP.

    Args:
        file: Multipart upload containing the image bytes.

    Returns:
        ``{"class": label, "probabilities": {"normal": pct, "buried": pct}}``
        built from the tuple returned by ``predict_image``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image
            (empty, corrupt, truncated, or a decompression bomb).
    """
    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; force a full decode here so bad uploads become a
        # client error instead of an unhandled failure inside the model call.
        image.load()
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # Inference is blocking CPU work. Run it in a worker thread so it does not
    # stall the event loop, and with it every other in-flight request.
    label, normal, buried = await asyncio.to_thread(predict_image, image)
    return {
        "class": label,
        "probabilities": {
            "normal": normal,
            "buried": buried,
        },
    }
# =========================
# 6. Gradio UI
# =========================
# Browser front-end: one image input, an "Analyze Image" button, and three
# output fields filled from predict_image's (label, normal %, buried %) tuple.
with gr.Blocks() as demo:
    # Heading matches the model this app actually serves (Buried Penis
    # classifier, as in the FastAPI title and class names); it previously
    # said "Hipospadia".
    gr.Markdown(
        """
        # Ashoka Buried Penis Classifier API - DenseNet
        **Medical screening tool for Buried Penis**
        ⚠️ This tool is **NOT a diagnostic device**.
        Results must be interpreted by **qualified medical professionals**.
        """
    )
    with gr.Row():
        # type="pil" so predict_image receives a PIL image (or None if empty).
        image_input = gr.Image(
            type="pil",
            label="Upload / Drag & Drop Medical Image",
        )
    classify_btn = gr.Button("Analyze Image")
    result_label = gr.Textbox(label="Prediction Result")
    prob_normal = gr.Number(label="Normal Probability (%)")
    prob_buried = gr.Number(label="Buried Probability (%)")
    # Output order must match predict_image's return tuple.
    classify_btn.click(
        fn=predict_image,
        inputs=image_input,
        outputs=[result_label, prob_normal, prob_buried],
    )
# =========================
# 7. Mount Gradio to FastAPI
# =========================
# Mount the Gradio UI at the site root and rebind `app` to the returned
# FastAPI application, so `app` serves both the API routes and the UI.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")