"""Gradio app that estimates coffee-leaf-rust severity.

Pipeline per uploaded image:
  1. YOLO detects leaf bounding boxes.
  2. A Segment Anything predictor (``SamPredictor``) segments each leaf
     inside its box.
  3. SAM3 (text prompt) or an HSV colour threshold segments lesions on the
     leaf cutout.
  4. Severity = lesion pixels / leaf pixels * 100.
"""
import contextlib
import os
import urllib.request

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from ultralytics import YOLO

# Try importing "SAM2".
# NOTE(review): `segment_anything` and the `sam_vit_b_01ec64.pth` checkpoint
# below come from the original Segment Anything release (the download URL is
# under .../segment_anything/); the "SAM2" names are kept for compatibility.
try:
    from segment_anything import sam_model_registry, SamPredictor
    SAM2_AVAILABLE = True
except ImportError:
    SAM2_AVAILABLE = False
    print("SAM2 not available")

# Try importing SAM3
try:
    from sam3.model_builder import build_sam3_image_model
    from sam3.model.sam3_image_processor import Sam3Processor
    SAM3_AVAILABLE = True
except ImportError:
    SAM3_AVAILABLE = False
    print("SAM3 not available")

############################################
# Configuration
############################################

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on:", DEVICE)

YOLO_MODEL_PATH = "clr_YOLOV8.pt"

# SAM2 vit_b (lighter)
SAM2_MODEL_TYPE = "vit_b"
SAM2_CHECKPOINT_PATH = "sam_vit_b_01ec64.pth"
SAM2_CHECKPOINT_URL = (
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
)

############################################
# Download SAM2 if needed
############################################


def _download_checkpoint(url, dest):
    """Download *url* to *dest* via a temporary ``.part`` file.

    Returns True on success. On failure the error is printed, any partial
    file is removed, and False is returned. Writing to a side file and
    renaming only after a complete download means an interrupted run never
    leaves a truncated checkpoint at *dest* that the existence check below
    would mistake for a good one.
    """
    tmp_path = dest + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, dest)
        return True
    except OSError as e:  # includes urllib.error.URLError / HTTPError
        print("SAM2 download error:", e)
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        return False


if SAM2_AVAILABLE and not os.path.exists(SAM2_CHECKPOINT_PATH):
    print("Downloading SAM2 vit_b checkpoint...")
    # A failed download is reported here; the load step below then fails
    # and is caught, so the app still starts without SAM2.
    _download_checkpoint(SAM2_CHECKPOINT_URL, SAM2_CHECKPOINT_PATH)

############################################
# Load Models
############################################

print("Loading models...")

# YOLO leaf detector (None if the weights are missing or fail to load)
yolo_model = None
try:
    if os.path.exists(YOLO_MODEL_PATH):
        yolo_model = YOLO(YOLO_MODEL_PATH)
        print("YOLO loaded")
    else:
        print("YOLO model not found")
except Exception as e:
    print("YOLO error:", e)

# SAM2 box-prompted leaf segmenter (None if unavailable)
sam2_predictor = None
if SAM2_AVAILABLE:
    try:
        sam2 = sam_model_registry[SAM2_MODEL_TYPE](
            checkpoint=SAM2_CHECKPOINT_PATH
        )
        sam2.to(DEVICE)
        sam2_predictor = SamPredictor(sam2)
        print("SAM2 loaded")
    except Exception as e:
        print("SAM2 error:", e)

# SAM3 text-prompted lesion segmenter (official builder, no checkpoint path)
sam3_processor = None
if SAM3_AVAILABLE:
    try:
        sam3_model = build_sam3_image_model(device=DEVICE)
        sam3_processor = Sam3Processor(sam3_model)
        print("SAM3 loaded")
    except Exception as e:
        print("SAM3 error:", e)

############################################
# Helper Functions
############################################


def fallback_segment_rust(
    leaf_img, lower=(10, 80, 80), upper=(35, 255, 255), kernel_size=3
):
    """Segment rust-coloured pixels with a fixed HSV colour range.

    Args:
        leaf_img: BGR image (uint8).
        lower, upper: Inclusive OpenCV HSV bounds (H in 0-179). The defaults
            keep the original thresholds.
        kernel_size: Side of the square kernel used for a morphological
            opening that removes isolated speckles.

    Returns:
        uint8 mask of the same height/width, 255 where the pixel is in range.
        Black (masked-out) background pixels have V=0 and so are excluded by
        the default lower V bound.
    """
    hsv = cv2.cvtColor(leaf_img, cv2.COLOR_BGR2HSV)
    mask = cv2.inRange(hsv, np.array(lower), np.array(upper))
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)


def _full_mask(image_rgb):
    """Return an all-255 uint8 mask covering the whole of *image_rgb*."""
    return np.full(image_rgb.shape[:2], 255, dtype=np.uint8)


def extract_leaf_sam2(image_rgb, box, reuse_embedding=False):
    """Segment the leaf inside *box* with the SAM predictor.

    Args:
        image_rgb: Full RGB image (H, W, 3).
        box: ``[x1, y1, x2, y2]`` in pixel coordinates of *image_rgb*.
        reuse_embedding: If True, assume ``sam2_predictor.set_image`` was
            already called for *image_rgb* and skip re-embedding (the
            expensive step) — lets a caller embed once for many boxes.

    Returns:
        uint8 mask the size of *image_rgb*, 255 on the leaf. When SAM is not
        loaded, or prediction fails, returns an all-255 mask so the leaf
        falls back to its whole bounding box.
    """
    if sam2_predictor is None:
        return _full_mask(image_rgb)
    try:
        if not reuse_embedding:
            sam2_predictor.set_image(image_rgb)
        masks, _, _ = sam2_predictor.predict(
            box=np.asarray(box),
            multimask_output=False,
        )
    except Exception as e:
        print("SAM2 error:", e)
        return _full_mask(image_rgb)
    return np.where(masks[0] > 0, 255, 0).astype(np.uint8)


def _mask_to_numpy(mask):
    """Convert a torch tensor or array-like mask to a NumPy array."""
    if hasattr(mask, "detach"):
        return mask.detach().cpu().numpy()
    return np.asarray(mask)


def segment_lesions_sam3(image_rgb, prompt="yellow spot"):
    """Segment lesions on *image_rgb* with SAM3, falling back to HSV.

    Args:
        image_rgb: RGB image (typically a leaf cutout with black background).
        prompt: Text prompt passed to SAM3; defaults to the original prompt.

    Returns:
        uint8 mask with the same height/width as *image_rgb*, 255 on lesion
        pixels. Uses ``fallback_segment_rust`` if SAM3 is not loaded or
        raises; returns an all-zero mask if SAM3 finds nothing.
    """
    h, w = image_rgb.shape[:2]
    if sam3_processor is None:
        return fallback_segment_rust(cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))

    try:
        with torch.no_grad():  # inference only; avoid tracking gradients
            state = sam3_processor.set_image(Image.fromarray(image_rgb))
            output = sam3_processor.set_text_prompt(state=state, prompt=prompt)

        masks = output.get("masks", None)
        if masks is None or len(masks) == 0:
            return np.zeros((h, w), dtype=np.uint8)

        # Union of all returned instance masks
        combined = None
        for m in masks:
            m_np = (np.squeeze(_mask_to_numpy(m)) > 0).astype(np.uint8)
            combined = m_np if combined is None else np.maximum(combined, m_np)

        # NOTE(review): SAM3 masks are expected at input resolution; resize
        # defensively so the caller's bitwise_and never sees a size mismatch.
        if combined.shape != (h, w):
            combined = cv2.resize(
                combined, (w, h), interpolation=cv2.INTER_NEAREST
            )
        return combined * 255
    except Exception as e:
        print("SAM3 error:", e)
        return fallback_segment_rust(cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))

############################################
# Main Function
############################################


def process_coffee_leaf(image):
    """Detect leaves, segment leaf and lesions, and estimate rust severity.

    Args:
        image: Uploaded image (a PIL image, since the input component uses
            ``type="pil"``) or None.

    Returns:
        ``(annotated, gallery, table)``:
          - annotated: RGB array with a box and severity label per leaf
            (or the raw input / None on early exit);
          - gallery: list of ``(image, caption)`` pairs — each leaf cutout
            followed by its lesion overlay — or None on early exit;
          - table: rows ``[leaf index, severity %, healthy %]``, or a single
            status row.
    """
    if image is None:
        return None, None, [["Upload image", "-", "-"]]
    if yolo_model is None:
        return image, None, [["Error", "YOLO not loaded", "-"]]

    # Force 3-channel RGB so the RGB->BGR conversion is valid for RGBA/L input
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    image_np = np.array(image)
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    h, w = image_cv.shape[:2]

    results = yolo_model(image_cv, verbose=False)
    boxes = results[0].boxes.xyxy.cpu().numpy()
    if len(boxes) == 0:
        return image_np, None, [["No leaves detected", "-", "-"]]

    # Embed the image with SAM once; every leaf box reuses this embedding.
    sam2_ready = False
    if sam2_predictor is not None:
        try:
            sam2_predictor.set_image(image_np)
            sam2_ready = True
        except Exception as e:
            print("SAM2 error:", e)

    annotated = image_np.copy()
    gallery = []
    table = []

    for i, box in enumerate(boxes):
        # Python ints (cv2 drawing is picky about numpy scalars), clamped so
        # slicing and drawing stay inside the image.
        x1, y1, x2, y2 = (int(v) for v in box)
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(w, x2), min(h, y2)

        leaf_crop = image_np[y1:y2, x1:x2]
        if leaf_crop.size == 0:
            continue

        # SAM leaf mask (whole box if SAM is unavailable)
        if sam2_ready:
            leaf_mask_full = extract_leaf_sam2(image_np, box, reuse_embedding=True)
        else:
            leaf_mask_full = _full_mask(image_np)
        leaf_mask = leaf_mask_full[y1:y2, x1:x2]
        leaf_pixels = cv2.countNonZero(leaf_mask)

        # Cutout: leaf pixels kept, background set to black
        cutout = np.zeros_like(leaf_crop)
        cutout[leaf_mask > 0] = leaf_crop[leaf_mask > 0]

        # SAM3 lesions, restricted to the leaf
        rust_mask = segment_lesions_sam3(cutout)
        rust_mask = cv2.bitwise_and(rust_mask, leaf_mask)
        rust_pixels = cv2.countNonZero(rust_mask)

        severity = (rust_pixels / leaf_pixels) * 100 if leaf_pixels > 0 else 0
        label = f"{severity:.1f}%"

        # Draw bbox; keep the label on-image when the box touches the top
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(
            annotated,
            label,
            (x1, max(y1 - 5, 12)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 255, 0),
            2,
        )

        # Overlay: lesion pixels painted purple on the cutout
        overlay = cutout.copy()
        overlay[rust_mask > 0] = [128, 0, 128]

        gallery.append((cutout, f"Leaf {i+1}"))
        gallery.append((overlay, label))
        table.append([str(i + 1), label, f"{100-severity:.1f}%"])

    if not table:
        table = [["No valid leaf crops", "-", "-"]]

    return annotated, gallery, table

############################################
# UI
############################################

with gr.Blocks() as demo:

    gr.Markdown("# ☕ Coffee Leaf Rust Severity Estimator")

    image_input = gr.Image(type="pil")

    submit = gr.Button("Run")
    clear = gr.Button("Clear")

    output_img = gr.Image()
    gallery = gr.Gallery(columns=2)
    table = gr.Dataframe(headers=["Leaf", "Severity", "Healthy"])

    submit.click(
        process_coffee_leaf,
        inputs=image_input,
        outputs=[output_img, gallery, table]
    )

    def clear_all():
        """Reset the input and all three outputs."""
        return None, None, None, None

    clear.click(
        clear_all,
        outputs=[image_input, output_img, gallery, table]
    )

############################################
# Launch
############################################

if __name__ == "__main__":
    demo.launch()