# CLR_Severity / app.py — MaryPazRB
# update app: SAM2 vit_b + SAM3 integration (commit 8006379)
import os
import cv2
import gradio as gr
import numpy as np
from PIL import Image
import torch
import urllib.request
from ultralytics import YOLO
# Try importing SAM2
# NOTE(review): this imports the original "segment_anything" package
# (the SAM v1 API/registry); the rest of the app labels it "SAM2".
try:
    from segment_anything import sam_model_registry, SamPredictor
    SAM2_AVAILABLE = True
except ImportError:
    # Optional dependency: the app degrades to a full-frame leaf mask.
    SAM2_AVAILABLE = False
    print("SAM2 not available")
# Try importing SAM3
try:
    from sam3.model_builder import build_sam3_image_model
    from sam3.model.sam3_image_processor import Sam3Processor
    SAM3_AVAILABLE = True
except ImportError:
    # Optional dependency: lesion segmentation falls back to HSV thresholding.
    SAM3_AVAILABLE = False
    print("SAM3 not available")
############################################
# Configuration
############################################
# Run on GPU when available; every model below is moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on:", DEVICE)

# Fine-tuned YOLOv8 leaf detector, expected next to this script.
YOLO_MODEL_PATH = "clr_YOLOV8.pt"
# SAM2 vit_b (lighter)
SAM2_MODEL_TYPE = "vit_b"
SAM2_CHECKPOINT_PATH = "sam_vit_b_01ec64.pth"
############################################
# Download SAM2 if needed
############################################
# Fetch the ViT-B checkpoint from Meta's public bucket on first run only.
if SAM2_AVAILABLE:
    if not os.path.exists(SAM2_CHECKPOINT_PATH):
        print("Downloading SAM2 vit_b checkpoint...")
        urllib.request.urlretrieve(
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
            SAM2_CHECKPOINT_PATH
        )
############################################
# Load Models
############################################
# All three loaders are best-effort: a failure leaves the handle as None
# and the pipeline degrades gracefully (checked at call sites below).
print("Loading models...")

# YOLO leaf detector.
yolo_model = None
try:
    if os.path.exists(YOLO_MODEL_PATH):
        yolo_model = YOLO(YOLO_MODEL_PATH)
        print("YOLO loaded")
    else:
        print("YOLO model not found")
except Exception as e:
    print("YOLO error:", e)

# SAM2 promptable leaf segmenter (None when unavailable).
sam2_predictor = None
if SAM2_AVAILABLE:
    try:
        sam2 = sam_model_registry[SAM2_MODEL_TYPE](
            checkpoint=SAM2_CHECKPOINT_PATH
        )
        sam2.to(DEVICE)
        sam2_predictor = SamPredictor(sam2)
        print("SAM2 loaded")
    except Exception as e:
        print("SAM2 error:", e)

# SAM3 (official, no checkpoint) — text-prompted lesion segmenter.
sam3_processor = None
if SAM3_AVAILABLE:
    try:
        sam3_model = build_sam3_image_model(device=DEVICE)
        sam3_processor = Sam3Processor(sam3_model)
        print("SAM3 loaded")
    except Exception as e:
        print("SAM3 error:", e)
############################################
# Helper Functions
############################################
def fallback_segment_rust(leaf_img):
    """Heuristic rust segmentation used when SAM3 is unavailable.

    Thresholds the yellow/orange hue band in HSV space, then removes
    speckle noise with a morphological opening.

    Args:
        leaf_img: BGR uint8 image of a leaf.

    Returns:
        uint8 binary mask (0/255) the same height/width as *leaf_img*.
    """
    hsv_img = cv2.cvtColor(leaf_img, cv2.COLOR_BGR2HSV)
    # Hue 10-35 with moderate saturation/value covers rust-like yellows.
    lower_bound = np.array([10, 80, 80])
    upper_bound = np.array([35, 255, 255])
    raw_mask = cv2.inRange(hsv_img, lower_bound, upper_bound)
    open_kernel = np.ones((3, 3), np.uint8)
    cleaned_mask = cv2.morphologyEx(raw_mask, cv2.MORPH_OPEN, open_kernel)
    return cleaned_mask
def extract_leaf_sam2(image_rgb, box):
    """Segment the leaf inside *box* with the SAM2 predictor.

    Args:
        image_rgb: full-frame RGB uint8 array.
        box: xyxy bounding box for the leaf.

    Returns:
        uint8 mask (0/255) at full-frame resolution; an all-255 mask
        when no SAM2 predictor is loaded (whole frame treated as leaf).
    """
    frame_h, frame_w = image_rgb.shape[:2]
    if not sam2_predictor:
        return np.full((frame_h, frame_w), 255, dtype=np.uint8)
    sam2_predictor.set_image(image_rgb)
    prompt_box = np.array(box)
    masks, _, _ = sam2_predictor.predict(
        box=prompt_box,
        multimask_output=False
    )
    # Single mask requested; scale the boolean mask to 0/255.
    return (masks[0] * 255).astype(np.uint8)
def segment_lesions_sam3(image_rgb):
    """Segment rust lesions via SAM3's text-prompted mode.

    Falls back to HSV thresholding whenever SAM3 is unavailable or
    raises. Returns a uint8 mask (0/255) sized like *image_rgb*.
    """
    if not sam3_processor:
        return fallback_segment_rust(cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
    try:
        state = sam3_processor.set_image(Image.fromarray(image_rgb))
        output = sam3_processor.set_text_prompt(
            state=state,
            prompt="yellow spot"
        )
        masks = output.get("masks", None)
        if masks is None or len(masks) == 0:
            # No lesions found: empty mask.
            height, width = image_rgb.shape[:2]
            return np.zeros((height, width), dtype=np.uint8)
        # Binarize each predicted mask and take their pixel-wise union.
        binary_masks = [
            (np.squeeze(m.detach().cpu().numpy()) > 0).astype(np.uint8)
            for m in masks
        ]
        union = binary_masks[0]
        for extra in binary_masks[1:]:
            union = np.maximum(union, extra)
        return union * 255
    except Exception as e:
        print("SAM3 error:", e)
        return fallback_segment_rust(cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
############################################
# Main Function
############################################
def process_coffee_leaf(image):
    """Estimate coffee-leaf-rust severity for every leaf in an image.

    Pipeline: YOLO detects leaf boxes, SAM2 masks each leaf, SAM3 (or
    the HSV fallback) segments rust lesions, and severity is the ratio
    of rust pixels to leaf pixels.

    Args:
        image: RGB PIL image, or None.

    Returns:
        Tuple of (annotated RGB array or None, gallery list of
        (image, caption) pairs or None, table rows
        [leaf #, severity %, healthy %]).
    """
    if image is None:
        return None, None, [["Upload image", "-", "-"]]
    if yolo_model is None:
        return image, None, [["Error", "YOLO not loaded", "-"]]
    image_np = np.array(image)
    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    h, w = image_cv.shape[:2]
    results = yolo_model(image_cv, verbose=False)
    boxes = results[0].boxes.xyxy.cpu().numpy()
    if len(boxes) == 0:
        return image_np, None, [["No leaves detected", "-", "-"]]
    annotated = image_np.copy()
    gallery = []
    table = []
    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = box.astype(int)
        # FIX: clamp boxes to the frame. YOLO xyxy can spill slightly
        # outside the image; a negative index would silently wrap the
        # numpy crop and misalign the mask with the leaf.
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, w), min(y2, h)
        if x2 <= x1 or y2 <= y1:
            continue
        leaf_crop = image_np[y1:y2, x1:x2]
        if leaf_crop.size == 0:
            continue
        # SAM2 leaf mask at full resolution, then cropped to the box.
        leaf_mask_full = extract_leaf_sam2(image_np, box)
        leaf_mask = leaf_mask_full[y1:y2, x1:x2]
        leaf_pixels = cv2.countNonZero(leaf_mask)
        # Black-background cutout so lesion segmentation sees only leaf.
        cutout = np.zeros_like(leaf_crop)
        cutout[leaf_mask > 0] = leaf_crop[leaf_mask > 0]
        # SAM3 lesions, restricted to the leaf area.
        rust_mask = segment_lesions_sam3(cutout)
        rust_mask = cv2.bitwise_and(rust_mask, leaf_mask)
        rust_pixels = cv2.countNonZero(rust_mask)
        severity = (rust_pixels / leaf_pixels) * 100 if leaf_pixels > 0 else 0
        # Draw bbox; FIX: keep the label on-frame for boxes at the top edge.
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
        label_y = y1 - 5 if y1 - 5 > 10 else y1 + 15
        cv2.putText(annotated, f"{severity:.1f}%", (x1, label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
        # Purple overlay on rust pixels for the gallery.
        overlay = cutout.copy()
        overlay[rust_mask > 0] = [128, 0, 128]
        gallery.append((cutout, f"Leaf {i+1}"))
        gallery.append((overlay, f"{severity:.1f}%"))
        table.append([str(i+1), f"{severity:.1f}%", f"{100-severity:.1f}%"])
    # FIX: every box may have been degenerate after clamping.
    if not table:
        return image_np, None, [["No leaves detected", "-", "-"]]
    return annotated, gallery, table
############################################
# UI
############################################
# Build the Gradio interface: one input image, an annotated result,
# a per-leaf gallery (cutout + overlay), and a severity table.
with gr.Blocks() as demo:
    gr.Markdown("# ☕ Coffee Leaf Rust Severity Estimator")
    image_input = gr.Image(type="pil")
    submit = gr.Button("Run")
    clear = gr.Button("Clear")
    output_img = gr.Image()
    gallery = gr.Gallery(columns=2)
    table = gr.Dataframe(headers=["Leaf", "Severity", "Healthy"])
    submit.click(
        process_coffee_leaf,
        inputs=image_input,
        outputs=[output_img, gallery, table]
    )
    def clear_all():
        # One None per cleared component: input image, result image,
        # gallery, and table (matches the outputs list below).
        return None, None, None, None
    clear.click(
        clear_all,
        outputs=[image_input, output_img, gallery, table]
    )
############################################
# Launch
############################################
# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()