Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import urllib.request | |
| from ultralytics import YOLO | |
# Optional dependency: the `segment_anything` package, referred to as "SAM2"
# throughout this file. NOTE(review): `segment_anything` appears to be the
# original SAM (v1) package rather than SAM 2 -- confirm the naming.
try:
    from segment_anything import sam_model_registry, SamPredictor
except ImportError:
    SAM2_AVAILABLE = False
    print("SAM2 not available")
else:
    SAM2_AVAILABLE = True

# Optional dependency: SAM3 image-model builder and its text-prompt processor.
try:
    from sam3.model_builder import build_sam3_image_model
    from sam3.model.sam3_image_processor import Sam3Processor
except ImportError:
    SAM3_AVAILABLE = False
    print("SAM3 not available")
else:
    SAM3_AVAILABLE = True
############################################
# Configuration
############################################
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print("Running on:", DEVICE)

# YOLOv8 leaf-detector weights (path relative to the working directory).
YOLO_MODEL_PATH = "clr_YOLOV8.pt"

# SAM2 vit_b (lighter): registry key and local checkpoint filename.
SAM2_MODEL_TYPE = "vit_b"
SAM2_CHECKPOINT_PATH = "sam_vit_b_01ec64.pth"
############################################
# Download SAM2 if needed
############################################
SAM2_CHECKPOINT_URL = (
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
)

if SAM2_AVAILABLE and not os.path.exists(SAM2_CHECKPOINT_PATH):
    print("Downloading SAM2 vit_b checkpoint...")
    # Download to a sibling ".part" file and rename only once the transfer has
    # finished. Writing straight to SAM2_CHECKPOINT_PATH would let an
    # interrupted download leave a truncated file behind, which the exists()
    # check above would then accept on every later start.
    _sam2_partial_path = SAM2_CHECKPOINT_PATH + ".part"
    try:
        urllib.request.urlretrieve(SAM2_CHECKPOINT_URL, _sam2_partial_path)
        os.replace(_sam2_partial_path, SAM2_CHECKPOINT_PATH)
    except OSError as e:
        # SAM2 is optional: report the failure and let the loader below fail
        # gracefully (it catches load errors) instead of crashing startup.
        # URLError, HTTPError and ContentTooShortError are all OSError subclasses.
        print("SAM2 download error:", e)
        if os.path.exists(_sam2_partial_path):
            os.remove(_sam2_partial_path)
############################################
# Load Models
############################################
print("Loading models...")

# YOLO leaf detector. Left as None when the weights file is absent or fails
# to load; process_coffee_leaf then reports "YOLO not loaded".
yolo_model = None
try:
    if not os.path.exists(YOLO_MODEL_PATH):
        print("YOLO model not found")
    else:
        yolo_model = YOLO(YOLO_MODEL_PATH)
        print("YOLO loaded")
except Exception as err:
    print("YOLO error:", err)

# SAM2 (segment_anything) predictor used for leaf masks. Stays None when the
# package is missing or loading fails; extract_leaf_sam2 then returns an
# all-255 mask instead.
sam2_predictor = None
if SAM2_AVAILABLE:
    try:
        sam2 = sam_model_registry[SAM2_MODEL_TYPE](checkpoint=SAM2_CHECKPOINT_PATH)
        sam2.to(DEVICE)
        sam2_predictor = SamPredictor(sam2)
        print("SAM2 loaded")
    except Exception as err:
        print("SAM2 error:", err)

# SAM3 (official, no checkpoint argument) processor for text-prompted lesion
# segmentation. Stays None when unavailable; segment_lesions_sam3 then uses
# the HSV colour-threshold fallback.
sam3_processor = None
if SAM3_AVAILABLE:
    try:
        sam3_model = build_sam3_image_model(device=DEVICE)
        sam3_processor = Sam3Processor(sam3_model)
        print("SAM3 loaded")
    except Exception as err:
        print("SAM3 error:", err)
############################################
# Helper Functions
############################################
def fallback_segment_rust(leaf_img):
    """Colour-threshold fallback for lesion segmentation.

    Keeps pixels whose HSV value lies in H 10-35 (OpenCV's 0-179 hue scale),
    S 80-255 and V 80-255, then removes speckle with a 3x3 morphological
    opening.

    Args:
        leaf_img: BGR image (callers convert from RGB before calling).

    Returns:
        uint8 mask (0 or 255, as produced by ``cv2.inRange``) with the same
        height and width as ``leaf_img``.
    """
    hsv_bounds = (np.array([10, 80, 80]), np.array([35, 255, 255]))
    in_range = cv2.inRange(cv2.cvtColor(leaf_img, cv2.COLOR_BGR2HSV), *hsv_bounds)
    return cv2.morphologyEx(in_range, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8))
# Private copy of the image most recently passed to sam2_predictor.set_image,
# or None when no embedding is known to be loaded.
_sam2_embedded_image = None


def extract_leaf_sam2(image_rgb, box):
    """Return a uint8 leaf mask (0/255, same H x W as ``image_rgb``) for one box.

    Args:
        image_rgb: full RGB image that the detection box refers to.
        box: [x1, y1, x2, y2] pixel box (the caller passes a YOLO
            ``boxes.xyxy`` row).

    Returns:
        SAM's single mask for the box prompt scaled to 0/255, or an all-255
        mask when the SAM2 predictor is not loaded.
    """
    global _sam2_embedded_image
    if sam2_predictor is None:
        return np.ones(image_rgb.shape[:2], dtype=np.uint8) * 255

    # set_image() runs SAM's image encoder, which is the expensive step. The
    # caller invokes this once per detected leaf with the same full image, so
    # re-encode only when the pixels differ from the last embedded image. A
    # copy is stored so that in-place edits to the caller's array are still
    # detected. NOTE: assumes this function is the only caller of
    # sam2_predictor.set_image.
    if _sam2_embedded_image is None or not np.array_equal(_sam2_embedded_image, image_rgb):
        # Invalidate first: if set_image fails part-way, the next call must
        # not trust a stale cache entry.
        _sam2_embedded_image = None
        sam2_predictor.set_image(image_rgb)
        _sam2_embedded_image = image_rgb.copy()

    masks, _, _ = sam2_predictor.predict(box=np.array(box), multimask_output=False)
    return (masks[0] * 255).astype(np.uint8)
def segment_lesions_sam3(image_rgb, prompt="yellow spot"):
    """Return a uint8 lesion mask (0/255, same H x W as ``image_rgb``).

    Uses SAM3 text-prompted segmentation when ``sam3_processor`` is loaded.
    Otherwise, or if SAM3 raises, it falls back to the HSV colour threshold in
    ``fallback_segment_rust``.

    Args:
        image_rgb: RGB uint8 image, (H, W, 3). The caller passes the leaf
            cutout with a black background.
        prompt: text prompt passed to SAM3. Defaults to the previously
            hard-coded "yellow spot".

    Returns:
        The union of all SAM3 instance masks scaled to 0/255, an all-zero mask
        when SAM3 returns no masks, or the fallback mask.
    """
    if sam3_processor is None:
        return fallback_segment_rust(cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
    try:
        state = sam3_processor.set_image(Image.fromarray(image_rgb))
        output = sam3_processor.set_text_prompt(state=state, prompt=prompt)
        masks = output.get("masks", None)

        combined = np.zeros(image_rgb.shape[:2], dtype=np.uint8)
        if masks is None or len(masks) == 0:
            return combined
        for m in masks:
            # Each mask is presumably a torch tensor with singleton leading
            # dims (TODO confirm against Sam3Processor); squeeze to (H, W).
            m_np = np.squeeze(m.detach().cpu().numpy())
            # OR in place into an image-sized buffer. A mask whose shape does
            # not fit raises here, which triggers the fallback below, instead
            # of returning a mis-sized mask that breaks the caller's
            # cv2.bitwise_and.
            combined |= (m_np > 0).astype(np.uint8)
        return combined * 255
    except Exception as e:
        print("SAM3 error:", e)
        return fallback_segment_rust(cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
############################################
# Main Function
############################################
def _measure_leaf(image_np, box, x1, y1, x2, y2):
    """Segment one detected leaf and measure the share of it covered by lesions.

    Args:
        image_np: full RGB image as a numpy array.
        box: the raw detection box, forwarded to the SAM2 leaf segmenter.
        x1, y1, x2, y2: integer crop bounds derived from ``box``.

    Returns:
        ``(cutout, overlay, severity)``, where severity is the lesion pixel
        percentage of the leaf mask. Returns None when the crop is empty.
    """
    leaf_crop = image_np[y1:y2, x1:x2]
    if leaf_crop.size == 0:
        return None

    # The leaf mask is predicted on the full image, then cropped to the box.
    leaf_mask = extract_leaf_sam2(image_np, box)[y1:y2, x1:x2]
    leaf_pixels = cv2.countNonZero(leaf_mask)

    # Black out everything outside the leaf so lesion segmentation only sees it.
    on_leaf = leaf_mask > 0
    cutout = np.zeros_like(leaf_crop)
    cutout[on_leaf] = leaf_crop[on_leaf]

    # Count lesions only where they overlap the leaf mask.
    rust_mask = cv2.bitwise_and(segment_lesions_sam3(cutout), leaf_mask)
    rust_pixels = cv2.countNonZero(rust_mask)
    severity = (rust_pixels / leaf_pixels) * 100 if leaf_pixels > 0 else 0

    overlay = cutout.copy()
    overlay[rust_mask > 0] = [128, 0, 128]  # purple in RGB
    return cutout, overlay, severity


def process_coffee_leaf(image):
    """Detect coffee leaves and estimate per-leaf rust severity.

    Args:
        image: PIL image from the Gradio input (``type="pil"``), or None.

    Returns:
        ``(annotated, gallery, table)``:
        annotated -- RGB numpy image with a green box and severity label per
            leaf (the unmodified input on the early-exit paths);
        gallery -- list of ``(image, caption)`` pairs, holding a cutout and a
            purple lesion overlay per leaf, or None;
        table -- rows of ``[leaf number, severity %, healthy %]`` or a single
            status-message row.
    """
    if image is None:
        return None, None, [["Upload image", "-", "-"]]
    if yolo_model is None:
        return image, None, [["Error", "YOLO not loaded", "-"]]

    # Everything below assumes 3-channel RGB (COLOR_RGB2BGR, SAM inputs).
    # gr.Image may already deliver RGB; this guards against RGBA or greyscale
    # PIL inputs.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    image_np = np.array(image)

    # YOLO is fed a BGR copy, presumably because Ultralytics treats numpy
    # inputs as OpenCV-style BGR.
    image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    boxes = yolo_model(image_bgr, verbose=False)[0].boxes.xyxy.cpu().numpy()
    if len(boxes) == 0:
        return image_np, None, [["No leaves detected", "-", "-"]]

    annotated = image_np.copy()
    gallery = []
    table = []
    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = box.astype(int)
        measured = _measure_leaf(image_np, box, x1, y1, x2, y2)
        if measured is None:
            continue
        cutout, overlay, severity = measured
        label = f"{severity:.1f}%"

        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Keep the label inside the frame when the box touches the top edge
        # (putText's origin is the text's bottom-left corner).
        cv2.putText(annotated, label, (x1, max(y1 - 5, 15)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        gallery.append((cutout, f"Leaf {i+1}"))
        gallery.append((overlay, label))
        table.append([str(i + 1), label, f"{100 - severity:.1f}%"])

    if not table:
        # Every detection cropped to an empty region.
        return annotated, None, [["No leaves detected", "-", "-"]]
    return annotated, gallery, table
############################################
# UI
############################################
def clear_all():
    """Reset the input image, annotated output, gallery and results table."""
    return (None,) * 4


with gr.Blocks() as demo:
    gr.Markdown("# ☕ Coffee Leaf Rust Severity Estimator")

    # Input and controls.
    image_input = gr.Image(type="pil")
    submit = gr.Button("Run")
    clear = gr.Button("Clear")

    # Outputs, in the order process_coffee_leaf returns them.
    output_img = gr.Image()
    gallery = gr.Gallery(columns=2)
    table = gr.Dataframe(headers=["Leaf", "Severity", "Healthy"])

    submit.click(
        fn=process_coffee_leaf,
        inputs=image_input,
        outputs=[output_img, gallery, table],
    )
    clear.click(fn=clear_all, outputs=[image_input, output_img, gallery, table])


############################################
# Launch
############################################
if __name__ == "__main__":
    demo.launch()