Spaces:
Sleeping
Sleeping
reorder LogitsWarpers and tidy
Browse files
app.py
CHANGED
|
@@ -51,16 +51,16 @@ def generate_step(out: object,
|
|
| 51 |
- list: batch_size tokens
|
| 52 |
"""
|
| 53 |
logits = out.logits[:, gen_idx]
|
| 54 |
-
|
| 55 |
-
if top_k > 0:
|
| 56 |
-
logit_warpers += [TopKLogitsWarper(top_k)]
|
| 57 |
if temperature:
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
if typical_p > 0:
|
| 60 |
if typical_p >= 1:
|
| 61 |
typical_p = 0.999
|
| 62 |
-
|
| 63 |
-
logits =
|
| 64 |
|
| 65 |
if sample:
|
| 66 |
probs = torch.nn.functional.softmax(logits, dim=-1)
|
|
|
|
| 51 |
- list: batch_size tokens
|
| 52 |
"""
|
| 53 |
logits = out.logits[:, gen_idx]
|
| 54 |
+
warpers = LogitsProcessorList()
|
|
|
|
|
|
|
| 55 |
if temperature:
|
| 56 |
+
warpers += [TemperatureLogitsWarper(temperature)]
|
| 57 |
+
if top_k > 0:
|
| 58 |
+
warpers += [TopKLogitsWarper(top_k)]
|
| 59 |
if typical_p > 0:
|
| 60 |
if typical_p >= 1:
|
| 61 |
typical_p = 0.999
|
| 62 |
+
warpers += [TypicalLogitsWarper(typical_p)]
|
| 63 |
+
logits = warpers(None, logits)
|
| 64 |
|
| 65 |
if sample:
|
| 66 |
probs = torch.nn.functional.softmax(logits, dim=-1)
|