| | import os |
| | import cv2 |
| | import tempfile |
| | import numpy as np |
| | from PIL import Image, UnidentifiedImageError |
| | import torch |
| | from torchvision import models, transforms |
| | from ultralytics import YOLO |
| | import gradio as gr |
| | import torch.nn as nn |
| | import pandas as pd |
| | from io import BytesIO |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
# Inference device: the CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
| |
|
| | |
def _load_models():
    """
    Build the grain detector and the grain classifier.

    Returns:
        (detector, classifier):
        - detector: the YOLO model loaded from ``best.pt``.
        - classifier: a ResNet50 whose final layer is replaced by a 3-output
          linear head, loaded with weights from ``rice_resnet_model.pth``,
          moved to ``device`` and switched to eval mode.

    Raises:
        Any exception from loading either file; the caller handles it.
    """
    detector = YOLO('best.pt')

    classifier = models.resnet50(weights=None)
    # Replace the stock head with one output per rice variety (3 classes).
    classifier.fc = nn.Linear(classifier.fc.in_features, 3)
    # NOTE(review): torch.load unpickles the checkpoint, so only load trusted
    # files (consider weights_only=True on torch versions that support it).
    state_dict = torch.load('rice_resnet_model.pth', map_location=device)
    classifier.load_state_dict(state_dict)
    classifier = classifier.to(device)
    classifier.eval()
    return detector, classifier


# Best-effort startup: on any failure the app still starts, with
# models_loaded=False, and the analysis functions report an error instead.
try:
    detection_model, classifier_network = _load_models()
    models_loaded = True
except Exception as e:
    print(f"Model initialization failed: {e}")
    detection_model = None
    classifier_network = None
    models_loaded = False
| |
|
| | |
# Classifier output index -> human-readable rice variety name.
# The index order must match the order used when the classifier was trained.
VARIETY_MAP = dict(enumerate(["C9 Premium", "Kant Special", "Superfine Grade"]))

# Annotation colour per variety. The triples are intended as RGB, matching the
# UI legend: red = C9 Premium, blue = Kant Special, green = Superfine Grade.
VARIETY_COLORS = dict(
    zip(
        VARIETY_MAP.values(),
        [(255, 100, 100), (100, 100, 255), (100, 255, 100)],
    )
)
| |
|
| | |
# Standard ImageNet per-channel mean / standard deviation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Turns a PIL grain crop into a normalized 224x224 tensor for the classifier.
# The resize is to a fixed square, so the aspect ratio is not preserved.
image_preprocessor = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
| |
|
| | |
| | |
| | |
| |
|
def classify_grain(grain_image):
    """
    Predict the rice variety of a single cropped grain.

    Args:
        grain_image: PIL image of one grain, as produced by the crop step in
            analyze_rice_image.

    Returns:
        The variety name from VARIETY_MAP for the highest-scoring class, or
        the string "System Error" if the models failed to load at startup.
    """
    if models_loaded:
        # Add a leading batch dimension of size 1 and move to the model's device.
        batch = image_preprocessor(grain_image)[None, ...].to(device)
        with torch.no_grad():
            logits = classifier_network(batch)
        best_index = int(logits.argmax(dim=1)[0])
        return VARIETY_MAP[best_index]
    return "System Error"
| |
|
def generate_distribution_report(variety_counts):
    """
    Render a Markdown summary of grain counts per variety.

    Args:
        variety_counts: Mapping of variety name -> number of grains.

    Returns:
        A Markdown string containing the total, one line per variety (sorted
        by count, highest first) with its percentage and a 20-cell text bar,
        and the dominant variety. If the counts sum to zero, returns
        "No grains detected for analysis." instead.
    """
    grand_total = sum(variety_counts.values())
    if grand_total == 0:
        return "No grains detected for analysis."

    # sorted() is stable even with reverse=True, so among tied counts the
    # first-inserted variety comes first. That makes ranked[0] the same
    # entry max() would return.
    ranked = sorted(variety_counts.items(), key=lambda pair: pair[1], reverse=True)

    lines = [
        "## Grain Distribution Report\n",
        f"Total Grains Detected: **{grand_total}**\n\n",
        "### Breakdown by Variety:\n",
    ]
    for name, tally in ranked:
        share = tally / grand_total * 100
        filled = int(share / 5)  # each bar cell represents 5 percentage points
        bar = "█" * filled + "░" * (20 - filled)
        lines.append(f"- {name}: {tally} ({share:.1f}%) {bar}\n")

    top_variety = ranked[0][0]
    lines.append(f"\nDominant Variety: **{top_variety}**\n")
    return "".join(lines)
| |
|
def generate_csv_export(grain_details):
    """
    Write per-grain detection results to a new temporary CSV file.

    Args:
        grain_details: List of dicts, one per grain. The dict keys become the
            CSV header columns.

    Returns:
        Path of the written ``.csv`` file, or None when ``grain_details`` is
        empty or None. The file is deliberately not deleted here, because the
        caller returns the path to a download component.

    Raises:
        Whatever pandas raises while building or writing the frame. The temp
        file created for the export is removed before the error propagates.
    """
    if not grain_details:
        return None

    df = pd.DataFrame(grain_details)

    # Create the file, then close our descriptor immediately. pandas becomes
    # the only writer, and no open handle is left behind if writing fails.
    fd, csv_path = tempfile.mkstemp(suffix=".csv")
    os.close(fd)
    try:
        df.to_csv(csv_path, index=False)
    except Exception:
        # Don't leave an empty or partial orphan file on disk.
        os.remove(csv_path)
        raise
    return csv_path
| |
|
def load_image_safe(input_image):
    """
    Load and validate an input image, returning it as an RGB PIL image.

    Args:
        input_image: One of:
            - a file path (``str`` or ``os.PathLike``),
            - a ``PIL.Image.Image``,
            - a ``numpy.ndarray``.

    Returns:
        A new RGB ``PIL.Image.Image``, independent of any open file.

    Raises:
        gr.Error: If no image is given, the path does not exist, the type is
            unsupported, the file cannot be decoded, or conversion fails. The
            underlying exception, if any, is chained as ``__cause__``.
    """
    if input_image is None:
        raise gr.Error("Please upload an image to start analysis.")

    try:
        if isinstance(input_image, (str, os.PathLike)):
            path = os.fspath(input_image)
            if not os.path.exists(path):
                raise gr.Error(f"Image file not found: {path}")
            # Use a context manager so the file handle is released.
            # convert() returns a detached copy, so it stays valid after close.
            with Image.open(path) as source:
                return source.convert("RGB")

        if isinstance(input_image, Image.Image):
            return input_image.convert("RGB")

        if isinstance(input_image, np.ndarray):
            return Image.fromarray(input_image).convert("RGB")

        raise gr.Error(f"Unsupported image type: {type(input_image)}")

    except gr.Error:
        # Already user-facing; pass through unchanged.
        raise
    except UnidentifiedImageError as e:
        raise gr.Error("Could not read the image file. It may be corrupted or in an unsupported format.") from e
    except Exception as e:
        raise gr.Error(f"Image loading failed: {str(e)}") from e
| |
|
def _draw_grain_annotation(canvas, box, variety_label):
    """Draw one coloured bounding box plus a filled label tag onto an RGB canvas, in place."""
    x1, y1, x2, y2 = box
    color = VARIETY_COLORS[variety_label]  # RGB triple; the canvas is RGB too
    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 3)

    (text_w, text_h), _ = cv2.getTextSize(variety_label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
    tag_height = text_h + 10
    # Put the tag just above the box. If that would fall off the top edge of
    # the image, put it just inside the box instead.
    if y1 - tag_height >= 0:
        tag_top, tag_bottom = y1 - tag_height, y1
    else:
        tag_top, tag_bottom = y1, y1 + tag_height
    cv2.rectangle(canvas, (x1, tag_top), (x1 + text_w, tag_bottom), color, -1)
    cv2.putText(
        canvas, variety_label, (x1, tag_bottom - 5),
        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2,
    )


def analyze_rice_image(input_image):
    """
    Run the full analysis pipeline on one image.

    Steps:
        1. Validate and load the image.
        2. Detect grains.
        3. Classify each grain.
        4. Annotate the image.
        5. Generate the distribution report.
        6. Generate the CSV export.

    Args:
        input_image: Anything accepted by load_image_safe (path, PIL image,
            or numpy array).

    Returns:
        Tuple of (annotated RGB PIL image, Markdown report text, CSV file path
        or None). If nothing is detected, returns the input image, a short
        message, and None.

    Raises:
        gr.Error: If the models are not loaded or the image cannot be loaded.
    """
    if not models_loaded:
        raise gr.Error("Analysis engine not available. Check model files.")

    pil_image = load_image_safe(input_image)

    # Pristine RGB pixels. Crops are always cut from this array, so the
    # classifier never sees boxes or labels drawn for earlier grains.
    img_rgb = np.array(pil_image)
    # Ultralytics treats numpy inputs as OpenCV-style BGR.
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

    results = detection_model(img_bgr, verbose=False)[0]
    boxes = results.boxes.xyxy.cpu().numpy()

    if len(boxes) == 0:
        return (
            pil_image,
            "No grains detected. Try a clearer image.",
            None
        )

    img_h, img_w = img_rgb.shape[:2]
    # Separate RGB canvas for drawing, so the VARIETY_COLORS RGB triples
    # render as the colours named in the UI legend.
    annotated = img_rgb.copy()
    variety_counts = {v: 0 for v in VARIETY_MAP.values()}
    grain_details = []

    for idx, box in enumerate(boxes):
        x1, y1, x2, y2 = map(int, box[:4])
        # Clamp to the image bounds, because negative indices would wrap
        # around in numpy slicing.
        x1, x2 = max(0, x1), min(img_w, x2)
        y1, y2 = max(0, y1), min(img_h, y2)

        crop = img_rgb[y1:y2, x1:x2]
        if crop.shape[0] == 0 or crop.shape[1] == 0:
            continue  # degenerate box: nothing to classify

        variety_label = classify_grain(Image.fromarray(crop.copy()))
        variety_counts[variety_label] += 1

        # IDs follow detection order (idx over all boxes), so skipped
        # degenerate boxes leave gaps in the numbering.
        grain_details.append({
            "Grain_ID": f"G{idx+1:04d}",
            "Variety": variety_label,
            "X_center": (x1 + x2) // 2,
            "Y_center": (y1 + y2) // 2
        })

        _draw_grain_annotation(annotated, (x1, y1, x2, y2), variety_label)

    report_text = generate_distribution_report(variety_counts)
    csv_path = generate_csv_export(grain_details)

    return (
        Image.fromarray(annotated),
        report_text,
        csv_path
    )
| |
|
| | |
| | |
| | |
| |
|
# Stylesheet passed to gr.Blocks(css=...).
# - .gradio-container: sets the app-wide font stack.
# - .header-box: styles the top banner emitted by gr.HTML (green gradient
#   background, rounded corners, centred white text).
custom_css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.header-box {
background: linear-gradient(135deg, #1e5631 0%, #4c9a2a 100%);
padding: 25px;
border-radius: 12px;
color: white;
text-align: center;
margin-bottom: 20px;
}
"""
| |
|
| | |
# Candidate gallery images. The paths are relative to the current working
# directory, and only the files that actually exist are offered as examples.
_all_samples = [f"samples/rice{n}.jpg" for n in (3, 2, 4, 5)]
sample_images = list(filter(os.path.exists, _all_samples))
| |
|
# Module-level Gradio UI. `app` is served by the __main__ guard at the end of the file.
with gr.Blocks(css=custom_css, title="Rice Classifier") as app:

    # Banner; its styling comes from the .header-box rule in custom_css.
    gr.HTML("""
<div class="header-box">
<h1>Rice Analyzer Pro</h1>
<p>Advanced Grain Classification | Rice Grain Locator</p>
</div>
""")

    with gr.Tabs():

        with gr.Tab("Analysis"):
            # The legend names the colours defined in VARIETY_COLORS.
            gr.Markdown("""
### How to Use
1. Upload a clear image of rice grains.
2. Click **Start Analysis**.
3. Review annotated results, distribution report, and download CSV.

**Color Coding:** Red = C9 Premium · Blue = Kant Special · Green = Superfine Grade
""")

            with gr.Row():
                with gr.Column(scale=1):
                    # type="pil" hands analyze_rice_image a PIL image.
                    image_input = gr.Image(type="pil", label="Upload Sample Image", height=600, width=600)
                    start_btn = gr.Button("Start Analysis", variant="primary", size="lg")

                with gr.Column(scale=1):
                    annotated_output = gr.Image(label="Annotated Results", height=600, width=600)

            with gr.Row():
                report_output = gr.Markdown(label="Distribution Report")

            with gr.Row():
                csv_output = gr.File(label="Download CSV Export")

            # analyze_rice_image returns (annotated image, Markdown report,
            # CSV path or None), in the same order as these outputs.
            start_btn.click(
                fn=analyze_rice_image,
                inputs=image_input,
                outputs=[annotated_output, report_output, csv_output]
            )

        with gr.Tab("Documentation"):
            gr.Markdown("""
## System Overview

Rice Classifier uses a deep learning pipeline:

1. **Grain Detection:** YOLO-based model identifies rice grains.
2. **Grain Classification:** ResNet50 model classifies grains into three varieties.
3. **CSV Export:** Detailed grain data available for download.

### Supported Varieties
| Variety | Description |
|---------|-------------|
| C9 Premium | High-quality long grain |
| Kant Special | Medium grain specialty |
| Superfine Grade | Ultra-refined grain |

### Best Practices
- Use well-lit images without shadows
- Keep grains separated
- Use plain backgrounds
- Resolution: 1024x1024 or higher for best results

### Technical Details
- Detection: YOLOv8
- Classification: ResNet50 fine-tuned
- GPU recommended for faster processing
""")

    gr.Markdown("---")

    # The gallery is only rendered when at least one sample file exists on disk.
    if sample_images:
        gr.Markdown("### Sample Gallery")
        gr.Examples(
            examples=sample_images,
            inputs=image_input,
            outputs=[annotated_output, report_output, csv_output],
            fn=analyze_rice_image,
            cache_examples=False,
            # With cache_examples=False, Gradio runs `fn` on click only when
            # run_on_click=True. Otherwise a click just loads the image,
            # which contradicts the label below.
            run_on_click=True,
            label="Click any sample to run analysis"
        )
    else:
        gr.Markdown("*No sample images found. Add images to the `samples/` folder.*")
| |
|
if __name__ == "__main__":
    # Gradio's Blocks.queue() returns the Blocks instance itself, so enabling
    # the request queue and starting the server can be chained.
    app.queue().launch()