Update generation_utils.py
Browse files- generation_utils.py +2 -1
generation_utils.py
CHANGED
|
@@ -369,7 +369,8 @@ class DreamGenerationMixin:
|
|
| 369 |
|
| 370 |
# 仅 rcr=True:引入轻量跟踪,不影响 baseline
|
| 371 |
is_fixed = torch.zeros_like(x, dtype=torch.bool) if rcr else None
|
| 372 |
-
fixed_conf = torch.
|
|
|
|
| 373 |
|
| 374 |
x = generation_tokens_hook_func(None, x, None)
|
| 375 |
for i in range(steps):
|
|
|
|
| 369 |
|
| 370 |
# 仅 rcr=True:引入轻量跟踪,不影响 baseline
|
| 371 |
is_fixed = torch.zeros_like(x, dtype=torch.bool) if rcr else None
|
| 372 |
+
fixed_conf = torch.full(x.shape, float("-inf"), device=x.device, dtype=torch.float32) if rcr else None
|
| 373 |
+
# 存放已确定位置的置信度
|
| 374 |
|
| 375 |
x = generation_tokens_hook_func(None, x, None)
|
| 376 |
for i in range(steps):
|