# riceguard / app.py
# Last commit by keypass123: "Update app.py" (06dc015, verified)
import gradio as gr
from transformers import AutoImageProcessor, SiglipForImageClassification
from PIL import Image
import torch
import json
import base64
import io
import numpy as np
# Hugging Face Hub checkpoint for the SigLIP image classifier used below.
model_name = "prithivMLmods/Rice-Leaf-Disease"

# Processor (resizing/normalisation) and model come from the same checkpoint.
processor = AutoImageProcessor.from_pretrained(model_name)
model = SiglipForImageClassification.from_pretrained(model_name)
model.eval()  # inference only: disable training-time behaviour such as dropout

# Class index -> display name used in the JSON output.
# NOTE(review): assumed to match the checkpoint's logit order; confirm against
# model.config.id2label if the checkpoint is ever changed.
labels = dict(
    enumerate(
        [
            "Bacterial Leaf Blight",
            "Rice Blast",
            "Brown Spot",
            "Healthy",
            "Rice Tungro Virus",
        ]
    )
)
def is_likely_rice_leaf(image):
    """Heuristically check whether an image is obviously NOT a plant.

    The check is deliberately permissive. Diseased rice leaves can be
    brown, yellow or dry, so only images dominated by gray/metallic
    pixels with almost no natural colour, or by near-black pixels, are
    rejected.

    Args:
        image: a PIL image in any mode; it is converted to RGB here.

    Returns:
        A tuple ``(is_plant, natural_pct, gray_pct)``:

        - ``is_plant`` is a bool.
        - ``natural_pct`` is the weighted share of natural-coloured
          pixels (0-100).
        - ``gray_pct`` is the share of gray/metallic pixels (0-100).

        Both percentages are rounded to one decimal place.
    """
    # Normalise the mode so grayscale/palette/RGBA inputs all yield exactly
    # three channels, then downsample: colour statistics need no detail.
    small = image.convert("RGB").resize((64, 64))
    arr = np.asarray(small, dtype=float)
    r, g, b = arr[..., 0], arr[..., 1], arr[..., 2]

    # Green pixels: green channel dominates (healthy leaves).
    green_mask = (g > r + 5) & (g > b + 5) & (g > 30)
    # Brown/tan pixels: diseased or dry leaves.
    brown_mask = (r > b + 10) & (r > 40) & (r < 220) & (g > 30)
    # Yellow/orange pixels: tungro, nutrient deficiency.
    yellow_mask = (r > 80) & (g > 60) & (b < g) & (np.abs(r - g) < 60)

    # The colour masks overlap. A yellow-green pixel can match all three,
    # so summing per-mask ratios would count it up to 2.4x and report more
    # than 100%. Give each pixel only its highest applicable weight
    # (green 1.0 > brown 0.8 > yellow 0.6).
    pixel_weight = np.select(
        [green_mask, brown_mask, yellow_mask], [1.0, 0.8, 0.6], default=0.0
    )
    natural_ratio = float(pixel_weight.mean())

    # Pure gray/metallic: electronics, concrete. All channels are nearly
    # equal and the pixel is not bright white.
    gray_mask = (np.abs(r - g) < 12) & (np.abs(g - b) < 12) & (r < 200)
    gray_ratio = float(gray_mask.mean())
    # Near-black: screens, dark objects.
    black_mask = (r < 30) & (g < 30) & (b < 30)
    black_ratio = float(black_mask.mean())

    # Reject only if the image is CLEARLY not a plant:
    # - more than 60% gray/metallic AND less than 8% natural colour, or
    # - more than 50% near-black.
    is_not_plant = (gray_ratio > 0.6 and natural_ratio < 0.08) or black_ratio > 0.5
    return not is_not_plant, round(natural_ratio * 100, 1), round(gray_ratio * 100, 1)
def classify(image):
    """Run the rice-disease classifier on one image and return a JSON report.

    Args:
        image: the image as a NumPy array (as delivered by ``gr.Image()`` and
            by ``detect_base64``), or ``None`` when no image was supplied.

    Returns:
        A JSON string with these keys:

        - ``disease``: the top label.
        - ``confidence``: its probability, in percent.
        - ``probabilities``: all ``{"label", "score"}`` pairs, sorted
          high to low.
        - The colour-heuristic fields ``is_valid_rice_leaf``,
          ``validation_message``, ``plant_pixel_ratio`` and
          ``gray_pixel_ratio``.

        When ``image`` is ``None`` the JSON is ``{"error": ...}`` instead,
        the same shape ``detect_base64`` uses for failures.
    """
    if image is None:
        # Gradio's Image input yields None when submitted empty; previously this
        # crashed inside Image.fromarray.
        return json.dumps({"error": "No image provided"})

    image = Image.fromarray(image).convert("RGB")

    # Step 1: colour validation (only rejects obvious non-plants). Its verdict
    # is reported alongside the prediction; the model runs regardless.
    is_leaf, plant_pct, gray_pct = is_likely_rice_leaf(image)

    # Step 2: run the disease model on a single-image batch.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # squeeze(0) drops only the batch axis, so the result is always a list of
    # per-class probabilities (plain squeeze() could collapse it to a scalar).
    probs = torch.nn.functional.softmax(logits, dim=1).squeeze(0).tolist()

    results = sorted(
        ({"label": labels[i], "score": round(p, 4)} for i, p in enumerate(probs)),
        key=lambda item: item["score"],
        reverse=True,
    )
    top = results[0]

    return json.dumps({
        "disease": top["label"],
        "confidence": round(top["score"] * 100, 1),
        "probabilities": results,
        "is_valid_rice_leaf": is_leaf,
        "validation_message": None if is_leaf else f"This does not appear to be a rice leaf (natural colors: {plant_pct}%, gray: {gray_pct}%). Please scan a rice leaf.",
        "plant_pixel_ratio": plant_pct,
        "gray_pixel_ratio": gray_pct
    })
def detect_base64(img_b64):
    """Decode a base64-encoded image and classify it.

    If the string contains a comma (as a ``data:image/...;base64,`` URL
    does), the payload is the text after the first comma, up to the next
    comma if there is one. Otherwise the whole string is decoded.

    Returns:
        The JSON string from ``classify``, or ``{"error": <message>}``.
        The error form is used if anything fails (bad input type, invalid
        base64, unreadable image, model error).
    """
    try:
        payload = img_b64.split(',')[1] if ',' in img_b64 else img_b64
        decoded = Image.open(io.BytesIO(base64.b64decode(payload)))
        return classify(np.array(decoded.convert("RGB")))
    except Exception as err:
        # API boundary: surface every failure to the caller as JSON
        # rather than raising.
        return json.dumps({"error": str(err)})
# Tab 1: interactive upload; the image arrives as an array and goes to classify().
upload_tab = gr.Interface(
    fn=classify,
    inputs=gr.Image(),
    outputs=gr.Textbox(label="Result"),
    title="RiceGuard Disease Detection",
)

# Tab 2: paste a base64 image string; decoded and classified by detect_base64().
api_tab = gr.Interface(
    fn=detect_base64,
    inputs=gr.Textbox(label="Base64 Image"),
    outputs=gr.Textbox(label="Result"),
    title="API",
)

# Both tabs show the JSON string returned by their handler.
demo = gr.TabbedInterface([upload_tab, api_tab], ["Upload Image", "Base64 API"])
demo.launch()