File size: 2,331 Bytes
8c66073
 
0257a40
 
96f45fe
 
0257a40
 
 
96f45fe
0257a40
 
8c66073
 
f560293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os

# Hugging Face Spaces guarantees /tmp is writable, so every HF cache
# location is pointed at a directory under /tmp before transformers loads.
CACHE_DIR = "/tmp/hf_cache"
for _env_var in ("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE"):
    os.environ[_env_var] = CACHE_DIR

# Idempotent: succeeds whether or not the directory already exists.
os.makedirs(CACHE_DIR, exist_ok=True)
print("✅ Using Hugging Face cache directory:", CACHE_DIR)



import torch
import numpy as np
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from transformers import AutoModelForImageSegmentation
from PIL import Image
from skimage import io
import io as sysio


class RMBGRemover:
    """Background remover built on the briaai/RMBG-1.4 segmentation model.

    Loads the model once at construction time and exposes
    ``remove_background``, which turns encoded image bytes into an RGBA
    PNG stream whose alpha channel is the predicted foreground mask.
    """

    def __init__(self):
        # Prefer GPU when available; the model is moved there once and
        # kept in eval mode for inference-only use.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModelForImageSegmentation.from_pretrained(
            "briaai/RMBG-1.4", trust_remote_code=True
        ).to(self.device)
        self.model.eval()

    def preprocess(self, im: np.ndarray, model_input_size: list) -> torch.Tensor:
        """Convert an HxWxC uint8 image array into the model input tensor.

        Args:
            im: image as a numpy array. A 2-D (grayscale) array is
                expanded and replicated to three channels.
            model_input_size: target [height, width] for the model.

        Returns:
            A 1x3xHxW float32 tensor, scaled to [0, 1] then normalized
            with per-channel mean 0.5 and std 1.0.
        """
        if len(im.shape) < 3:
            im = im[:, :, np.newaxis]
        if im.shape[2] == 1:
            # Fix: a single-channel image would crash the 3-channel
            # normalize below; replicate it to a 3-channel image instead.
            im = np.repeat(im, 3, axis=2)
        im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = F.interpolate(im_tensor.unsqueeze(0), size=model_input_size, mode="bilinear")
        image = im_tensor / 255.0
        image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        return image

    def postprocess(self, result: torch.Tensor, im_size: list) -> np.ndarray:
        """Resize the model's mask back to the source size as uint8 HxW.

        Min-max normalizes the mask to [0, 1] before scaling to 0-255.
        A constant mask (max == min) is mapped to all zeros instead of
        dividing by zero.
        """
        result = F.interpolate(result, size=im_size, mode="bilinear").squeeze(0)
        lo, hi = result.min(), result.max()
        if hi > lo:
            result = (result - lo) / (hi - lo)
        else:
            # Fix: the original divided by zero here on a constant mask,
            # producing a NaN-filled array; treat it as fully background.
            result = torch.zeros_like(result)
        im_array = (result * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
        return np.squeeze(im_array)

    def remove_background(self, image_bytes: bytes) -> "sysio.BytesIO":
        """Predict a foreground mask and return an RGBA PNG stream.

        Args:
            image_bytes: encoded image bytes in any format PIL can open.

        Returns:
            A ``BytesIO`` positioned at offset 0 containing the PNG.
            (The original annotation said ``bytes``, but the method has
            always returned a stream; the annotation is corrected rather
            than the behavior so existing callers keep working.)
        """
        im = Image.open(sysio.BytesIO(image_bytes)).convert("RGB")
        np_im = np.array(im)
        im_size = np_im.shape[0:2]

        model_input_size = [1024, 1024]
        image = self.preprocess(np_im, model_input_size).to(self.device)

        with torch.no_grad():
            result = self.model(image)

        # RMBG-1.4 returns nested outputs; result[0][0] is the 1xHxW mask.
        mask = self.postprocess(result[0][0], im_size)
        pil_mask = Image.fromarray(mask)
        im.putalpha(pil_mask)

        out_bytes = sysio.BytesIO()
        im.save(out_bytes, format="PNG")
        out_bytes.seek(0)
        return out_bytes