"""Gradio app that classifies rice-leaf diseases from an uploaded image.

Two tabs are exposed: one takes an uploaded image, the other takes a
base64-encoded image string. Both return a JSON string with the predicted
disease, per-class probabilities and a coarse colour-based plausibility check.
"""
import base64
import io
import json

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, SiglipForImageClassification

# Model and preprocessor are loaded once at import time and shared by requests.
model_name = "prithivMLmods/Rice-Leaf-Disease"
model = SiglipForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
model.eval()

# Class index -> human-readable label used in the JSON response.
labels = {
    0: "Bacterial Leaf Blight",
    1: "Rice Blast",
    2: "Brown Spot",
    3: "Healthy",
    4: "Rice Tungro Virus",
}


def is_likely_rice_leaf(image):
    """Check whether an image is obviously NOT a plant (laptop, face, etc.).

    The check is deliberately permissive: diseased leaves can be brown,
    yellow or dry, so only images dominated by gray or black pixels are
    rejected.

    Args:
        image: A PIL image. It is converted to RGB before analysis so that
            grayscale/palette/alpha images do not break channel indexing.

    Returns:
        A ``(is_plausible, natural_pct, gray_pct)`` tuple: a bool that is
        False only for clearly non-plant images, followed by the weighted
        "natural colour" score and the gray-pixel share, both as percentages
        rounded to one decimal place.
    """
    # Downscale to a fixed 64x64 grid: the heuristic only needs colour stats.
    img = image.convert("RGB").resize((64, 64))
    arr = np.asarray(img, dtype=float)
    r, g, b = arr[..., 0], arr[..., 1], arr[..., 2]
    total_pixels = r.size

    def ratio(mask):
        """Return the fraction of pixels selected by a boolean mask."""
        return int(np.count_nonzero(mask)) / total_pixels

    # Green pixels: green channel dominates (healthy leaves).
    green_ratio = ratio((g > r + 5) & (g > b + 5) & (g > 30))

    # Brown/tan pixels: diseased or dry leaves.
    brown_ratio = ratio((r > b + 10) & (r > 40) & (r < 220) & (g > 30))

    # Yellow/orange: tungro, nutrient deficiency.
    yellow_ratio = ratio(
        (r > 80) & (g > 60) & (b < g) & (np.abs(r - g) < 60)
    )

    # Any natural/organic colour. The masks can overlap, so brown and yellow
    # are down-weighted; the score is not a strict pixel share.
    natural_ratio = green_ratio + brown_ratio * 0.8 + yellow_ratio * 0.6

    # Pure gray/metallic: electronics, concrete (all channels nearly equal
    # AND not bright white).
    gray_ratio = ratio(
        (np.abs(r - g) < 12) & (np.abs(g - b) < 12) & (r < 200)
    )

    # Pure black: screens, dark objects.
    black_ratio = ratio((r < 30) & (g < 30) & (b < 30))

    # Only reject if the image is CLEARLY not a plant:
    # - more than 60% gray/metallic AND less than 8% natural colours, or
    # - more than 50% pure black (screens, dark electronics).
    is_not_plant = (gray_ratio > 0.6 and natural_ratio < 0.08) or (
        black_ratio > 0.5
    )
    return (
        not is_not_plant,
        round(natural_ratio * 100, 1),
        round(gray_ratio * 100, 1),
    )


def classify(image):
    """Classify an image array and return the result as a JSON string.

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (this is what the
            ``gr.Image()`` input supplies), or ``None`` when no image was
            provided.

    Returns:
        A JSON string. On success it contains ``disease``, ``confidence``
        (percent), ``probabilities`` (sorted by descending score),
        ``is_valid_rice_leaf``, ``validation_message``, ``plant_pixel_ratio``
        and ``gray_pixel_ratio``. When ``image`` is ``None`` it contains a
        single ``error`` key instead.
    """
    # The UI may submit without an image; report that instead of crashing
    # inside Image.fromarray.
    if image is None:
        return json.dumps({"error": "No image provided"})

    image = Image.fromarray(image).convert("RGB")

    # Step 1: colour validation (only rejects obvious non-plants). The model
    # still runs afterwards; the verdict is reported alongside the prediction.
    is_leaf, plant_pct, gray_pct = is_likely_rice_leaf(image)

    # Step 2: run the disease model on the single image.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Batch size is 1, so take row 0 of the (batch, classes) probabilities.
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0].tolist()

    results = sorted(
        (
            {"label": labels[idx], "score": round(prob, 4)}
            for idx, prob in enumerate(probs)
        ),
        key=lambda item: item["score"],
        reverse=True,
    )
    top = results[0]

    validation_message = None
    if not is_leaf:
        validation_message = (
            f"This does not appear to be a rice leaf (natural colors: {plant_pct}%, gray: {gray_pct}%). Please scan a rice leaf."
        )

    return json.dumps({
        "disease": top["label"],
        "confidence": round(top["score"] * 100, 1),
        "probabilities": results,
        "is_valid_rice_leaf": is_leaf,
        "validation_message": validation_message,
        "plant_pixel_ratio": plant_pct,
        "gray_pixel_ratio": gray_pct,
    })


def detect_base64(img_b64):
    """Classify a base64-encoded image and return the result as a JSON string.

    Args:
        img_b64: Base64 text, optionally in data-URL form
            (``data:image/...;base64,<payload>``).

    Returns:
        The JSON string produced by ``classify``, or a JSON object with a
        single ``error`` key if the input is empty or cannot be decoded,
        opened or classified.
    """
    if not img_b64:
        return json.dumps({"error": "No image data provided"})

    # This is the API boundary: any failure is reported to the caller as JSON
    # rather than propagated, hence the deliberately broad except.
    try:
        # Drop a data-URL header such as "data:image/png;base64,".
        if "," in img_b64:
            img_b64 = img_b64.split(",")[1]
        img_bytes = base64.b64decode(img_b64)
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        return classify(np.array(image))
    except Exception as e:
        return json.dumps({"error": str(e)})


demo = gr.TabbedInterface(
    [
        gr.Interface(
            fn=classify,
            inputs=gr.Image(),
            outputs=gr.Textbox(label="Result"),
            title="RiceGuard Disease Detection",
        ),
        gr.Interface(
            fn=detect_base64,
            inputs=gr.Textbox(label="Base64 Image"),
            outputs=gr.Textbox(label="Result"),
            title="API",
        ),
    ],
    ["Upload Image", "Base64 API"],
)

# Launched at module level, as in the original script.
demo.launch()