Update modeling_auristream.py
Browse files- modeling_auristream.py +1 -1
modeling_auristream.py
CHANGED
|
@@ -320,7 +320,7 @@ class AuriStream(PreTrainedModel):
|
|
| 320 |
|
| 321 |
# First prediction of the model is the decoding of the last time bin
|
| 322 |
logits = self.coch_head(x[:, [-1]])
|
| 323 |
-
predictions = [self.sample_logits(logits, temperature=temp)]
|
| 324 |
all_logits.append(logits)
|
| 325 |
|
| 326 |
### Predict future tokens
|
|
|
|
| 320 |
|
| 321 |
# First prediction of the model is the decoding of the last time bin
|
| 322 |
logits = self.coch_head(x[:, [-1]])
|
| 323 |
+
predictions = [self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p)]
|
| 324 |
all_logits.append(logits)
|
| 325 |
|
| 326 |
### Predict future tokens
|