"""
Minimal segmentation manager.
"""

import logging
from typing import Optional

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

logger = logging.getLogger(__name__)


class SegmentationManager:
    """Minimal BRIA segmentation."""
    
    def __init__(self, model_name: str = "briaai/RMBG-2.0", device: str = "auto",
                 threshold: float = 0.5, trust_remote_code: bool = True,
                 cache_dir: Optional[str] = None, local_files_only: bool = False):
        """Initialize segmentation."""
        self.model_name = model_name
        self.threshold = threshold
        self.device = "cuda" if device == "auto" and torch.cuda.is_available() else device
        
        logger.info(f"Loading BRIA model: {model_name}")
        self.model = AutoModelForImageSegmentation.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        ).eval().to(self.device)
        
        # RMBG-2.0 preprocessing: resize to 1024x1024 and normalize with ImageNet statistics.
        self.transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        logger.info("BRIA model loaded")
    
    def segment_image_soft(self, image: np.ndarray) -> np.ndarray:
        """Segment image and return soft mask [0,1]."""
        try:
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(rgb_image)
            input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
            
            with torch.no_grad():
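                # RMBG-2.0 returns a list of multi-scale outputs; take the final
                # one, map logits to [0, 1], and drop the batch and channel dims.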
                preds = self.model(input_tensor)[-1].sigmoid().cpu()[0].squeeze(0).numpy()
            
            # cv2.resize takes (width, height), while image.shape is (rows, cols, ...).
            original_size = (image.shape[1], image.shape[0])
            soft_mask = cv2.resize(preds.astype(np.float32), original_size, interpolation=cv2.INTER_LINEAR)
            return np.clip(soft_mask, 0.0, 1.0)
        except Exception as e:
            logger.error(f"Segmentation failed: {e}")
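            # Fail soft: return an all-background (zero) mask matching the input size.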
            return np.zeros(image.shape[:2], dtype=np.float32)
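
    def segment_image(self, image: np.ndarray) -> np.ndarray:
        """Segment a BGR image and return a binary uint8 mask.

        A minimal convenience sketch: the class stores ``threshold`` but only
        exposes the soft mask, so this helper is an assumed addition that
        simply binarizes the soft output.
        """
        soft_mask = self.segment_image_soft(image)
        return (soft_mask >= self.threshold).astype(np.uint8)


# Minimal usage sketch, assuming an OpenCV-readable image; "input.jpg" is a
# hypothetical placeholder path, not part of the module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    manager = SegmentationManager(device="auto")
    bgr = cv2.imread("input.jpg")  # hypothetical input path
    if bgr is not None:
        soft = manager.segment_image_soft(bgr)
        print(f"Soft mask shape={soft.shape}, range=[{soft.min():.3f}, {soft.max():.3f}]")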