Commit ·
838299a
1
Parent(s): 2120bf6
Fixes
Browse files
Model_Architecture/generation.py
CHANGED
|
@@ -69,7 +69,11 @@ def generate_text_with_sampling(model, idx, max_new_tokens, context_size, temper
|
|
| 69 |
logits = model(idx_cond)
|
| 70 |
|
| 71 |
# Focus only on the last time step
|
| 72 |
-
logits = logits[:, -1, :]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
# Optional: apply top-k filtering
|
| 75 |
if top_k is not None:
|
|
@@ -77,7 +81,15 @@ def generate_text_with_sampling(model, idx, max_new_tokens, context_size, temper
|
|
| 77 |
logits[logits < v[:, [-1]]] = -float('Inf')
|
| 78 |
|
| 79 |
# Apply softmax to get probabilities
|
| 80 |
-
probs = torch.softmax(logits, dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
# Sample from the distribution
|
| 83 |
idx_next = torch.multinomial(probs, num_samples=1)
|
|
|
|
| 69 |
logits = model(idx_cond)
|
| 70 |
|
| 71 |
# Focus only on the last time step
|
| 72 |
+
logits = logits[:, -1, :]
|
| 73 |
+
|
| 74 |
+
# Clamp temperature to a small positive floor to avoid division by zero or a negative value
|
| 75 |
+
temperature = max(temperature, 1e-8)
|
| 76 |
+
logits = logits / temperature
|
| 77 |
|
| 78 |
# Optional: apply top-k filtering
|
| 79 |
if top_k is not None:
|
|
|
|
| 81 |
logits[logits < v[:, [-1]]] = -float('Inf')
|
| 82 |
|
| 83 |
# Apply softmax to get probabilities
|
| 84 |
+
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
|
| 85 |
+
|
| 86 |
+
# Handle edge cases: check for invalid probabilities
|
| 87 |
+
if torch.isnan(probs).any() or torch.isinf(probs).any():
|
| 88 |
+
# Fall back to a uniform distribution over the entire last dimension (all tokens, including any masked by top-k)
|
| 89 |
+
probs = torch.ones_like(probs) / probs.size(-1)
|
| 90 |
+
|
| 91 |
+
# Ensure probabilities sum to 1
|
| 92 |
+
probs = probs / probs.sum(dim=-1, keepdim=True)
|
| 93 |
|
| 94 |
# Sample from the distribution
|
| 95 |
idx_next = torch.multinomial(probs, num_samples=1)
|