# kilimo-guard-ai / app.py
# Last change: "Update app.py" by sagecipher (commit 5314260, verified)
import json
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Tuple

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from PIL import Image
# ========================================
# ANALYZER
# ========================================
class KilimoGuardAnalyzer:
    """Classify crop diseases with an ONNX model and look up advice in a JSON DB.

    Loading problems never raise from the constructor: they are printed and
    the analyzer degrades gracefully (``predict`` returns ``"Model Error"``
    without a model; ``get_disease_info`` returns ``{}`` without a database).
    """

    # Fallbacks used when the model is missing or declares dynamic dimensions.
    DEFAULT_INPUT_SIZE = 384
    DEFAULT_NUM_CLASSES = 6

    def __init__(self, model_path: str, disease_db_path: str):
        """Load the ONNX model at ``model_path`` and the database at ``disease_db_path``."""
        self.session = None  # onnxruntime InferenceSession once loaded
        self.input_name = None  # name of the model's first input
        self.output_name = None  # name of the model's first output
        self.input_h = self.DEFAULT_INPUT_SIZE
        self.input_w = self.DEFAULT_INPUT_SIZE
        self.num_classes = self.DEFAULT_NUM_CLASSES
        self.disease_db: Dict[str, Any] = {}
        # Class index -> label: a dict keyed by the index as a string, or a list.
        self.class_mapping: Any = {}
        self._load_model(model_path)
        self._load_disease_db(disease_db_path)

    # -----------------------------
    def _load_model(self, model_path: str) -> None:
        """Create the inference session and read input size / class count from it."""
        print(f"🔹 Loading ONNX model: {model_path}")
        if not os.path.exists(model_path):
            print(f"⚠️ Model not found at {model_path}")
            return
        try:
            session = ort.InferenceSession(model_path)
            model_input = session.get_inputs()[0]
            model_output = session.get_outputs()[0]
        except Exception as e:  # ORT raises several unrelated types for bad models
            print(f"❌ Failed to load ONNX model: {e}")
            return

        self.session = session
        self.input_name = model_input.name
        self.output_name = model_output.name

        shape = model_input.shape
        # Dynamic dimensions are reported as strings/None: keep the defaults then.
        if len(shape) == 4:
            if isinstance(shape[2], int):
                self.input_h = shape[2]
            if isinstance(shape[3], int):
                self.input_w = shape[3]
        output_shape = model_output.shape
        if len(output_shape) > 1 and isinstance(output_shape[1], int):
            self.num_classes = output_shape[1]
        print(f"✅ Model loaded: Input: {shape}, Output: {output_shape}")

    # -----------------------------
    def _load_disease_db(self, disease_db_path: str) -> None:
        """Read the disease database JSON object and its ``class_mapping``."""
        print(f"🔹 Loading disease database: {disease_db_path}")
        if not os.path.exists(disease_db_path):
            print(f"⚠️ Disease database not found at {disease_db_path}")
            return
        try:
            with open(disease_db_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers bad JSON / bad UTF-8
            print(f"❌ Failed to load disease database: {e}")
            return
        if not isinstance(data, dict):
            print(f"❌ Failed to load disease database: expected a JSON object, got {type(data).__name__}")
            return
        self.disease_db = data
        self.class_mapping = data.get("class_mapping", {})
        print(f"✅ Disease database loaded: {len(self.class_mapping)} classes")

    # -----------------------------
    def _class_name(self, class_idx: int) -> str:
        """Return the label for ``class_idx``, or ``"Class_<idx>"`` if unmapped."""
        mapping = self.class_mapping
        name = None
        if isinstance(mapping, dict):
            name = mapping.get(str(class_idx))
        elif isinstance(mapping, list) and 0 <= class_idx < len(mapping):
            name = mapping[class_idx]
        return str(name) if name is not None else f"Class_{class_idx}"

    # -----------------------------
    def preprocess_image(self, img_input) -> np.ndarray:
        """Load one image and convert it into the model's input tensor.

        Args:
            img_input: a PIL image, or anything whose ``str()`` is an image
                file path that OpenCV can read.

        Returns:
            float32 array of shape (1, 3, input_h, input_w), RGB channel order,
            values scaled to [0, 1].

        Raises:
            ValueError: if the path cannot be decoded as an image.
        """
        try:
            if isinstance(img_input, Image.Image):
                # Convert in memory: no temp file to leak, no JPEG re-compression,
                # and RGBA/palette images work. Go to BGR so both branches feed
                # the same channel order into the shared steps below.
                img = cv2.cvtColor(np.asarray(img_input.convert("RGB")), cv2.COLOR_RGB2BGR)
            else:
                path = str(img_input)
                img = cv2.imread(path)  # BGR, 3 channels
                if img is None:
                    raise ValueError(f"Cannot read image: {path}")
            img = cv2.resize(img, (self.input_w, self.input_h))  # dsize is (width, height)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = img.astype(np.float32) / 255.0
            img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
            img = np.expand_dims(img, 0)  # add batch dimension
            print(f" Preprocessed image shape: {img.shape}")
            return img
        except Exception as e:
            print(f"❌ Preprocessing failed: {e}")
            raise

    # -----------------------------
    @staticmethod
    def _to_probabilities(output: np.ndarray) -> np.ndarray:
        """Return ``output`` as class scores usable as probabilities.

        Vectors already inside [0, 1] are passed through unchanged; anything
        else (negative values or values above 1) is treated as logits and put
        through a numerically stable softmax.
        """
        output = np.asarray(output)
        if output.min() < 0.0 or output.max() > 1.0:
            exp_out = np.exp(output - output.max())
            return exp_out / exp_out.sum()
        return output

    @staticmethod
    def _severity(confidence: float) -> str:
        """Map a confidence in [0, 1] to a traffic-light severity label."""
        # NOTE(review): severity is derived from model confidence, not from any
        # measure of disease extent — confirm this is the intended semantics.
        if confidence < 0.5:
            return "🟢 Low"
        if confidence < 0.8:
            return "🟡 Medium"
        return "🔴 High"

    # -----------------------------
    def predict(self, images: List) -> Tuple[str, float, str]:
        """Classify several images of the same crop and aggregate the result.

        Every image is scored separately; the per-image probability vectors are
        averaged and the top class of that average is reported, so a single
        confident outlier frame cannot outvote the other frames. Images that
        fail to load or run are logged and skipped.

        Returns:
            ``(disease_name, confidence, severity)``; ``("Model Error", 0.0,
            "Unknown")`` when no model is loaded and ``("Unable to Classify",
            0.0, "Unknown")`` when no image could be scored.
        """
        if self.session is None:
            return "Model Error", 0.0, "Unknown"

        prob_vectors = []
        for idx, img in enumerate(images):
            try:
                img_tensor = self.preprocess_image(img)
                outputs = self.session.run([self.output_name], {self.input_name: img_tensor})
                output = np.asarray(outputs[0][0]).ravel()  # first (only) batch item
                print(f" Raw output: min {output.min():.6f}, max {output.max():.6f}")
                prob_vectors.append(self._to_probabilities(output))
            except Exception as e:  # best effort: one bad frame must not abort the batch
                print(f"❌ Inference failed for image {idx}: {e}")

        if not prob_vectors:
            return "Unable to Classify", 0.0, "Unknown"

        mean_probs = np.mean(prob_vectors, axis=0)
        class_idx = int(np.argmax(mean_probs))
        confidence = float(mean_probs[class_idx])
        disease_name = self._class_name(class_idx)
        severity = self._severity(confidence)
        print(f"✅ Final prediction: {disease_name} ({confidence*100:.1f}%), severity: {severity}")
        return disease_name, confidence, severity

    # -----------------------------
    def get_disease_info(self, disease_name: str, language="Swahili") -> Dict[str, Any]:
        """Return ``disease_db["diseases"][disease_name][language]``, or ``{}`` if absent."""
        diseases = self.disease_db.get("diseases")
        if not isinstance(diseases, dict):
            return {}
        entry = diseases.get(disease_name)
        if not isinstance(entry, dict):
            return {}
        info = entry.get(language)
        return info if isinstance(info, dict) else {}
# ========================================
# INITIALIZE ANALYZER
# ========================================
# Resolve the bundled assets relative to this file rather than the current
# working directory, so the app also works when launched from elsewhere.
_APP_DIR = Path(__file__).resolve().parent
MODEL_PATH = str(_APP_DIR / "kilimoguard_v2.onnx")
DISEASE_DB_PATH = str(_APP_DIR / "kilimoguard_disease_database.json")
analyzer = KilimoGuardAnalyzer(MODEL_PATH, DISEASE_DB_PATH)
# ========================================
# ANALYSIS FUNCTION
# ========================================
# Upper bound on images passed to the model per request.
_MAX_IMAGES = 6
# How many leading video frames are scanned, and keep every Nth of them.
_MAX_VIDEO_FRAMES_SCANNED = 120
_FRAME_STRIDE = 8


def analyze_crop(images, video, description, crop_type, language, progress=gr.Progress()):
    """Gradio callback: diagnose a crop from uploaded images and/or a video.

    Args:
        images: value of the ``gr.Gallery`` input — a list whose items may be
            file paths, PIL images, numpy arrays or ``(media, caption)``
            tuples; ``None`` when nothing was uploaded.
        video: path of the uploaded video, or ``None``.
        description: free-text problem description (currently not used).
        crop_type: crop chosen in the dropdown (may be ``None``).
        language: disease-database language key (``None`` falls back to Swahili).
        progress: Gradio progress tracker (unused here).

    Returns:
        ``(report, report)``: Markdown for the output pane and a copy for state.
    """
    print("\n🚀 Starting Analysis")
    temp_files: List[str] = []
    try:
        img_inputs = _collect_image_inputs(images, video, temp_files)
        if not img_inputs:
            msg = "❌ No images or video frames provided"
            return msg, msg

        disease_name, confidence, severity = analyzer.predict(img_inputs[:_MAX_IMAGES])
        disease_info = analyzer.get_disease_info(disease_name, language or "Swahili")
        report = _format_report(crop_type, disease_name, confidence, severity, disease_info)
        return report, report
    finally:
        # Runs even if decoding or inference raised, so temp files never pile up.
        for path in temp_files:
            try:
                os.unlink(path)
            except OSError:
                pass  # already gone / not removable: nothing useful to do


def _new_temp_jpeg(temp_files: List[str]) -> str:
    """Reserve a temporary ``.jpg`` path, register it for cleanup, return it."""
    # Only the name is needed; closing the handle avoids leaking a descriptor
    # (and lets other code reopen the file on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
        path = tmp.name
    temp_files.append(path)
    return path


def _collect_image_inputs(images, video, temp_files: List[str]) -> List[Any]:
    """Return gallery images plus sampled video frames as model inputs.

    In-memory images and video frames are written to temp files whose paths
    are appended to ``temp_files``; other gallery items are passed through.
    """
    img_inputs: List[Any] = []
    for item in images or []:
        # gr.Gallery delivers (media, caption) pairs; keep only the media.
        if isinstance(item, (tuple, list)) and item:
            item = item[0]
        if isinstance(item, Image.Image):
            path = _new_temp_jpeg(temp_files)
            item.convert("RGB").save(path)  # JPEG cannot store alpha/palette modes
            img_inputs.append(path)
        elif isinstance(item, np.ndarray):
            path = _new_temp_jpeg(temp_files)
            # NOTE(review): cv2.imwrite assumes BGR; Gradio numpy images are
            # usually RGB — confirm the channel order here.
            cv2.imwrite(path, item)
            img_inputs.append(path)
        elif item is not None:
            img_inputs.append(item)

    if video and len(img_inputs) < _MAX_IMAGES:
        cap = cv2.VideoCapture(str(video))
        try:
            frame_idx = 0
            while len(img_inputs) < _MAX_IMAGES and frame_idx < _MAX_VIDEO_FRAMES_SCANNED:
                ok, frame = cap.read()
                if not ok:
                    break
                if frame_idx % _FRAME_STRIDE == 0:
                    path = _new_temp_jpeg(temp_files)
                    cv2.imwrite(path, frame)
                    img_inputs.append(path)
                frame_idx += 1
        finally:
            cap.release()
    return img_inputs


def _md_list(title: str, items) -> str:
    """Render ``items`` as a bold-titled Markdown bullet list."""
    if isinstance(items, str):
        items = [items]  # a bare string must not be split into characters
    return f"**{title}:**\n\n" + "\n".join(f"- {item}" for item in items)


def _format_report(crop_type, disease_name: str, confidence: float, severity: str,
                   disease_info: Dict[str, Any]) -> str:
    """Render the diagnosis and any database advice as Markdown.

    Blocks are separated by blank lines and lists use Markdown ``-`` items,
    because ``gr.Markdown`` does not treat a single newline as a line break by
    default and newline-separated lines would collapse into one paragraph.
    """
    confidence_pct = confidence * 100
    # Clamp so an out-of-range score cannot produce a bar that is not 10 cells.
    filled = min(max(int(confidence_pct / 10), 0), 10)
    confidence_bar = "█" * filled + "░" * (10 - filled)
    summary = "\n".join([
        f"- **Crop:** {crop_type or 'Not specified'}",
        f"- **Condition:** {disease_name}",
        f"- **Confidence:** {confidence_bar} {confidence_pct:.1f}%",
        f"- **Severity:** {severity}",
    ])
    sections = ["### 🔍 Analysis Results", summary]

    if disease_info:
        if "description" in disease_info:
            sections.append(f"**Description:** {disease_info['description']}")
        if "symptoms" in disease_info:
            sections.append(_md_list("Symptoms", disease_info["symptoms"]))
        remedies = disease_info.get("remedies") or {}
        for r_type in ("organic", "chemical"):
            if r_type in remedies:
                sections.append(_md_list(f"{r_type.capitalize()} Remedies", remedies[r_type]))
        if "prevention" in disease_info:
            sections.append(_md_list("Prevention", disease_info["prevention"]))
    return "\n\n".join(sections) + "\n"
# ========================================
# GRADIO UI
# ========================================
# NOTE(review): these example rows are not attached to the UI (no gr.Examples
# uses them) — wire them up or remove them.
examples = [
    [None, None, "Leaves turning yellow", "Maize (Mahindi)", "Swahili"],
    [None, None, "Brown spots on leaves", "Tomato", "Swahili"],
]
with gr.Blocks(title="KilimoGuard AI - Crop Doctor") as demo:
    gr.Markdown("# 🌾 KilimoGuard AI\nYour AI-Powered Crop Doctor")
    with gr.Row():
        crop_type = gr.Dropdown(
            ["Maize (Mahindi)", "Tomato", "Potato", "Cassava", "Beans", "Banana", "Other"],
            label="Select Crop",
        )
        # Preselect a language: an untouched radio would send None, for which
        # the disease database lookup returns no advice.
        language = gr.Radio(["Swahili", "English"], value="Swahili", label="Language")
    with gr.Row():
        images = gr.Gallery(label="Upload Images", columns=3, height=300, preview=True)
        video = gr.Video(label="Upload Video")
    description = gr.Textbox(label="Describe the Problem", lines=3)
    with gr.Row():
        diagnose_btn = gr.Button("Analyze Now")
    output = gr.Markdown()
    # Receives a copy of the latest report (second output of analyze_crop).
    report_state = gr.State("")
    diagnose_btn.click(
        analyze_crop,
        inputs=[images, video, description, crop_type, language],
        outputs=[output, report_state],
    )
# ========================================
# LAUNCH
# ========================================
if __name__ == "__main__":
    # queue() returns the Blocks object itself, so launch() chains onto it.
    # Listen on every interface, port 7860, with debug output enabled.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
    )