Update controllable_blender/generation_methods.py
Browse files
controllable_blender/generation_methods.py
CHANGED
|
@@ -62,7 +62,7 @@ class RerankedTopKSampling(TreeSearch):
|
|
| 62 |
|
| 63 |
if all_penalties.dim() == 1:
|
| 64 |
all_penalties = all_penalties.unsqueeze(0)
|
| 65 |
-
all_penalties = all_penalties.expand(batch_size,
|
| 66 |
|
| 67 |
penalties = torch.gather(all_penalties, -1, indices)
|
| 68 |
penalised_probs = torch.mul(probs, penalties)
|
|
|
|
| 62 |
|
| 63 |
if all_penalties.dim() == 1:
|
| 64 |
all_penalties = all_penalties.unsqueeze(0)
|
| 65 |
+
all_penalties = all_penalties.expand(batch_size, 8008)
|
| 66 |
|
| 67 |
penalties = torch.gather(all_penalties, -1, indices)
|
| 68 |
penalised_probs = torch.mul(probs, penalties)
|