Spaces:
Runtime error
Runtime error
Update autoregressive/models/generate.py
Browse files
autoregressive/models/generate.py
CHANGED
|
@@ -60,8 +60,6 @@ def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sa
|
|
| 60 |
logits = logits[:, -1, :] / max(temperature, 1e-5)
|
| 61 |
if top_k > 0 or top_p < 1.0:
|
| 62 |
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
| 63 |
-
print(logits.sum())
|
| 64 |
-
print(logits)
|
| 65 |
probs = F.softmax(logits, dim=-1)
|
| 66 |
# values, indices = torch.max(probs, dim=1, keepdim=True)
|
| 67 |
# mask = (probs == values).float()
|
|
@@ -93,6 +91,8 @@ def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: i
|
|
| 93 |
def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, **sampling_kwargs):
|
| 94 |
if cfg_scale > 1.0:
|
| 95 |
logits, _ = model(None, cond_idx, input_pos, condition=condition)
|
|
|
|
|
|
|
| 96 |
logits_combined = logits
|
| 97 |
cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
|
| 98 |
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
|
|
|
|
| 60 |
logits = logits[:, -1, :] / max(temperature, 1e-5)
|
| 61 |
if top_k > 0 or top_p < 1.0:
|
| 62 |
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
|
|
|
|
|
|
| 63 |
probs = F.softmax(logits, dim=-1)
|
| 64 |
# values, indices = torch.max(probs, dim=1, keepdim=True)
|
| 65 |
# mask = (probs == values).float()
|
|
|
|
| 91 |
def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, **sampling_kwargs):
|
| 92 |
if cfg_scale > 1.0:
|
| 93 |
logits, _ = model(None, cond_idx, input_pos, condition=condition)
|
| 94 |
+
print(logits.sum())
|
| 95 |
+
print(logits)
|
| 96 |
logits_combined = logits
|
| 97 |
cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
|
| 98 |
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
|