Spaces:
Sleeping
Sleeping
sanchit-gandhi
commited on
Commit
·
6f5cea7
1
Parent(s):
5039fa6
generation logic
Browse files
app.py
CHANGED
|
@@ -149,10 +149,22 @@ class ParlerTTSStreamer(BaseStreamer):
|
|
| 149 |
# send the input_ids to the correct device
|
| 150 |
input_ids = input_ids.to(self.audio_encoder.device)
|
| 151 |
|
| 152 |
-
|
| 153 |
-
input_ids
|
| 154 |
-
|
|
|
|
| 155 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
audio_values = output_values.audio_values[0, 0]
|
| 157 |
return audio_values.cpu().float().numpy()
|
| 158 |
|
|
|
|
| 149 |
# send the input_ids to the correct device
|
| 150 |
input_ids = input_ids.to(self.audio_encoder.device)
|
| 151 |
|
| 152 |
+
decode_sequentially = (
|
| 153 |
+
self.generation_config.bos_token_id in input_ids
|
| 154 |
+
or self.generation_config.pad_token_id in input_ids
|
| 155 |
+
or self.generation_config.eos_token_id in input_ids
|
| 156 |
)
|
| 157 |
+
if not decode_sequentially:
|
| 158 |
+
output_values = self.audio_encoder.decode(
|
| 159 |
+
input_ids,
|
| 160 |
+
audio_scales=[None],
|
| 161 |
+
)
|
| 162 |
+
else:
|
| 163 |
+
sample = input_ids[:, 0]
|
| 164 |
+
sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
|
| 165 |
+
sample = sample[:, :, sample_mask]
|
| 166 |
+
output_values = self.audio_encoder.decode(sample[None, ...], [None])
|
| 167 |
+
|
| 168 |
audio_values = output_values.audio_values[0, 0]
|
| 169 |
return audio_values.cpu().float().numpy()
|
| 170 |
|