File size: 3,553 Bytes
fe98e45
fb6147c
 
 
 
 
 
 
 
fe98e45
fb6147c
 
 
 
 
 
 
 
fe98e45
fb6147c
 
 
eaa9aed
fb6147c
 
fe98e45
 
 
 
 
 
 
 
 
 
 
fb6147c
 
c65bcc1
fb6147c
fe98e45
 
fb6147c
 
 
 
 
 
 
fe98e45
fb6147c
c65bcc1
fe98e45
 
eaa9aed
4f29f63
 
fb6147c
 
 
 
c65bcc1
fb6147c
4f29f63
eaa9aed
4f29f63
 
 
eaa9aed
 
fe98e45
 
fb6147c
4f29f63
fe98e45
fb6147c
 
 
4f29f63
fb6147c
 
c65bcc1
fb6147c
4f29f63
fe98e45
 
 
c65bcc1
fe98e45
 
c65bcc1
fe98e45
 
 
c65bcc1
fe98e45
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import Dict, Any, Tuple
import os
import requests
from io import BytesIO
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# --- 1. Configuration ---
# Ask torch for the "high" float32 matmul precision setting.
torch.set_float32_matmul_precision("high")

# Run on CUDA when torch reports it as available, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Usage profile name -> name of the matching BiRefNet weights.
usage_to_weights_file = {
    "General": "BiRefNet",
    "General-Lite": "BiRefNet_lite",
    "General-Lite-2K": "BiRefNet_lite-2K",
    "General-reso_512": "BiRefNet-reso_512",
    "General-HR": "BiRefNet_HR",
}

# Active profile (a key of ``usage_to_weights_file``), the size images are
# resized to before inference, and whether half precision is requested.
half_precision = True
resolution = (1024, 1024)
usage = "General"

class ImagePreprocessor():
    """Turn a PIL image into a normalized tensor.

    The pipeline resizes the image to ``resolution``, converts it to a
    tensor and normalizes the three channels with fixed mean/std values.
    """

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        """Build the transform pipeline.

        Args:
            resolution: Size passed to ``transforms.Resize``.
        """
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        pipeline_steps = [
            transforms.Resize(resolution),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
        self.transform_image = transforms.Compose(pipeline_steps)

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Apply the transform pipeline to ``image`` and return the result."""
        return self.transform_image(image)

class EndpointHandler():
    """Inference handler that cuts out the foreground of an image with BiRefNet.

    The model predicts a mask for the input image; the mask is attached to
    the original RGB image as its alpha channel and the RGBA image is
    returned.
    """

    # Seconds to wait for a remote image before giving up, so that a stalled
    # host cannot block the worker indefinitely.
    request_timeout = 30

    def __init__(self, path=''):
        """Load the model selected by the module-level ``usage`` setting.

        Args:
            path: Accepted for interface compatibility; not used by this
                handler.
        """
        # Load the model
        self.birefnet = AutoModelForImageSegmentation.from_pretrained(
            '/'.join(('zhengpeng7', usage_to_weights_file[usage])),
            trust_remote_code=True
        )
        self.birefnet.to(device)
        self.birefnet.eval()

        # Half precision is only enabled on CUDA. On CPU the model stays in
        # float32, since fp16 kernels there may be missing or very slow.
        self.use_half = bool(half_precision) and device == "cuda"
        if self.use_half:
            self.birefnet.half()

        # The transform pipeline does not depend on the request, so build it
        # once instead of on every call.
        self.image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))

    def _load_image(self, image_src):
        """Return a PIL image for the supported kinds of ``image_src``.

        Args:
            image_src: An object exposing ``convert`` (e.g. a PIL image), a
                string holding a local file path or a URL, encoded image
                bytes, or an array accepted by ``Image.fromarray``.

        Returns:
            The loaded PIL image (not yet converted to RGB).

        Raises:
            requests.RequestException: If downloading a URL fails, times out
                or returns an HTTP error status.
            ValueError: If the input cannot be interpreted as an image.
        """
        # Already an image-like object: use it as is.
        if hasattr(image_src, 'convert'):
            return image_src

        if isinstance(image_src, str):
            # NOTE(security): a string is treated as a local path or a URL.
            # With untrusted callers this allows reading files on the server
            # and issuing server-side requests; restrict it upstream if the
            # endpoint is publicly reachable.
            if os.path.isfile(image_src):
                return Image.open(image_src)
            response = requests.get(image_src, timeout=self.request_timeout)
            # Fail on 4xx/5xx instead of trying to decode an error page.
            response.raise_for_status()
            return Image.open(BytesIO(response.content))

        # Anything else: first try encoded image bytes, then a pixel array.
        decoders = (
            lambda src: Image.open(BytesIO(src)),
            Image.fromarray,
        )
        last_error = None
        for decode in decoders:
            try:
                return decode(image_src)
            except Exception as err:  # each decoder is a best-effort attempt
                last_error = err
        raise ValueError(
            f"Unsupported or undecodable image input of type "
            f"{type(image_src).__name__}"
        ) from last_error

    def __call__(self, data: Dict[str, Any]):
        """Remove the background of the image given in ``data["inputs"]``.

        Args:
            data: Request payload; ``data["inputs"]`` holds the image in one
                of the forms accepted by ``_load_image``.

        Returns:
            A PIL image with the original colors and the predicted mask as
            its alpha channel.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
            ValueError: If the input cannot be interpreted as an image.
        """
        # --- STEP 1: Load the image safely ---
        image_ori = self._load_image(data["inputs"])

        # Convert to RGB so the model always receives three color channels.
        image = image_ori.convert('RGB')

        # --- STEP 2: The model predicts the foreground mask ---
        image_proc = self.image_preprocessor.proc(image).unsqueeze(0).to(device)
        if self.use_half:
            image_proc = image_proc.half()

        with torch.no_grad():
            preds = self.birefnet(image_proc)[-1].sigmoid().cpu()

        pred = preds[0].squeeze()

        # --- STEP 3: Apply the mask as transparency ---
        # Turn the prediction into a grayscale mask image.
        mask_pil = transforms.ToPILImage()(pred)

        # Resize the mask to the exact size of the original photo.
        mask_pil = mask_pil.resize(image.size, resample=Image.Resampling.LANCZOS)

        # The RGB channels are left untouched; only the alpha channel is added.
        image.putalpha(mask_pil)

        return image