BiRefNet-Enterprise / handler.py
RepuestosMOM's picture
Update handler.py
be6a051 verified
raw
history blame
2.09 kB
from typing import Dict, Any
import os
import requests
from io import BytesIO
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# Global runtime configuration, applied once at import time.
# Per the PyTorch docs, "high" allows float32 matrix multiplications to use
# faster reduced-precision internal math where the hardware supports it.
torch.set_float32_matmul_precision("high")

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class EndpointHandler:
    """Inference Endpoints handler that removes the background of an image.

    Loads the ``zhengpeng7/BiRefNet`` segmentation model once at start-up.
    For each request it returns the input image as RGBA, with the model's
    predicted mask as the alpha channel.
    """

    # Size (width, height) the image is resized to before it is fed to the model.
    INPUT_SIZE = (1024, 1024)
    # Seconds to wait when downloading an image from a URL before giving up.
    DOWNLOAD_TIMEOUT = 30

    def __init__(self, path=''):
        """Load the BiRefNet model onto the module-level ``device``.

        Args:
            path: Repository directory passed in by the inference toolkit.
                Unused here: the model is always loaded from the
                ``zhengpeng7/BiRefNet`` Hub repository.
        """
        # Original BiRefNet model. trust_remote_code lets transformers execute
        # the model code shipped in that repository.
        self.model = AutoModelForImageSegmentation.from_pretrained(
            'zhengpeng7/BiRefNet',
            trust_remote_code=True
        )
        self.model.to(device)
        self.model.eval()
        # float16 is only used on CUDA: many CPU kernels have no Half
        # implementation, so the CPU fallback must stay in float32.
        self.dtype = torch.float16 if device == "cuda" else torch.float32
        if self.dtype == torch.float16:
            self.model.half()
        # Preprocessing is the same for every request, so build it once.
        # The Normalize constants are the standard ImageNet mean / std.
        self.transform = transforms.Compose([
            transforms.Resize(self.INPUT_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _load_image(self, src):
        """Open ``src`` (PIL image, http(s) URL, file path or encoded bytes) with PIL.

        Raises:
            requests.HTTPError: if downloading a URL returns an error status.
        """
        if isinstance(src, Image.Image):
            return src
        if isinstance(src, str):
            if src.startswith(('http://', 'https://')):
                response = requests.get(src, timeout=self.DOWNLOAD_TIMEOUT)
                # Fail with a clear HTTP error instead of handing an error
                # page to PIL.
                response.raise_for_status()
                return Image.open(BytesIO(response.content))
            return Image.open(src)
        # Anything else is treated as raw encoded image bytes.
        return Image.open(BytesIO(src))

    def __call__(self, data: Dict[str, Any]):
        """Cut out the foreground of one image.

        Args:
            data: Request payload. ``data["inputs"]`` may be a PIL image, an
                http(s) URL, a local file path, or raw encoded image bytes.

        Returns:
            A ``PIL.Image.Image`` in RGBA mode with the input's size. Its RGB
            pixels are the input's, and its alpha channel is the predicted mask.

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` entry.
            requests.HTTPError: if downloading a URL returns an error status.
        """
        # 1. Receive the image.
        image = self._load_image(data["inputs"])

        # 2. Force 3-channel RGB for the model. Any existing alpha channel is
        #    discarded here and replaced by the predicted mask in step 5.
        image = image.convert("RGB")
        orig_size = image.size

        # 3. Model inference, in the dtype chosen at load time.
        input_tensor = self.transform(image).unsqueeze(0).to(device=device, dtype=self.dtype)
        with torch.no_grad():
            # Only the last model output is used; sigmoid maps it to [0, 1].
            preds = self.model(input_tensor)[-1].sigmoid()

        # 4. Mask: first batch item, converted to a float32 CPU tensor. After
        #    squeeze it is 2-D, so ToPILImage yields a single-channel image,
        #    which is then resized back to the input size.
        pred = preds[0].squeeze().float().cpu()
        mask_pil = transforms.ToPILImage()(pred)
        mask_pil = mask_pil.resize(orig_size, resample=Image.Resampling.LANCZOS)

        # 5. Attach the mask as alpha without modifying the RGB pixels.
        image.putalpha(mask_pil)
        return image