gaoyang07 commited on
Commit
078ab4e
·
1 Parent(s): b1cede0

Fix apply_repetition_penalty

Browse files
Files changed (1) hide show
  1. streaming_mossttsrealtime.py +2 -12
streaming_mossttsrealtime.py CHANGED
@@ -365,24 +365,14 @@ class MossTTSRealtimeInference:
365
  repetition_window: Optional[int] = None,
366
  ):
367
  scores_ = scores[:, 0, :]
368
- batch_size = scores_.shape[0]
369
  ht = history_tokens
370
 
371
  if repetition_window is not None and repetition_window > 0:
372
  ht = ht[:, -repetition_window:]
373
 
374
- ht_sorted, _ = torch.sort(ht, dim=1)
375
- uniq = torch.unique_consecutive(ht_sorted, dim=1)
376
-
377
- b_idx = torch.arange(batch_size, device=uniq.device).unsqueeze(1).expand_as(uniq)
378
- b_flat = b_idx.reshape(-1)
379
- t_flat = uniq.reshape(-1)
380
-
381
- cur = scores_[b_flat, t_flat]
382
  new = torch.where(cur < 0, cur * penalty, cur / penalty)
383
-
384
- scores_[b_flat, t_flat] = new
385
-
386
  return scores_
387
 
388
  def sample_token(self, logits, temperature, top_p=0.6, top_k=30, do_sample=True):
 
365
  repetition_window: Optional[int] = None,
366
  ):
367
  scores_ = scores[:, 0, :]
 
368
  ht = history_tokens
369
 
370
  if repetition_window is not None and repetition_window > 0:
371
  ht = ht[:, -repetition_window:]
372
 
373
+ cur = scores_.gather(1, ht)
 
 
 
 
 
 
 
374
  new = torch.where(cur < 0, cur * penalty, cur / penalty)
375
+ scores_.scatter_(1, ht, new)
 
 
376
  return scores_
377
 
378
  def sample_token(self, logits, temperature, top_p=0.6, top_k=30, do_sample=True):