Youzhi Yu committed on
Commit · ea46d13
1 Parent(s): 5324bdd
Fix generate method to handle CausalLMOutput, plus other updates
model.py CHANGED
@@ -9,6 +9,7 @@ from transformers import (
     AutoModel,
     AutoModelForCausalLM
 )
+from transformers.modeling_outputs import CausalLMOutput
 
 from typing import Optional
 
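For reference, `CausalLMOutput` is the standard `transformers` output container for causal language models: a `ModelOutput` dataclass with optional `loss`, `logits`, `hidden_states`, and `attentions` fields that supports both attribute and key access. A small illustrative sketch (not part of the commit; tensor shapes are arbitrary):

import torch
from transformers.modeling_outputs import CausalLMOutput

# Build a dummy output the same way the new forward() does.
out = CausalLMOutput(loss=torch.tensor(2.31), logits=torch.randn(1, 8, 100))

print(out.loss)              # attribute access
print(out["logits"].shape)   # key access works too
print(len(out.to_tuple()))   # 2 -- fields left as None are dropped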
@@ -102,6 +103,9 @@ class MLP(nn.Module):
 class ArgonneModel(PreTrainedModel):
     config_class = ArgonneConfig
 
+    # for device_map = "auto"
+    _no_split_modules = ["Block"]
+
     def __init__(self, config, device_map=None):
         super().__init__(config)
         # Create embeddings on CPU initially
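`_no_split_modules = ["Block"]` is what makes `device_map="auto"` workable: when Accelerate shards the weights across devices it keeps each `Block` intact on a single device instead of splitting it. A hedged loading sketch; the checkpoint path is a placeholder, and `trust_remote_code=True` is assumed to be required because the architecture lives in this repo rather than in `transformers`:

from transformers import AutoModel

# "path/to/argonne-checkpoint" is a placeholder for a local dir or Hub repo id.
model = AutoModel.from_pretrained(
    "path/to/argonne-checkpoint",
    device_map="auto",        # Accelerate places whole Blocks across available devices
    trust_remote_code=True,   # needed when the model class ships with the checkpoint
)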
@@ -214,18 +218,40 @@ class ArgonneModel(PreTrainedModel):
         # For now, we'll just return self since our model structure should be compatible
         return self
 
-    def forward(
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        labels=None,
+        **kwargs
+    ):
         """
-
-
-
+        HF-friendly forward method.
+
+        Args:
+            input_ids (torch.LongTensor): Tokens to be fed to the model. [batch_size, seq_len].
+            attention_mask (torch.LongTensor, optional): Mask of shape [batch_size, seq_len],
+                with 1 for actual tokens and 0 for padding, if you want to incorporate it.
+                Currently ignored in this minimal example.
+            labels (torch.LongTensor, optional): Targets for language modeling, same shape as `input_ids`.
+            **kwargs: Catch-all for any additional arguments (e.g. past_key_values) so we don't crash.
         """
-        #
+        # 1) We'll rename the parameters from the old code
+        if input_ids is None:
+            raise ValueError("`input_ids` must be provided.")
+
+        # We used to call it 'idx'
+        idx = input_ids
+        # We used to call it 'targets'
+        targets = labels
+
+        # [Optional] If we want to handle single-dim input_ids
         if idx.dim() == 1:
-            # Add batch dimension if missing
             idx = idx.unsqueeze(0)
-
-        #
+
+        # 2) Now the rest of your old forward logic remains, just replacing references
+        #    to "idx" and "targets" with these new variables.
+
         if self.pipeline_stages is None:
             # Single-device forward pass
             device = self.token_embedding.weight.device
@@ -250,7 +276,11 @@ class ArgonneModel(PreTrainedModel):
                 targets = targets.view(-1)
                 loss = F.cross_entropy(logits, targets)
 
-            return
+            return CausalLMOutput(
+                loss=loss,
+                logits=logits,
+            )
+
         else:
             # Pipeline parallel forward
             first_device = next(self.token_embedding.parameters()).device
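Because `forward` now takes `input_ids`/`labels` and returns a `CausalLMOutput`, callers read `outputs.loss` and `outputs.logits` instead of tuple-unpacking the old return value. A minimal sketch, assuming the config exposes `vocab_size` and that `model` is an `ArgonneModel` instance:

import torch

# Toy batch of token ids, shape [batch_size, seq_len]; vocab_size is assumed to be on the config.
input_ids = torch.randint(0, model.config.vocab_size, (2, 64))

outputs = model(input_ids=input_ids, labels=input_ids)  # CausalLMOutput
loss = outputs.loss        # cross-entropy computed inside forward()
logits = outputs.logits
loss.backward()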
@@ -270,7 +300,7 @@ class ArgonneModel(PreTrainedModel):
                 hidden_states = hidden_states.to(device_stage)
                 hidden_states = stage(hidden_states)
 
-            #
+            # Move to last device before final ops
             hidden_states = hidden_states.to(last_device)
             hidden_states = self.ln_f(hidden_states)
             logits = self.head(hidden_states)
@@ -282,7 +312,11 @@ class ArgonneModel(PreTrainedModel):
                 targets = targets.view(-1)
                 loss = F.cross_entropy(logits, targets)
 
-            return
+            return CausalLMOutput(
+                loss=loss,
+                logits=logits,
+            )
+
 
 
     @torch.no_grad()
@@ -342,8 +376,9 @@ class ArgonneModel(PreTrainedModel):
                 generated = generated[:, -self.config.block_size:]
 
             # Forward pass
-
-            logits = logits
+            outputs = self.forward(generated)
+            logits = outputs.logits  # outputs is a CausalLMOutput
+            logits = logits[:, -1, :]  # get the last token's logits
 
             # Temperature
             if temperature != 1.0:
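With `generate` now reading `outputs.logits` from the returned `CausalLMOutput`, end-to-end generation works again. A hedged usage sketch, assuming a compatible `tokenizer` and that the method keeps the argument names of the previously commented-out version (`max_new_tokens`, `temperature`, `top_k`):

import torch

# `tokenizer` is assumed to be the tokenizer this checkpoint was trained with.
prompt_ids = tokenizer("Argonne is", return_tensors="pt").input_ids

with torch.no_grad():
    generated = model.generate(prompt_ids, max_new_tokens=50, temperature=0.7, top_k=50)

print(tokenizer.decode(generated[0]))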
@@ -382,91 +417,6 @@ class ArgonneModel(PreTrainedModel):
 
         return generated
 
-
-    # @torch.no_grad()
-    # def generate(self, input_ids, max_new_tokens, temperature=0.7, top_k=None, top_p=None, sample=True):
-    #     """
-    #     Generate text using the model.
-
-    #     Args:
-    #         input_ids: Input token IDs to continue from
-    #         max_new_tokens: Number of tokens to generate
-    #         temperature: Temperature for sampling (higher = more random)
-    #         top_k: If set, only sample from the top k most likely tokens
-    #         top_p: If set, sample from the smallest set of tokens whose cumulative probability exceeds p
-    #         sample: If True, sample from the distribution; if False, use greedy decoding
-
-    #     Returns:
-    #         Tensor containing the input_ids extended with max_new_tokens generated tokens
-    #     """
-    #     self.eval()
-
-    #     # Determine which device to use - explicitly use first device for consistency
-    #     if self.pipeline_stages is not None and len(self.devices) > 0:
-    #         device = self.devices[0]  # Always use first device for generation
-    #     else:
-    #         device = next(self.parameters()).device
-
-    #     # Ensure input is on the correct device
-    #     generated = input_ids.to(device)
-
-    #     for _ in range(max_new_tokens):
-    #         # Truncate if necessary to fit within the model's context window
-    #         if generated.shape[1] > self.config.block_size:
-    #             generated = generated[:, -self.config.block_size:]
-
-    #         # Forward pass
-    #         logits, _ = self.forward(generated)
-
-    #         # Make sure logits are on the same device
-    #         logits = logits.to(device)
-
-    #         # Get logits for the last token only
-    #         logits = logits[:, -1, :]
-
-    #         # Apply temperature
-    #         if temperature != 1.0:
-    #             logits = logits / temperature
-
-    #         # Greedy decoding (argmax) if sample=False
-    #         if not sample:
-    #             next_token = torch.argmax(logits, dim=-1, keepdim=True)
-    #         else:
-    #             # Sampling logic
-    #             # Apply top-k filtering
-    #             if top_k is not None:
-    #                 indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-    #                 logits = logits.masked_fill(indices_to_remove, float('-inf'))
-
-    #             # Apply top-p (nucleus) filtering
-    #             if top_p is not None:
-    #                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-    #                 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-    #                 # Remove tokens with cumulative probability above the threshold
-    #                 sorted_indices_to_remove = cumulative_probs > top_p
-
-    #                 # Shift the indices to the right to keep the first token above the threshold
-    #                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-    #                 sorted_indices_to_remove[..., 0] = 0
-
-    #                 indices_to_remove = sorted_indices_to_remove.scatter(
-    #                     dim=1, index=sorted_indices, src=sorted_indices_to_remove
-    #                 )
-    #                 logits = logits.masked_fill(indices_to_remove, float('-inf'))
-
-    #             # Convert to probability distribution and sample
-    #             probs = F.softmax(logits, dim=-1)
-    #             next_token = torch.multinomial(probs, num_samples=1)
-
-    #         # Ensure next_token is on the same device before concatenation
-    #         next_token = next_token.to(device)
-
-    #         # Append the generated token to the sequence
-    #         generated = torch.cat((generated, next_token), dim=1)
-
-    #     return generated
-
 # Register the model with Hugging Face's Auto classes
 AutoConfig.register("argonne", ArgonneConfig)
 AutoModel.register(ArgonneConfig, ArgonneModel)