RepuestosMOM commited on
Commit
2d27abc
·
verified ·
1 Parent(s): fe98e45

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +45 -62
handler.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict, Any, Tuple
2
  import os
3
  import requests
4
  from io import BytesIO
@@ -7,90 +7,73 @@ import torch
7
  from torchvision import transforms
8
  from transformers import AutoModelForImageSegmentation
9
 
10
- # --- 1. Configuración ---
11
  torch.set_float32_matmul_precision(["high", "highest"][0])
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
- usage_to_weights_file = {
15
- 'General': 'BiRefNet',
16
- 'General-Lite': 'BiRefNet_lite',
17
- 'General-Lite-2K': 'BiRefNet_lite-2K',
18
- 'General-reso_512': 'BiRefNet-reso_512',
19
- 'General-HR': 'BiRefNet_HR'
20
- }
21
-
22
- usage = 'General'
23
- resolution = (1024, 1024)
24
- half_precision = True
25
-
26
- class ImagePreprocessor():
27
- def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
28
- self.transform_image = transforms.Compose([
29
- transforms.Resize(resolution),
30
- transforms.ToTensor(),
31
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
32
- ])
33
- def proc(self, image: Image.Image) -> torch.Tensor:
34
- image = self.transform_image(image)
35
- return image
36
-
37
  class EndpointHandler():
38
  def __init__(self, path=''):
39
- # Carga del modelo
40
- self.birefnet = AutoModelForImageSegmentation.from_pretrained(
41
- '/'.join(('zhengpeng7', usage_to_weights_file[usage])),
42
  trust_remote_code=True
43
  )
44
- self.birefnet.to(device)
45
- self.birefnet.eval()
46
- if half_precision:
47
- self.birefnet.half()
48
 
49
  def __call__(self, data: Dict[str, Any]):
50
- # --- PASO 1: Carga Segura de la Imagen ---
51
  image_src = data["inputs"]
52
- image_ori = None
53
 
54
- # Detectamos qué nos enviaron (Objeto, URL o Bytes)
55
- if hasattr(image_src, 'convert') or isinstance(image_src, Image.Image):
56
- image_ori = image_src
57
  elif isinstance(image_src, str):
58
- if os.path.isfile(image_src):
59
- image_ori = Image.open(image_src)
60
- else:
61
- response = requests.get(image_src)
62
- image_ori = Image.open(BytesIO(response.content))
63
  else:
 
64
  try:
65
- image_ori = Image.open(BytesIO(image_src))
66
- except Exception:
67
- try:
68
- image_ori = Image.fromarray(image_src)
69
- except Exception:
70
- image_ori = image_src
71
 
72
- # Convertimos a RGB (Esto limpia cualquier rareza del archivo original y asegura color)
73
- image = image_ori.convert('RGB')
 
74
 
75
- # --- PASO 2: La IA detecta la silueta ---
76
- image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
77
- image_proc = image_preprocessor.proc(image)
78
- image_proc = image_proc.unsqueeze(0)
 
 
 
 
 
 
 
 
79
 
80
  with torch.no_grad():
81
- preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
 
82
 
 
83
  pred = preds[0].squeeze()
84
-
85
- # --- PASO 3: Recorte Limpio (Sin matemáticas raras) ---
86
- # Convertimos la predicción en una máscara (imagen en blanco y negro)
87
  mask_pil = transforms.ToPILImage()(pred)
88
 
89
- # Redimensionamos la máscara al tamaño EXACTO de la foto original
90
- mask_pil = mask_pil.resize(image.size, resample=Image.Resampling.LANCZOS)
91
 
92
- # MAGIA: Simplemente le decimos a la foto original "Usa esta transparencia"
93
- # No tocamos los canales de color (RGB), solo añadimos el canal Alpha.
 
94
  image.putalpha(mask_pil)
95
 
96
  return image
 
1
+ from typing import Dict, List, Any
2
  import os
3
  import requests
4
  from io import BytesIO
 
7
  from torchvision import transforms
8
  from transformers import AutoModelForImageSegmentation
9
 
10
+ # --- Configuración ---
11
  torch.set_float32_matmul_precision(["high", "highest"][0])
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class EndpointHandler():
15
  def __init__(self, path=''):
16
+ # Cargamos BiRefNet (General)
17
+ self.model = AutoModelForImageSegmentation.from_pretrained(
18
+ 'zhengpeng7/BiRefNet',
19
  trust_remote_code=True
20
  )
21
+ self.model.to(device)
22
+ self.model.eval()
23
+ self.model.half() # Usamos media precisión para velocidad
 
24
 
25
  def __call__(self, data: Dict[str, Any]):
26
+ # 1. RECIBIR IMAGEN (Entrada Blindada)
27
  image_src = data["inputs"]
28
+ image = None
29
 
30
+ # Detectar si es Bytes, URL o PIL Image
31
+ if isinstance(image_src, Image.Image):
32
+ image = image_src
33
  elif isinstance(image_src, str):
34
+ if os.path.exists(image_src):
35
+ image = Image.open(image_src)
36
+ elif image_src.startswith('http'):
37
+ image = Image.open(BytesIO(requests.get(image_src).content))
 
38
  else:
39
+ # Asumimos bytes
40
  try:
41
+ image = Image.open(BytesIO(image_src))
42
+ except:
43
+ # Fallback final
44
+ image = Image.fromarray(image_src)
 
 
45
 
46
+ # 2. LIMPIEZA DE COLOR (CRUCIAL)
47
+ # Convertimos a RGB puro para eliminar cualquier rareza del archivo original
48
+ image = image.convert("RGB")
49
 
50
+ # Guardamos el tamaño original para luego
51
+ orig_size = image.size
52
+
53
+ # 3. PROCESAMIENTO IA
54
+ # Transformación estándar para BiRefNet (1024x1024)
55
+ transform = transforms.Compose([
56
+ transforms.Resize((1024, 1024)),
57
+ transforms.ToTensor(),
58
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
59
+ ])
60
+
61
+ input_tensor = transform(image).unsqueeze(0).to(device).half()
62
 
63
  with torch.no_grad():
64
+ # Predicción
65
+ preds = self.model(input_tensor)[-1].sigmoid().cpu()
66
 
67
+ # 4. MÁSCARA (Sin Numpy, Solo PIL)
68
  pred = preds[0].squeeze()
 
 
 
69
  mask_pil = transforms.ToPILImage()(pred)
70
 
71
+ # Redimensionamos la máscara al tamaño EXACTO de la imagen original
72
+ mask_pil = mask_pil.resize(orig_size, resample=Image.Resampling.LANCZOS)
73
 
74
+ # 5. APLICACIÓN FINAL
75
+ # Tomamos la imagen RGB original y le "inyectamos" la máscara en el canal Alfa.
76
+ # NO tocamos los colores. Solo decimos qué es transparente.
77
  image.putalpha(mask_pil)
78
 
79
  return image