"""Minimal segmentation manager."""

import logging
import os
from typing import Optional

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

logger = logging.getLogger(__name__)


class SegmentationManager:
    """Minimal BRIA segmentation.

    Loads a Hugging Face image-segmentation model (``briaai/RMBG-2.0`` by
    default) once at construction time and exposes
    :meth:`segment_image_soft`, which turns an OpenCV-style image into a
    per-pixel soft mask.
    """

    def __init__(self, model_name: str = "briaai/RMBG-2.0", device: str = "auto",
                 threshold: float = 0.5, trust_remote_code: bool = True,
                 cache_dir: Optional[str] = None, local_files_only: bool = False):
        """Load the segmentation model and build the preprocessing pipeline.

        Args:
            model_name: Model id or path passed to
                ``AutoModelForImageSegmentation.from_pretrained``.
            device: Torch device string. ``"auto"`` selects ``"cuda"`` when
                CUDA is available and ``"cpu"`` otherwise; any other value is
                used as given.
            threshold: Stored as ``self.threshold``; not used by
                :meth:`segment_image_soft`.
            trust_remote_code: Forwarded to ``from_pretrained``.
            cache_dir: Forwarded to ``from_pretrained``; ``None`` or an empty
                string means "use the default cache".
            local_files_only: Forwarded to ``from_pretrained``.
        """
        self.model_name = model_name
        self.threshold = threshold
        # Resolve "auto" to a concrete device. Without the CPU fallback the
        # literal string "auto" would reach ``.to()``, which torch rejects.
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # HF token from the environment (set as a Space secret); may be None.
        hf_token = os.environ.get("HF_TOKEN")

        logger.info("Loading BRIA model: %s", model_name)
        self.model = AutoModelForImageSegmentation.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
            cache_dir=cache_dir or None,  # treat "" the same as None
            local_files_only=local_files_only,
            token=hf_token,
        ).eval().to(self.device)

        # Fixed 1024x1024 model input, normalized with ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        logger.info("BRIA model loaded")

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """Convert a grayscale, BGR, or BGRA OpenCV image to 3-channel RGB."""
        if image.ndim == 2 or image.shape[2] == 1:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
        """Segment image and return soft mask [0,1].

        Args:
            image: OpenCV-style image: BGR ``(H, W, 3)``, BGRA ``(H, W, 4)``,
                or grayscale ``(H, W)`` / ``(H, W, 1)``. Presumably ``uint8``;
                other dtypes may be rejected by ``PIL.Image.fromarray``.

        Returns:
            ``float32`` array of shape ``(H, W)`` with values clipped to
            ``[0, 1]``. If any step fails, the error is logged with its
            traceback and an all-zero mask of shape ``(H, W)`` is returned.
        """
        try:
            rgb_image = self._to_rgb(image)
            pil_image = Image.fromarray(rgb_image)
            input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                # The model output is indexable; the last entry is taken as the
                # final prediction (presumably the finest scale — confirm).
                preds = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
            # cv2.resize takes (width, height), the reverse of numpy's shape order.
            original_size = (image.shape[1], image.shape[0])
            soft_mask = cv2.resize(preds.astype(np.float32), original_size,
                                   interpolation=cv2.INTER_LINEAR)
            return np.clip(soft_mask, 0.0, 1.0)
        except Exception as e:
            # Best-effort contract: never raise to the caller, but keep the traceback.
            logger.exception("Segmentation failed: %s", e)
            return np.zeros(image.shape[:2], dtype=np.float32)