"""Custom inference handler that cuts out image backgrounds with BiRefNet."""

from typing import Dict, Any
import os
from io import BytesIO

import requests
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Configuration
torch.set_float32_matmul_precision("high")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only used on CUDA; fp16 inference on CPU is unsupported
# or very slow for many operators, so the CPU path stays in float32.
_MODEL_DTYPE = torch.float16 if device == "cuda" else torch.float32

# Upper bound (seconds) for fetching a remote image, so that a stalled server
# cannot block the worker forever.
_HTTP_TIMEOUT_S = 30


class EndpointHandler():
    """Callable handler that returns the input image with a predicted alpha mask."""

    def __init__(self, path=''):
        """Load the BiRefNet segmentation model and build the preprocessing pipeline.

        Args:
            path: Accepted for interface compatibility; it is not used here,
                the weights are always fetched from ``zhengpeng7/BiRefNet``.
        """
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # model repository; keep it only while that repository is trusted.
        self.model = AutoModelForImageSegmentation.from_pretrained(
            'zhengpeng7/BiRefNet', trust_remote_code=True
        )
        self.model.to(device)
        self.model.eval()
        if _MODEL_DTYPE == torch.float16:
            self.model.half()

        # Built once and reused: the pipeline does not depend on the request.
        # The normalization constants are the commonly used ImageNet mean/std.
        self._transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _load_image(self, image_src):
        """Turn the request payload into a PIL image.

        Accepts a PIL image (returned as is), an ``http(s)://`` URL, a local
        file path, or raw encoded image bytes.

        Raises:
            requests.RequestException: If a URL cannot be fetched in time or
                the server answers with an error status.
        """
        if isinstance(image_src, Image.Image):
            return image_src

        if isinstance(image_src, str):
            # NOTE(review): URLs and file paths come straight from the request;
            # if callers are untrusted this allows SSRF / local file reads.
            if image_src.startswith(('http://', 'https://')):
                response = requests.get(image_src, timeout=_HTTP_TIMEOUT_S)
                # Fail on 4xx/5xx instead of handing an error page to PIL.
                response.raise_for_status()
                return Image.open(BytesIO(response.content))
            return Image.open(image_src)

        # Anything else is treated as raw encoded image bytes.
        return Image.open(BytesIO(image_src))

    def __call__(self, data: Dict[str, Any]):
        """Run background segmentation on one image.

        Args:
            data: Request payload; ``data["inputs"]`` holds the image as a PIL
                image, URL, local path, or raw bytes.

        Returns:
            A PIL image at the original size whose alpha channel is the
            predicted mask; the colour channels are the RGB-converted input.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
        """
        # 1. Receive the image.
        image = self._load_image(data["inputs"])

        # 2. Normalize to RGB; convert() returns a new image, so a PIL image
        #    passed in by the caller is not modified by putalpha() below.
        image = image.convert("RGB")
        orig_size = image.size

        # 3. Model inference.
        input_tensor = self._transform(image).unsqueeze(0)
        input_tensor = input_tensor.to(device=device, dtype=_MODEL_DTYPE)
        with torch.no_grad():
            # The last element of the model output is used as the prediction.
            preds = self.model(input_tensor)[-1].sigmoid().cpu()

        # 4. Mask: cast to float32 so the conversion to an 8-bit image does
        #    not rely on half-precision CPU kernels.
        pred = preds[0].squeeze().float()
        mask_pil = transforms.ToPILImage()(pred)
        mask_pil = mask_pil.resize(orig_size, resample=Image.Resampling.LANCZOS)

        # 5. Apply the mask as alpha without touching the colour channels.
        image.putalpha(mask_pil)
        return image