gaoyang07
committed on
Commit
·
078ab4e
1
Parent(s):
b1cede0
Fix apply_repetition_penalty
Browse files- streaming_mossttsrealtime.py +2 -12
streaming_mossttsrealtime.py
CHANGED
|
@@ -365,24 +365,14 @@ class MossTTSRealtimeInference:
|
|
| 365 |
repetition_window: Optional[int] = None,
|
| 366 |
):
|
| 367 |
scores_ = scores[:, 0, :]
|
| 368 |
-
batch_size = scores_.shape[0]
|
| 369 |
ht = history_tokens
|
| 370 |
|
| 371 |
if repetition_window is not None and repetition_window > 0:
|
| 372 |
ht = ht[:, -repetition_window:]
|
| 373 |
|
| 374 |
-
|
| 375 |
-
uniq = torch.unique_consecutive(ht_sorted, dim=1)
|
| 376 |
-
|
| 377 |
-
b_idx = torch.arange(batch_size, device=uniq.device).unsqueeze(1).expand_as(uniq)
|
| 378 |
-
b_flat = b_idx.reshape(-1)
|
| 379 |
-
t_flat = uniq.reshape(-1)
|
| 380 |
-
|
| 381 |
-
cur = scores_[b_flat, t_flat]
|
| 382 |
new = torch.where(cur < 0, cur * penalty, cur / penalty)
|
| 383 |
-
|
| 384 |
-
scores_[b_flat, t_flat] = new
|
| 385 |
-
|
| 386 |
return scores_
|
| 387 |
|
| 388 |
def sample_token(self, logits, temperature, top_p=0.6, top_k=30, do_sample=True):
|
|
|
|
| 365 |
repetition_window: Optional[int] = None,
|
| 366 |
):
|
| 367 |
scores_ = scores[:, 0, :]
|
|
|
|
| 368 |
ht = history_tokens
|
| 369 |
|
| 370 |
if repetition_window is not None and repetition_window > 0:
|
| 371 |
ht = ht[:, -repetition_window:]
|
| 372 |
|
| 373 |
+
cur = scores_.gather(1, ht)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
new = torch.where(cur < 0, cur * penalty, cur / penalty)
|
| 375 |
+
scores_.scatter_(1, ht, new)
|
|
|
|
|
|
|
| 376 |
return scores_
|
| 377 |
|
| 378 |
def sample_token(self, logits, temperature, top_p=0.6, top_k=30, do_sample=True):
|