Update modeling_ministu.py
Browse files- modeling_ministu.py +97 -0
modeling_ministu.py
CHANGED
|
@@ -138,3 +138,100 @@ class MiniSTU(PreTrainedModel):
|
|
| 138 |
torch.nn.init.zeros_(module.c_attn.bias)
|
| 139 |
if module.c_proj.bias is not None:
|
| 140 |
torch.nn.init.zeros_(module.c_proj.bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
torch.nn.init.zeros_(module.c_attn.bias)
|
| 139 |
if module.c_proj.bias is not None:
|
| 140 |
torch.nn.init.zeros_(module.c_proj.bias)
|
| 141 |
+
|
| 142 |
+
@staticmethod
|
| 143 |
+
def top_k_top_p_filtering(
|
| 144 |
+
logits: torch.Tensor,
|
| 145 |
+
top_k: int = 50,
|
| 146 |
+
top_p: float = 0.95,
|
| 147 |
+
filter_value: float = float("-inf"),
|
| 148 |
+
):
|
| 149 |
+
"""
|
| 150 |
+
Filters a distribution of logits using top-k and/or nucleus (top-p) filtering.
|
| 151 |
+
"""
|
| 152 |
+
# top_k
|
| 153 |
+
if top_k > 0:
|
| 154 |
+
top_k = min(top_k, logits.size(-1))
|
| 155 |
+
# Remove all logits that are not in the top k
|
| 156 |
+
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1).values[:, -1, None]
|
| 157 |
+
logits[indices_to_remove] = filter_value
|
| 158 |
+
|
| 159 |
+
# top_p (nucleus)
|
| 160 |
+
if 0 < top_p < 1.0:
|
| 161 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
| 162 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 163 |
+
|
| 164 |
+
# Remove tokens with cumulative probability above the threshold
|
| 165 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 166 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
| 167 |
+
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
|
| 168 |
+
sorted_indices_to_remove[:, 0] = False
|
| 169 |
+
|
| 170 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
| 171 |
+
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
| 172 |
+
)
|
| 173 |
+
logits[indices_to_remove] = filter_value
|
| 174 |
+
|
| 175 |
+
return logits
|
| 176 |
+
|
| 177 |
+
def generate(
    self,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    eos_token_id: int = None,
    pad_token_id: int = 0,
    **kwargs
):
    """
    Naive token-by-token generation loop using top-k/top-p filtering and temperature.

    Improvements over the previous revision:
      * the loop runs under ``torch.no_grad()`` so no autograd graph is
        accumulated across sampling steps;
      * ``temperature <= 0`` falls back to greedy (argmax) decoding instead
        of dividing by zero;
      * the last-step logits are cloned before filtering — the slice of
        ``outputs.logits`` is a view and the filter writes in place;
      * ``pad_token_id`` is actually used: once a sequence emits
        ``eos_token_id`` its remaining positions are filled with
        ``pad_token_id`` (per-sequence stopping in a batch).

    Args:
        input_ids (torch.LongTensor): shape (batch_size, sequence_length).
        max_new_tokens (int): max number of tokens to generate (beyond input_ids length).
        temperature (float): sampling temperature; <= 0 selects greedy decoding.
        top_k (int): Top-K sampling cutoff (0 disables).
        top_p (float): Nucleus sampling cutoff.
        eos_token_id (int): If set, stop a sequence when this token is produced.
        pad_token_id (int): Token used to pad sequences that already finished.
        kwargs: Unused arguments (like num_beams) for compatibility.

    Returns:
        torch.LongTensor: shape (batch_size, sequence_length + generated_tokens).
    """
    generated_ids = input_ids.clone()
    # Rows that have not yet produced eos_token_id; shape (batch, 1).
    unfinished = torch.ones(
        input_ids.size(0), 1, dtype=torch.bool, device=input_ids.device
    )

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Full forward pass each step (no KV cache in this naive loop);
            # __call__ rather than .forward so module hooks still fire.
            outputs = self(generated_ids)
            # Clone: the slice is a view and top_k_top_p_filtering mutates
            # its input in place — don't write through into outputs.logits.
            logits = outputs.logits[:, -1, :].clone()  # (batch_size, vocab_size)

            if temperature <= 0:
                # Degenerate temperature: deterministic greedy decoding.
                next_token = logits.argmax(dim=-1, keepdim=True)  # (batch, 1)
            else:
                if temperature != 1.0:
                    logits = logits / temperature
                logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                probabilities = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probabilities, num_samples=1)  # (batch, 1)

            if eos_token_id is not None:
                # Already-finished rows keep emitting pad_token_id.
                next_token = torch.where(
                    unfinished,
                    next_token,
                    torch.full_like(next_token, pad_token_id),
                )
                unfinished = unfinished & (next_token != eos_token_id)

            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Stop early once every sequence in the batch has finished.
            if eos_token_id is not None and not unfinished.any():
                break

    return generated_ids
|