klemenk commited on
Commit
9c1f3a9
·
verified ·
1 Parent(s): d329d9e

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +2 -2
modeling_auristream.py CHANGED
@@ -199,7 +199,7 @@ class AuriStream(PreTrainedModel):
199
  return logits, None
200
 
201
  def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
202
- top_k: int = 500, top_p: float = 0.5) -> torch.LongTensor:
203
  """
204
  Samples an integer from the distribution of logits
205
  Parameters:
@@ -251,7 +251,7 @@ class AuriStream(PreTrainedModel):
251
 
252
  @torch.no_grad()
253
  def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
254
- top_k=500, top_p=0.5, seed=None):
255
  """
256
  Parameters:
257
  seq: torch.Tensor of shape (b, t, n_freq_bins)
 
199
  return logits, None
200
 
201
  def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
202
+ top_k: int = None, top_p: float = None) -> torch.LongTensor:
203
  """
204
  Samples an integer from the distribution of logits
205
  Parameters:
 
251
 
252
  @torch.no_grad()
253
  def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
254
+ top_k=None, top_p=None, seed=None):
255
  """
256
  Parameters:
257
  seq: torch.Tensor of shape (b, t, n_freq_bins)