""" Minimal segmentation manager. """ import numpy as np import cv2 import torch from PIL import Image from torchvision import transforms from transformers import AutoModelForImageSegmentation from typing import Optional import logging logger = logging.getLogger(__name__) class SegmentationManager: """Minimal BRIA segmentation.""" def __init__(self, model_name: str = "briaai/RMBG-2.0", device: str = "auto", threshold: float = 0.5, trust_remote_code: bool = True, cache_dir: Optional[str] = None, local_files_only: bool = False): """Initialize segmentation.""" self.model_name = model_name self.threshold = threshold self.device = "cuda" if device == "auto" and torch.cuda.is_available() else device logger.info(f"Loading BRIA model: {model_name}") self.model = AutoModelForImageSegmentation.from_pretrained( model_name, trust_remote_code=trust_remote_code, cache_dir=cache_dir if cache_dir else None, local_files_only=local_files_only, ).eval().to(self.device) self.transform = transforms.Compose([ transforms.Resize((1024, 1024)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) logger.info("BRIA model loaded") def segment_image_soft(self, image: np.ndarray) -> np.ndarray: """Segment image and return soft mask [0,1].""" try: rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) pil_image = Image.fromarray(rgb_image) input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device) with torch.no_grad(): preds = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy() original_size = (image.shape[1], image.shape[0]) soft_mask = cv2.resize(preds.astype(np.float32), original_size, interpolation=cv2.INTER_LINEAR) return np.clip(soft_mask, 0.0, 1.0) except Exception as e: logger.error(f"Segmentation failed: {e}") return np.zeros(image.shape[:2], dtype=np.float32)