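# NOTE(review): the lines below appear to be repository web-page header text
# captured together with the file; they are not part of the module itself.
#   Fahimeh Orvati Nia
#   Add sorghum_pipeline code
#   b4123b8
#   raw
#   history blame
#   10.6 kB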
"""
Segmentation manager for the Sorghum Pipeline.
This module handles image segmentation using the BRIA model
and provides post-processing capabilities.
"""
import numpy as np
import cv2
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from typing import Optional, Tuple
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class SegmentationManager:
    """Manages image segmentation using BRIA model.

    The model is loaded once, at construction time. ``segment_image`` and
    ``segment_image_soft`` run inference; the remaining public methods are
    mask utilities (clean-up, validation, measurement, overlay) that do not
    touch the model.
    """

    def __init__(self,
                 model_name: str = "briaai/RMBG-2.0",
                 device: str = "auto",
                 threshold: float = 0.5,
                 trust_remote_code: bool = True,
                 cache_dir: Optional[str] = None,
                 local_files_only: bool = False):
        """
        Initialize segmentation manager and load the model.

        Args:
            model_name: Name of the BRIA model
            device: Device to run model on ("auto", "cpu", "cuda")
            threshold: Segmentation threshold
            trust_remote_code: Whether to trust remote code
            cache_dir: Hugging Face cache directory for model weights
            local_files_only: If True, only load from local cache

        Raises:
            Exception: Whatever the model loader raises is propagated.
        """
        self.model_name = model_name
        self.threshold = threshold
        self.trust_remote_code = trust_remote_code
        self.cache_dir = cache_dir
        self.local_files_only = local_files_only

        # "auto" picks CUDA when torch reports it available, otherwise CPU;
        # any other value is used verbatim.
        if device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # Both attributes are populated by _load_model().
        self.model = None
        self.transform = None
        self._load_model()

    def _load_model(self):
        """Load the BRIA segmentation model and build the input transform.

        Raises:
            Exception: Any loading failure is logged and re-raised.
        """
        try:
            logger.info("Loading BRIA model: %s", self.model_name)
            self.model = AutoModelForImageSegmentation.from_pretrained(
                self.model_name,
                trust_remote_code=self.trust_remote_code,
                # An empty string is treated like "no cache dir given".
                cache_dir=self.cache_dir or None,
                local_files_only=self.local_files_only,
            ).eval().to(self.device)

            # Preprocessing applied to every PIL image before inference.
            self.transform = transforms.Compose([
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225]),
            ])
            logger.info("BRIA model loaded successfully")
        except Exception as e:
            # Re-raised below, so the caller gets the traceback; log only
            # the message here.
            logger.error("Failed to load BRIA model: %s", e)
            raise

    def _predict_probabilities(self, image: np.ndarray) -> np.ndarray:
        """Run the model on a BGR image and return its sigmoid output.

        Shared by ``segment_image`` and ``segment_image_soft`` so that both
        use exactly the same preprocessing and inference path. The result is
        at the model's working resolution, not the input image's size.
        """
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(rgb_image)
        input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # The last element of the model output is used; the batch
            # dimension and then the leading dimension are dropped.
            return (
                self.model(input_tensor)[-1]
                .sigmoid()
                .cpu()[0]
                .squeeze(0)
                .numpy()
            )

    def segment_image(self, image: np.ndarray) -> np.ndarray:
        """
        Segment an image using the BRIA model.

        Args:
            image: Input image (BGR format)

        Returns:
            Binary mask (0/255) with the same height and width as ``image``.
            If inference fails, the error is logged and an all-zero mask is
            returned instead.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")
        try:
            probabilities = self._predict_probabilities(image)
            mask = (probabilities > self.threshold).astype(np.uint8) * 255
            # cv2.resize expects (width, height); nearest-neighbour keeps
            # the mask strictly 0/255.
            original_size = (image.shape[1], image.shape[0])
            return cv2.resize(mask, original_size,
                              interpolation=cv2.INTER_NEAREST)
        except Exception as e:
            logger.exception("Segmentation failed: %s", e)
            # Best-effort fallback: empty mask
            return np.zeros(image.shape[:2], dtype=np.uint8)

    def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
        """
        Segment an image and return a soft mask in [0, 1] resized to
        original size. No thresholding or post-processing is applied.

        Args:
            image: Input image (BGR format)

        Returns:
            Float mask in [0,1] with shape (H, W). If inference fails, the
            error is logged and an all-zero float32 mask is returned.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")
        try:
            probabilities = self._predict_probabilities(image)
            original_size = (image.shape[1], image.shape[0])
            soft_mask = cv2.resize(probabilities.astype(np.float32),
                                   original_size,
                                   interpolation=cv2.INTER_LINEAR)
            return np.clip(soft_mask, 0.0, 1.0)
        except Exception as e:
            logger.exception("Soft segmentation failed: %s", e)
            return np.zeros(image.shape[:2], dtype=np.float32)

    def post_process_mask(self, mask: np.ndarray,
                          min_area: int = 1000,
                          kernel_size: int = 5) -> np.ndarray:
        """
        Post-process segmentation mask.

        Applies a morphological opening and then drops every connected
        component whose area is below ``min_area``.

        Args:
            mask: Input mask
            min_area: Minimum area for connected components
            kernel_size: Kernel size for morphological operations

        Returns:
            Post-processed mask (kept components set to 255). If processing
            fails, the error is logged and ``mask`` is returned unchanged.
        """
        try:
            # Morphological opening to remove noise
            kernel = cv2.getStructuringElement(
                cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
            )
            opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

            _, labels, stats, _ = cv2.connectedComponentsWithStats(
                opened, connectivity=8
            )

            # Per-label keep/drop table, applied to the label image in one
            # pass instead of rescanning the image once per component.
            keep = stats[:, cv2.CC_STAT_AREA] >= min_area
            keep[0] = False  # label 0 is the background
            return np.where(keep[labels], 255, 0).astype(opened.dtype)
        except Exception as e:
            logger.exception("Mask post-processing failed: %s", e)
            return mask

    def keep_largest_component(self, mask: np.ndarray) -> np.ndarray:
        """
        Keep only the largest connected component.

        Args:
            mask: Input mask

        Returns:
            Mask (0/255, uint8) with only the largest component. ``mask`` is
            returned unchanged when it has no foreground component or when
            processing fails (the error is logged).
        """
        try:
            # Connectivity is passed by keyword, as in post_process_mask.
            num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
                mask, connectivity=8
            )
            if num_labels <= 1:
                return mask

            # Row 0 of stats is the background, hence the +1 offset.
            areas = stats[1:, cv2.CC_STAT_AREA]
            largest_label = 1 + np.argmax(areas)
            return (labels == largest_label).astype(np.uint8) * 255
        except Exception as e:
            logger.exception("Largest component extraction failed: %s", e)
            return mask

    def validate_mask(self, mask: np.ndarray) -> bool:
        """
        Validate segmentation mask.

        A mask is valid when it is a 2-D NumPy array of dtype uint8 or bool
        with at least one pixel greater than zero.

        Args:
            mask: Mask to validate

        Returns:
            True if valid, False otherwise
        """
        # isinstance also rejects None.
        if not isinstance(mask, np.ndarray):
            return False
        if mask.ndim != 2:
            return False
        if mask.dtype not in (np.uint8, np.bool_):
            return False
        if not np.any(mask > 0):
            logger.warning("Mask has no foreground pixels")
            return False
        return True

    def get_mask_properties(self, mask: np.ndarray) -> dict:
        """
        Get properties of the segmentation mask.

        Args:
            mask: Binary mask. uint8 masks are thresholded at > 127; boolean
                masks are used as they are.

        Returns:
            Dictionary of mask properties with keys ``area`` (foreground
            pixel count), ``perimeter`` (summed length of all external
            contours), ``bbox_area`` and ``aspect_ratio`` (of the box
            enclosing all external contours), ``coverage`` (area divided by
            the number of mask pixels) and ``num_components`` (number of
            external contours). An empty dict is returned when the mask is
            invalid or the computation fails.
        """
        if not self.validate_mask(mask):
            return {}
        try:
            # validate_mask accepts bool masks; "> 127" would map every True
            # to background, so bool masks are converted directly.
            if mask.dtype == np.bool_:
                binary_mask = mask.astype(np.uint8)
            else:
                binary_mask = (mask > 127).astype(np.uint8)

            area = int(np.count_nonzero(binary_mask))

            # [-2] selects the contour list whether findContours returns
            # two values or three.
            contours = cv2.findContours(
                binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )[-2]

            if contours:
                # Measure every contour rather than an arbitrary first one,
                # so the values describe the same pixels as `area`.
                perimeter = sum(cv2.arcLength(c, True) for c in contours)
                _, _, w, h = cv2.boundingRect(np.vstack(contours))
                bbox_area = w * h
                aspect_ratio = w / h if h > 0 else 0
            else:
                perimeter = 0
                bbox_area = 0
                aspect_ratio = 0

            return {
                "area": area,
                "perimeter": float(perimeter),
                "bbox_area": int(bbox_area),
                "aspect_ratio": float(aspect_ratio),
                "coverage": float(area) / mask.size if mask.size > 0 else 0.0,
                "num_components": len(contours),
            }
        except Exception as e:
            logger.exception("Mask property calculation failed: %s", e)
            return {}

    def create_overlay(self, image: np.ndarray, mask: np.ndarray,
                       color: Tuple[int, int, int] = (0, 255, 0),
                       alpha: float = 0.5) -> np.ndarray:
        """
        Create overlay of mask on image.

        Args:
            image: Base image
            mask: Binary mask (pixels equal to 255, or True for boolean
                masks, are treated as foreground)
            color: Overlay color (B, G, R)
            alpha: Overlay transparency

        Returns:
            Image with mask overlay. If blending fails, the error is logged
            and ``image`` is returned unchanged.
        """
        try:
            # Boolean masks never equal 255, so use them directly.
            foreground = mask if mask.dtype == np.bool_ else (mask == 255)
            overlay = image.copy()
            overlay[foreground] = color
            return cv2.addWeighted(image, 1.0 - alpha, overlay, alpha, 0)
        except Exception as e:
            logger.exception("Overlay creation failed: %s", e)
            return image