Spaces:
Sleeping
Sleeping
LI Junxing commited on
Commit ·
7063f69
1
Parent(s): ec737d7
修复 RankSEG 输出模式」} তদন্ত? Wait format
Browse files- app.py +4 -2
- app_local.py +4 -2
app.py
CHANGED
|
@@ -168,9 +168,11 @@ def refine_foreground(image, mask, r=90, device='cuda'):
|
|
| 168 |
|
| 169 |
|
| 170 |
def get_rankseg_mask(pred: torch.Tensor, metric: str) -> Image.Image:
|
| 171 |
-
|
|
|
|
|
|
|
| 172 |
probs = pred.unsqueeze(0).unsqueeze(0)
|
| 173 |
-
rankseg_pred = rankseg.predict(probs).squeeze(0).to(torch.float32)
|
| 174 |
return transforms.ToPILImage()(rankseg_pred)
|
| 175 |
|
| 176 |
|
|
|
|
| 168 |
|
| 169 |
|
| 170 |
def get_rankseg_mask(pred: torch.Tensor, metric: str) -> Image.Image:
|
| 171 |
+
# BiRefNet produces a single foreground probability map, so RankSEG should
|
| 172 |
+
# return a binary mask for that one channel instead of a multiclass map.
|
| 173 |
+
rankseg = RankSEG(metric=metric, output_mode='multilabel', solver='RMA')
|
| 174 |
probs = pred.unsqueeze(0).unsqueeze(0)
|
| 175 |
+
rankseg_pred = rankseg.predict(probs).squeeze(0).squeeze(0).to(torch.float32)
|
| 176 |
return transforms.ToPILImage()(rankseg_pred)
|
| 177 |
|
| 178 |
|
app_local.py
CHANGED
|
@@ -165,9 +165,11 @@ def refine_foreground(image, mask, r=90, device='cuda'):
|
|
| 165 |
|
| 166 |
|
| 167 |
def get_rankseg_mask(pred: torch.Tensor, metric: str) -> Image.Image:
|
| 168 |
-
|
|
|
|
|
|
|
| 169 |
probs = pred.unsqueeze(0).unsqueeze(0)
|
| 170 |
-
rankseg_pred = rankseg.predict(probs).squeeze(0).to(torch.float32)
|
| 171 |
return transforms.ToPILImage()(rankseg_pred)
|
| 172 |
|
| 173 |
|
|
|
|
| 165 |
|
| 166 |
|
| 167 |
def get_rankseg_mask(pred: torch.Tensor, metric: str) -> Image.Image:
|
| 168 |
+
# BiRefNet produces a single foreground probability map, so RankSEG should
|
| 169 |
+
# return a binary mask for that one channel instead of a multiclass map.
|
| 170 |
+
rankseg = RankSEG(metric=metric, output_mode='multilabel', solver='RMA')
|
| 171 |
probs = pred.unsqueeze(0).unsqueeze(0)
|
| 172 |
+
rankseg_pred = rankseg.predict(probs).squeeze(0).squeeze(0).to(torch.float32)
|
| 173 |
return transforms.ToPILImage()(rankseg_pred)
|
| 174 |
|
| 175 |
|