# NOTE(review): removed extraction residue (file-size line, git-blame hashes,
# and column numbers) that was prepended to this file — not part of the source.
from typing import Dict, Any, Tuple
import os
import requests
from io import BytesIO
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# --- 1. Configuration ---
# "high" enables TF32 matmuls on Ampere+ GPUs (was `["high", "highest"][0]`,
# which is just an obfuscated way of writing the constant "high").
torch.set_float32_matmul_precision("high")
# Run on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Map of human-readable variant names to BiRefNet checkpoint repo names.
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-Lite': 'BiRefNet_lite',
    'General-Lite-2K': 'BiRefNet_lite-2K',
    'General-reso_512': 'BiRefNet-reso_512',
    'General-HR': 'BiRefNet_HR'
}
# Selected checkpoint variant, model input resolution, and fp16 toggle.
usage = 'General'
resolution = (1024, 1024)
half_precision = True
class ImagePreprocessor():
    """Prepare a PIL image for BiRefNet inference.

    Resizes to the target resolution, converts to a CHW float tensor,
    and applies the standard ImageNet mean/std normalization.
    """

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        steps = [
            transforms.Resize(resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        self.transform_image = transforms.Compose(steps)

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Run the preprocessing pipeline on *image* and return the tensor."""
        return self.transform_image(image)
class EndpointHandler():
    """Inference Endpoints handler running BiRefNet background removal.

    Loads the checkpoint selected by the module-level ``usage`` setting;
    each call segments the foreground of the input image and returns the
    original image with the predicted mask attached as its alpha channel.
    """

    def __init__(self, path=''):
        # Load the model. The checkpoint ships custom modeling code, hence
        # trust_remote_code — acceptable only because the repo is pinned here.
        self.birefnet = AutoModelForImageSegmentation.from_pretrained(
            '/'.join(('zhengpeng7', usage_to_weights_file[usage])),
            trust_remote_code=True
        )
        self.birefnet.to(device)
        self.birefnet.eval()
        if half_precision:
            self.birefnet.half()

    def __call__(self, data: Dict[str, Any]):
        """Return ``data['inputs']`` as an RGBA PIL image, background removed.

        ``inputs`` may be a PIL image, a local file path, a URL, raw image
        bytes, or a numpy array (best-effort fallbacks, in that order).
        """
        # --- STEP 1: safe image loading ---
        image_src = data["inputs"]
        image_ori = None
        # Detect what we were sent (PIL object, path/URL string, or bytes).
        if hasattr(image_src, 'convert') or isinstance(image_src, Image.Image):
            image_ori = image_src
        elif isinstance(image_src, str):
            if os.path.isfile(image_src):
                image_ori = Image.open(image_src)
            else:
                # Fix: bound the request with a timeout so a stalled server
                # cannot hang the endpoint, and fail fast on HTTP errors
                # instead of trying to decode an error page as an image.
                response = requests.get(image_src, timeout=30)
                response.raise_for_status()
                image_ori = Image.open(BytesIO(response.content))
        else:
            # Deliberate best-effort chain: raw bytes, then array, then as-is.
            try:
                image_ori = Image.open(BytesIO(image_src))
            except Exception:
                try:
                    image_ori = Image.fromarray(image_src)
                except Exception:
                    image_ori = image_src
        # Convert to RGB (normalizes palette/mode quirks of the source file).
        image = image_ori.convert('RGB')
        # --- STEP 2: the model predicts the foreground silhouette ---
        image_preprocessor = ImagePreprocessor(resolution=resolution)
        image_proc = image_preprocessor.proc(image).unsqueeze(0).to(device)
        if half_precision:
            image_proc = image_proc.half()
        with torch.no_grad():
            preds = self.birefnet(image_proc)[-1].sigmoid().cpu()
        # Fix: cast to float32 so ToPILImage is safe on the fp16 path
        # (a no-op when half_precision is False).
        pred = preds[0].squeeze().float()
        # --- STEP 3: clean cutout ---
        # Turn the prediction into a grayscale mask image...
        mask_pil = transforms.ToPILImage()(pred)
        # ...resized to the exact size of the original photo.
        mask_pil = mask_pil.resize(image.size, resample=Image.Resampling.LANCZOS)
        # Leave the RGB channels untouched; only attach the alpha channel.
        image.putalpha(mask_pil)
        return image