File size: 3,155 Bytes
b4123b8 dd1d7f5 b4123b8 dd1d7f5 b4123b8 dd1d7f5 b4123b8 dd1d7f5 b4123b8 dd1d7f5 dffab99 7ac2007 dd1d7f5 7ac2007 dd1d7f5 dffab99 7ac2007 dd1d7f5 668a993 dd1d7f5 668a993 dd1d7f5 b4123b8 dd1d7f5 b4123b8 31ddfa7 b4123b8 31ddfa7 dd1d7f5 b4123b8 31ddfa7 dd1d7f5 b4123b8 31ddfa7 b4123b8 dd1d7f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
"""
Minimal segmentation manager.

Wraps a BRIA background-removal model (default ``briaai/RMBG-2.0``) loaded
through Hugging Face ``transformers.AutoModelForImageSegmentation`` and
produces soft (probability) foreground masks for OpenCV-style BGR images.
"""
import numpy as np
import cv2
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from typing import Optional
import logging
# Module-level logger, named after this module so its output can be filtered.
logger = logging.getLogger(__name__)
class SegmentationManager:
    """Minimal BRIA segmentation.

    Loads a Hugging Face ``AutoModelForImageSegmentation`` checkpoint once at
    construction and exposes :meth:`segment_image_soft`, which turns a BGR
    image into a per-pixel foreground probability mask.
    """

    def __init__(self, model_name: str = "briaai/RMBG-2.0", device: str = "auto",
                 threshold: float = 0.5, trust_remote_code: bool = True,
                 cache_dir: Optional[str] = None, local_files_only: bool = False,
                 input_size: int = 512):
        """Load the segmentation model and build the preprocessing pipeline.

        Args:
            model_name: Hugging Face model id or local path passed to
                ``from_pretrained``.
            device: Torch device string. ``"auto"`` selects ``"cuda"`` when
                CUDA is available and ``"cpu"`` otherwise; any other value is
                used as given.
            threshold: Stored as ``self.threshold``; not used by the methods
                of this class.
            trust_remote_code: Forwarded to ``from_pretrained``.
            cache_dir: Download cache directory. Defaults to
                ``/tmp/huggingface_cache``.
            local_files_only: Forwarded to ``from_pretrained``.
            input_size: Side length in pixels of the square image fed to the
                model. Defaults to 512.
        """
        import os  # only needed here, to read the HF token

        self.model_name = model_name
        self.threshold = threshold

        # "auto" must resolve to a concrete device string: torch cannot move a
        # model to a device named "auto", so fall back to CPU without CUDA.
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # HF token from the environment (e.g. set as a Space secret); None if unset.
        hf_token = os.environ.get("HF_TOKEN")

        # Default to /tmp to avoid persistent-storage issues.
        if cache_dir is None:
            cache_dir = "/tmp/huggingface_cache"

        logger.info("Loading BRIA model: %s (cache: %s)", model_name, cache_dir)
        self.model = AutoModelForImageSegmentation.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            token=hf_token,
            low_cpu_mem_usage=True,  # Reduce memory usage during loading
        ).eval().to(self.device)

        # Inference cost grows with input resolution; the 512 default was
        # chosen for speed over mask detail.
        self.input_size = input_size
        self.transform = transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean / std.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        logger.info("BRIA model loaded")

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """Convert a grayscale, BGR or BGRA OpenCV image to 3-channel RGB."""
        channels = 1 if image.ndim == 2 else image.shape[2]
        if channels == 1:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if channels == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
        """Segment image and return soft mask [0,1].

        Args:
            image: Image in OpenCV channel order: grayscale (HxW or HxWx1),
                BGR (HxWx3) or BGRA (HxWx4). Presumably uint8, since
                ``PIL.Image.fromarray`` is used on the converted RGB array.

        Returns:
            float32 array of shape (H, W) with values clipped to [0, 1]. If any
            step fails, the error is logged with its traceback and an all-zero
            (H, W) float32 mask is returned instead (deliberate best-effort).
        """
        try:
            logger.info("Segmentation: input image shape=%s, dtype=%s",
                        image.shape, image.dtype)
            pil_image = Image.fromarray(self._to_rgb(image))
            input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
            logger.info("Segmentation: tensor shape=%s, device=%s",
                        input_tensor.shape, self.device)

            with torch.no_grad():
                # The last element of the model output is taken as the mask
                # logits. Assumes shape (1, 1, h, w): [0] then squeeze(0)
                # reduce it to (h, w). TODO confirm for other checkpoints.
                preds = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
            logger.info("Segmentation: raw preds shape=%s, dtype=%s",
                        preds.shape, preds.dtype)

            height, width = image.shape[:2]
            # cv2.resize expects the target size as (width, height).
            soft_mask = cv2.resize(preds.astype(np.float32), (width, height),
                                   interpolation=cv2.INTER_LINEAR)
            logger.info("Segmentation: resized soft_mask shape=%s, dtype=%s",
                        soft_mask.shape, soft_mask.dtype)
            # Guard against float drift outside the probability range.
            return np.clip(soft_mask, 0.0, 1.0)
        except Exception as e:
            # Best-effort contract: callers receive an empty mask, not an
            # exception. logger.exception keeps the traceback for debugging.
            logger.exception("Segmentation failed: %s", e)
            return np.zeros(image.shape[:2], dtype=np.float32)