# Source: BiRefNet-Enterprise / handler.py (exported from a web page view)
# Last commit: "Update handler.py" by RepuestosMOM (fe98e45, verified), 3.55 kB
from typing import Dict, Any, Tuple
import os
import requests
from io import BytesIO
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# --- 1. Configuration ---
# "high" allows float32 matrix multiplications to use faster reduced-precision
# internal math (e.g. TF32) where the hardware supports it.
torch.set_float32_matmul_precision("high")
device = "cuda" if torch.cuda.is_available() else "cpu"
# Maps a usage preset to a BiRefNet checkpoint name.
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-Lite': 'BiRefNet_lite',
    'General-Lite-2K': 'BiRefNet_lite-2K',
    'General-reso_512': 'BiRefNet-reso_512',
    'General-HR': 'BiRefNet_HR'
}
# Preset to load; must be a key of usage_to_weights_file.
usage = 'General'
# (height, width) the input is resized to before inference.
resolution = (1024, 1024)
# Run the model in fp16 only on CUDA. Many PyTorch releases lack CPU Half
# kernels (e.g. conv2d), so forcing fp16 on CPU either fails or is very slow.
half_precision = device == "cuda"
class ImagePreprocessor():
    """Resize and normalize a PIL image into a model-ready tensor."""

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        """Build the preprocessing pipeline.

        Args:
            resolution: Target size handed to ``transforms.Resize``.
        """
        # Standard ImageNet per-channel mean and standard deviation.
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        pipeline_steps = [
            transforms.Resize(resolution),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ]
        self.transform_image = transforms.Compose(pipeline_steps)

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Apply the pipeline to *image* and return the resulting tensor."""
        return self.transform_image(image)
class EndpointHandler():
    """Inference handler that cuts out the foreground of an image with BiRefNet.

    Each call loads the image found in ``data["inputs"]``, predicts a
    foreground mask with BiRefNet, and returns the RGB image with that mask
    attached as its alpha channel.
    """

    # Seconds to wait when downloading an image from a URL before giving up.
    REQUEST_TIMEOUT = 30

    def __init__(self, path=''):
        """Load the BiRefNet checkpoint selected by the module-level ``usage``.

        Args:
            path: Model directory supplied by the endpoint runtime. Not used;
                the weights are loaded from the ``zhengpeng7/<checkpoint>`` repo id.
        """
        # trust_remote_code executes Python code shipped in the model repo.
        # NOTE(review): consider pinning `revision=` so upstream repo changes
        # cannot silently alter behavior.
        self.birefnet = AutoModelForImageSegmentation.from_pretrained(
            '/'.join(('zhengpeng7', usage_to_weights_file[usage])),
            trust_remote_code=True
        )
        self.birefnet.to(device)
        self.birefnet.eval()
        if half_precision:
            self.birefnet.half()
        # The transform pipeline holds no per-request state, so build it once
        # here rather than on every call.
        self.image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))

    def _load_image(self, image_src: Any) -> Image.Image:
        """Turn the ``inputs`` payload into a PIL image.

        Accepted forms: a PIL image (or any object exposing ``.convert``), a
        local file path, an http(s) URL, a base64 string (optionally a
        ``data:`` URI), raw encoded image bytes, or an object exposing
        ``__array_interface__`` (e.g. a NumPy array).

        Raises:
            requests.HTTPError: if downloading a URL returns an error status.
            ValueError: if a string is neither a file, an http(s) URL, nor base64.
            TypeError: if the payload type is not supported.
        """
        if hasattr(image_src, 'convert'):
            return image_src
        if isinstance(image_src, str):
            return self._load_image_from_string(image_src)
        if isinstance(image_src, (bytes, bytearray, memoryview)):
            return Image.open(BytesIO(image_src))
        if hasattr(image_src, '__array_interface__'):
            return Image.fromarray(image_src)
        raise TypeError(
            f"Unsupported 'inputs' type for an image: {type(image_src).__name__}"
        )

    def _load_image_from_string(self, image_src: str) -> Image.Image:
        """Open an image given as a file path, an http(s) URL, or base64 text."""
        if os.path.isfile(image_src):
            return Image.open(image_src)
        if image_src.lower().startswith(('http://', 'https://')):
            response = requests.get(image_src, timeout=self.REQUEST_TIMEOUT)
            # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            return Image.open(BytesIO(response.content))

        import base64
        import binascii

        payload = image_src
        if payload.startswith('data:'):
            # data URI: keep only what follows "data:<mime>;base64,".
            payload = payload.partition(',')[2]
        # Drop line breaks/spaces that wrapped base64 often contains.
        payload = ''.join(payload.split())
        try:
            raw = base64.b64decode(payload, validate=True)
        except binascii.Error as err:
            raise ValueError(
                "String 'inputs' must be an existing file path, an http(s) URL, "
                "or base64-encoded image data"
            ) from err
        return Image.open(BytesIO(raw))

    def _predict_mask(self, image: Image.Image) -> Image.Image:
        """Run BiRefNet on *image* and return its mask resized to ``image.size``."""
        batch = self.image_preprocessor.proc(image).unsqueeze(0).to(device)
        if half_precision:
            batch = batch.half()
        with torch.no_grad():
            # [-1] takes the last model output -- presumably the final,
            # full-resolution prediction (verify against BiRefNet's remote code).
            # Sigmoid maps it into (0, 1).
            preds = self.birefnet(batch)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        # ToPILImage turns the single-channel float map into an 8-bit grayscale image.
        mask = transforms.ToPILImage()(pred)
        # Scale the mask back to the exact size of the original photo.
        return mask.resize(image.size, resample=Image.Resampling.LANCZOS)

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """Remove the background from the image in ``data["inputs"]``.

        Args:
            data: Request payload; ``data["inputs"]`` holds the image in any
                form accepted by ``_load_image``.

        Returns:
            The input image (converted to RGB) with the predicted foreground
            mask as its alpha channel.
        """
        # Converting to RGB normalizes palette/grayscale/alpha inputs and
        # returns a copy, so a caller-supplied image is not modified.
        image = self._load_image(data["inputs"]).convert('RGB')
        mask = self._predict_mask(image)
        # Only an alpha channel is added; the RGB pixel values are untouched.
        image.putalpha(mask)
        return image