Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoImageProcessor, SiglipForImageClassification | |
| from PIL import Image | |
| import torch | |
| import json | |
| import base64 | |
| import io | |
| import numpy as np | |
# SigLIP image classifier for rice-leaf diseases (per the checkpoint name).
# Weights and the matching preprocessor are loaded via from_pretrained at import time.
model_name = "prithivMLmods/Rice-Leaf-Disease"
model = SiglipForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
model.eval()  # evaluation mode: this module only runs inference

# Output index -> human-readable class name; classify() looks up labels[i] for
# the i-th softmax probability.
# NOTE(review): assumed to match model.config.id2label order — confirm.
labels = dict(enumerate([
    "Bacterial Leaf Blight",
    "Rice Blast",
    "Brown Spot",
    "Healthy",
    "Rice Tungro Virus",
]))
def is_likely_rice_leaf(image):
    """Cheap colour heuristic that rejects images that are obviously NOT a plant.

    PERMISSIVE by design: only clearly non-plant images (laptops, faces,
    screens, concrete, ...) are rejected. Diseased leaves can be brown,
    yellow or dry, so those colours also count as "natural".

    Args:
        image: A PIL image. It is converted to RGB first, so RGBA, palette or
            greyscale inputs are accepted too.

    Returns:
        ``(is_leaf, natural_pct, gray_pct)``:
        ``is_leaf`` is a bool. ``natural_pct`` is the weighted share of
        natural-coloured pixels in percent (0-100, one decimal). ``gray_pct``
        is the share of grey/metallic pixels in percent (0-100, one decimal).
    """
    # Downsample: only coarse colour statistics are needed. Use float so that
    # expressions like ``r + 5`` cannot wrap around the way uint8 arithmetic would.
    arr = np.asarray(image.convert("RGB").resize((64, 64)), dtype=float)
    r, g, b = arr[..., 0], arr[..., 1], arr[..., 2]

    # Green pixels: green channel dominates (healthy leaves).
    green_mask = (g > r + 5) & (g > b + 5) & (g > 30)
    # Brown/tan pixels: diseased or dry leaves.
    brown_mask = (r > b + 10) & (r > 40) & (r < 220) & (g > 30)
    # Yellow/orange pixels: tungro, nutrient deficiency.
    yellow_mask = (r > 80) & (g > 60) & (b < g) & (np.abs(r - g) < 60)

    # Natural-colour score per pixel. The masks overlap (a yellow-green pixel can
    # match all three), so take the strongest matching weight rather than summing
    # the per-mask ratios. Summing let the score exceed 100%.
    natural_weight = np.maximum.reduce([
        green_mask * 1.0,
        brown_mask * 0.8,
        yellow_mask * 0.6,
    ])
    natural_ratio = float(natural_weight.mean())

    # Grey/metallic pixels: electronics, concrete (all channels nearly equal
    # AND not bright white).
    gray_mask = (np.abs(r - g) < 12) & (np.abs(g - b) < 12) & (r < 200)
    gray_ratio = float(gray_mask.mean())
    # Near-black pixels: screens, dark objects.
    black_ratio = float(((r < 30) & (g < 30) & (b < 30)).mean())

    # Only reject if the image is CLEARLY not a plant:
    # - more than 60% grey/metallic AND less than 8% natural colours, or
    # - more than 50% near-black (screens, dark electronics).
    is_not_plant = (gray_ratio > 0.6 and natural_ratio < 0.08) or black_ratio > 0.5
    return not is_not_plant, round(natural_ratio * 100, 1), round(gray_ratio * 100, 1)
def classify(image):
    """Validate and classify a rice-leaf image, returning a JSON string.

    Args:
        image: Pixel array as passed by ``gr.Image()``, or ``None`` when the
            user submits without uploading an image.

    Returns:
        A JSON string. On success it holds the top ``disease`` label, its
        ``confidence`` in percent, ``probabilities`` for every class sorted
        from most to least likely, and the colour-heuristic validation fields.
        If ``image`` is ``None`` it is ``{"error": ...}``, the same error
        shape ``detect_base64`` returns.
    """
    if image is None:
        # Gradio calls fn with None when Submit is pressed with no upload;
        # Image.fromarray(None) would otherwise raise an opaque AttributeError.
        return json.dumps({"error": "No image provided. Please upload a rice leaf image."})

    image = Image.fromarray(image).convert("RGB")

    # Step 1: colour validation (only rejects obvious non-plants). The model
    # still runs either way; the result is reported via is_valid_rice_leaf.
    is_leaf, plant_pct, gray_pct = is_likely_rice_leaf(image)

    # Step 2: run the disease model on a batch of one image.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # squeeze() drops the batch axis so probs is a flat per-class list.
    probs = torch.nn.functional.softmax(outputs.logits, dim=1).squeeze().tolist()

    results = sorted(
        ({"label": labels[i], "score": round(p, 4)} for i, p in enumerate(probs)),
        key=lambda r: r["score"],
        reverse=True,
    )
    top = results[0]
    return json.dumps({
        "disease": top["label"],
        "confidence": round(top["score"] * 100, 1),
        "probabilities": results,
        "is_valid_rice_leaf": is_leaf,
        "validation_message": None if is_leaf else f"This does not appear to be a rice leaf (natural colors: {plant_pct}%, gray: {gray_pct}%). Please scan a rice leaf.",
        "plant_pixel_ratio": plant_pct,
        "gray_pixel_ratio": gray_pct,
    })
def detect_base64(img_b64):
    """Decode a base64-encoded image and classify it.

    Accepts either bare base64 or a data URL such as
    ``data:image/png;base64,<payload>``. Everything after the first comma is
    treated as the payload.

    Returns:
        The JSON string produced by ``classify``. On failure it returns
        ``{"error": ...}`` instead. That happens when no payload is given or
        when anything goes wrong while decoding or classifying.
    """
    try:
        text = img_b64 or ""
        # Strip a data-URL header ("data:image/jpeg;base64,") if present.
        head, sep, tail = text.partition(",")
        payload = tail if sep else head
        if not payload.strip():
            return json.dumps({"error": "No base64 image data provided."})
        img_bytes = base64.b64decode(payload)
        with Image.open(io.BytesIO(img_bytes)) as img:
            image = img.convert("RGB")
        return classify(np.array(image))
    except Exception as e:
        # API boundary: report any decode/classification failure to the client
        # as JSON rather than letting the request fail with a traceback.
        return json.dumps({"error": str(e)})
# Two tabs: an interactive upload UI, and a text endpoint that accepts a base64
# string (or data URL) for programmatic clients. Both return a JSON string.
demo = gr.TabbedInterface(
    [
        gr.Interface(fn=classify, inputs=gr.Image(), outputs=gr.Textbox(label="Result"), title="RiceGuard Disease Detection"),
        gr.Interface(fn=detect_base64, inputs=gr.Textbox(label="Base64 Image"), outputs=gr.Textbox(label="Result"), title="API"),
    ],
    ["Upload Image", "Base64 API"],
)

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (e.g. from tests) does not launch it as a side effect.
    demo.launch()