RepuestosMOM committed on
Commit
4f29f63
·
verified ·
1 Parent(s): fb6147c

Commit changes to main

Browse files
Files changed (1) hide show
  1. handler.py +23 -13
handler.py CHANGED
@@ -11,7 +11,6 @@ from torchvision import transforms
11
  from transformers import AutoModelForImageSegmentation
12
 
13
  torch.set_float32_matmul_precision(["high", "highest"][0])
14
-
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
  ### image_proc.py
@@ -24,22 +23,18 @@ def refine_foreground(image, mask, r=90):
24
  image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
25
  return image_masked
26
 
27
-
28
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    """Two-pass blur-fusion foreground estimation.

    Runs a coarse pass (blur radius ``r``) followed by a fine pass
    (fixed radius 6) and returns only the refined foreground estimate.

    Thanks to the source:
    https://github.com/Photoroom/fast-foreground-estimation
    """
    # Add a trailing channel axis so the mask broadcasts against the image.
    expanded_alpha = alpha[:, :, None]
    # Coarse pass: seed both F and B with the input image itself.
    coarse_fg, coarse_blur_bg = FB_blur_fusion_foreground_estimator(
        image, image, image, expanded_alpha, r
    )
    # Fine pass refines the coarse estimates; only the foreground is needed.
    refined = FB_blur_fusion_foreground_estimator(
        image, coarse_fg, coarse_blur_bg, expanded_alpha, r=6
    )
    return refined[0]
33
 
34
-
35
  def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
36
  if isinstance(image, Image.Image):
37
  image = np.array(image) / 255.0
38
  blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
39
-
40
  blurred_FA = cv2.blur(F * alpha, (r, r))
41
  blurred_F = blurred_FA / (blurred_alpha + 1e-5)
42
-
43
  blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
44
  blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
45
  F = blurred_F + alpha * \
@@ -47,7 +42,6 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
47
  F = np.clip(F, 0, 1)
48
  return F, blurred_B
49
 
50
-
51
  class ImagePreprocessor():
52
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
53
  self.transform_image = transforms.Compose([
@@ -55,7 +49,6 @@ class ImagePreprocessor():
55
  transforms.ToTensor(),
56
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
57
  ])
58
-
59
  def proc(self, image: Image.Image) -> torch.Tensor:
60
  image = self.transform_image(image)
61
  return image
@@ -111,7 +104,13 @@ class EndpointHandler():
111
  """
112
  print('data["inputs"] = ', data["inputs"])
113
  image_src = data["inputs"]
114
- if isinstance(image_src, str):
 
 
 
 
 
 
115
  if os.path.isfile(image_src):
116
  image_ori = Image.open(image_src)
117
  else:
@@ -119,21 +118,32 @@ class EndpointHandler():
119
  image_data = BytesIO(response.content)
120
  image_ori = Image.open(image_data)
121
  else:
122
- image_ori = Image.fromarray(image_src)
123
-
 
 
 
 
 
 
 
 
 
 
124
  image = image_ori.convert('RGB')
 
125
  # Preprocess the image
126
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
127
  image_proc = image_preprocessor.proc(image)
128
  image_proc = image_proc.unsqueeze(0)
129
-
130
  # Prediction
131
  with torch.no_grad():
132
  preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
133
  pred = preds[0].squeeze()
134
-
135
  # Show Results
136
  pred_pil = transforms.ToPILImage()(pred)
137
  image_masked = refine_foreground(image, pred_pil)
138
  image_masked.putalpha(pred_pil.resize(image.size))
139
- return image_masked
 
11
  from transformers import AutoModelForImageSegmentation
12
 
13
  torch.set_float32_matmul_precision(["high", "highest"][0])
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
  ### image_proc.py
 
23
  image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
24
  return image_masked
25
 
 
26
  def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
27
  # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
28
  alpha = alpha[:, :, None]
29
  F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
30
  return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
31
 
 
32
  def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
33
  if isinstance(image, Image.Image):
34
  image = np.array(image) / 255.0
35
  blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
 
36
  blurred_FA = cv2.blur(F * alpha, (r, r))
37
  blurred_F = blurred_FA / (blurred_alpha + 1e-5)
 
38
  blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
39
  blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
40
  F = blurred_F + alpha * \
 
42
  F = np.clip(F, 0, 1)
43
  return F, blurred_B
44
 
 
45
  class ImagePreprocessor():
46
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
47
  self.transform_image = transforms.Compose([
 
49
  transforms.ToTensor(),
50
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
51
  ])
 
52
  def proc(self, image: Image.Image) -> torch.Tensor:
53
  image = self.transform_image(image)
54
  return image
 
104
  """
105
  print('data["inputs"] = ', data["inputs"])
106
  image_src = data["inputs"]
107
+
108
+ # ------------------------------------------------------------------
109
+ # MODIFICACION REPUESTOS MOM: Soporte para imágenes directas (Bytes/PIL)
110
+ # ------------------------------------------------------------------
111
+ if isinstance(image_src, Image.Image):
112
+ image_ori = image_src
113
+ elif isinstance(image_src, str):
114
  if os.path.isfile(image_src):
115
  image_ori = Image.open(image_src)
116
  else:
 
118
  image_data = BytesIO(response.content)
119
  image_ori = Image.open(image_data)
120
  else:
121
+ try:
122
+ # Intento leer como array (comportamiento original)
123
+ image_ori = Image.fromarray(image_src)
124
+ except Exception:
125
+ # Fallback: Intento leer como bytes crudos (para Odoo)
126
+ try:
127
+ image_ori = Image.open(BytesIO(image_src))
128
+ except Exception:
129
+ # Si falla, intentamos array de nuevo como último recurso
130
+ image_ori = Image.fromarray(image_src)
131
+ # ------------------------------------------------------------------
132
+
133
  image = image_ori.convert('RGB')
134
+
135
  # Preprocess the image
136
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
137
  image_proc = image_preprocessor.proc(image)
138
  image_proc = image_proc.unsqueeze(0)
139
+
140
  # Prediction
141
  with torch.no_grad():
142
  preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
143
  pred = preds[0].squeeze()
144
+
145
  # Show Results
146
  pred_pil = transforms.ToPILImage()(pred)
147
  image_masked = refine_foreground(image, pred_pil)
148
  image_masked.putalpha(pred_pil.resize(image.size))
149
+ return image_masked