Update modeling_auristream.py
Browse files- modeling_auristream.py +2 -2
modeling_auristream.py
CHANGED
|
@@ -199,7 +199,7 @@ class AuriStream(PreTrainedModel):
|
|
| 199 |
return logits, None
|
| 200 |
|
| 201 |
def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
|
| 202 |
-
top_k: int =
|
| 203 |
"""
|
| 204 |
Samples an integer from the distribution of logits
|
| 205 |
Parameters:
|
|
@@ -251,7 +251,7 @@ class AuriStream(PreTrainedModel):
|
|
| 251 |
|
| 252 |
@torch.no_grad()
|
| 253 |
def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
|
| 254 |
-
top_k=
|
| 255 |
"""
|
| 256 |
Parameters:
|
| 257 |
seq: torch.Tensor of shape (b, t, n_freq_bins)
|
|
|
|
| 199 |
return logits, None
|
| 200 |
|
| 201 |
def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
|
| 202 |
+
top_k: int = None, top_p: float = None) -> torch.LongTensor:
|
| 203 |
"""
|
| 204 |
Samples an integer from the distribution of logits
|
| 205 |
Parameters:
|
|
|
|
| 251 |
|
| 252 |
@torch.no_grad()
|
| 253 |
def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
|
| 254 |
+
top_k=None, top_p=None, seed=None):
|
| 255 |
"""
|
| 256 |
Parameters:
|
| 257 |
seq: torch.Tensor of shape (b, t, n_freq_bins)
|