mlbench123 commited on
Commit
4752ca3
Β·
verified Β·
1 Parent(s): 99624c3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +346 -0
app.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ from PIL import Image
5
+ from ultralytics import YOLO
6
+
7
+ # ── Model ─────────────────────────────────────────────────────────────────────
8
+ model = YOLO("best.pt")
9
+
10
+ CLASS_NAMES = {0: "Full", 1: "Broken"}
11
+ CLASS_COLORS = {0: (34, 197, 94), 1: (239, 68, 68)} # green, red
12
+
13
+
14
+ # ── Inference ─────────────────────────────────────────────────────────────────
15
+ def predict(image: Image.Image):
16
+ if image is None:
17
+ return None, "", ""
18
+
19
+ img_np = np.array(image)
20
+ h, w = img_np.shape[:2]
21
+ results = model(img_np, imgsz=1280, conf=0.25)[0]
22
+
23
+ annotated = img_np.copy()
24
+ overlay = img_np.copy()
25
+
26
+ counts = {"Full": 0, "Broken": 0}
27
+
28
+ if results.masks is not None:
29
+ # Adaptive font scale based on image size
30
+ font_scale = max(0.35, min(0.65, w / 2000))
31
+ font_thick = 1
32
+ font = cv2.FONT_HERSHEY_SIMPLEX
33
+
34
+ for mask_tensor, box in zip(results.masks.data, results.boxes):
35
+ cls_id = int(box.cls[0])
36
+ cls_name = CLASS_NAMES.get(cls_id, "?")
37
+ color = CLASS_COLORS.get(cls_id, (200, 200, 200))
38
+
39
+ counts[cls_name] += 1
40
+
41
+ # Resize mask to image size
42
+ mask_np = mask_tensor.cpu().numpy().astype(np.uint8)
43
+ mask_np = cv2.resize(mask_np, (w, h), interpolation=cv2.INTER_NEAREST)
44
+
45
+ # Fill overlay
46
+ overlay[mask_np == 1] = color
47
+
48
+ # Draw contour
49
+ cnts, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
50
+ cv2.drawContours(annotated, cnts, -1, color, 2)
51
+
52
+ # Label placement β€” centroid of mask
53
+ ys, xs = np.where(mask_np == 1)
54
+ if len(xs) == 0:
55
+ continue
56
+ cx = int(xs.mean())
57
+ cy = int(ys.mean())
58
+
59
+ label = cls_name
60
+ (tw, th), baseline = cv2.getTextSize(label, font, font_scale, font_thick)
61
+
62
+ # Small pill background behind text
63
+ pad = 3
64
+ cv2.rectangle(annotated,
65
+ (cx - tw // 2 - pad, cy - th - pad),
66
+ (cx + tw // 2 + pad, cy + pad),
67
+ (0, 0, 0), -1)
68
+ cv2.putText(annotated, label,
69
+ (cx - tw // 2, cy),
70
+ font, font_scale, color, font_thick, cv2.LINE_AA)
71
+
72
+ # Blend mask overlay with original
73
+ annotated = cv2.addWeighted(annotated, 0.72, overlay, 0.28, 0)
74
+
75
+ # Redraw contours on top of blend so they stay sharp
76
+ if results.masks is not None:
77
+ for mask_tensor, box in zip(results.masks.data, results.boxes):
78
+ cls_id = int(box.cls[0])
79
+ color = CLASS_COLORS.get(cls_id, (200, 200, 200))
80
+ mask_np = mask_tensor.cpu().numpy().astype(np.uint8)
81
+ mask_np = cv2.resize(mask_np, (w, h), interpolation=cv2.INTER_NEAREST)
82
+ cnts, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
83
+ cv2.drawContours(annotated, cnts, -1, color, 2)
84
+
85
+ total = counts["Full"] + counts["Broken"]
86
+ summary = f"**Total:** {total} 🟒 **Full:** {counts['Full']} πŸ”΄ **Broken:** {counts['Broken']}"
87
+
88
+ # Table markdown
89
+ table = f"""| | Count |
90
+ |---|---|
91
+ | 🌾 Total Grains | **{total}** |
92
+ | 🟒 Full Grains | **{counts['Full']}** |
93
+ | πŸ”΄ Broken Grains | **{counts['Broken']}** |"""
94
+
95
+ return Image.fromarray(annotated), summary, table
96
+
97
+
98
+ # ── Custom CSS ────────────────────────────────────────────────────────────────
99
+ css = """
100
+ @import url('https://fonts.googleapis.com/css2?family=DM+Serif+Display&family=DM+Sans:wght@300;400;500&display=swap');
101
+
102
+ * { box-sizing: border-box; }
103
+
104
+ body, .gradio-container {
105
+ background: #0c0f0a !important;
106
+ font-family: 'DM Sans', sans-serif !important;
107
+ }
108
+
109
+ .gradio-container {
110
+ max-width: 1100px !important;
111
+ margin: 0 auto !important;
112
+ padding: 0 24px !important;
113
+ }
114
+
115
+ /* Header */
116
+ #header {
117
+ text-align: center;
118
+ padding: 48px 0 32px;
119
+ border-bottom: 1px solid #1e2a1a;
120
+ margin-bottom: 36px;
121
+ }
122
+
123
+ #header h1 {
124
+ font-family: 'DM Serif Display', serif !important;
125
+ font-size: 2.6rem;
126
+ color: #e8f5e1;
127
+ letter-spacing: -0.5px;
128
+ margin: 0 0 10px;
129
+ }
130
+
131
+ #header p {
132
+ color: #6b8f5e;
133
+ font-size: 1rem;
134
+ font-weight: 300;
135
+ margin: 0;
136
+ letter-spacing: 0.3px;
137
+ }
138
+
139
+ /* Accent pill */
140
+ #header span {
141
+ display: inline-block;
142
+ background: #1a2e14;
143
+ color: #7ec86a;
144
+ font-size: 0.72rem;
145
+ font-weight: 500;
146
+ letter-spacing: 1.5px;
147
+ text-transform: uppercase;
148
+ padding: 4px 14px;
149
+ border-radius: 20px;
150
+ margin-bottom: 18px;
151
+ border: 1px solid #2d4a24;
152
+ }
153
+
154
+ /* Panels */
155
+ .panel-box {
156
+ background: #111710 !important;
157
+ border: 1px solid #1e2a1a !important;
158
+ border-radius: 12px !important;
159
+ padding: 0 !important;
160
+ overflow: hidden;
161
+ }
162
+
163
+ /* Input/output image areas */
164
+ .image-wrap {
165
+ border-radius: 10px;
166
+ overflow: hidden;
167
+ background: #0c0f0a;
168
+ }
169
+
170
+ /* Upload button */
171
+ .upload-btn, button[class*="upload"] {
172
+ background: #1a2e14 !important;
173
+ color: #7ec86a !important;
174
+ border: 1px dashed #2d4a24 !important;
175
+ border-radius: 10px !important;
176
+ font-family: 'DM Sans', sans-serif !important;
177
+ }
178
+
179
+ /* Submit button */
180
+ #run-btn button, button#run-btn {
181
+ background: #3d7a2e !important;
182
+ color: #e8f5e1 !important;
183
+ font-family: 'DM Serif Display', serif !important;
184
+ font-size: 1.05rem !important;
185
+ letter-spacing: 0.3px !important;
186
+ border: none !important;
187
+ border-radius: 8px !important;
188
+ padding: 12px 0 !important;
189
+ transition: background 0.2s ease !important;
190
+ width: 100% !important;
191
+ }
192
+
193
+ #run-btn button:hover {
194
+ background: #4e9939 !important;
195
+ }
196
+
197
+ /* Summary text */
198
+ #summary-box {
199
+ background: #111710 !important;
200
+ border: 1px solid #1e2a1a !important;
201
+ border-radius: 10px !important;
202
+ padding: 16px 20px !important;
203
+ color: #b5d4a8 !important;
204
+ font-size: 1rem !important;
205
+ font-family: 'DM Sans', sans-serif !important;
206
+ min-height: 52px !important;
207
+ }
208
+
209
+ /* Table */
210
+ #table-box {
211
+ background: #111710 !important;
212
+ border: 1px solid #1e2a1a !important;
213
+ border-radius: 10px !important;
214
+ padding: 4px 0 !important;
215
+ font-family: 'DM Sans', sans-serif !important;
216
+ }
217
+
218
+ #table-box table {
219
+ width: 100% !important;
220
+ border-collapse: collapse !important;
221
+ }
222
+
223
+ #table-box th {
224
+ background: #182413 !important;
225
+ color: #6b8f5e !important;
226
+ font-size: 0.72rem !important;
227
+ font-weight: 500 !important;
228
+ letter-spacing: 1.2px !important;
229
+ text-transform: uppercase !important;
230
+ padding: 10px 18px !important;
231
+ border-bottom: 1px solid #1e2a1a !important;
232
+ }
233
+
234
+ #table-box td {
235
+ padding: 12px 18px !important;
236
+ color: #c8e0bf !important;
237
+ font-size: 0.95rem !important;
238
+ border-bottom: 1px solid #141d10 !important;
239
+ }
240
+
241
+ #table-box tr:last-child td {
242
+ border-bottom: none !important;
243
+ }
244
+
245
+ #table-box tr:hover td {
246
+ background: #141d10 !important;
247
+ }
248
+
249
+ /* Section labels */
250
+ .section-label {
251
+ color: #4a6e3e;
252
+ font-size: 0.7rem;
253
+ font-weight: 500;
254
+ letter-spacing: 1.4px;
255
+ text-transform: uppercase;
256
+ margin-bottom: 8px;
257
+ padding: 0 2px;
258
+ }
259
+
260
+ /* Footer */
261
+ #footer {
262
+ text-align: center;
263
+ color: #2d4224;
264
+ font-size: 0.78rem;
265
+ padding: 32px 0 24px;
266
+ border-top: 1px solid #141d10;
267
+ margin-top: 40px;
268
+ letter-spacing: 0.3px;
269
+ }
270
+
271
+ /* Gradio component overrides */
272
+ .gr-block, .gr-box { background: transparent !important; border: none !important; }
273
+ label.svelte-1b6s6vi, .gr-input-label { color: #4a6e3e !important; font-size: 0.7rem !important; letter-spacing: 1.2px !important; text-transform: uppercase !important; font-weight: 500 !important; }
274
+ """
275
+
276
+
277
+ # ── Layout ────────────────────────────────────────────────────────────────────
278
+ with gr.Blocks(css=css, title="GrainVision β€” Rice Grain Classifier") as demo:
279
+
280
+ gr.HTML("""
281
+ <div id="header">
282
+ <span>AI Β· Segmentation Β· Classification</span>
283
+ <h1>🌾 GrainVision</h1>
284
+ <p>Upload a rice image to detect and classify each grain as Full or Broken</p>
285
+ </div>
286
+ """)
287
+
288
+ with gr.Row(equal_height=False):
289
+
290
+ # ── Left: Input ───────────────────────────────────────────────────────
291
+ with gr.Column(scale=1):
292
+ gr.HTML('<div class="section-label">Input Image</div>')
293
+ input_img = gr.Image(
294
+ type="pil",
295
+ label="",
296
+ elem_classes=["image-wrap"],
297
+ height=420,
298
+ )
299
+ gr.HTML('<div style="height:12px"></div>')
300
+ run_btn = gr.Button("Analyse Grains", elem_id="run-btn")
301
+
302
+ # ── Right: Output ─────────────────────────────────────────────────────
303
+ with gr.Column(scale=1):
304
+ gr.HTML('<div class="section-label">Segmentation Result</div>')
305
+ output_img = gr.Image(
306
+ type="pil",
307
+ label="",
308
+ elem_classes=["image-wrap"],
309
+ height=420,
310
+ )
311
+
312
+ gr.HTML('<div style="height:20px"></div>')
313
+
314
+ # ── Results row ───────────────────────────────────────────────────────────
315
+ with gr.Row():
316
+ with gr.Column(scale=1):
317
+ gr.HTML('<div class="section-label">Detection Summary</div>')
318
+ summary_md = gr.Markdown(
319
+ value="Results will appear here after analysis.",
320
+ elem_id="summary-box",
321
+ )
322
+
323
+ with gr.Column(scale=1):
324
+ gr.HTML('<div class="section-label">Grain Count Table</div>')
325
+ table_md = gr.Markdown(
326
+ value="| | Count |\n|---|---|\n| 🌾 Total Grains | β€” |\n| 🟒 Full Grains | β€” |\n| πŸ”΄ Broken Grains | β€” |",
327
+ elem_id="table-box",
328
+ )
329
+
330
+ # ── Events ────────────────────────────────────────────────────────────────
331
+ run_btn.click(
332
+ fn = predict,
333
+ inputs = [input_img],
334
+ outputs = [output_img, summary_md, table_md],
335
+ )
336
+ input_img.change(
337
+ fn = predict,
338
+ inputs = [input_img],
339
+ outputs = [output_img, summary_md, table_md],
340
+ )
341
+
342
+ gr.HTML('<div id="footer">GrainVision Β· Powered by YOLO11x-seg Β· For research & quality inspection use</div>')
343
+
344
+
345
+ if __name__ == "__main__":
346
+ demo.launch()