lukeafullard commited on
Commit
5b003bf
·
verified ·
1 Parent(s): 169cdb3

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +23 -12
src/streamlit_app.py CHANGED
@@ -36,7 +36,7 @@ def load_upscaler(scale=2):
36
 
37
  def safe_rembg_inference(model, image, device):
38
  """
39
- Robust inference for RMBG-1.4 that handles different output formats.
40
  """
41
  w, h = image.size
42
 
@@ -52,22 +52,33 @@ def safe_rembg_inference(model, image, device):
52
  with torch.no_grad():
53
  outputs = model(input_images)
54
 
55
- # FIX: Handle List vs Tuple vs Tensor output
56
- # BiRefNet usually returns a list/tuple of tensors.
57
- # The output we want is usually the LAST element or the FIRST depending on version.
58
- # We check if 'outputs' is a sequence (list/tuple) and grab the tensor.
59
- if isinstance(outputs, (list, tuple)):
60
- # We assume the last element is the high-res prediction for RMBG-1.4
61
- result_tensor = outputs[-1]
 
 
 
 
 
 
 
62
 
63
- # Double check: if the result is still a list (nested), grab the first item
64
- if isinstance(result_tensor, (list, tuple)):
65
- result_tensor = result_tensor[0]
 
 
66
  else:
67
  result_tensor = outputs
 
68
 
69
  # Post-processing
70
- pred = result_tensor.sigmoid().cpu()[0].squeeze()
 
71
 
72
  # Convert mask to PIL
73
  pred_pil = transforms.ToPILImage()(pred)
 
36
 
37
  def safe_rembg_inference(model, image, device):
38
  """
39
+ Robust inference for RMBG-1.4 that finds the correct mask tensor.
40
  """
41
  w, h = image.size
42
 
 
52
  with torch.no_grad():
53
  outputs = model(input_images)
54
 
55
+ # --- FIX START ---
56
+ result_tensor = None
57
+
58
+ # Priority 1: Check for explicit 'logits' attribute (Standard Hugging Face)
59
+ if hasattr(outputs, "logits"):
60
+ result_tensor = outputs.logits
61
+
62
+ # Priority 2: Iterate through list/tuple to find the 1-channel mask
63
+ elif isinstance(outputs, (list, tuple)):
64
+ for tensor in outputs:
65
+ # We are looking for shape [Batch, 1, Height, Width]
66
+ if isinstance(tensor, torch.Tensor) and tensor.dim() == 4 and tensor.shape[1] == 1:
67
+ result_tensor = tensor
68
+ break
69
 
70
+ # Fallback: If no 1-channel tensor found, take the first element
71
+ if result_tensor is None:
72
+ result_tensor = outputs[0]
73
+
74
+ # Priority 3: It's already a tensor
75
  else:
76
  result_tensor = outputs
77
+ # --- FIX END ---
78
 
79
  # Post-processing
80
+ # Squeeze removes batch dim (1, 1, 1024, 1024) -> (1024, 1024)
81
+ pred = result_tensor.squeeze().sigmoid().cpu()
82
 
83
  # Convert mask to PIL
84
  pred_pil = transforms.ToPILImage()(pred)