# bonsAI_App / app.py
# jin3213's picture
# Update app.py
# 0a36196 verified
from ultralytics import YOLO
import gradio as gr
from PIL import Image
import torch # (kept, since you already imported it)
# Load the detection model once at import time so every request reuses it.
# The weights are referenced by a relative path: make sure best.pt is in the
# same folder as this script.
_WEIGHTS_FILE = "best.pt"
model = YOLO(_WEIGHTS_FILE)
# Prediction function: YOLO predict -> plot -> (annotated image, label dict)
def predict(inp):
    """Run the detector on one image and summarise what it found.

    Args:
        inp: The image to analyse, or ``None`` when nothing was provided.

    Returns:
        A ``(annotated, label_dict)`` tuple.

        * ``annotated`` is a PIL image built from the plotted result after
          reversing its channel axis, or ``None`` when ``inp`` is ``None``.
        * ``label_dict`` maps a class name to the highest confidence among
          that class's boxes, inserted in descending confidence order. It is
          empty when ``inp`` is ``None`` or the result has no boxes.
    """
    if inp is None:
        return None, {}

    result = model.predict(source=inp, conf=0.5, iou=0.5, imgsz=640)[0]

    # Reverse the channel axis (BGR -> RGB) before handing the array to PIL.
    rgb_pixels = result.plot()[:, :, ::-1]
    annotated = Image.fromarray(rgb_pixels)

    # No detections: return the plotted image with an empty label mapping.
    boxes = getattr(result, "boxes", None)
    if boxes is None or len(boxes) == 0:
        return annotated, {}

    class_names = model.names if hasattr(model, "names") else {}
    class_ids = boxes.cls.tolist()
    scores = boxes.conf.tolist()

    # Keep only the strongest confidence seen for each class id.
    top_score = {}
    for raw_id, score in zip(class_ids, scores):
        class_id = int(raw_id)
        top_score[class_id] = max(top_score.get(class_id, 0.0), float(score))

    # Highest confidence first; ids without a known name fall back to str(id).
    ranked = sorted(top_score.items(), key=lambda pair: pair[1], reverse=True)
    label_dict = {
        class_names.get(class_id, str(class_id)): score
        for class_id, score in ranked
    }
    return annotated, label_dict
# -----------------------------
# UI ONLY (layout + styling)
# -----------------------------
# Custom stylesheet handed to gr.Blocks(css=CSS). Selectors and where the
# layout code attaches them:
#   #page                      -> gr.Column(elem_id="page")
#   .hero, .small-muted        -> classes inside the header gr.HTML markup
#   .card                      -> elem_classes on the gr.Group / gr.Accordion
#   .btn-primary / .btn-ghost  -> elem_classes on the Submit / Clear buttons
# The text between the triple quotes is sent to the browser verbatim, so any
# notes about it must stay out here as Python comments.
CSS = """
:root { --radius: 16px; }
#page {
max-width: 1200px;
margin: 0 auto;
}
.hero {
padding: 18px 18px 8px 18px;
border-radius: var(--radius);
border: 1px solid rgba(255,255,255,0.08);
background: linear-gradient(180deg, rgba(255,126,0,0.10), rgba(0,0,0,0.0));
}
.hero h1 {
font-size: 28px;
line-height: 1.1;
margin: 0 0 6px 0;
}
.hero p {
margin: 0;
opacity: 0.9;
}
.card {
border-radius: var(--radius) !important;
border: 1px solid rgba(255,255,255,0.10) !important;
background: rgba(255,255,255,0.03) !important;
}
.btn-primary button {
background: linear-gradient(90deg, #ff7a00, #ff4d00) !important;
border: 0 !important;
border-radius: 14px !important;
font-weight: 700 !important;
}
.btn-ghost button {
border-radius: 14px !important;
font-weight: 600 !important;
}
.small-muted {
font-size: 12px;
opacity: 0.75;
}
.gradio-container .prose h3 {
margin-top: 6px !important;
}
"""
# Header markup rendered by gr.HTML at the top of the page. Kept as a
# module-level constant so the literal's content is not affected by the
# indentation of the layout code.
HERO_HTML = """
<div class="hero">
<h1>bonsAI Pill Detection</h1>
<p>
Upload an image of a pill. The YOLOv12 model detects and classifies pill types commonly found in the Philippines.
This study aims to automate pill recognition for pharmaceutical verification and healthcare support.
</p>
<p class="small-muted" style="margin-top:10px;">
Tip: Use clear lighting, avoid blur, and keep pills centered for best results.
</p>
</div>
"""


def _clear_all():
    """Return one cleared value per component wired to the Clear button.

    The order matches ``outputs=[inp, out_img, out_lbl]``: input image,
    annotated image, label. The handler must return exactly as many values
    as there are output components.
    """
    return None, None, {}


# Page layout: header, a two-column row (upload + actions | results), and a
# study summary, followed by the event wiring.
# NOTE(review): nesting was rebuilt from the component order and the
# LEFT/RIGHT comments; confirm the button row and accordion placement.
with gr.Blocks(
    css=CSS,
    theme=gr.themes.Soft(
        radius_size="lg",
        text_size="md",
    ),
) as demo:
    with gr.Column(elem_id="page"):
        gr.HTML(HERO_HTML)

        with gr.Row(equal_height=True):
            # LEFT: Upload + Actions
            with gr.Column(scale=5):
                with gr.Group(elem_classes=["card"]):
                    gr.Markdown("#### Upload Pill Image")
                    inp = gr.Image(
                        type="pil",
                        label="Drop image here or click to upload",
                        height=330,
                    )
                    with gr.Row():
                        clear_btn = gr.Button("Clear", elem_classes=["btn-ghost"], scale=1)
                        submit_btn = gr.Button("Submit", elem_classes=["btn-primary"], scale=1)

            # RIGHT: Results
            with gr.Column(scale=5):
                with gr.Group(elem_classes=["card"]):
                    gr.Markdown("#### Results")
                    with gr.Tabs():
                        with gr.Tab("Detected Pills"):
                            out_img = gr.Image(
                                label="Detected Pills",
                                height=330,
                            )
                        with gr.Tab("Predictions"):
                            out_lbl = gr.Label(
                                label="Predictions",
                                num_top_classes=5,
                            )

        with gr.Accordion("Study Summary", open=True, elem_classes=["card"]):
            gr.Markdown(
                "The bonsAI project demonstrates the application of YOLOv12 in real-time pill classification "
                "and segmentation. By training on the Pharmaceutical Drugs and Vitamins Dataset (Version 2), "
                "the system accurately identifies tablets and capsules across 20 classes using bounding boxes "
                "and mask segmentation. The model achieved high mAP and F1-scores, confirming its potential "
                "for aiding pharmacists and healthcare providers in ensuring drug authenticity and preventing "
                "dispensing errors."
            )

    # Events (UI wiring only). Listeners are registered inside the Blocks
    # context so they attach to this app.
    submit_btn.click(fn=predict, inputs=inp, outputs=[out_img, out_lbl])
    # Clear resets all three components, so the handler returns three values
    # (the previous two-value lambda did not match the three outputs).
    clear_btn.click(fn=_clear_all, inputs=None, outputs=[inp, out_img, out_lbl])

demo.launch()