Spaces:
Running
Running
tidy
Browse files
app.py
CHANGED
|
@@ -41,7 +41,7 @@ def generate_step(out: object,
|
|
| 41 |
|
| 42 |
args:
|
| 43 |
- out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
|
| 44 |
-
- gen_idx (int): location for which to generate
|
| 45 |
- top_k (int): if >0, only sample from the top k most probable words
|
| 46 |
- temperature (float): sampling temperature
|
| 47 |
- typical_p (float): if >0 use typical sampling
|
|
@@ -53,13 +53,13 @@ def generate_step(out: object,
|
|
| 53 |
logits = out.logits[:, gen_idx]
|
| 54 |
warpers = LogitsProcessorList()
|
| 55 |
if temperature:
|
| 56 |
-
warpers
|
| 57 |
if top_k > 0:
|
| 58 |
-
warpers
|
| 59 |
if typical_p > 0:
|
| 60 |
if typical_p >= 1:
|
| 61 |
typical_p = 0.999
|
| 62 |
-
warpers
|
| 63 |
logits = warpers(None, logits)
|
| 64 |
|
| 65 |
if sample:
|
|
|
|
| 41 |
|
| 42 |
args:
|
| 43 |
- out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
|
| 44 |
+
- gen_idx (int): location for which to generate
|
| 45 |
- top_k (int): if >0, only sample from the top k most probable words
|
| 46 |
- temperature (float): sampling temperature
|
| 47 |
- typical_p (float): if >0 use typical sampling
|
|
|
|
| 53 |
logits = out.logits[:, gen_idx]
|
| 54 |
warpers = LogitsProcessorList()
|
| 55 |
if temperature:
|
| 56 |
+
warpers.append(TemperatureLogitsWarper(temperature))
|
| 57 |
if top_k > 0:
|
| 58 |
+
warpers.append(TopKLogitsWarper(top_k))
|
| 59 |
if typical_p > 0:
|
| 60 |
if typical_p >= 1:
|
| 61 |
typical_p = 0.999
|
| 62 |
+
warpers.append(TypicalLogitsWarper(typical_p))
|
| 63 |
logits = warpers(None, logits)
|
| 64 |
|
| 65 |
if sample:
|