RepuestosMOM committed on
Commit
c65bcc1
·
verified ·
1 Parent(s): eaa9aed

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +43 -57
handler.py CHANGED
@@ -3,46 +3,17 @@ from typing import Dict, List, Any, Tuple
3
  import os
4
  import requests
5
  from io import BytesIO
6
- import cv2
7
- import numpy as np
8
  from PIL import Image
9
  import torch
10
  from torchvision import transforms
11
  from transformers import AutoModelForImageSegmentation
12
 
 
13
  torch.set_float32_matmul_precision(["high", "highest"][0])
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
- ### image_proc.py
17
def refine_foreground(image, mask, r=90):
    """Refine the foreground of *image* under *mask* via blur-fusion.

    ``image`` is a PIL RGB image and ``mask`` a single-channel PIL matte;
    ``r`` is the blur window used by the estimator. Returns the refined
    foreground as a PIL image.
    """
    # The matte must match the image resolution before fusing.
    if mask.size != image.size:
        mask = mask.resize(image.size)
    img_arr = np.array(image) / 255.0
    alpha_arr = np.array(mask) / 255.0
    fg = FB_blur_fusion_foreground_estimator_2(img_arr, alpha_arr, r=r)
    return Image.fromarray((fg * 255.0).astype(np.uint8))
25
-
26
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    """Two-pass blur-fusion foreground estimation.

    ``image`` is an HxWx3 float array in [0, 1]; ``alpha`` is an HxW float
    matte in [0, 1]. Runs a coarse pass with window ``r`` followed by a
    fine pass with a fixed window of 6, returning the refined foreground.
    """
    matte = alpha[:, :, None]  # promote to HxWx1 so it broadcasts over RGB
    F, blurred_B = FB_blur_fusion_foreground_estimator(image, image, image, matte, r)
    refined, _ = FB_blur_fusion_foreground_estimator(image, F, blurred_B, matte, r=6)
    return refined
30
-
31
def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    """Single blur-fusion pass estimating foreground F and blurred background.

    ``image``/``F``/``B`` are HxWx3 float arrays in [0, 1] (or ``image`` may be
    a PIL image, converted here); ``alpha`` is HxWx1 in [0, 1]; ``r`` is the
    box-blur window. Returns ``(F, blurred_B)`` with F clipped to [0, 1].

    Bug fix: the previous guard was ``hasattr(image, 'size')``, intended to
    detect PIL inputs — but NumPy arrays also expose ``.size``, so an
    already-normalized array was divided by 255 a second time, darkening the
    result. Only a genuine ``PIL.Image.Image`` needs the conversion.
    """
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0

    blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
    blurred_FA = cv2.blur(F * alpha, (r, r))
    blurred_F = blurred_FA / (blurred_alpha + 1e-5)  # epsilon avoids /0 where alpha ~ 0
    blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
    # Fuse: pull F toward the observed image inside the matte.
    F = blurred_F + alpha * \
        (image - alpha * blurred_F - (1 - alpha) * blurred_B)
    F = np.clip(F, 0, 1)
    return F, blurred_B
45
-
46
  class ImagePreprocessor():
47
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
48
  self.transform_image = transforms.Compose([
@@ -60,13 +31,6 @@ usage_to_weights_file = {
60
  'General-Lite': 'BiRefNet_lite',
61
  'General-Lite-2K': 'BiRefNet_lite-2K',
62
  'General-reso_512': 'BiRefNet-reso_512',
63
- 'Matting': 'BiRefNet-matting',
64
- 'Matting-HR': 'BiRefNet_HR-Matting',
65
- 'Portrait': 'BiRefNet-portrait',
66
- 'DIS': 'BiRefNet-DIS5K',
67
- 'HRSOD': 'BiRefNet-HRSOD',
68
- 'COD': 'BiRefNet-COD',
69
- 'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
70
  'General-legacy': 'BiRefNet-legacy'
71
  }
72
 
@@ -76,6 +40,7 @@ half_precision = True
76
 
77
  class EndpointHandler():
78
  def __init__(self, path=''):
 
79
  self.birefnet = AutoModelForImageSegmentation.from_pretrained(
80
  '/'.join(('zhengpeng7', usage_to_weights_file[usage])), trust_remote_code=True
81
  )
@@ -85,48 +50,69 @@ class EndpointHandler():
85
  self.birefnet.half()
86
 
87
  def __call__(self, data: Dict[str, Any]):
88
- print('data["inputs"] type:', type(data["inputs"])) # Log para debug
 
 
89
  image_src = data["inputs"]
 
90
 
91
- # --- LOGICA BLINDADA ---
92
- # 1. Si ya es una imagen (tiene atributo 'size' o 'convert'), úsala directo.
93
  if hasattr(image_src, 'convert') or isinstance(image_src, Image.Image):
94
  image_ori = image_src
95
-
96
- # 2. Si es una ruta de archivo o URL (String)
97
  elif isinstance(image_src, str):
98
  if os.path.isfile(image_src):
99
  image_ori = Image.open(image_src)
100
  else:
101
  response = requests.get(image_src)
102
- image_data = BytesIO(response.content)
103
- image_ori = Image.open(image_data)
104
-
105
- # 3. Último recurso: Bytes crudos o Arrays
106
  else:
107
  try:
108
- # Intenta abrirlo como bytes (lo más común si falla el paso 1)
109
  image_ori = Image.open(BytesIO(image_src))
110
  except Exception:
111
  try:
112
- # Intenta como array de numpy
113
  image_ori = Image.fromarray(image_src)
114
  except Exception:
115
- # Si falla todo, asume que YA es una imagen que falló la detección
116
  image_ori = image_src
117
- # -----------------------
118
-
119
  image = image_ori.convert('RGB')
120
 
 
 
 
121
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
122
  image_proc = image_preprocessor.proc(image)
123
  image_proc = image_proc.unsqueeze(0)
124
 
125
  with torch.no_grad():
126
  preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
 
127
  pred = preds[0].squeeze()
128
 
129
- pred_pil = transforms.ToPILImage()(pred)
130
- image_masked = refine_foreground(image, pred_pil)
131
- image_masked.putalpha(pred_pil.resize(image.size))
132
- return image_masked
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import os
4
  import requests
5
  from io import BytesIO
6
+ import cv2 # Importante para el redimensionado preciso
7
+ import numpy as np # Importante para la manipulación de píxeles
8
  from PIL import Image
9
  import torch
10
  from torchvision import transforms
11
  from transformers import AutoModelForImageSegmentation
12
 
13
+ # --- Configuración Básica ---
14
  torch.set_float32_matmul_precision(["high", "highest"][0])
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  class ImagePreprocessor():
18
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
19
  self.transform_image = transforms.Compose([
 
31
  'General-Lite': 'BiRefNet_lite',
32
  'General-Lite-2K': 'BiRefNet_lite-2K',
33
  'General-reso_512': 'BiRefNet-reso_512',
 
 
 
 
 
 
 
34
  'General-legacy': 'BiRefNet-legacy'
35
  }
36
 
 
40
 
41
  class EndpointHandler():
42
  def __init__(self, path=''):
43
+ # Carga del modelo
44
  self.birefnet = AutoModelForImageSegmentation.from_pretrained(
45
  '/'.join(('zhengpeng7', usage_to_weights_file[usage])), trust_remote_code=True
46
  )
 
50
  self.birefnet.half()
51
 
52
  def __call__(self, data: Dict[str, Any]):
53
+ # ---------------------------------------------------------
54
+ # 1. LÓGICA BLINDADA DE ENTRADA (Mantenemos lo que ya funcionaba)
55
+ # ---------------------------------------------------------
56
  image_src = data["inputs"]
57
+ image_ori = None
58
 
 
 
59
  if hasattr(image_src, 'convert') or isinstance(image_src, Image.Image):
60
  image_ori = image_src
 
 
61
  elif isinstance(image_src, str):
62
  if os.path.isfile(image_src):
63
  image_ori = Image.open(image_src)
64
  else:
65
  response = requests.get(image_src)
66
+ image_ori = Image.open(BytesIO(response.content))
 
 
 
67
  else:
68
  try:
 
69
  image_ori = Image.open(BytesIO(image_src))
70
  except Exception:
71
  try:
 
72
  image_ori = Image.fromarray(image_src)
73
  except Exception:
 
74
  image_ori = image_src
75
+
76
+ # Convertimos a RGB para asegurar consistencia
77
  image = image_ori.convert('RGB')
78
 
79
+ # ---------------------------------------------------------
80
+ # 2. INFERENCIA (Detectar qué es fondo y qué es producto)
81
+ # ---------------------------------------------------------
82
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
83
  image_proc = image_preprocessor.proc(image)
84
  image_proc = image_proc.unsqueeze(0)
85
 
86
  with torch.no_grad():
87
  preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
88
+
89
  pred = preds[0].squeeze()
90
 
91
+ # ---------------------------------------------------------
92
+ # 3. RECONSTRUCCIÓN MATEMÁTICA (Solución al problema del negro)
93
+ # ---------------------------------------------------------
94
+
95
+ # A. Convertimos la predicción a array numpy y normalizamos
96
+ mask_np = pred.numpy()
97
+ mask_np = (mask_np - mask_np.min()) / (mask_np.max() - mask_np.min() + 1e-8)
98
+
99
+ # B. Convertimos la imagen original a matriz de números [Alto, Ancho, 3]
100
+ image_np = np.array(image)
101
+
102
+ # C. Redimensionamos la máscara al tamaño EXACTO de la imagen original
103
+ # (Esto evita deformaciones o bordes extraños)
104
+ mask_resized = cv2.resize(mask_np, (image_np.shape[1], image_np.shape[0]))
105
+
106
+ # D. Creamos una imagen vacía de 4 canales (RGBA - Rojo, Verde, Azul, Alfa)
107
+ rgba_image = np.zeros((image_np.shape[0], image_np.shape[1], 4), dtype=np.uint8)
108
+
109
+ # E. Copiamos los colores ORIGINALES (Sin modificarlos ni mezclarlos)
110
+ rgba_image[:, :, :3] = image_np
111
+
112
+ # F. Aplicamos la máscara al canal Alfa (Transparencia)
113
+ rgba_image[:, :, 3] = (mask_resized * 255).astype(np.uint8)
114
+
115
+ # G. Convertimos de vuelta a imagen PIL para devolverla
116
+ final_image = Image.fromarray(rgba_image)
117
+
118
+ return final_image