"""
Segmentation manager for the Sorghum Pipeline.

This module handles image segmentation using the BRIA RMBG-2.0
background-removal model and provides mask post-processing utilities.
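
Example (a minimal sketch; assumes a BGR image loaded with cv2.imread):
    manager = SegmentationManager(device="auto")
    mask = manager.segment_image(bgr_image)    # binary 0/255 mask
    mask = manager.post_process_mask(mask)     # remove speckles and small blobs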
"""

import logging
from typing import Optional, Tuple

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

logger = logging.getLogger(__name__)


class SegmentationManager:
    """Manages image segmentation using the BRIA RMBG-2.0 model."""

    def __init__(self,
                 model_name: str = "briaai/RMBG-2.0",
                 device: str = "auto",
                 threshold: float = 0.5,
                 trust_remote_code: bool = True,
                 cache_dir: Optional[str] = None,
                 local_files_only: bool = False):
        """
        Initialize the segmentation manager.

        Args:
            model_name: Name of the BRIA model on the Hugging Face Hub
            device: Device to run the model on ("auto", "cpu", or "cuda")
            threshold: Probability threshold for binarizing the soft mask
            trust_remote_code: Whether to trust remote code when loading the model
            cache_dir: Hugging Face cache directory for model weights
            local_files_only: If True, load only from the local cache
        """
        self.model_name = model_name
        self.threshold = threshold
        self.trust_remote_code = trust_remote_code
        self.cache_dir = cache_dir
        self.local_files_only = local_files_only
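
        # Resolve "auto" to CUDA when available; otherwise fall back to CPU.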
        if device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.model = None
        self.transform = None
        self._load_model()

    def _load_model(self):
        """Load the BRIA segmentation model."""
        try:
            logger.info(f"Loading BRIA model: {self.model_name}")

            self.model = AutoModelForImageSegmentation.from_pretrained(
                self.model_name,
                trust_remote_code=self.trust_remote_code,
                cache_dir=self.cache_dir,
                local_files_only=self.local_files_only,
            ).eval().to(self.device)
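
            # Preprocess to the model's expected 1024x1024 input with ImageNet
            # normalization, as in the RMBG-2.0 usage example.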
            self.transform = transforms.Compose([
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

            logger.info("BRIA model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load BRIA model: {e}")
            raise

    def segment_image(self, image: np.ndarray) -> np.ndarray:
        """
        Segment an image using the BRIA model.

        Args:
            image: Input image (BGR format)

        Returns:
            Binary mask (0/255) at the original image resolution
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")

        try:
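            # OpenCV loads images in BGR order; the model expects RGB.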
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(rgb_image)

            input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
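
            # The model returns a sequence of prediction maps; the last one is
            # taken as the final output and squashed to [0, 1] by the sigmoid.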
            with torch.no_grad():
                predictions = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
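
            # Binarize the soft prediction at the configured threshold.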
            mask = (predictions > self.threshold).astype(np.uint8) * 255
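
            # Resize back to the input resolution (cv2 sizes are width, height);
            # nearest-neighbor interpolation keeps the mask strictly binary.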
            original_size = (image.shape[1], image.shape[0])
            mask_resized = cv2.resize(mask, original_size, interpolation=cv2.INTER_NEAREST)

            return mask_resized

        except Exception as e:
            logger.error(f"Segmentation failed: {e}")
            return np.zeros(image.shape[:2], dtype=np.uint8)

    def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
        """
        Segment an image and return a soft mask in [0, 1] resized to the original size.
        No thresholding or post-processing is applied.

        Args:
            image: Input image (BGR format)

        Returns:
            Float mask in [0, 1] with shape (H, W)
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")
        try:
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(rgb_image)
            input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                preds = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
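            # Bilinear interpolation preserves the soft gradations when
            # resizing back to the original resolution.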
            original_size = (image.shape[1], image.shape[0])
            soft_mask = cv2.resize(preds.astype(np.float32), original_size, interpolation=cv2.INTER_LINEAR)
            return np.clip(soft_mask, 0.0, 1.0)
        except Exception as e:
            logger.error(f"Soft segmentation failed: {e}")
            return np.zeros(image.shape[:2], dtype=np.float32)

    def post_process_mask(self, mask: np.ndarray,
                          min_area: int = 1000,
                          kernel_size: int = 5) -> np.ndarray:
        """
        Post-process a segmentation mask.

        Args:
            mask: Input binary mask (0/255)
            min_area: Minimum area in pixels for connected components to keep
            kernel_size: Kernel size for the morphological opening

        Returns:
            Post-processed mask
        """
        try:
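            # Morphological opening removes small speckles and thin spurs.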
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
            opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
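
            # Keep only connected components of at least min_area pixels;
            # label 0 is the background and is skipped.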
            num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
                opened, connectivity=8
            )

            processed_mask = np.zeros_like(opened)
            for label in range(1, num_labels):
                if stats[label, cv2.CC_STAT_AREA] >= min_area:
                    processed_mask[labels == label] = 255

            return processed_mask

        except Exception as e:
            logger.error(f"Mask post-processing failed: {e}")
            return mask

    def keep_largest_component(self, mask: np.ndarray) -> np.ndarray:
        """
        Keep only the largest connected component.

        Args:
            mask: Input binary mask (0/255)

        Returns:
            Mask containing only the largest component
        """
        try:
            num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
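
            # Label 0 is the background; bail out if there is no foreground.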
            if num_labels <= 1:
                return mask
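
            # Row 0 of stats is the background, so search rows 1..num_labels-1
            # and shift the argmax index back by one.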
            areas = stats[1:, cv2.CC_STAT_AREA]
            largest_label = 1 + np.argmax(areas)

            largest_mask = (labels == largest_label).astype(np.uint8) * 255

            return largest_mask

        except Exception as e:
            logger.error(f"Largest component extraction failed: {e}")
            return mask

    def validate_mask(self, mask: np.ndarray) -> bool:
        """
        Validate a segmentation mask.

        Args:
            mask: Mask to validate

        Returns:
            True if valid, False otherwise
        """
        if mask is None:
            return False

        if not isinstance(mask, np.ndarray):
            return False

        if mask.ndim != 2:
            return False

        if mask.dtype not in (np.uint8, np.bool_):
            return False
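
        # An all-background mask usually indicates a failed segmentation.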
        if np.sum(mask > 0) == 0:
            logger.warning("Mask has no foreground pixels")
            return False

        return True

    def get_mask_properties(self, mask: np.ndarray) -> dict:
        """
        Get properties of a segmentation mask.

        Args:
            mask: Binary mask

        Returns:
            Dictionary with area, perimeter, bbox_area, aspect_ratio,
            coverage, and num_components; empty if the mask is invalid
        """
        if not self.validate_mask(mask):
            return {}

        try:
            # Treat values above 127 as foreground.
            binary_mask = (mask > 127).astype(np.uint8)

            area = np.sum(binary_mask)
            perimeter = 0

            contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                # Measure the largest contour; findContours gives no ordering
                # guarantee, so contours[0] may not be the biggest region.
                largest_contour = max(contours, key=cv2.contourArea)
                perimeter = cv2.arcLength(largest_contour, True)

                x, y, w, h = cv2.boundingRect(largest_contour)
                bbox_area = w * h
                aspect_ratio = w / h if h > 0 else 0
            else:
                bbox_area = 0
                aspect_ratio = 0

            return {
                "area": int(area),
                "perimeter": float(perimeter),
                "bbox_area": int(bbox_area),
                "aspect_ratio": float(aspect_ratio),
                "coverage": float(area) / (mask.shape[0] * mask.shape[1]) if mask.size > 0 else 0.0,
                "num_components": len(contours)
            }

        except Exception as e:
            logger.error(f"Mask property calculation failed: {e}")
            return {}

    def create_overlay(self, image: np.ndarray, mask: np.ndarray,
                       color: Tuple[int, int, int] = (0, 255, 0),
                       alpha: float = 0.5) -> np.ndarray:
        """
        Create an overlay of the mask on the image.

        Args:
            image: Base image (BGR)
            mask: Binary mask (0/255)
            color: Overlay color (B, G, R)
            alpha: Overlay opacity in [0, 1]

        Returns:
            Image with the mask blended on top
        """
        try:
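            # Paint the masked pixels, then alpha-blend with the original image.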
            overlay = image.copy()
            overlay[mask == 255] = color
            return cv2.addWeighted(image, 1.0 - alpha, overlay, alpha, 0)
        except Exception as e:
            logger.error(f"Overlay creation failed: {e}")
            return image