LI Junxing commited on
Commit
7063f69
·
1 Parent(s): ec737d7

修复 RankSEG 输出模式」} তদন্ত? Wait format

Browse files
Files changed (2) hide show
  1. app.py +4 -2
  2. app_local.py +4 -2
app.py CHANGED
@@ -168,9 +168,11 @@ def refine_foreground(image, mask, r=90, device='cuda'):
168
 
169
 
170
  def get_rankseg_mask(pred: torch.Tensor, metric: str) -> Image.Image:
171
- rankseg = RankSEG(metric=metric, output_mode='multiclass', solver='RMA')
 
 
172
  probs = pred.unsqueeze(0).unsqueeze(0)
173
- rankseg_pred = rankseg.predict(probs).squeeze(0).to(torch.float32)
174
  return transforms.ToPILImage()(rankseg_pred)
175
 
176
 
 
168
 
169
 
170
  def get_rankseg_mask(pred: torch.Tensor, metric: str) -> Image.Image:
171
+ # BiRefNet produces a single foreground probability map, so RankSEG should
172
+ # return a binary mask for that one channel instead of a multiclass map.
173
+ rankseg = RankSEG(metric=metric, output_mode='multilabel', solver='RMA')
174
  probs = pred.unsqueeze(0).unsqueeze(0)
175
+ rankseg_pred = rankseg.predict(probs).squeeze(0).squeeze(0).to(torch.float32)
176
  return transforms.ToPILImage()(rankseg_pred)
177
 
178
 
app_local.py CHANGED
@@ -165,9 +165,11 @@ def refine_foreground(image, mask, r=90, device='cuda'):
165
 
166
 
167
  def get_rankseg_mask(pred: torch.Tensor, metric: str) -> Image.Image:
168
- rankseg = RankSEG(metric=metric, output_mode='multiclass', solver='RMA')
 
 
169
  probs = pred.unsqueeze(0).unsqueeze(0)
170
- rankseg_pred = rankseg.predict(probs).squeeze(0).to(torch.float32)
171
  return transforms.ToPILImage()(rankseg_pred)
172
 
173
 
 
165
 
166
 
167
  def get_rankseg_mask(pred: torch.Tensor, metric: str) -> Image.Image:
168
+ # BiRefNet produces a single foreground probability map, so RankSEG should
169
+ # return a binary mask for that one channel instead of a multiclass map.
170
+ rankseg = RankSEG(metric=metric, output_mode='multilabel', solver='RMA')
171
  probs = pred.unsqueeze(0).unsqueeze(0)
172
+ rankseg_pred = rankseg.predict(probs).squeeze(0).squeeze(0).to(torch.float32)
173
  return transforms.ToPILImage()(rankseg_pred)
174
 
175