| """ | |
| Minimal segmentation manager. | |
| """ | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| from PIL import Image | |
| from torchvision import transforms | |
| from transformers import AutoModelForImageSegmentation | |
| from typing import Optional | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class SegmentationManager:
    """Minimal BRIA (briaai/RMBG-2.0) background-removal segmentation."""

    def __init__(self, model_name: str = "briaai/RMBG-2.0", device: str = "auto",
                 threshold: float = 0.5, trust_remote_code: bool = True,
                 cache_dir: Optional[str] = None, local_files_only: bool = False):
        """Load the BRIA segmentation model onto the resolved device.

        Args:
            model_name: HuggingFace model id to load.
            device: "auto" selects CUDA when available, else CPU; any other
                value (e.g. "cpu", "cuda:1") is passed through to torch.
            threshold: stored on the instance; not applied by
                ``segment_image_soft`` (soft masks are returned un-thresholded).
            trust_remote_code: forwarded to ``from_pretrained``; RMBG-2.0
                ships custom modeling code, so this normally stays True.
            cache_dir: optional HuggingFace cache directory (None = default).
            local_files_only: if True, skip network and use local cache only.
        """
        self.model_name = model_name
        self.threshold = threshold
        # BUG FIX: the original expression
        #   "cuda" if device == "auto" and torch.cuda.is_available() else device
        # left self.device == "auto" on CUDA-less machines, and
        # tensor/module .to("auto") raises. Resolve "auto" explicitly.
        if device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        logger.info(f"Loading BRIA model: {model_name}")
        self.model = AutoModelForImageSegmentation.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
            cache_dir=cache_dir,  # None already means "use the default cache"
            local_files_only=local_files_only,
        ).eval().to(self.device)
        # RMBG-2.0 expects 1024x1024 ImageNet-normalized RGB input.
        self.transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        logger.info("BRIA model loaded")

    def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
        """Segment a BGR image and return a soft foreground mask in [0, 1].

        Args:
            image: HxWx3 uint8 BGR image (OpenCV convention) —
                presumably; confirmed only by the BGR2RGB conversion below.

        Returns:
            float32 HxW array clipped to [0, 1], resized to the input's
            spatial size. On any inference failure an all-zeros mask is
            returned (best-effort contract preserved from the original).
        """
        try:
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(rgb_image)
            input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                # The model returns a list of side outputs; [-1] is the final
                # prediction. Drop the batch dim ([0]) and channel dim (squeeze).
                preds = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
            # cv2.resize takes (width, height), hence the shape swap.
            original_size = (image.shape[1], image.shape[0])
            soft_mask = cv2.resize(preds.astype(np.float32), original_size,
                                   interpolation=cv2.INTER_LINEAR)
            return np.clip(soft_mask, 0.0, 1.0)
        except Exception as e:
            # Deliberate best-effort: log with traceback, return empty mask.
            logger.exception(f"Segmentation failed: {e}")
            return np.zeros(image.shape[:2], dtype=np.float32)