RepuestosMOM commited on
Commit
fe98e45
·
verified ·
1 Parent(s): c65bcc1

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +31 -53
handler.py CHANGED
@@ -1,48 +1,45 @@
1
- # These HF deployment codes refer to https://huggingface.co/not-lain/BiRefNet/raw/main/handler.py.
2
- from typing import Dict, List, Any, Tuple
3
  import os
4
  import requests
5
  from io import BytesIO
6
- import cv2 # Importante para el redimensionado preciso
7
- import numpy as np # Importante para la manipulación de píxeles
8
  from PIL import Image
9
  import torch
10
  from torchvision import transforms
11
  from transformers import AutoModelForImageSegmentation
12
 
13
- # --- Configuración Básica ---
14
  torch.set_float32_matmul_precision(["high", "highest"][0])
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
class ImagePreprocessor():
    """Convert a PIL image into a normalized tensor for BiRefNet inference."""

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        # Resize to the model's working resolution, then apply the standard
        # ImageNet mean/std normalization expected by the BiRefNet weights.
        steps = [
            transforms.Resize(resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        self.transform_image = transforms.Compose(steps)

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Run the preprocessing pipeline and return a CHW float tensor."""
        return self.transform_image(image)
27
-
28
# Map each usage profile to its BiRefNet checkpoint name on the Hugging Face
# Hub (under the `zhengpeng7` namespace). Key order is preserved deliberately.
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-HR': 'BiRefNet_HR',
    'General-Lite': 'BiRefNet_lite',
    'General-Lite-2K': 'BiRefNet_lite-2K',
    'General-reso_512': 'BiRefNet-reso_512',
    'General-legacy': 'BiRefNet-legacy',
}

# Active configuration: which checkpoint to load, the inference resolution,
# and whether to run the model in fp16.
usage = 'General'
resolution = (1024, 1024)
half_precision = True
40
 
 
 
 
 
 
 
 
 
 
 
 
41
  class EndpointHandler():
42
  def __init__(self, path=''):
43
  # Carga del modelo
44
  self.birefnet = AutoModelForImageSegmentation.from_pretrained(
45
- '/'.join(('zhengpeng7', usage_to_weights_file[usage])), trust_remote_code=True
 
46
  )
47
  self.birefnet.to(device)
48
  self.birefnet.eval()
@@ -50,12 +47,11 @@ class EndpointHandler():
50
  self.birefnet.half()
51
 
52
  def __call__(self, data: Dict[str, Any]):
53
- # ---------------------------------------------------------
54
- # 1. LÓGICA BLINDADA DE ENTRADA (Mantenemos lo que ya funcionaba)
55
- # ---------------------------------------------------------
56
  image_src = data["inputs"]
57
  image_ori = None
58
-
 
59
  if hasattr(image_src, 'convert') or isinstance(image_src, Image.Image):
60
  image_ori = image_src
61
  elif isinstance(image_src, str):
@@ -72,13 +68,11 @@ class EndpointHandler():
72
  image_ori = Image.fromarray(image_src)
73
  except Exception:
74
  image_ori = image_src
75
-
76
- # Convertimos a RGB para asegurar consistencia
77
  image = image_ori.convert('RGB')
78
 
79
- # ---------------------------------------------------------
80
- # 2. INFERENCIA (Detectar qué es fondo y qué es producto)
81
- # ---------------------------------------------------------
82
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
83
  image_proc = image_preprocessor.proc(image)
84
  image_proc = image_proc.unsqueeze(0)
@@ -88,31 +82,15 @@ class EndpointHandler():
88
 
89
  pred = preds[0].squeeze()
90
 
91
- # ---------------------------------------------------------
92
- # 3. RECONSTRUCCIÓN MATEMÁTICA (Solución al problema del negro)
93
- # ---------------------------------------------------------
94
-
95
- # A. Convertimos la predicción a array numpy y normalizamos
96
- mask_np = pred.numpy()
97
- mask_np = (mask_np - mask_np.min()) / (mask_np.max() - mask_np.min() + 1e-8)
98
-
99
- # B. Convertimos la imagen original a matriz de números [Alto, Ancho, 3]
100
- image_np = np.array(image)
101
 
102
- # C. Redimensionamos la máscara al tamaño EXACTO de la imagen original
103
- # (Esto evita deformaciones o bordes extraños)
104
- mask_resized = cv2.resize(mask_np, (image_np.shape[1], image_np.shape[0]))
105
 
106
- # D. Creamos una imagen vacía de 4 canales (RGBA - Rojo, Verde, Azul, Alfa)
107
- rgba_image = np.zeros((image_np.shape[0], image_np.shape[1], 4), dtype=np.uint8)
108
-
109
- # E. Copiamos los colores ORIGINALES (Sin modificarlos ni mezclarlos)
110
- rgba_image[:, :, :3] = image_np
111
-
112
- # F. Aplicamos la máscara al canal Alfa (Transparencia)
113
- rgba_image[:, :, 3] = (mask_resized * 255).astype(np.uint8)
114
-
115
- # G. Convertimos de vuelta a imagen PIL para devolverla
116
- final_image = Image.fromarray(rgba_image)
117
 
118
- return final_image
 
1
+ from typing import Dict, Any, Tuple
 
2
  import os
3
  import requests
4
  from io import BytesIO
 
 
5
  from PIL import Image
6
  import torch
7
  from torchvision import transforms
8
  from transformers import AutoModelForImageSegmentation
9
 
10
+ # --- 1. Configuración ---
11
  torch.set_float32_matmul_precision(["high", "highest"][0])
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
 
 
 
 
 
 
 
 
 
 
 
14
# Map each usage profile to its BiRefNet checkpoint name on the Hugging Face
# Hub (under the `zhengpeng7` namespace). Key order is preserved deliberately.
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-Lite': 'BiRefNet_lite',
    'General-Lite-2K': 'BiRefNet_lite-2K',
    'General-reso_512': 'BiRefNet-reso_512',
    'General-HR': 'BiRefNet_HR',
}

# Active configuration: which checkpoint to load, the inference resolution,
# and whether to run the model in fp16.
usage = 'General'
resolution = (1024, 1024)
half_precision = True
25
 
26
class ImagePreprocessor():
    """Prepare a PIL image for BiRefNet: resize, tensorize, and normalize."""

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        # The Normalize stats are the standard ImageNet mean/std that the
        # BiRefNet checkpoints were trained with.
        pipeline = [
            transforms.Resize(resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        self.transform_image = transforms.Compose(pipeline)

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Apply the preprocessing pipeline to *image* and return the tensor."""
        return self.transform_image(image)
36
+
37
  class EndpointHandler():
38
  def __init__(self, path=''):
39
  # Carga del modelo
40
  self.birefnet = AutoModelForImageSegmentation.from_pretrained(
41
+ '/'.join(('zhengpeng7', usage_to_weights_file[usage])),
42
+ trust_remote_code=True
43
  )
44
  self.birefnet.to(device)
45
  self.birefnet.eval()
 
47
  self.birefnet.half()
48
 
49
  def __call__(self, data: Dict[str, Any]):
50
+ # --- PASO 1: Carga Segura de la Imagen ---
 
 
51
  image_src = data["inputs"]
52
  image_ori = None
53
+
54
+ # Detectamos qué nos enviaron (Objeto, URL o Bytes)
55
  if hasattr(image_src, 'convert') or isinstance(image_src, Image.Image):
56
  image_ori = image_src
57
  elif isinstance(image_src, str):
 
68
  image_ori = Image.fromarray(image_src)
69
  except Exception:
70
  image_ori = image_src
71
+
72
+ # Convertimos a RGB (Esto limpia cualquier rareza del archivo original y asegura color)
73
  image = image_ori.convert('RGB')
74
 
75
+ # --- PASO 2: La IA detecta la silueta ---
 
 
76
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
77
  image_proc = image_preprocessor.proc(image)
78
  image_proc = image_proc.unsqueeze(0)
 
82
 
83
  pred = preds[0].squeeze()
84
 
85
+ # --- PASO 3: Recorte Limpio (Sin matemáticas raras) ---
86
+ # Convertimos la predicción en una máscara (imagen en blanco y negro)
87
+ mask_pil = transforms.ToPILImage()(pred)
 
 
 
 
 
 
 
88
 
89
+ # Redimensionamos la máscara al tamaño EXACTO de la foto original
90
+ mask_pil = mask_pil.resize(image.size, resample=Image.Resampling.LANCZOS)
 
91
 
92
+ # MAGIA: Simplemente le decimos a la foto original "Usa esta transparencia"
93
+ # No tocamos los canales de color (RGB), solo añadimos el canal Alpha.
94
+ image.putalpha(mask_pil)
 
 
 
 
 
 
 
 
95
 
96
+ return image