File size: 5,350 Bytes
24c865a
 
 
0a36196
24c865a
 
 
 
0a36196
 
24c865a
 
 
0a36196
24c865a
 
 
 
0a36196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24c865a
0a36196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24c865a
0a36196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24c865a
 
0a36196
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
from pathlib import Path

import gradio as gr
import torch  # (kept, since you already imported it)
from PIL import Image
from ultralytics import YOLO

# Load the trained weights once at import time; every request reuses this model.
# best.pt is looked for next to this script first (so the app works no matter
# which directory it is launched from) and then, as before, relative to the
# current working directory.
try:
    _WEIGHTS_PATH = Path(__file__).resolve().parent / "best.pt"
except NameError:  # no __file__, e.g. code pasted into a notebook or REPL
    _WEIGHTS_PATH = Path("best.pt")
model = YOLO(str(_WEIGHTS_PATH) if _WEIGHTS_PATH.is_file() else "best.pt")


# Prediction function: YOLO predict -> plot -> return annotated image + labels.
def predict(inp, *, conf=0.5, iou=0.5, imgsz=640):
    """Run the pill detector on one image and summarise the detections.

    Args:
        inp: Image from the ``gr.Image(type="pil")`` input, or ``None`` when
            Submit is pressed without an upload.
        conf: Confidence threshold forwarded to ``model.predict``
            (default 0.5, the value previously hard-coded here).
        iou: IoU threshold forwarded to ``model.predict`` (default 0.5).
        imgsz: Inference image size forwarded to ``model.predict``
            (default 640).

    Returns:
        ``(annotated, label_dict)``. ``annotated`` is a PIL image with the
        detections drawn on it (``None`` when ``inp`` is ``None``).
        ``label_dict`` maps each detected class name to the highest confidence
        seen for that class, inserted from most to least confident; it is
        empty when nothing was detected.
    """
    if inp is None:
        return None, {}

    result = model.predict(source=inp, conf=conf, iou=iou, imgsz=imgsz)[0]

    # Result.plot() returns a BGR array; reversing the channel axis gives the
    # RGB layout that PIL / Gradio expect.
    annotated = Image.fromarray(result.plot()[:, :, ::-1])

    boxes = getattr(result, "boxes", None)
    if boxes is None or len(boxes) == 0:
        return annotated, {}

    # Keep only the strongest detection per class so gr.Label shows one entry
    # per pill type.
    best = {}
    for class_id, score in zip(boxes.cls.tolist(), boxes.conf.tolist()):
        class_id = int(class_id)
        best[class_id] = max(best.get(class_id, 0.0), float(score))

    # Prefer the model's class names; fall back to those stored on the result.
    names = getattr(model, "names", None) or getattr(result, "names", None)

    label_dict = {}
    for class_id, score in sorted(best.items(), key=lambda item: item[1], reverse=True):
        label_dict[_class_name(names, class_id)] = score
    return annotated, label_dict


def _class_name(names, class_id):
    """Return the display name for ``class_id``, or ``str(class_id)`` if unknown.

    ``names`` may be an ``{id: name}`` mapping, an index-aligned list/tuple, or
    ``None``; any lookup miss falls back to the numeric id as a string.
    """
    if isinstance(names, dict):
        return names.get(class_id, str(class_id))
    if isinstance(names, (list, tuple)) and 0 <= class_id < len(names):
        return names[class_id]
    return str(class_id)


# -----------------------------
# UI ONLY (layout + styling)
# -----------------------------
# Custom stylesheet passed to gr.Blocks(css=CSS) below. Its selectors target
# hooks the layout sets explicitly: elem_id="page" (#page), elem_classes
# "card", "btn-primary" and "btn-ghost", and the "hero" / "small-muted" classes
# used in the gr.HTML banner.
# NOTE(review): several rules use !important, presumably so they win over the
# Soft theme's own styles — confirm before removing any of them.
# NOTE(review): the last rule styles h3 inside .prose, but the layout's Markdown
# headings are written with "####" (h4), so it may not apply to them — verify.
CSS = """
:root { --radius: 16px; }

#page {
  max-width: 1200px;
  margin: 0 auto;
}

.hero {
  padding: 18px 18px 8px 18px;
  border-radius: var(--radius);
  border: 1px solid rgba(255,255,255,0.08);
  background: linear-gradient(180deg, rgba(255,126,0,0.10), rgba(0,0,0,0.0));
}

.hero h1 {
  font-size: 28px;
  line-height: 1.1;
  margin: 0 0 6px 0;
}

.hero p {
  margin: 0;
  opacity: 0.9;
}

.card {
  border-radius: var(--radius) !important;
  border: 1px solid rgba(255,255,255,0.10) !important;
  background: rgba(255,255,255,0.03) !important;
}

.btn-primary button {
  background: linear-gradient(90deg, #ff7a00, #ff4d00) !important;
  border: 0 !important;
  border-radius: 14px !important;
  font-weight: 700 !important;
}

.btn-ghost button {
  border-radius: 14px !important;
  font-weight: 600 !important;
}

.small-muted {
  font-size: 12px;
  opacity: 0.75;
}

.gradio-container .prose h3 {
  margin-top: 6px !important;
}
"""

# Assemble the page: a hero banner, an upload card with Clear/Submit buttons,
# a results card (annotated image + per-class confidences) and a study summary.
# Styling comes from the CSS string above and the Soft theme.
with gr.Blocks(
    css=CSS,
    theme=gr.themes.Soft(
        radius_size="lg",
        text_size="md"
    ),
) as demo:
    with gr.Column(elem_id="page"):
        gr.HTML(
            """
            <div class="hero">
              <h1>bonsAI Pill Detection</h1>
              <p>
                Upload an image of a pill. The YOLOv12 model detects and classifies pill types commonly found in the Philippines.
                This study aims to automate pill recognition for pharmaceutical verification and healthcare support.
              </p>
              <p class="small-muted" style="margin-top:10px;">
                Tip: Use clear lighting, avoid blur, and keep pills centered for best results.
              </p>
            </div>
            """
        )

        with gr.Row(equal_height=True):
            # LEFT: Upload + Actions
            with gr.Column(scale=5):
                with gr.Group(elem_classes=["card"]):
                    gr.Markdown("#### Upload Pill Image")
                    # type="pil" means predict() receives a PIL image (or None).
                    inp = gr.Image(
                        type="pil",
                        label="Drop image here or click to upload",
                        height=330,
                    )

                with gr.Row():
                    clear_btn = gr.Button("Clear", elem_classes=["btn-ghost"], scale=1)
                    submit_btn = gr.Button("Submit", elem_classes=["btn-primary"], scale=1)

            # RIGHT: Results
            with gr.Column(scale=5):
                with gr.Group(elem_classes=["card"]):
                    gr.Markdown("#### Results")
                    with gr.Tabs():
                        with gr.Tab("Detected Pills"):
                            out_img = gr.Image(
                                label="Detected Pills",
                                height=330
                            )
                        with gr.Tab("Predictions"):
                            # Shows predict()'s {class name: confidence} dict.
                            out_lbl = gr.Label(
                                label="Predictions",
                                num_top_classes=5
                            )

        with gr.Accordion("Study Summary", open=True, elem_classes=["card"]):
            gr.Markdown(
                "The bonsAI project demonstrates the application of YOLOv12 in real-time pill classification "
                "and segmentation. By training on the Pharmaceutical Drugs and Vitamins Dataset (Version 2), "
                "the system accurately identifies tablets and capsules across 20 classes using bounding boxes "
                "and mask segmentation. The model achieved high mAP and F1-scores, confirming its potential "
                "for aiding pharmacists and healthcare providers in ensuring drug authenticity and preventing "
                "dispensing errors."
            )

        # Events (UI wiring only)
        submit_btn.click(fn=predict, inputs=inp, outputs=[out_img, out_lbl])
        # A handler must return exactly one value per output component. Clear
        # has three outputs (upload, annotated image, label), so return three
        # Nones to empty each of them.
        clear_btn.click(
            fn=lambda: (None, None, None),
            inputs=None,
            outputs=[inp, out_img, out_lbl],
        )

demo.launch()