RepuestosMOM commited on
Commit
eaa9aed
verified
1 Parent(s): 4f29f63
Files changed (1) hide show
  1. handler.py +20 -37
handler.py CHANGED
@@ -24,14 +24,15 @@ def refine_foreground(image, mask, r=90):
24
  return image_masked
25
 
26
  def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
27
- # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
28
  alpha = alpha[:, :, None]
29
  F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
30
  return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
31
 
32
  def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
33
- if isinstance(image, Image.Image):
 
34
  image = np.array(image) / 255.0
 
35
  blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
36
  blurred_FA = cv2.blur(F * alpha, (r, r))
37
  blurred_F = blurred_FA / (blurred_alpha + 1e-5)
@@ -69,19 +70,8 @@ usage_to_weights_file = {
69
  'General-legacy': 'BiRefNet-legacy'
70
  }
71
 
72
- # Choose the version of BiRefNet here.
73
  usage = 'General'
74
-
75
- # Set resolution
76
- if usage in ['General-Lite-2K']:
77
- resolution = (2560, 1440)
78
- elif usage in ['General-reso_512']:
79
- resolution = (512, 512)
80
- elif usage in ['General-HR', 'Matting-HR']:
81
- resolution = (2048, 2048)
82
- else:
83
- resolution = (1024, 1024)
84
-
85
  half_precision = True
86
 
87
  class EndpointHandler():
@@ -95,21 +85,15 @@ class EndpointHandler():
95
  self.birefnet.half()
96
 
97
  def __call__(self, data: Dict[str, Any]):
98
- """
99
- data args:
100
- inputs (:obj: `str`)
101
- date (:obj: `str`)
102
- Return:
103
- A :obj:`list` | `dict`: will be serialized and returned
104
- """
105
- print('data["inputs"] = ', data["inputs"])
106
  image_src = data["inputs"]
107
 
108
- # ------------------------------------------------------------------
109
- # MODIFICACION REPUESTOS MOM: Soporte para im谩genes directas (Bytes/PIL)
110
- # ------------------------------------------------------------------
111
- if isinstance(image_src, Image.Image):
112
  image_ori = image_src
 
 
113
  elif isinstance(image_src, str):
114
  if os.path.isfile(image_src):
115
  image_ori = Image.open(image_src)
@@ -117,32 +101,31 @@ class EndpointHandler():
117
  response = requests.get(image_src)
118
  image_data = BytesIO(response.content)
119
  image_ori = Image.open(image_data)
 
 
120
  else:
121
  try:
122
- # Intento leer como array (comportamiento original)
123
- image_ori = Image.fromarray(image_src)
124
  except Exception:
125
- # Fallback: Intento leer como bytes crudos (para Odoo)
126
  try:
127
- image_ori = Image.open(BytesIO(image_src))
128
- except Exception:
129
- # Si falla, intentamos array de nuevo como 煤ltimo recurso
130
  image_ori = Image.fromarray(image_src)
131
- # ------------------------------------------------------------------
132
-
 
 
 
133
  image = image_ori.convert('RGB')
134
 
135
- # Preprocess the image
136
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
137
  image_proc = image_preprocessor.proc(image)
138
  image_proc = image_proc.unsqueeze(0)
139
 
140
- # Prediction
141
  with torch.no_grad():
142
  preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
143
  pred = preds[0].squeeze()
144
 
145
- # Show Results
146
  pred_pil = transforms.ToPILImage()(pred)
147
  image_masked = refine_foreground(image, pred_pil)
148
  image_masked.putalpha(pred_pil.resize(image.size))
 
24
  return image_masked
25
 
26
  def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
 
27
  alpha = alpha[:, :, None]
28
  F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
29
  return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
30
 
31
  def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
32
+ # Detecci贸n segura para helpers internos
33
+ if hasattr(image, 'size') or isinstance(image, Image.Image):
34
  image = np.array(image) / 255.0
35
+
36
  blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
37
  blurred_FA = cv2.blur(F * alpha, (r, r))
38
  blurred_F = blurred_FA / (blurred_alpha + 1e-5)
 
70
  'General-legacy': 'BiRefNet-legacy'
71
  }
72
 
 
73
  usage = 'General'
74
+ resolution = (1024, 1024)
 
 
 
 
 
 
 
 
 
 
75
  half_precision = True
76
 
77
  class EndpointHandler():
 
85
  self.birefnet.half()
86
 
87
  def __call__(self, data: Dict[str, Any]):
88
+ print('data["inputs"] type:', type(data["inputs"])) # Log para debug
 
 
 
 
 
 
 
89
  image_src = data["inputs"]
90
 
91
+ # --- LOGICA BLINDADA ---
92
+ # 1. Si ya es una imagen (tiene atributo 'size' o 'convert'), 煤sala directo.
93
+ if hasattr(image_src, 'convert') or isinstance(image_src, Image.Image):
 
94
  image_ori = image_src
95
+
96
+ # 2. Si es una ruta de archivo o URL (String)
97
  elif isinstance(image_src, str):
98
  if os.path.isfile(image_src):
99
  image_ori = Image.open(image_src)
 
101
  response = requests.get(image_src)
102
  image_data = BytesIO(response.content)
103
  image_ori = Image.open(image_data)
104
+
105
+ # 3. 脷ltimo recurso: Bytes crudos o Arrays
106
  else:
107
  try:
108
+ # Intenta abrirlo como bytes (lo m谩s com煤n si falla el paso 1)
109
+ image_ori = Image.open(BytesIO(image_src))
110
  except Exception:
 
111
  try:
112
+ # Intenta como array de numpy
 
 
113
  image_ori = Image.fromarray(image_src)
114
+ except Exception:
115
+ # Si falla todo, asume que YA es una imagen que fall贸 la detecci贸n
116
+ image_ori = image_src
117
+ # -----------------------
118
+
119
  image = image_ori.convert('RGB')
120
 
 
121
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
122
  image_proc = image_preprocessor.proc(image)
123
  image_proc = image_proc.unsqueeze(0)
124
 
 
125
  with torch.no_grad():
126
  preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
127
  pred = preds[0].squeeze()
128
 
 
129
  pred_pil = transforms.ToPILImage()(pred)
130
  image_masked = refine_foreground(image, pred_pil)
131
  image_masked.putalpha(pred_pil.resize(image.size))