muhammadhamza-stack commited on
Commit
64af005
·
1 Parent(s): 697fc7d

resolve the image loading issue

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +424 -79
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ venv
app.py CHANGED
@@ -1,6 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import cv2
 
2
  import numpy as np
3
- from PIL import Image
4
  import torch
5
  from torchvision import models, transforms
6
  from ultralytics import YOLO
@@ -65,7 +367,7 @@ def classify_grain(grain_image):
65
  """
66
  if not models_loaded:
67
  return "System Error"
68
-
69
  tensor_input = image_preprocessor(grain_image).unsqueeze(0).to(device)
70
  with torch.no_grad():
71
  output = classifier_network(tensor_input)
@@ -80,101 +382,139 @@ def generate_distribution_report(variety_counts):
80
  total = sum(variety_counts.values())
81
  if total == 0:
82
  return "No grains detected for analysis."
83
-
84
  report = ["## Grain Distribution Report\n"]
85
  report.append(f"Total Grains Detected: **{total}**\n\n")
86
  report.append("### Breakdown by Variety:\n")
87
-
88
  for variety, count in sorted(variety_counts.items(), key=lambda x: x[1], reverse=True):
89
  percentage = (count / total) * 100
90
  bar_length = int(percentage / 5)
91
  bar = "█" * bar_length + "░" * (20 - bar_length)
92
  report.append(f"- {variety}: {count} ({percentage:.1f}%) {bar}\n")
93
-
94
  dominant_variety = max(variety_counts.items(), key=lambda x: x[1])[0]
95
  report.append(f"\nDominant Variety: **{dominant_variety}**\n")
96
  return "".join(report)
97
 
98
  def generate_csv_export(grain_details):
99
  """
100
- Convert grain detection results into CSV format for export.
 
101
  """
102
  if not grain_details:
103
  return None
104
-
105
  df = pd.DataFrame(grain_details)
106
- csv_buffer = BytesIO()
107
- df.to_csv(csv_buffer, index=False)
108
- csv_buffer.seek(0)
109
- return csv_buffer.getvalue().decode()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  def analyze_rice_image(input_image):
112
  """
113
  Full analysis pipeline:
114
- 1. Detect grains
115
- 2. Classify each grain
116
- 3. Annotate image
117
- 4. Generate distribution report
118
- 5. Generate CSV export
 
119
  """
120
  if not models_loaded:
121
  raise gr.Error("Analysis engine not available. Check model files.")
122
-
123
- if input_image is None:
124
- raise gr.Error("Please upload an image to start analysis.")
125
-
126
  # Convert PIL image to BGR array for OpenCV
127
- img_array = np.array(input_image)
128
  img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
129
-
130
  # Step 1: Detect grains
131
  results = detection_model(img_bgr, verbose=False)[0]
132
  boxes = results.boxes.xyxy.cpu().numpy()
133
-
134
  if len(boxes) == 0:
135
  return (
136
- Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)),
137
  "No grains detected. Try a clearer image.",
138
  None
139
  )
140
-
141
  # Step 2: Classify grains
142
  variety_counts = {v: 0 for v in VARIETY_MAP.values()}
143
  grain_details = []
144
-
145
  for idx, box in enumerate(boxes):
146
  x1, y1, x2, y2 = map(int, box[:4])
147
  crop = img_bgr[y1:y2, x1:x2]
148
-
149
  if crop.shape[0] > 0 and crop.shape[1] > 0:
150
  pil_crop = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
151
  variety_label = classify_grain(pil_crop)
152
  variety_counts[variety_label] += 1
153
-
154
  # Save details for CSV export
155
  grain_details.append({
156
  "Grain_ID": f"G{idx+1:04d}",
157
  "Variety": variety_label,
158
- "X_center": (x1 + x2)//2,
159
- "Y_center": (y1 + y2)//2
160
  })
161
-
162
  # Annotate image
163
  color = VARIETY_COLORS[variety_label]
164
  cv2.rectangle(img_bgr, (x1, y1), (x2, y2), color, 3)
165
  label = variety_label
166
  (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
167
- cv2.rectangle(img_bgr, (x1, y1-h-10), (x1+w, y1), color, -1)
168
- cv2.putText(img_bgr, label, (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,255,255), 2)
169
-
170
- # Step 3: Generate analytics report
171
  report_text = generate_distribution_report(variety_counts)
172
- csv_export = generate_csv_export(grain_details)
173
-
174
  return (
175
  Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)),
176
  report_text,
177
- csv_export
178
  )
179
 
180
  # ============================================
@@ -195,76 +535,82 @@ custom_css = """
195
  }
196
  """
197
 
 
 
 
 
 
 
 
 
 
198
  with gr.Blocks(css=custom_css, title="Rice Classifier") as app:
199
-
200
  gr.HTML("""
201
  <div class="header-box">
202
  <h1>Rice Analyzer Pro</h1>
203
- <p>Advanced Grain Classification | Rice Grain Locattor</p>
204
  </div>
205
  """)
206
-
207
  with gr.Tabs():
208
  # Analysis Tab
209
  with gr.Tab("Analysis"):
210
  gr.Markdown("""
211
  ### How to Use
212
  1. Upload a clear image of rice grains.
213
- 2. Click "Start Analysis".
214
- 3. Review annotated results, distribution report, and export CSV.
215
-
216
- **Color Coding:** Red = C9 Premium, Blue = Kant Special, Green = Superfine Grade
217
  """)
218
-
219
  with gr.Row():
220
  with gr.Column(scale=1):
221
- image_input = gr.Image(type="pil", label="Sample Image")
222
  start_btn = gr.Button("Start Analysis", variant="primary", size="lg")
223
-
224
- #show the annotated image in specific width and height
225
  with gr.Column(scale=1):
 
226
  annotated_output = gr.Image(label="Annotated Results", height=600, width=600)
227
-
228
  with gr.Row():
229
  report_output = gr.Markdown(label="Distribution Report")
230
-
231
  with gr.Row():
232
- csv_output = gr.Textbox(
233
- label="CSV Export (Copy or Save)",
234
- lines=8,
235
- max_lines=15,
236
- )
237
-
238
  start_btn.click(
239
  fn=analyze_rice_image,
240
  inputs=image_input,
241
  outputs=[annotated_output, report_output, csv_output]
242
  )
243
-
244
  # Documentation Tab
245
  with gr.Tab("Documentation"):
246
  gr.Markdown("""
247
  ## System Overview
248
-
249
  Rice Classifier uses a deep learning pipeline:
250
-
251
  1. **Grain Detection:** YOLO-based model identifies rice grains.
252
  2. **Grain Classification:** ResNet50 model classifies grains into three varieties.
253
- 3. **CSV Export:** Detailed grain data available for download or copy.
254
-
255
  ### Supported Varieties
256
  | Variety | Description |
257
  |---------|-------------|
258
  | C9 Premium | High-quality long grain |
259
  | Kant Special | Medium grain specialty |
260
  | Superfine Grade | Ultra-refined grain |
261
-
262
  ### Best Practices
263
  - Use well-lit images without shadows
264
  - Keep grains separated
265
  - Use plain backgrounds
266
  - Resolution: 1024x1024 or higher for best results
267
-
268
  ### Technical Details
269
  - Detection: YOLOv8
270
  - Classification: ResNet50 fine-tuned
@@ -272,21 +618,20 @@ with gr.Blocks(css=custom_css, title="Rice Classifier") as app:
272
  """)
273
 
274
  gr.Markdown("---")
275
- gr.Markdown("### Sample Gallery")
276
-
277
- gr.Examples(
278
- examples=[
279
- "samples/rice1.jpg",
280
- "samples/rice2.jpg",
281
- "samples/rice4.jpg",
282
- "samples/rice5.jpg"
283
- ],
284
- inputs=image_input,
285
- outputs=[annotated_output, report_output, csv_output],
286
- fn=analyze_rice_image,
287
- label="Click any sample to run analysis"
288
- )
289
 
290
  if __name__ == "__main__":
291
  app.queue()
292
- app.launch()
 
1
+ # import cv2
2
+ # import numpy as np
3
+ # from PIL import Image
4
+ # import torch
5
+ # from torchvision import models, transforms
6
+ # from ultralytics import YOLO
7
+ # import gradio as gr
8
+ # import torch.nn as nn
9
+ # import pandas as pd
10
+ # from io import BytesIO
11
+
12
+ # # ============================================
13
+ # # RICE ANALYZER PRO
14
+ # # Advanced Grain Analytics and Quality Assessment Platform
15
+ # # ============================================
16
+
17
+ # # --- SYSTEM CONFIGURATION ---
18
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+
20
+ # # Initialize detection and classification models
21
+ # try:
22
+ # detection_model = YOLO('best.pt')
23
+ # classifier_network = models.resnet50(weights=None)
24
+ # classifier_network.fc = nn.Linear(classifier_network.fc.in_features, 3)
25
+ # classifier_network.load_state_dict(
26
+ # torch.load('rice_resnet_model.pth', map_location=device)
27
+ # )
28
+ # classifier_network = classifier_network.to(device)
29
+ # classifier_network.eval()
30
+ # models_loaded = True
31
+ # except Exception as e:
32
+ # print(f"Model initialization failed: {e}")
33
+ # detection_model = None
34
+ # classifier_network = None
35
+ # models_loaded = False
36
+
37
+ # # --- VARIETY DEFINITIONS ---
38
+ # VARIETY_MAP = {
39
+ # 0: "C9 Premium",
40
+ # 1: "Kant Special",
41
+ # 2: "Superfine Grade"
42
+ # }
43
+
44
+ # VARIETY_COLORS = {
45
+ # "C9 Premium": (255, 100, 100), # Red
46
+ # "Kant Special": (100, 100, 255), # Blue
47
+ # "Superfine Grade": (100, 255, 100) # Green
48
+ # }
49
+
50
+ # # --- IMAGE PREPROCESSING ---
51
+ # image_preprocessor = transforms.Compose([
52
+ # transforms.Resize((224, 224)),
53
+ # transforms.ToTensor(),
54
+ # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
55
+ # ])
56
+
57
+ # # ============================================
58
+ # # ANALYTICS FUNCTIONS
59
+ # # ============================================
60
+
61
+ # def classify_grain(grain_image):
62
+ # """
63
+ # Classify a single grain using the neural network.
64
+ # Returns the grain variety label.
65
+ # """
66
+ # if not models_loaded:
67
+ # return "System Error"
68
+
69
+ # tensor_input = image_preprocessor(grain_image).unsqueeze(0).to(device)
70
+ # with torch.no_grad():
71
+ # output = classifier_network(tensor_input)
72
+ # class_idx = torch.argmax(output, dim=1).item()
73
+ # return VARIETY_MAP[class_idx]
74
+
75
+ # def generate_distribution_report(variety_counts):
76
+ # """
77
+ # Generate a text-based summary of grain variety distribution
78
+ # with total counts, percentages, and dominant variety.
79
+ # """
80
+ # total = sum(variety_counts.values())
81
+ # if total == 0:
82
+ # return "No grains detected for analysis."
83
+
84
+ # report = ["## Grain Distribution Report\n"]
85
+ # report.append(f"Total Grains Detected: **{total}**\n\n")
86
+ # report.append("### Breakdown by Variety:\n")
87
+
88
+ # for variety, count in sorted(variety_counts.items(), key=lambda x: x[1], reverse=True):
89
+ # percentage = (count / total) * 100
90
+ # bar_length = int(percentage / 5)
91
+ # bar = "█" * bar_length + "░" * (20 - bar_length)
92
+ # report.append(f"- {variety}: {count} ({percentage:.1f}%) {bar}\n")
93
+
94
+ # dominant_variety = max(variety_counts.items(), key=lambda x: x[1])[0]
95
+ # report.append(f"\nDominant Variety: **{dominant_variety}**\n")
96
+ # return "".join(report)
97
+
98
+ # def generate_csv_export(grain_details):
99
+ # """
100
+ # Convert grain detection results into CSV format for export.
101
+ # """
102
+ # if not grain_details:
103
+ # return None
104
+
105
+ # df = pd.DataFrame(grain_details)
106
+ # csv_buffer = BytesIO()
107
+ # df.to_csv(csv_buffer, index=False)
108
+ # csv_buffer.seek(0)
109
+ # return csv_buffer.getvalue().decode()
110
+
111
+ # def analyze_rice_image(input_image):
112
+ # """
113
+ # Full analysis pipeline:
114
+ # 1. Detect grains
115
+ # 2. Classify each grain
116
+ # 3. Annotate image
117
+ # 4. Generate distribution report
118
+ # 5. Generate CSV export
119
+ # """
120
+ # if not models_loaded:
121
+ # raise gr.Error("Analysis engine not available. Check model files.")
122
+
123
+ # if input_image is None:
124
+ # raise gr.Error("Please upload an image to start analysis.")
125
+
126
+ # # Convert PIL image to BGR array for OpenCV
127
+ # img_array = np.array(input_image)
128
+ # img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
129
+
130
+ # # Step 1: Detect grains
131
+ # results = detection_model(img_bgr, verbose=False)[0]
132
+ # boxes = results.boxes.xyxy.cpu().numpy()
133
+
134
+ # if len(boxes) == 0:
135
+ # return (
136
+ # Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)),
137
+ # "No grains detected. Try a clearer image.",
138
+ # None
139
+ # )
140
+
141
+ # # Step 2: Classify grains
142
+ # variety_counts = {v: 0 for v in VARIETY_MAP.values()}
143
+ # grain_details = []
144
+
145
+ # for idx, box in enumerate(boxes):
146
+ # x1, y1, x2, y2 = map(int, box[:4])
147
+ # crop = img_bgr[y1:y2, x1:x2]
148
+
149
+ # if crop.shape[0] > 0 and crop.shape[1] > 0:
150
+ # pil_crop = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
151
+ # variety_label = classify_grain(pil_crop)
152
+ # variety_counts[variety_label] += 1
153
+
154
+ # # Save details for CSV export
155
+ # grain_details.append({
156
+ # "Grain_ID": f"G{idx+1:04d}",
157
+ # "Variety": variety_label,
158
+ # "X_center": (x1 + x2)//2,
159
+ # "Y_center": (y1 + y2)//2
160
+ # })
161
+
162
+ # # Annotate image
163
+ # color = VARIETY_COLORS[variety_label]
164
+ # cv2.rectangle(img_bgr, (x1, y1), (x2, y2), color, 3)
165
+ # label = variety_label
166
+ # (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
167
+ # cv2.rectangle(img_bgr, (x1, y1-h-10), (x1+w, y1), color, -1)
168
+ # cv2.putText(img_bgr, label, (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,255,255), 2)
169
+
170
+ # # Step 3: Generate analytics report
171
+ # report_text = generate_distribution_report(variety_counts)
172
+ # csv_export = generate_csv_export(grain_details)
173
+
174
+ # return (
175
+ # Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)),
176
+ # report_text,
177
+ # csv_export
178
+ # )
179
+
180
+ # # ============================================
181
+ # # GRADIO USER INTERFACE
182
+ # # ============================================
183
+
184
+ # custom_css = """
185
+ # .gradio-container {
186
+ # font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
187
+ # }
188
+ # .header-box {
189
+ # background: linear-gradient(135deg, #1e5631 0%, #4c9a2a 100%);
190
+ # padding: 25px;
191
+ # border-radius: 12px;
192
+ # color: white;
193
+ # text-align: center;
194
+ # margin-bottom: 20px;
195
+ # }
196
+ # """
197
+
198
+ # with gr.Blocks(css=custom_css, title="Rice Classifier") as app:
199
+
200
+ # gr.HTML("""
201
+ # <div class="header-box">
202
+ # <h1>Rice Analyzer Pro</h1>
203
+ # <p>Advanced Grain Classification | Rice Grain Locattor</p>
204
+ # </div>
205
+ # """)
206
+
207
+ # with gr.Tabs():
208
+ # # Analysis Tab
209
+ # with gr.Tab("Analysis"):
210
+ # gr.Markdown("""
211
+ # ### How to Use
212
+ # 1. Upload a clear image of rice grains.
213
+ # 2. Click "Start Analysis".
214
+ # 3. Review annotated results, distribution report, and export CSV.
215
+
216
+ # **Color Coding:** Red = C9 Premium, Blue = Kant Special, Green = Superfine Grade
217
+ # """)
218
+
219
+ # with gr.Row():
220
+ # with gr.Column(scale=1):
221
+ # image_input = gr.Image(type="pil", label="Sample Image")
222
+ # start_btn = gr.Button("Start Analysis", variant="primary", size="lg")
223
+
224
+ # #show the annotated image in specific width and height
225
+ # with gr.Column(scale=1):
226
+ # annotated_output = gr.Image(label="Annotated Results", height=600, width=600)
227
+
228
+ # with gr.Row():
229
+ # report_output = gr.Markdown(label="Distribution Report")
230
+
231
+ # with gr.Row():
232
+ # csv_output = gr.Textbox(
233
+ # label="CSV Export (Copy or Save)",
234
+ # lines=8,
235
+ # max_lines=15,
236
+ # )
237
+
238
+ # start_btn.click(
239
+ # fn=analyze_rice_image,
240
+ # inputs=image_input,
241
+ # outputs=[annotated_output, report_output, csv_output]
242
+ # )
243
+
244
+ # # Documentation Tab
245
+ # with gr.Tab("Documentation"):
246
+ # gr.Markdown("""
247
+ # ## System Overview
248
+
249
+ # Rice Classifier uses a deep learning pipeline:
250
+
251
+ # 1. **Grain Detection:** YOLO-based model identifies rice grains.
252
+ # 2. **Grain Classification:** ResNet50 model classifies grains into three varieties.
253
+ # 3. **CSV Export:** Detailed grain data available for download or copy.
254
+
255
+ # ### Supported Varieties
256
+ # | Variety | Description |
257
+ # |---------|-------------|
258
+ # | C9 Premium | High-quality long grain |
259
+ # | Kant Special | Medium grain specialty |
260
+ # | Superfine Grade | Ultra-refined grain |
261
+
262
+ # ### Best Practices
263
+ # - Use well-lit images without shadows
264
+ # - Keep grains separated
265
+ # - Use plain backgrounds
266
+ # - Resolution: 1024x1024 or higher for best results
267
+
268
+ # ### Technical Details
269
+ # - Detection: YOLOv8
270
+ # - Classification: ResNet50 fine-tuned
271
+ # - GPU recommended for faster processing
272
+ # """)
273
+
274
+ # gr.Markdown("---")
275
+ # gr.Markdown("### Sample Gallery")
276
+
277
+ # gr.Examples(
278
+ # examples=[
279
+ # "samples/rice1.jpg",
280
+ # "samples/rice2.jpg",
281
+ # "samples/rice4.jpg",
282
+ # "samples/rice5.jpg"
283
+ # ],
284
+ # inputs=image_input,
285
+ # outputs=[annotated_output, report_output, csv_output],
286
+ # fn=analyze_rice_image,
287
+ # label="Click any sample to run analysis"
288
+ # )
289
+
290
+ # if __name__ == "__main__":
291
+ # app.queue()
292
+ # app.launch()
293
+
294
+
295
+
296
+
297
+
298
+
299
+
300
+
301
+ import os
302
  import cv2
303
+ import tempfile
304
  import numpy as np
305
+ from PIL import Image, UnidentifiedImageError
306
  import torch
307
  from torchvision import models, transforms
308
  from ultralytics import YOLO
 
367
  """
368
  if not models_loaded:
369
  return "System Error"
370
+
371
  tensor_input = image_preprocessor(grain_image).unsqueeze(0).to(device)
372
  with torch.no_grad():
373
  output = classifier_network(tensor_input)
 
382
  total = sum(variety_counts.values())
383
  if total == 0:
384
  return "No grains detected for analysis."
385
+
386
  report = ["## Grain Distribution Report\n"]
387
  report.append(f"Total Grains Detected: **{total}**\n\n")
388
  report.append("### Breakdown by Variety:\n")
389
+
390
  for variety, count in sorted(variety_counts.items(), key=lambda x: x[1], reverse=True):
391
  percentage = (count / total) * 100
392
  bar_length = int(percentage / 5)
393
  bar = "█" * bar_length + "░" * (20 - bar_length)
394
  report.append(f"- {variety}: {count} ({percentage:.1f}%) {bar}\n")
395
+
396
  dominant_variety = max(variety_counts.items(), key=lambda x: x[1])[0]
397
  report.append(f"\nDominant Variety: **{dominant_variety}**\n")
398
  return "".join(report)
399
 
400
  def generate_csv_export(grain_details):
401
  """
402
+ Convert grain detection results into a temporary CSV file for download.
403
+ Returns the file path.
404
  """
405
  if not grain_details:
406
  return None
407
+
408
  df = pd.DataFrame(grain_details)
409
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode='w')
410
+ df.to_csv(tmp.name, index=False)
411
+ tmp.close()
412
+ return tmp.name
413
+
414
+ def load_image_safe(input_image):
415
+ """
416
+ Safely load and validate an image from various input types.
417
+ Accepts PIL Image, numpy array, or file path string.
418
+ Returns a valid RGB PIL Image or raises gr.Error.
419
+ """
420
+ try:
421
+ if input_image is None:
422
+ raise gr.Error("Please upload an image to start analysis.")
423
+
424
+ # If it's a file path string (e.g. from gr.Examples)
425
+ if isinstance(input_image, str):
426
+ if not os.path.exists(input_image):
427
+ raise gr.Error(f"Image file not found: {input_image}")
428
+ img = Image.open(input_image).convert("RGB")
429
+
430
+ # If it's already a PIL Image
431
+ elif isinstance(input_image, Image.Image):
432
+ img = input_image.convert("RGB")
433
+
434
+ # If it's a numpy array
435
+ elif isinstance(input_image, np.ndarray):
436
+ img = Image.fromarray(input_image).convert("RGB")
437
+
438
+ else:
439
+ raise gr.Error(f"Unsupported image type: {type(input_image)}")
440
+
441
+ return img
442
+
443
+ except UnidentifiedImageError:
444
+ raise gr.Error("Could not read the image file. It may be corrupted or in an unsupported format.")
445
+ except gr.Error:
446
+ raise
447
+ except Exception as e:
448
+ raise gr.Error(f"Image loading failed: {str(e)}")
449
 
450
  def analyze_rice_image(input_image):
451
  """
452
  Full analysis pipeline:
453
+ 1. Validate and load image
454
+ 2. Detect grains
455
+ 3. Classify each grain
456
+ 4. Annotate image
457
+ 5. Generate distribution report
458
+ 6. Generate CSV export
459
  """
460
  if not models_loaded:
461
  raise gr.Error("Analysis engine not available. Check model files.")
462
+
463
+ # Safely load and validate the image
464
+ pil_image = load_image_safe(input_image)
465
+
466
  # Convert PIL image to BGR array for OpenCV
467
+ img_array = np.array(pil_image)
468
  img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
469
+
470
  # Step 1: Detect grains
471
  results = detection_model(img_bgr, verbose=False)[0]
472
  boxes = results.boxes.xyxy.cpu().numpy()
473
+
474
  if len(boxes) == 0:
475
  return (
476
+ pil_image,
477
  "No grains detected. Try a clearer image.",
478
  None
479
  )
480
+
481
  # Step 2: Classify grains
482
  variety_counts = {v: 0 for v in VARIETY_MAP.values()}
483
  grain_details = []
484
+
485
  for idx, box in enumerate(boxes):
486
  x1, y1, x2, y2 = map(int, box[:4])
487
  crop = img_bgr[y1:y2, x1:x2]
488
+
489
  if crop.shape[0] > 0 and crop.shape[1] > 0:
490
  pil_crop = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
491
  variety_label = classify_grain(pil_crop)
492
  variety_counts[variety_label] += 1
493
+
494
  # Save details for CSV export
495
  grain_details.append({
496
  "Grain_ID": f"G{idx+1:04d}",
497
  "Variety": variety_label,
498
+ "X_center": (x1 + x2) // 2,
499
+ "Y_center": (y1 + y2) // 2
500
  })
501
+
502
  # Annotate image
503
  color = VARIETY_COLORS[variety_label]
504
  cv2.rectangle(img_bgr, (x1, y1), (x2, y2), color, 3)
505
  label = variety_label
506
  (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
507
+ cv2.rectangle(img_bgr, (x1, y1 - h - 10), (x1 + w, y1), color, -1)
508
+ cv2.putText(img_bgr, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
509
+
510
+ # Step 3: Generate analytics
511
  report_text = generate_distribution_report(variety_counts)
512
+ csv_path = generate_csv_export(grain_details)
513
+
514
  return (
515
  Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)),
516
  report_text,
517
+ csv_path
518
  )
519
 
520
  # ============================================
 
535
  }
536
  """
537
 
538
+ # Only include sample images that actually exist on disk
539
+ _all_samples = [
540
+ "samples/rice1.jpg",
541
+ "samples/rice2.jpg",
542
+ "samples/rice4.jpg",
543
+ "samples/rice5.jpg"
544
+ ]
545
+ sample_images = [s for s in _all_samples if os.path.exists(s)]
546
+
547
  with gr.Blocks(css=custom_css, title="Rice Classifier") as app:
548
+
549
  gr.HTML("""
550
  <div class="header-box">
551
  <h1>Rice Analyzer Pro</h1>
552
+ <p>Advanced Grain Classification | Rice Grain Locator</p>
553
  </div>
554
  """)
555
+
556
  with gr.Tabs():
557
  # Analysis Tab
558
  with gr.Tab("Analysis"):
559
  gr.Markdown("""
560
  ### How to Use
561
  1. Upload a clear image of rice grains.
562
+ 2. Click **Start Analysis**.
563
+ 3. Review annotated results, distribution report, and download CSV.
564
+
565
+ **Color Coding:** Red = C9 Premium &nbsp; Blue = Kant Special &nbsp; Green = Superfine Grade
566
  """)
567
+
568
  with gr.Row():
569
  with gr.Column(scale=1):
570
+ image_input = gr.Image(type="pil", label="Upload Sample Image")
571
  start_btn = gr.Button("Start Analysis", variant="primary", size="lg")
572
+
 
573
  with gr.Column(scale=1):
574
+ # Removed unsupported `width` parameter
575
  annotated_output = gr.Image(label="Annotated Results", height=600, width=600)
576
+
577
  with gr.Row():
578
  report_output = gr.Markdown(label="Distribution Report")
579
+
580
  with gr.Row():
581
+ # Changed to gr.File so users can download the CSV properly
582
+ csv_output = gr.File(label="Download CSV Export")
583
+
 
 
 
584
  start_btn.click(
585
  fn=analyze_rice_image,
586
  inputs=image_input,
587
  outputs=[annotated_output, report_output, csv_output]
588
  )
589
+
590
  # Documentation Tab
591
  with gr.Tab("Documentation"):
592
  gr.Markdown("""
593
  ## System Overview
594
+
595
  Rice Classifier uses a deep learning pipeline:
596
+
597
  1. **Grain Detection:** YOLO-based model identifies rice grains.
598
  2. **Grain Classification:** ResNet50 model classifies grains into three varieties.
599
+ 3. **CSV Export:** Detailed grain data available for download.
600
+
601
  ### Supported Varieties
602
  | Variety | Description |
603
  |---------|-------------|
604
  | C9 Premium | High-quality long grain |
605
  | Kant Special | Medium grain specialty |
606
  | Superfine Grade | Ultra-refined grain |
607
+
608
  ### Best Practices
609
  - Use well-lit images without shadows
610
  - Keep grains separated
611
  - Use plain backgrounds
612
  - Resolution: 1024x1024 or higher for best results
613
+
614
  ### Technical Details
615
  - Detection: YOLOv8
616
  - Classification: ResNet50 fine-tuned
 
618
  """)
619
 
620
  gr.Markdown("---")
621
+
622
+ if sample_images:
623
+ gr.Markdown("### Sample Gallery")
624
+ gr.Examples(
625
+ examples=sample_images,
626
+ inputs=image_input,
627
+ outputs=[annotated_output, report_output, csv_output],
628
+ fn=analyze_rice_image,
629
+ cache_examples=False, # Prevents stale/corrupted cache issues
630
+ label="Click any sample to run analysis"
631
+ )
632
+ else:
633
+ gr.Markdown("*No sample images found. Add images to the `samples/` folder.*")
 
634
 
635
  if __name__ == "__main__":
636
  app.queue()
637
+ app.launch()