Update
Browse files- handler.py +20 -37
handler.py
CHANGED
|
@@ -24,14 +24,15 @@ def refine_foreground(image, mask, r=90):
|
|
| 24 |
return image_masked
|
| 25 |
|
| 26 |
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
|
| 27 |
-
# Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
|
| 28 |
alpha = alpha[:, :, None]
|
| 29 |
F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
|
| 30 |
return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
|
| 31 |
|
| 32 |
def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
|
| 33 |
-
|
|
|
|
| 34 |
image = np.array(image) / 255.0
|
|
|
|
| 35 |
blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
|
| 36 |
blurred_FA = cv2.blur(F * alpha, (r, r))
|
| 37 |
blurred_F = blurred_FA / (blurred_alpha + 1e-5)
|
|
@@ -69,19 +70,8 @@ usage_to_weights_file = {
|
|
| 69 |
'General-legacy': 'BiRefNet-legacy'
|
| 70 |
}
|
| 71 |
|
| 72 |
-
# Choose the version of BiRefNet here.
|
| 73 |
usage = 'General'
|
| 74 |
-
|
| 75 |
-
# Set resolution
|
| 76 |
-
if usage in ['General-Lite-2K']:
|
| 77 |
-
resolution = (2560, 1440)
|
| 78 |
-
elif usage in ['General-reso_512']:
|
| 79 |
-
resolution = (512, 512)
|
| 80 |
-
elif usage in ['General-HR', 'Matting-HR']:
|
| 81 |
-
resolution = (2048, 2048)
|
| 82 |
-
else:
|
| 83 |
-
resolution = (1024, 1024)
|
| 84 |
-
|
| 85 |
half_precision = True
|
| 86 |
|
| 87 |
class EndpointHandler():
|
|
@@ -95,21 +85,15 @@ class EndpointHandler():
|
|
| 95 |
self.birefnet.half()
|
| 96 |
|
| 97 |
def __call__(self, data: Dict[str, Any]):
|
| 98 |
-
"""
|
| 99 |
-
data args:
|
| 100 |
-
inputs (:obj: `str`)
|
| 101 |
-
date (:obj: `str`)
|
| 102 |
-
Return:
|
| 103 |
-
A :obj:`list` | `dict`: will be serialized and returned
|
| 104 |
-
"""
|
| 105 |
-
print('data["inputs"] = ', data["inputs"])
|
| 106 |
image_src = data["inputs"]
|
| 107 |
|
| 108 |
-
#
|
| 109 |
-
#
|
| 110 |
-
|
| 111 |
-
if isinstance(image_src, Image.Image):
|
| 112 |
image_ori = image_src
|
|
|
|
|
|
|
| 113 |
elif isinstance(image_src, str):
|
| 114 |
if os.path.isfile(image_src):
|
| 115 |
image_ori = Image.open(image_src)
|
|
@@ -117,32 +101,31 @@ class EndpointHandler():
|
|
| 117 |
response = requests.get(image_src)
|
| 118 |
image_data = BytesIO(response.content)
|
| 119 |
image_ori = Image.open(image_data)
|
|
|
|
|
|
|
| 120 |
else:
|
| 121 |
try:
|
| 122 |
-
#
|
| 123 |
-
image_ori = Image.
|
| 124 |
except Exception:
|
| 125 |
-
# Fallback: Intento leer como bytes crudos (para Odoo)
|
| 126 |
try:
|
| 127 |
-
|
| 128 |
-
except Exception:
|
| 129 |
-
# Si falla, intentamos array de nuevo como 煤ltimo recurso
|
| 130 |
image_ori = Image.fromarray(image_src)
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
| 133 |
image = image_ori.convert('RGB')
|
| 134 |
|
| 135 |
-
# Preprocess the image
|
| 136 |
image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
|
| 137 |
image_proc = image_preprocessor.proc(image)
|
| 138 |
image_proc = image_proc.unsqueeze(0)
|
| 139 |
|
| 140 |
-
# Prediction
|
| 141 |
with torch.no_grad():
|
| 142 |
preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
|
| 143 |
pred = preds[0].squeeze()
|
| 144 |
|
| 145 |
-
# Show Results
|
| 146 |
pred_pil = transforms.ToPILImage()(pred)
|
| 147 |
image_masked = refine_foreground(image, pred_pil)
|
| 148 |
image_masked.putalpha(pred_pil.resize(image.size))
|
|
|
|
| 24 |
return image_masked
|
| 25 |
|
| 26 |
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
|
|
|
|
| 27 |
alpha = alpha[:, :, None]
|
| 28 |
F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
|
| 29 |
return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
|
| 30 |
|
| 31 |
def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
|
| 32 |
+
# Detecci贸n segura para helpers internos
|
| 33 |
+
if hasattr(image, 'size') or isinstance(image, Image.Image):
|
| 34 |
image = np.array(image) / 255.0
|
| 35 |
+
|
| 36 |
blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
|
| 37 |
blurred_FA = cv2.blur(F * alpha, (r, r))
|
| 38 |
blurred_F = blurred_FA / (blurred_alpha + 1e-5)
|
|
|
|
| 70 |
'General-legacy': 'BiRefNet-legacy'
|
| 71 |
}
|
| 72 |
|
|
|
|
| 73 |
usage = 'General'
|
| 74 |
+
resolution = (1024, 1024)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
half_precision = True
|
| 76 |
|
| 77 |
class EndpointHandler():
|
|
|
|
| 85 |
self.birefnet.half()
|
| 86 |
|
| 87 |
def __call__(self, data: Dict[str, Any]):
|
| 88 |
+
print('data["inputs"] type:', type(data["inputs"])) # Log para debug
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
image_src = data["inputs"]
|
| 90 |
|
| 91 |
+
# --- LOGICA BLINDADA ---
|
| 92 |
+
# 1. Si ya es una imagen (tiene atributo 'size' o 'convert'), 煤sala directo.
|
| 93 |
+
if hasattr(image_src, 'convert') or isinstance(image_src, Image.Image):
|
|
|
|
| 94 |
image_ori = image_src
|
| 95 |
+
|
| 96 |
+
# 2. Si es una ruta de archivo o URL (String)
|
| 97 |
elif isinstance(image_src, str):
|
| 98 |
if os.path.isfile(image_src):
|
| 99 |
image_ori = Image.open(image_src)
|
|
|
|
| 101 |
response = requests.get(image_src)
|
| 102 |
image_data = BytesIO(response.content)
|
| 103 |
image_ori = Image.open(image_data)
|
| 104 |
+
|
| 105 |
+
# 3. 脷ltimo recurso: Bytes crudos o Arrays
|
| 106 |
else:
|
| 107 |
try:
|
| 108 |
+
# Intenta abrirlo como bytes (lo m谩s com煤n si falla el paso 1)
|
| 109 |
+
image_ori = Image.open(BytesIO(image_src))
|
| 110 |
except Exception:
|
|
|
|
| 111 |
try:
|
| 112 |
+
# Intenta como array de numpy
|
|
|
|
|
|
|
| 113 |
image_ori = Image.fromarray(image_src)
|
| 114 |
+
except Exception:
|
| 115 |
+
# Si falla todo, asume que YA es una imagen que fall贸 la detecci贸n
|
| 116 |
+
image_ori = image_src
|
| 117 |
+
# -----------------------
|
| 118 |
+
|
| 119 |
image = image_ori.convert('RGB')
|
| 120 |
|
|
|
|
| 121 |
image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
|
| 122 |
image_proc = image_preprocessor.proc(image)
|
| 123 |
image_proc = image_proc.unsqueeze(0)
|
| 124 |
|
|
|
|
| 125 |
with torch.no_grad():
|
| 126 |
preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
|
| 127 |
pred = preds[0].squeeze()
|
| 128 |
|
|
|
|
| 129 |
pred_pil = transforms.ToPILImage()(pred)
|
| 130 |
image_masked = refine_foreground(image, pred_pil)
|
| 131 |
image_masked.putalpha(pred_pil.resize(image.size))
|