ikaganacar committed
Commit 838299a · 1 Parent(s): 2120bf6
Files changed (1)
  1. Model_Architecture/generation.py +14 -2
Model_Architecture/generation.py CHANGED
@@ -69,7 +69,11 @@ def generate_text_with_sampling(model, idx, max_new_tokens, context_size, temper
     logits = model(idx_cond)
 
     # Focus only on the last time step
-    logits = logits[:, -1, :] / temperature
+    logits = logits[:, -1, :]
+
+    # Clamp temperature to avoid division by very small numbers
+    temperature = max(temperature, 1e-8)
+    logits = logits / temperature
 
     # Optional: apply top-k filtering
     if top_k is not None:
@@ -77,7 +81,15 @@ def generate_text_with_sampling(model, idx, max_new_tokens, context_size, temper
         logits[logits < v[:, [-1]]] = -float('Inf')
 
     # Apply softmax to get probabilities
-    probs = torch.softmax(logits, dim=-1)
+    probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
+
+    # Handle edge cases: check for invalid probabilities
+    if torch.isnan(probs).any() or torch.isinf(probs).any():
+        # Fallback to uniform distribution over valid tokens
+        probs = torch.ones_like(probs) / probs.size(-1)
+
+    # Ensure probabilities sum to 1
+    probs = probs / probs.sum(dim=-1, keepdim=True)
 
     # Sample from the distribution
     idx_next = torch.multinomial(probs, num_samples=1)
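For context (not part of the commit), a minimal self-contained sketch of why the guard matters. The vocab_size and random raw tensor are hypothetical stand-ins for the model's last-step logits, and PyTorch is assumed installed; the function itself is not reproduced here.

import torch

torch.manual_seed(0)
vocab_size = 50
raw = torch.randn(1, vocab_size)   # stand-in for model(idx_cond)[:, -1, :]
temperature = 0.0                  # the degenerate setting the commit guards against

# Old path: dividing by zero yields inf logits, softmax turns them into NaN,
# and torch.multinomial rejects such a distribution with a RuntimeError.
bad_probs = torch.softmax(raw / temperature, dim=-1)
print(torch.isnan(bad_probs).any())   # tensor(True)

# New path, mirroring the commit: clamp the temperature, run softmax in
# float32, fall back to a uniform distribution on invalid values, and
# renormalize before sampling.
temperature = max(temperature, 1e-8)
probs = torch.softmax(raw / temperature, dim=-1, dtype=torch.float32)
if torch.isnan(probs).any() or torch.isinf(probs).any():
    probs = torch.ones_like(probs) / probs.size(-1)
probs = probs / probs.sum(dim=-1, keepdim=True)
idx_next = torch.multinomial(probs, num_samples=1)
print(idx_next.item() == raw.argmax().item())   # True: near-zero temperature ≈ greedy

With the clamp, temperature=0.0 degenerates gracefully into near-greedy decoding instead of crashing, and the uniform fallback plus renormalization keeps multinomial sampling valid even if the softmax ever produces NaN or inf values.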