saim1309 commited on
Commit
458095e
Β·
verified Β·
1 Parent(s): 8290f6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +224 -268
app.py CHANGED
@@ -2,27 +2,38 @@ import gradio as gr
2
  import cv2
3
  import numpy as np
4
  import io
5
- import os
6
- import tempfile
7
- import base64
8
- from PIL import Image, ImageDraw
9
  import matplotlib
10
  matplotlib.use("Agg")
11
  import matplotlib.pyplot as plt
12
 
13
- # ─── PDF ──────────────────────────────────────────────────────────────────────
14
- from reportlab.lib.pagesizes import A4
15
- from reportlab.lib import colors
16
- from reportlab.lib.units import cm, mm
17
- from reportlab.platypus import (
18
- SimpleDocTemplate, Paragraph, Spacer,
19
- Image as RLImage, Table, TableStyle, HRFlowable,
20
- )
21
- from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
22
- from reportlab.lib.enums import TA_CENTER, TA_LEFT, TA_RIGHT
23
- from reportlab.pdfgen import canvas as rl_canvas
24
- from reportlab.platypus import BaseDocTemplate, PageTemplate, Frame
25
- from reportlab.lib.colors import red, black
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # ─── Cellpose model (lazy) ────────────────────────────────────────────────────
28
  _model = None
@@ -33,10 +44,9 @@ def get_model():
33
  from cellpose import models
34
  from huggingface_hub import hf_hub_download
35
  fpath = hf_hub_download(repo_id="mouseland/cellpose-sam", filename="cpsam")
36
- _model = models.CellposeModel(gpu=False, pretrained_model=fpath)
37
  return _model
38
 
39
-
40
  # ─── Image helpers ────────────────────────────────────────────────────────────
41
  def normalize99(img):
42
  X = img.copy().astype(np.float32)
@@ -89,245 +99,197 @@ def build_outline_image(img, masks) -> Image.Image:
89
  return out
90
 
91
 
92
- # ─── Colours ──────────────────────────────────────────────────────────────────
93
- VIOLET = colors.HexColor("#7C3AED")
94
- VIOLET_DARK = colors.HexColor("#5B21B6")
95
- VIOLET_LITE = colors.HexColor("#EDE9FE")
96
- TEAL = colors.HexColor("#0D9488")
97
- TEAL_LITE = colors.HexColor("#CCFBF1")
98
- AMBER = colors.HexColor("#D97706")
99
- AMBER_LITE = colors.HexColor("#FEF3C7")
100
- ROSE = colors.HexColor("#E11D48")
101
- ROSE_LITE = colors.HexColor("#FFE4E6")
102
- SLATE = colors.HexColor("#1E293B")
103
- SLATE_MID = colors.HexColor("#64748B")
104
- SLATE_LITE = colors.HexColor("#F1F5F9")
105
- WHITE = colors.white
 
106
 
 
 
 
 
107
 
108
- def _header_canvas(c, doc):
 
 
 
 
 
 
 
109
  """
110
- Header: NO background fill β€” just the MLBench text, tagline, and teal line.
111
- Footer: page number + thin rule.
 
 
112
  """
113
- W, H = A4
114
- c.saveState()
115
-
116
- # ── "ML" in red, "Bench" in black β€” no background ──────────────────────
117
- c.setFont("Helvetica-Bold", 24)
118
- c.setFillColor(red)
119
- ml_w = c.stringWidth("ML", "Helvetica-Bold", 24)
120
- c.drawString(1.8*cm, H - 1.6*cm, "ML")
121
- c.setFillColor(black)
122
- c.drawString(1.8*cm + ml_w, H - 1.6*cm, "Bench")
123
-
124
- # Tagline right-aligned β€” dark colour since no purple bg
125
- c.setFont("Helvetica", 9)
126
- c.setFillColor(SLATE_MID)
127
- c.drawRightString(W - 1.8*cm, H - 1.55*cm, "Rice Grain Analysis Report")
128
-
129
- # Teal accent line β€” kept exactly as before
130
- c.setStrokeColor(TEAL)
131
- c.setLineWidth(2.5)
132
- c.line(0, H - 2.0*cm, W, H - 2.0*cm)
133
-
134
- # Footer rule + page number
135
- c.setStrokeColor(colors.HexColor("#E2E8F0"))
136
- c.setLineWidth(0.5)
137
- c.line(1.8*cm, 1.4*cm, W - 1.8*cm, 1.4*cm)
138
- c.setFont("Helvetica", 8)
139
- c.setFillColor(SLATE_MID)
140
- c.drawRightString(W - 1.8*cm, 0.9*cm, f"Page {doc.page}")
141
-
142
- c.restoreState()
143
-
144
-
145
- def build_pdf(segmented_pil: Image.Image, total_count: int) -> str:
146
- out_path = tempfile.mktemp(suffix=".pdf")
147
- PAGE_W, PAGE_H = A4
148
- LM = RM = 1.8 * cm
149
- TM = 2.6 * cm # slightly less since no coloured band
150
- BM = 2.0 * cm
151
- usable_w = PAGE_W - LM - RM
152
-
153
- doc = BaseDocTemplate(
154
- out_path, pagesize=A4,
155
- leftMargin=LM, rightMargin=RM,
156
- topMargin=TM, bottomMargin=BM,
157
- )
158
- frame = Frame(LM, BM, usable_w, PAGE_H - TM - BM, id="main")
159
- doc.addPageTemplates([PageTemplate(id="pg", frames=[frame], onPage=_header_canvas)])
160
-
161
- # ── Styles ───────────────────────────────────────────────────────────────
162
- section_s = ParagraphStyle(
163
- "SEC", fontSize=10, fontName="Helvetica-Bold",
164
- textColor=WHITE, alignment=TA_LEFT,
165
- leftIndent=6, spaceAfter=0, leading=14,
166
- )
167
- stat_label_s = ParagraphStyle(
168
- "SL", fontSize=9, fontName="Helvetica-Bold",
169
- textColor=SLATE, alignment=TA_LEFT, leading=13,
170
- )
171
- stat_val_s = ParagraphStyle(
172
- "SV", fontSize=9, fontName="Helvetica-Bold",
173
- textColor=SLATE, alignment=TA_CENTER, leading=13,
174
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
- def section_header(title, bg_color):
177
- t = Table([[Paragraph(title, section_s)]], colWidths=[usable_w])
178
- t.setStyle(TableStyle([
179
- ("BACKGROUND", (0,0),(-1,-1), bg_color),
180
- ("TOPPADDING", (0,0),(-1,-1), 5),
181
- ("BOTTOMPADDING", (0,0),(-1,-1), 5),
182
- ("LEFTPADDING", (0,0),(-1,-1), 8),
183
- ("RIGHTPADDING", (0,0),(-1,-1), 8),
184
- ]))
185
- return t
186
-
187
- story = []
188
-
189
- # ── Column widths & gap ──────────────────────────────────────────────────
190
- GAP = 0.5 * cm # space between the two columns
191
- stat_w = usable_w * 0.42 # left β†’ stats table
192
- img_w = usable_w - stat_w - GAP # right β†’ image
193
-
194
- # ── Separate column headers (stats left, image right) ────────────────────
195
- stat_hdr_s = ParagraphStyle("SHD", fontSize=10, fontName="Helvetica-Bold",
196
- textColor=WHITE, alignment=TA_CENTER, leading=14)
197
- img_hdr_s = ParagraphStyle("IHD", fontSize=10, fontName="Helvetica-Bold",
198
- textColor=WHITE, alignment=TA_CENTER, leading=14)
199
-
200
- stat_hdr_cell = Table([[Paragraph("Grain Count Statistics", stat_hdr_s)]],
201
- colWidths=[stat_w])
202
- stat_hdr_cell.setStyle(TableStyle([
203
- ("BACKGROUND", (0,0),(-1,-1), TEAL),
204
- ("TOPPADDING", (0,0),(-1,-1), 5),
205
- ("BOTTOMPADDING", (0,0),(-1,-1), 5),
206
- ("LEFTPADDING", (0,0),(-1,-1), 8),
207
- ("RIGHTPADDING", (0,0),(-1,-1), 8),
208
- ]))
209
-
210
- img_hdr_cell = Table([[Paragraph("Segmentation Output", img_hdr_s)]],
211
- colWidths=[img_w])
212
- img_hdr_cell.setStyle(TableStyle([
213
- ("BACKGROUND", (0,0),(-1,-1), VIOLET),
214
- ("TOPPADDING", (0,0),(-1,-1), 5),
215
- ("BOTTOMPADDING", (0,0),(-1,-1), 5),
216
- ("LEFTPADDING", (0,0),(-1,-1), 8),
217
- ("RIGHTPADDING", (0,0),(-1,-1), 8),
218
- ]))
219
-
220
- # Combined header row (gap column in between)
221
- hdr_row = Table(
222
- [[stat_hdr_cell, "", img_hdr_cell]],
223
- colWidths=[stat_w, GAP, img_w],
224
- )
225
- hdr_row.setStyle(TableStyle([
226
- ("LEFTPADDING", (0,0),(-1,-1), 0),
227
- ("RIGHTPADDING", (0,0),(-1,-1), 0),
228
- ("TOPPADDING", (0,0),(-1,-1), 0),
229
- ("BOTTOMPADDING",(0,0),(-1,-1), 0),
230
- ("VALIGN", (0,0),(-1,-1), "TOP"),
231
- ]))
232
- story.append(hdr_row)
233
- story.append(Spacer(1, 5))
234
-
235
- # ── Stats table ──────────────────────────────────────────────────────────
236
  stat_rows_def = [
237
- ("Total Rice Grain", str(total_count), VIOLET, VIOLET_LITE),
238
- ("Long Grain", "β€”", TEAL, TEAL_LITE),
239
- ("Short Grain", "β€”", AMBER, AMBER_LITE),
240
- ("Half Grain", "β€”", ROSE, ROSE_LITE),
241
- ("Broken Edge", "β€”", SLATE_MID,SLATE_LITE),
242
  ]
243
 
244
- col_hdr_s = ParagraphStyle("CHD", fontSize=8, fontName="Helvetica-Bold",
245
- textColor=WHITE, alignment=TA_CENTER)
246
- stripe_w = 0.28 * cm
247
- label_w = stat_w * 0.64
248
- val_w = stat_w - stripe_w - label_w
249
-
250
- tdata = [["", Paragraph("Category", col_hdr_s), Paragraph("Count", col_hdr_s)]]
251
- for label, val, _, _ in stat_rows_def:
252
- tdata.append(["", Paragraph(label, stat_label_s), Paragraph(val, stat_val_s)])
253
-
254
- stat_table = Table(tdata, colWidths=[stripe_w, label_w, val_w])
255
- ts = TableStyle([
256
- ("BACKGROUND", (0,0),(-1,0), SLATE),
257
- ("TOPPADDING", (0,0),(-1,-1), 7),
258
- ("BOTTOMPADDING", (0,0),(-1,-1), 7),
259
- ("LEFTPADDING", (0,0),(-1,-1), 4),
260
- ("RIGHTPADDING", (0,0),(-1,-1), 4),
261
- ("VALIGN", (0,0),(-1,-1), "MIDDLE"),
262
- ("ALIGN", (2,0),(2,-1), "CENTER"),
263
- ("BOX", (0,0),(-1,-1), 1, colors.HexColor("#CBD5E1")),
264
- ("INNERGRID", (1,1),(-1,-1), 0.4, colors.HexColor("#E2E8F0")),
265
- ])
266
- for i, (_, _, accent, bg) in enumerate(stat_rows_def, start=1):
267
- ts.add("BACKGROUND", (0,i), (0,i), accent)
268
- ts.add("BACKGROUND", (1,i), (2,i), bg)
269
- ts.add("FONTSIZE", (2,1),(2,1), 11)
270
- ts.add("TEXTCOLOR", (2,1),(2,1), VIOLET)
271
- stat_table.setStyle(ts)
272
-
273
- # ── Image β€” exact height match to stats table ───────────────────────────
274
- # Measure the actual rendered height of stat_table using wrap()
275
- from reportlab.lib.pagesizes import A4 as _A4
276
- _exact_h = stat_table.wrap(stat_w, 9999)[1] # returns (w, h) in points
277
- img_max_h = _exact_h # points β€” same unit RLImage uses
278
-
279
- buf = io.BytesIO()
280
- segmented_pil.save(buf, format="PNG"); buf.seek(0)
281
- iw, ih = segmented_pil.size
282
- # img_w is in points (cm * 28.35); keep image within column width & exact table height
283
- ratio = min(img_w / iw, img_max_h / ih)
284
- rl_img = RLImage(buf, width=iw * ratio, height=ih * ratio)
285
-
286
- img_cell = Table([[rl_img]], colWidths=[img_w], rowHeights=[_exact_h])
287
- img_cell.setStyle(TableStyle([
288
- ("BOX", (0,0),(-1,-1), 1.5, VIOLET),
289
- ("BACKGROUND", (0,0),(-1,-1), colors.black),
290
- ("ALIGN", (0,0),(-1,-1), "CENTER"),
291
- ("VALIGN", (0,0),(-1,-1), "MIDDLE"),
292
- ("TOPPADDING", (0,0),(-1,-1), 0),
293
- ("BOTTOMPADDING", (0,0),(-1,-1), 0),
294
- ("LEFTPADDING", (0,0),(-1,-1), 4),
295
- ("RIGHTPADDING", (0,0),(-1,-1), 4),
296
- ]))
297
-
298
- # ── Side-by-side: stats LEFT, image RIGHT, gap in middle ─────────────────
299
- side_by_side = Table(
300
- [[stat_table, "", img_cell]],
301
- colWidths=[stat_w, GAP, img_w],
302
- )
303
- side_by_side.setStyle(TableStyle([
304
- ("VALIGN", (0,0),(-1,-1), "TOP"),
305
- ("LEFTPADDING", (0,0),(-1,-1), 0),
306
- ("RIGHTPADDING", (0,0),(-1,-1), 0),
307
- ("TOPPADDING", (0,0),(-1,-1), 0),
308
- ("BOTTOMPADDING",(0,0),(-1,-1), 0),
309
- ]))
310
- story.append(side_by_side)
311
-
312
- doc.build(story)
313
- return out_path
314
-
315
-
316
- def pdf_to_preview_html(pdf_path: str) -> str:
317
- with open(pdf_path, "rb") as f:
318
- b64 = base64.b64encode(f.read()).decode()
319
- return (
320
- f'<iframe src="data:application/pdf;base64,{b64}" '
321
- f'width="100%" height="680px" '
322
- f'style="border:1px solid #7C3AED; border-radius:8px;"></iframe>'
323
- )
324
 
325
 
326
  # ─── Sample example images ────────────────────────────────────────────────────
327
  SAMPLE_PATHS = [
328
- "IMG_0614.jpg",
329
- "IMG_0623.jpg",
330
- "IMG_0693.jpg",
331
  ]
332
 
333
  # ─── Status helpers ───────────────────────────────────────────────────────────
@@ -339,8 +301,9 @@ def make_status(level: str, message: str) -> dict:
339
 
340
  # ─── Main processing ──────────────────────────────────────────────────────────
341
  def process_image(pil_image):
 
342
  if pil_image is None:
343
- return None, "", make_status("warning", "No image provided. Please upload or select a sample image first.")
344
 
345
  try:
346
  img_np = np.array(pil_image.convert("RGB"))
@@ -351,7 +314,7 @@ def process_image(pil_image):
351
  total_count = int(masks.max())
352
 
353
  if total_count == 0:
354
- return None, "", make_status(
355
  "warning",
356
  "No rice grains were detected in this image. "
357
  "Try a clearer photo or adjust the image contrast."
@@ -362,22 +325,20 @@ def process_image(pil_image):
362
  (img_resized.shape[1], img_resized.shape[0]), resample=Image.BICUBIC
363
  )
364
 
365
- pdf_path = build_pdf(outline_pil, total_count)
366
- preview_html = pdf_to_preview_html(pdf_path)
367
 
368
  return (
369
- pdf_path,
370
- preview_html,
371
- make_status("success", f"{total_count} rice grains detected. PDF report ready β€” preview on the right, download below it."),
372
  )
373
 
374
  except MemoryError:
375
- return None, "", make_status("error", "Out of memory. Try uploading a smaller image.")
376
 
377
  except Exception as e:
378
  import traceback
379
  traceback.print_exc()
380
- return None, "", make_status("error", f"Unexpected error: {type(e).__name__}: {str(e)}")
381
 
382
 
383
  # ─── UI ───────────────────────────────────────────────────────────────────────
@@ -393,7 +354,7 @@ CSS = """
393
  #status-box textarea { font-size: 0.92rem; }
394
  """
395
 
396
- with gr.Blocks(title="Rice Grain Counter", css=CSS) as demo:
397
 
398
  gr.HTML("""
399
  <div style="padding:18px 12px 10px 12px; background-color:#0F172A; border-radius:10px; margin-bottom:10px;">
@@ -401,7 +362,7 @@ with gr.Blocks(title="Rice Grain Counter", css=CSS) as demo:
401
  Rice Grain Counter
402
  </span>
403
  <p style="color:#CBD5E1;font-size:0.9rem;margin-top:4px;font-family:sans-serif;">
404
- Upload a rice image to segment each grain and generate a PDF report.
405
  </p>
406
  </div>
407
  """)
@@ -436,23 +397,18 @@ with gr.Blocks(title="Rice Grain Counter", css=CSS) as demo:
436
 
437
  # ── RIGHT COLUMN ──────────────────────────────────────────────────
438
  with gr.Column(scale=1):
439
- gr.Markdown("### PDF Report Preview")
440
- pdf_preview = gr.HTML(
441
- value="<div style='height:680px;border:1px dashed #CBD5E1;border-radius:8px;"
442
- "display:flex;align-items:center;justify-content:center;"
443
- "color:#94A3B8;font-family:sans-serif;font-size:0.95rem;'>"
444
- "PDF preview will appear here after analysis.</div>"
445
- )
446
- pdf_output = gr.File(
447
- label="⬇ Download PDF",
448
  interactive=False,
 
449
  )
450
 
451
  run_btn.click(
452
  fn=process_image,
453
  inputs=[inp_image],
454
- outputs=[pdf_output, pdf_preview, status_box],
455
  )
456
 
457
  if __name__ == "__main__":
458
- demo.launch(share=True)
 
2
  import cv2
3
  import numpy as np
4
  import io
5
+ from PIL import Image, ImageDraw, ImageFont
 
 
 
6
  import matplotlib
7
  matplotlib.use("Agg")
8
  import matplotlib.pyplot as plt
9
 
10
+ # ─── Font paths ───────────────────────────────────────────────────────────────
11
+ FONT_BOLD = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
12
+ FONT_REGULAR = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"
13
+
14
+ def _font(size, bold=True):
15
+ try:
16
+ return ImageFont.truetype(FONT_BOLD if bold else FONT_REGULAR, size)
17
+ except Exception:
18
+ return ImageFont.load_default()
19
+
20
+ # ─── Colours (R,G,B) ──────────────────────────────────────────────────────────
21
+ C_VIOLET = (124, 58, 237)
22
+ C_VIOLET_DARK = ( 91, 33, 182)
23
+ C_VIOLET_LITE = (237, 233, 254)
24
+ C_TEAL = ( 13, 148, 136)
25
+ C_TEAL_LITE = (204, 251, 241)
26
+ C_AMBER = (217, 119, 6)
27
+ C_AMBER_LITE = (254, 243, 199)
28
+ C_ROSE = (225, 29, 72)
29
+ C_ROSE_LITE = (255, 228, 230)
30
+ C_SLATE = ( 30, 41, 59)
31
+ C_SLATE_MID = (100, 116, 139)
32
+ C_SLATE_LITE = (241, 245, 249)
33
+ C_WHITE = (255, 255, 255)
34
+ C_BLACK = ( 0, 0, 0)
35
+ C_RED = (220, 38, 38)
36
+ C_BG = (255, 255, 255) # page background
37
 
38
  # ─── Cellpose model (lazy) ────────────────────────────────────────────────────
39
  _model = None
 
44
  from cellpose import models
45
  from huggingface_hub import hf_hub_download
46
  fpath = hf_hub_download(repo_id="mouseland/cellpose-sam", filename="cpsam")
47
+ _model = models.CellposeModel(gpu=True, pretrained_model=fpath)
48
  return _model
49
 
 
50
  # ─── Image helpers ────────────────────────────────────────────────────────────
51
  def normalize99(img):
52
  X = img.copy().astype(np.float32)
 
99
  return out
100
 
101
 
102
+ # ─── Drawing helpers ──────────────────────────────────────────────────────────
103
+ def _text_size(draw, text, font):
104
+ """Return (width, height) of text."""
105
+ bbox = draw.textbbox((0, 0), text, font=font)
106
+ return bbox[2] - bbox[0], bbox[3] - bbox[1]
107
+
108
+ def _draw_rect(img, x0, y0, x1, y1, fill, border=None, border_width=2, radius=0):
109
+ """Draw a filled rectangle with optional border on a PIL Image."""
110
+ draw = ImageDraw.Draw(img)
111
+ if radius > 0:
112
+ draw.rounded_rectangle([x0, y0, x1, y1], radius=radius, fill=fill,
113
+ outline=border, width=border_width if border else 0)
114
+ else:
115
+ draw.rectangle([x0, y0, x1, y1], fill=fill,
116
+ outline=border, width=border_width if border else 0)
117
 
118
+ def _draw_text_centred(img, cx, cy, text, font, color):
119
+ draw = ImageDraw.Draw(img)
120
+ tw, th = _text_size(draw, text, font)
121
+ draw.text((cx - tw // 2, cy - th // 2), text, font=font, fill=color)
122
 
123
+ def _draw_text_left(img, x, cy, text, font, color):
124
+ draw = ImageDraw.Draw(img)
125
+ _, th = _text_size(draw, text, font)
126
+ draw.text((x, cy - th // 2), text, font=font, fill=color)
127
+
128
+
129
+ # ─── Report image builder ─────────────────────────────────────────────────────
130
+ def build_report_image(segmented_pil: Image.Image, total_count: int) -> Image.Image:
131
  """
132
+ Renders the full report as a PIL Image with the same structure as the PDF:
133
+ β€’ Header : MLBench + tagline + teal rule
134
+ β€’ Body : [Grain Count Statistics table] | gap | [Segmentation Output image]
135
+ No footer line / page number.
136
  """
137
+ DPI = 150
138
+ PW_IN = 8.27 # A4 width in inches
139
+ PH_IN = 11.69 # A4 height in inches (we'll crop to content)
140
+ PW = int(PW_IN * DPI)
141
+ MARGIN = int(0.7 * DPI) # ~0.7 inch margin
142
+
143
+ # ── Fonts ─────────────────────────────────────────────────────────────
144
+ f_logo_ml = _font(int(0.28 * DPI)) # "ML" large
145
+ f_logo_b = _font(int(0.28 * DPI)) # "Bench" same size
146
+ f_tagline = _font(int(0.09 * DPI), bold=False)
147
+ f_sec_hdr = _font(int(0.11 * DPI)) # section bar text
148
+ f_col_hdr = _font(int(0.09 * DPI)) # table column headers
149
+ f_label = _font(int(0.10 * DPI)) # row labels
150
+ f_val_total = _font(int(0.13 * DPI)) # total count value (bigger)
151
+ f_val = _font(int(0.10 * DPI)) # other value cells
152
+
153
+ # ── Dimensions ────────────────────────────────────────────────────────
154
+ usable_w = PW - 2 * MARGIN
155
+ GAP = int(0.18 * DPI)
156
+ stat_w = int(usable_w * 0.43)
157
+ img_col_w = usable_w - stat_w - GAP
158
+
159
+ HDR_H = int(0.55 * DPI) # header area height
160
+ SEC_BAR_H = int(0.22 * DPI) # coloured section title bar
161
+ COL_HDR_H = int(0.18 * DPI) # table column header row
162
+ ROW_H = int(0.17 * DPI) # each data row
163
+ STRIPE_W = int(0.07 * DPI) # coloured left stripe on each row
164
+ TEAL_LINE = 3 # teal rule thickness
165
+
166
+ N_ROWS = 5
167
+ TABLE_H = COL_HDR_H + N_ROWS * ROW_H
168
+
169
+ # Total canvas height: margin + header + gap + sec_bar + content + margin
170
+ BODY_TOP = HDR_H + int(0.12 * DPI) # y where body starts
171
+ CONTENT_H = SEC_BAR_H + TABLE_H
172
+ CANVAS_H = BODY_TOP + CONTENT_H + MARGIN
173
+
174
+ # ── Create canvas ─────────────────────────────────────────────────────
175
+ img = Image.new("RGB", (PW, CANVAS_H), C_BG)
176
+ draw = ImageDraw.Draw(img)
177
+
178
+ # ── Header ────────────────────────────────────────────────────────────
179
+ # "ML" in red, "Bench" in black
180
+ logo_y = int(HDR_H * 0.38)
181
+ ml_w, _ = _text_size(draw, "ML", f_logo_ml)
182
+ draw.text((MARGIN, logo_y), "ML", font=f_logo_ml, fill=C_RED)
183
+ draw.text((MARGIN + ml_w, logo_y), "Bench", font=f_logo_b, fill=C_BLACK)
184
+
185
+ # Tagline right-aligned
186
+ tag = "Rice Grain Analysis Report"
187
+ tag_w, tag_h = _text_size(draw, tag, f_tagline)
188
+ draw.text((PW - MARGIN - tag_w, logo_y + 6), tag, font=f_tagline, fill=C_SLATE_MID)
189
+
190
+ # Teal horizontal rule
191
+ rule_y = HDR_H - 4
192
+ draw.rectangle([0, rule_y, PW, rule_y + TEAL_LINE], fill=C_TEAL)
193
+
194
+ # ── Section header bars ───────────────────────────────────────────────
195
+ stat_x = MARGIN
196
+ img_x = MARGIN + stat_w + GAP
197
+
198
+ stat_bar_y0 = BODY_TOP
199
+ stat_bar_y1 = BODY_TOP + SEC_BAR_H
200
+
201
+ # Teal bar β€” "Grain Count Statistics"
202
+ _draw_rect(img, stat_x, stat_bar_y0, stat_x + stat_w, stat_bar_y1, fill=C_TEAL)
203
+ _draw_text_centred(img, stat_x + stat_w // 2, (stat_bar_y0 + stat_bar_y1) // 2,
204
+ "Grain Count Statistics", f_sec_hdr, C_WHITE)
205
+
206
+ # Violet bar β€” "Segmentation Output"
207
+ _draw_rect(img, img_x, stat_bar_y0, img_x + img_col_w, stat_bar_y1, fill=C_VIOLET)
208
+ _draw_text_centred(img, img_x + img_col_w // 2, (stat_bar_y0 + stat_bar_y1) // 2,
209
+ "Segmentation Output", f_sec_hdr, C_WHITE)
210
+
211
+ # ── Stats table ───────────────────────────────────────────────────────
212
+ table_top = BODY_TOP + SEC_BAR_H
213
+ col_hdr_y0 = table_top
214
+ col_hdr_y1 = table_top + COL_HDR_H
215
+
216
+ # Column header background
217
+ _draw_rect(img, stat_x, col_hdr_y0, stat_x + stat_w, col_hdr_y1, fill=C_SLATE)
218
+ cat_cx = stat_x + STRIPE_W + (stat_w - STRIPE_W) // 2 - int((stat_w - STRIPE_W) * 0.18)
219
+ count_cx = stat_x + STRIPE_W + int((stat_w - STRIPE_W) * 0.78)
220
+ _draw_text_centred(img, cat_cx, (col_hdr_y0 + col_hdr_y1) // 2, "Category", f_col_hdr, C_WHITE)
221
+ _draw_text_centred(img, count_cx, (col_hdr_y0 + col_hdr_y1) // 2, "Count", f_col_hdr, C_WHITE)
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  stat_rows_def = [
224
+ ("Total Rice Grain", str(total_count), C_VIOLET, C_VIOLET_LITE),
225
+ ("Long Grain", "β€”", C_TEAL, C_TEAL_LITE),
226
+ ("Short Grain", "β€”", C_AMBER, C_AMBER_LITE),
227
+ ("Half Grain", "β€”", C_ROSE, C_ROSE_LITE),
228
+ ("Broken Edge", "β€”", C_SLATE_MID,C_SLATE_LITE),
229
  ]
230
 
231
+ border_color = (203, 213, 225)
232
+ grid_color = (226, 232, 240)
233
+
234
+ for i, (label, val, accent, bg) in enumerate(stat_rows_def):
235
+ ry0 = table_top + COL_HDR_H + i * ROW_H
236
+ ry1 = ry0 + ROW_H
237
+ cy = (ry0 + ry1) // 2
238
+
239
+ # Row background
240
+ _draw_rect(img, stat_x, ry0, stat_x + stat_w, ry1, fill=bg)
241
+ # Accent stripe
242
+ _draw_rect(img, stat_x, ry0, stat_x + STRIPE_W, ry1, fill=accent)
243
+ # Label
244
+ f_lbl = f_label
245
+ _draw_text_left(img, stat_x + STRIPE_W + 8, cy, label, f_lbl, C_SLATE)
246
+ # Value
247
+ f_v = f_val_total if i == 0 else f_val
248
+ c_v = C_VIOLET if i == 0 else C_SLATE
249
+ vw, _ = _text_size(draw, val, f_v)
250
+ draw.text((stat_x + stat_w - vw - 14, cy - _text_size(draw, val, f_v)[1] // 2),
251
+ val, font=f_v, fill=c_v)
252
+ # Horizontal grid line
253
+ draw.rectangle([stat_x, ry1 - 1, stat_x + stat_w, ry1], fill=grid_color)
254
+
255
+ # Outer border of table (column header + rows)
256
+ draw.rectangle([stat_x, col_hdr_y0, stat_x + stat_w,
257
+ table_top + COL_HDR_H + N_ROWS * ROW_H], outline=border_color, width=1)
258
+
259
+ # ── Segmentation image ────────────────────────────────────────────────
260
+ # Fit segmented image to exactly match table height (SEC_BAR already above)
261
+ target_h = TABLE_H # must match table area below sec bar
262
+ target_w = img_col_w
263
+
264
+ seg_np = np.array(segmented_pil)
265
+ ih, iw = seg_np.shape[:2]
266
+ scale = min(target_w / iw, target_h / ih)
267
+ new_w = int(iw * scale)
268
+ new_h = int(ih * scale)
269
+ seg_resized = segmented_pil.resize((new_w, new_h), Image.BICUBIC)
270
+
271
+ # Black background box β€” same height as table
272
+ box_x0 = img_x
273
+ box_y0 = table_top # align top with table (below sec bar)
274
+ box_x1 = img_x + img_col_w
275
+ box_y1 = table_top + TABLE_H
276
+
277
+ _draw_rect(img, box_x0, box_y0, box_x1, box_y1,
278
+ fill=C_BLACK, border=C_VIOLET, border_width=2)
279
+
280
+ # Centre the image inside the black box
281
+ paste_x = box_x0 + (img_col_w - new_w) // 2
282
+ paste_y = box_y0 + (TABLE_H - new_h) // 2
283
+ img.paste(seg_resized, (paste_x, paste_y))
284
+
285
+ return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
 
288
  # ─── Sample example images ────────────────────────────────────────────────────
289
  SAMPLE_PATHS = [
290
+ "sample_image/IMG_0614.jpg",
291
+ "sample_image/IMG_0623.jpg",
292
+ "sample_image/IMG_0693.jpg",
293
  ]
294
 
295
  # ─── Status helpers ───────────────────────────────────────────────────────────
 
301
 
302
  # ─── Main processing ──────────────────────────────────────────────────────────
303
  def process_image(pil_image):
304
+ # Returns: (report_image, status_update)
305
  if pil_image is None:
306
+ return None, make_status("warning", "No image provided. Please upload or select a sample image first.")
307
 
308
  try:
309
  img_np = np.array(pil_image.convert("RGB"))
 
314
  total_count = int(masks.max())
315
 
316
  if total_count == 0:
317
+ return None, make_status(
318
  "warning",
319
  "No rice grains were detected in this image. "
320
  "Try a clearer photo or adjust the image contrast."
 
325
  (img_resized.shape[1], img_resized.shape[0]), resample=Image.BICUBIC
326
  )
327
 
328
+ report_img = build_report_image(outline_pil, total_count)
 
329
 
330
  return (
331
+ report_img,
332
+ make_status("success", f"{total_count} rice grains detected. Report image shown on the right."),
 
333
  )
334
 
335
  except MemoryError:
336
+ return None, make_status("error", "Out of memory. Try uploading a smaller image.")
337
 
338
  except Exception as e:
339
  import traceback
340
  traceback.print_exc()
341
+ return None, make_status("error", f"Unexpected error: {type(e).__name__}: {str(e)}")
342
 
343
 
344
  # ─── UI ───────────────────────────────────────────────────────────────────────
 
354
  #status-box textarea { font-size: 0.92rem; }
355
  """
356
 
357
+ with gr.Blocks(title="Rice Grain Counter") as demo:
358
 
359
  gr.HTML("""
360
  <div style="padding:18px 12px 10px 12px; background-color:#0F172A; border-radius:10px; margin-bottom:10px;">
 
362
  Rice Grain Counter
363
  </span>
364
  <p style="color:#CBD5E1;font-size:0.9rem;margin-top:4px;font-family:sans-serif;">
365
+ Upload a rice image to segment each grain and generate a report.
366
  </p>
367
  </div>
368
  """)
 
397
 
398
  # ── RIGHT COLUMN ──────────────────────────────────────────────────
399
  with gr.Column(scale=1):
400
+ gr.Markdown("### Report")
401
+ report_out = gr.Image(
402
+ label="",
 
 
 
 
 
 
403
  interactive=False,
404
+
405
  )
406
 
407
  run_btn.click(
408
  fn=process_image,
409
  inputs=[inp_image],
410
+ outputs=[report_out, status_box],
411
  )
412
 
413
  if __name__ == "__main__":
414
+ demo.launch(share=True, css=CSS)