Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -31,18 +31,18 @@ def process(image: Image.Image) -> Image.Image:
|
|
| 31 |
input_tensor = transform_image(image).unsqueeze(0).to(device)
|
| 32 |
|
| 33 |
with torch.no_grad():
|
| 34 |
-
# RMBG v1.4 returns a tuple, we need the first element
|
| 35 |
preds = model(input_tensor)
|
| 36 |
|
| 37 |
-
# Handle
|
| 38 |
-
if isinstance(preds,
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
| 42 |
else:
|
| 43 |
pred = preds
|
| 44 |
|
| 45 |
-
#
|
| 46 |
mask = pred.sigmoid().cpu()
|
| 47 |
|
| 48 |
# Process the mask
|
|
@@ -91,7 +91,6 @@ def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str]
|
|
| 91 |
|
| 92 |
except Exception as e:
|
| 93 |
print("General error:", e)
|
| 94 |
-
# Add debug info
|
| 95 |
import traceback
|
| 96 |
traceback.print_exc()
|
| 97 |
|
|
|
|
| 31 |
input_tensor = transform_image(image).unsqueeze(0).to(device)
|
| 32 |
|
| 33 |
with torch.no_grad():
|
|
|
|
| 34 |
preds = model(input_tensor)
|
| 35 |
|
| 36 |
+
# Handle list output - extract the tensor from the list
|
| 37 |
+
if isinstance(preds, list):
|
| 38 |
+
# Usually the mask is the last or first element
|
| 39 |
+
pred = preds[-1] if len(preds) > 0 else preds[0]
|
| 40 |
+
elif isinstance(preds, tuple):
|
| 41 |
+
pred = preds[0]
|
| 42 |
else:
|
| 43 |
pred = preds
|
| 44 |
|
| 45 |
+
# Now apply sigmoid to the tensor
|
| 46 |
mask = pred.sigmoid().cpu()
|
| 47 |
|
| 48 |
# Process the mask
|
|
|
|
| 91 |
|
| 92 |
except Exception as e:
|
| 93 |
print("General error:", e)
|
|
|
|
| 94 |
import traceback
|
| 95 |
traceback.print_exc()
|
| 96 |
|