# NOTE: lines below were page-scrape residue from the Hugging Face Space
# listing ("Spaces: Sleeping") — not part of the application source.
import json
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Tuple

import cv2
import gradio as gr
import numpy as np
import onnxruntime as ort
from PIL import Image
# ========================================
# ANALYZER
# ========================================
class KilimoGuardAnalyzer:
    """ONNX crop-disease classifier plus a JSON disease knowledge base.

    Loads an ONNX image-classification model and a disease database on
    construction; both loads degrade gracefully (the app still starts and
    `predict` returns a "Model Error" sentinel if the model is missing).
    """

    def __init__(self, model_path: str, disease_db_path: str):
        """Load the model and database from disk, logging progress to stdout.

        Args:
            model_path: path to the ONNX model file.
            disease_db_path: path to the JSON database with a top-level
                "class_mapping" (class index -> disease name) and "diseases"
                (disease name -> language -> info dict) structure.
        """
        self.session = None        # ort.InferenceSession, or None if load failed
        self.input_name = None     # name of the model's first input tensor
        self.output_name = None    # name of the model's first output tensor
        self.input_h = 384         # fallback input size for dynamic/unknown dims
        self.input_w = 384
        self.num_classes = 6       # fallback class count
        self.disease_db: Dict[str, Any] = {}
        self.class_mapping: Dict[str, str] = {}

        print(f"🔹 Loading ONNX model: {model_path}")
        if os.path.exists(model_path):
            try:
                self.session = ort.InferenceSession(model_path)
                self.input_name = self.session.get_inputs()[0].name
                self.output_name = self.session.get_outputs()[0].name
                shape = self.session.get_inputs()[0].shape
                # Dynamic axes are reported as strings/None, not ints — keep
                # the 384x384 fallback in that case.
                if len(shape) == 4:
                    self.input_h = shape[2] if isinstance(shape[2], int) else 384
                    self.input_w = shape[3] if isinstance(shape[3], int) else 384
                output_shape = self.session.get_outputs()[0].shape
                self.num_classes = output_shape[1] if len(output_shape) > 1 else 6
                print(f"✅ Model loaded: Input: {shape}, Output: {output_shape}")
            except Exception as e:
                print(f"❌ Failed to load ONNX model: {e}")
        else:
            print(f"⚠️ Model not found at {model_path}")

        print(f"🔹 Loading disease database: {disease_db_path}")
        if os.path.exists(disease_db_path):
            try:
                with open(disease_db_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                self.disease_db = data
                self.class_mapping = data.get("class_mapping", {})
                print(f"✅ Disease database loaded: {len(self.class_mapping)} classes")
            except Exception as e:
                print(f"❌ Failed to load disease database: {e}")
        else:
            print(f"⚠️ Disease database not found at {disease_db_path}")

    # -----------------------------
    def preprocess_image(self, img_input) -> np.ndarray:
        """Read an image and convert it to the model's NCHW float32 tensor.

        Args:
            img_input: a PIL Image, or anything convertible to a file path.

        Returns:
            float32 array of shape (1, 3, input_h, input_w), RGB, in [0, 1].

        Raises:
            ValueError: if the image file cannot be decoded.
        """
        tmp_path = None
        try:
            if isinstance(img_input, Image.Image):
                # cv2.imread only takes paths, so round-trip PIL input via a
                # temp JPEG. Close the handle first so the file can be
                # reopened for writing/reading on Windows.
                tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
                tmp.close()
                img_input.save(tmp.name)
                tmp_path = tmp.name
                path = tmp_path
            else:
                path = str(img_input)
            img = cv2.imread(path)
            if img is None:
                raise ValueError(f"Cannot read image: {path}")
            img = cv2.resize(img, (self.input_w, self.input_h))
            img = img.astype(np.float32) / 255.0
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # cv2 decodes as BGR
            img = np.transpose(img, (2, 0, 1))          # HWC -> CHW
            img = np.expand_dims(img, 0)                # add batch dimension
            print(f" Preprocessed image shape: {img.shape}")
            return img
        except Exception as e:
            print(f"❌ Preprocessing failed: {e}")
            raise
        finally:
            # Fix: the original leaked one temp file per PIL input.
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    # -----------------------------
    def predict(self, images: List) -> Tuple[str, float, str]:
        """Classify a batch of images and aggregate per-disease confidence.

        Args:
            images: list of PIL Images or file paths.

        Returns:
            (disease_name, avg_confidence, severity) — sentinel values
            ("Model Error" / "Unable to Classify") when inference cannot run.
        """
        if not self.session:
            return "Model Error", 0.0, "Unknown"
        predictions: Dict[str, List[float]] = {}
        for idx, img in enumerate(images):
            try:
                img_tensor = self.preprocess_image(img)
                outputs = self.session.run([self.output_name], {self.input_name: img_tensor})
                output = outputs[0][0]
                print(f" Raw output: min {np.min(output):.6f}, max {np.max(output):.6f}")
                # Heuristic: values above 1.0 cannot be probabilities, so treat
                # the vector as logits and apply a numerically stable softmax.
                # NOTE(review): logits that all happen to be <= 1.0 would pass
                # through unnormalized — confirm the model's output layer.
                if np.max(output) > 1.0:
                    exp_out = np.exp(output - np.max(output))
                    probs = exp_out / exp_out.sum()
                else:
                    probs = output
                class_idx = int(np.argmax(probs))
                confidence = float(probs[class_idx])
                disease_name = self.class_mapping.get(str(class_idx), f"Class_{class_idx}")
                predictions.setdefault(disease_name, []).append(confidence)
            except Exception as e:
                # Best-effort: skip unreadable frames, keep classifying the rest.
                print(f"❌ Inference failed for image {idx}: {e}")
                continue
        if not predictions:
            return "Unable to Classify", 0.0, "Unknown"
        # Pick the disease with the highest mean confidence across frames.
        best_disease = max(predictions.items(), key=lambda item: np.mean(item[1]))
        disease_name = best_disease[0]
        avg_confidence = float(np.mean(best_disease[1]))
        # Severity buckets: <50% low, 50-80% medium, >=80% high.
        if avg_confidence < 0.5:
            severity = "🟢 Low"
        elif avg_confidence < 0.8:
            severity = "🟡 Medium"
        else:
            severity = "🔴 High"
        print(f"✅ Final prediction: {disease_name} ({avg_confidence*100:.1f}%), severity: {severity}")
        return disease_name, avg_confidence, severity

    # -----------------------------
    def get_disease_info(self, disease_name: str, language="Swahili") -> Dict[str, Any]:
        """Return the localized info dict for a disease, or {} if unknown."""
        if "diseases" in self.disease_db:
            return self.disease_db["diseases"].get(disease_name, {}).get(language, {})
        return {}
# ========================================
# INITIALIZE ANALYZER
# ========================================
# Paths are relative to the app's working directory (the Space repo root).
MODEL_PATH = "kilimoguard_v2.onnx"
DISEASE_DB_PATH = "kilimoguard_disease_database.json"
# Module-level singleton: model and database load once at import/startup time.
analyzer = KilimoGuardAnalyzer(MODEL_PATH, DISEASE_DB_PATH)
def analyze_crop(images, video, description, crop_type, language, progress=gr.Progress()):
    """Run disease analysis over uploaded images and/or video frames.

    Args:
        images: gallery items (file paths, numpy arrays, or PIL Images).
        video: optional video file path; every 8th frame is sampled.
        description: free-text problem description (currently unused by the
            model; kept for the UI contract).
        crop_type: selected crop label, echoed into the report.
        language: "Swahili" or "English" — selects localized disease info.
        progress: Gradio progress tracker (unused, injected by Gradio).

    Returns:
        (report_markdown, report_markdown) — duplicated for the Markdown
        output component and the report State.
    """
    print("\n🔍 Starting Analysis")
    img_inputs = []
    temp_files = []
    try:
        # Persist in-memory uploads to temp JPEGs so the analyzer can read paths.
        if images:
            for img in images:
                if isinstance(img, (np.ndarray, Image.Image)):
                    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
                    temp_files.append(tmp.name)
                    if isinstance(img, Image.Image):
                        img.save(tmp.name)
                    else:
                        cv2.imwrite(tmp.name, img)
                    img_inputs.append(tmp.name)
                else:
                    # Already a path-like item from the gallery.
                    img_inputs.append(img)
        # Sample up to 6 frames (every 8th, within the first 120) from the video.
        if video:
            cap = cv2.VideoCapture(str(video))
            count = 0
            while len(img_inputs) < 6 and count < 120:
                ret, frame = cap.read()
                if not ret:
                    break
                if count % 8 == 0:
                    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
                    cv2.imwrite(tmp.name, frame)
                    img_inputs.append(tmp.name)
                    temp_files.append(tmp.name)
                count += 1
            cap.release()
        if not img_inputs:
            return "❌ No images or video frames provided", "❌ No images or video frames provided"
        img_inputs = img_inputs[:6]  # cap inference cost at 6 images

        disease_name, confidence, severity = analyzer.predict(img_inputs)
        disease_info = analyzer.get_disease_info(disease_name, language)

        # Build the Markdown report: header + optional localized detail sections.
        confidence_pct = confidence * 100
        confidence_bar = "█" * int(confidence_pct / 10) + "░" * (10 - int(confidence_pct / 10))
        report = (
            f"### 📋 Analysis Results\n**Crop:** {crop_type}\n**Condition:** {disease_name}\n"
            f"**Confidence:** {confidence_bar} {confidence_pct:.1f}%\n**Severity:** {severity}\n"
        )
        if disease_info:
            if "description" in disease_info:
                report += f"\n**Description:** {disease_info['description']}\n"
            if "symptoms" in disease_info:
                report += "\n**Symptoms:**\n" + "\n".join(f"• {s}" for s in disease_info["symptoms"])
            if "remedies" in disease_info:
                for r_type in ["organic", "chemical"]:
                    if r_type in disease_info["remedies"]:
                        report += f"\n**{r_type.capitalize()} Remedies:**\n" + "\n".join(
                            f"• {r}" for r in disease_info["remedies"][r_type]
                        )
            if "prevention" in disease_info:
                report += "\n**Prevention:**\n" + "\n".join(f"• {p}" for p in disease_info["prevention"])
        return report, report
    finally:
        # Fix: cleanup now runs on every exit path (the original only cleaned
        # up on the success path), and swallows only filesystem errors.
        for f in temp_files:
            try:
                os.unlink(f)
            except OSError:
                pass
# ========================================
# GRADIO UI
# ========================================
# Example rows matching the (images, video, description, crop_type, language)
# inputs. NOTE(review): this list is never wired to a gr.Examples component —
# hook it up or remove it.
examples = [
    [None, None, "Leaves turning yellow", "Maize (Mahindi)", "Swahili"],
    [None, None, "Brown spots on leaves", "Tomato", "Swahili"],
]

with gr.Blocks(title="KilimoGuard AI - Crop Doctor") as demo:
    gr.Markdown("# 🌾 KilimoGuard AI\nYour AI-Powered Crop Doctor")
    with gr.Row():
        crop_type = gr.Dropdown(
            ["Maize (Mahindi)", "Tomato", "Potato", "Cassava", "Beans", "Banana", "Other"],
            label="Select Crop",
        )
        language = gr.Radio(["Swahili", "English"], label="Language")
    with gr.Row():
        images = gr.Gallery(label="Upload Images", columns=3, height=300, preview=True)
        video = gr.Video(label="Upload Video")
    description = gr.Textbox(label="Describe the Problem", lines=3)
    with gr.Row():
        diagnose_btn = gr.Button("Analyze Now")
    output = gr.Markdown()
    # State keeps the last report server-side (analyze_crop returns it twice).
    report_state = gr.State("")
    diagnose_btn.click(
        analyze_crop,
        inputs=[images, video, description, crop_type, language],
        outputs=[output, report_state],
    )
# ========================================
# LAUNCH
# ========================================
if __name__ == "__main__":
    # Enable request queuing so long-running inferences don't block the UI.
    demo.queue()
    # 0.0.0.0:7860 is the standard bind for Hugging Face Spaces containers.
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)