import functools
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
import torch.nn.functional as F
from megatron.core import InferenceParams, ModelParallelConfig, parallel_state
from safetensors.torch import load_file
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, StateDictType
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
from torch.nn.modules.module import _IncompatibleKeys

from cosmos_predict1.autoregressive.configs.base.model import TrainingModelConfig as ModelConfig
from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig
from cosmos_predict1.autoregressive.modules.mm_projector import MultimodalProjector
from cosmos_predict1.autoregressive.networks.vit import VisionTransformer, get_vit_config
from cosmos_predict1.autoregressive.tokenizer.tokenizer import DiscreteMultimodalTokenizer
from cosmos_predict1.autoregressive.training.networks.transformer import (
    Transformer,
    TransformerBlock,
    TransformerBlockTE,
)
from cosmos_predict1.autoregressive.utils.checkpoint import (
    get_partial_state_dict,
    maybe_convert_checkpoint_to_backend,
    obtain_tensor_parallel_state_dict,
    process_state_dict,
    substrings_to_ignore,
)
from cosmos_predict1.autoregressive.utils.misc import random_dropout
from cosmos_predict1.autoregressive.utils.parallel import broadcast_data_batch_in_tp_cp_group, get_batch_on_this_cp_rank
from cosmos_predict1.autoregressive.utils.sampling import (
    decode_n_tokens,
    decode_one_token,
    prefill,
    sample_top_k,
    sample_top_p,
)
from cosmos_predict1.diffusion.training.utils.fsdp_helper import apply_fsdp_checkpointing, hsdp_device_mesh
from cosmos_predict1.utils import distributed, log, misc
from cosmos_predict1.utils.lazy_config import LazyDict
from cosmos_predict1.utils.misc import download_from_s3_with_cache, sync_s3_dir_to_local
from cosmos_predict1.utils.model import Model

class AutoRegressiveTrainingModel(Model):
    """
    A class to build and use a Llama-style autoregressive model for token generation.

    Methods:
        build: Build a model instance by initializing and loading a model checkpoint.
        generate: Generate token sequences based on provided prompts.
    """

    def __init__(
        self,
        model: Transformer,
        tokenizer: DiscreteMultimodalTokenizer,
        config: ModelConfig,
        model_parallel: ModelParallelConfig = None,
        vision_encoder: VisionTransformer = None,
        mm_projector: MultimodalProjector = None,
    ):
        """
        Initialize the model instance with a Transformer and a tokenizer.

        Args:
            model (Transformer): The Transformer model for token generation.
            tokenizer (DiscreteMultimodalTokenizer): The tokenizer for encoding and decoding tokens.
            config (ModelConfig): The configuration for the model.
            model_parallel (ModelParallelConfig, optional): Model-parallel configuration. Defaults to None.
            vision_encoder (VisionTransformer, optional): Optional vision encoder for multimodal inputs.
            mm_projector (MultimodalProjector, optional): Optional projector from vision features to the LLM embedding space.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.precision = self.model.precision
        self.vision_encoder = vision_encoder
        self.mm_projector = mm_projector
        assert (self.vision_encoder is None) == (
            self.mm_projector is None
        ), "vision_encoder and mm_projector should be both None or not None simultaneously"
        self.model_parallel = model_parallel
        self.monitor_output_logits = False
        self.inference_params = None

        if self.config.freeze_vision_encoder and vision_encoder is not None:
            for param in self.vision_encoder.parameters():
                param.requires_grad = False
            log.critical("Vision encoder parameters are frozen.")

        num_params = self.get_num_params()
        log.info(f"Number of model parameters: {round(num_params / 1e9, 3)}B")

    def get_num_params(
        self,
    ) -> int:
        """
        Return the number of parameters in the model.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def training_step(
        self, data_batch: dict[str, torch.Tensor], iteration: int
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        broadcast_data_batch_in_tp_cp_group(data_batch)

        context = data_batch.get("context", None)
        context_mask = data_batch.get("context_mask", None)
        if context is not None:
            if self.config.embedding_dropout > 0:
                context = random_dropout(
                    context,
                    self.config.embedding_dropout,
                )
            context = misc.to(context, device="cuda")
        if context_mask is not None:
            context_mask = misc.to(context_mask, device="cuda")
        action = data_batch.get("action", None)
        if action is not None:
            action = misc.to(action, device="cuda")

        tokens, token_boundaries = self.tokenizer.tokenize(data_batch)
        tokens = misc.to(tokens, device="cuda")

        labels = data_batch.get("labels", None)

        masks = data_batch.get("token_mask", None)
        apply_token_mask = masks is not None
        if masks is None:
            masks = torch.ones_like(tokens, dtype=torch.bool)
        masks = misc.to(masks, device="cuda")
        assert (
            data_batch.get("labels", None) is None or apply_token_mask
        ), "The code is not tested for the case when both labels and token_mask are provided."

        if self.config.ignore_first_num_tokens > 0:
            assert self.config.ignore_first_num_tokens < masks.shape[1]
            masks[:, : self.config.ignore_first_num_tokens] = False
        seq_len = tokens.shape[1]
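
        # Context parallelism splits the sequence dimension across CP ranks, so each rank
        # only processes (and computes the loss on) its own shard of the token sequence.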
        if parallel_state.get_context_parallel_world_size() > 1:
            cp_group = parallel_state.get_context_parallel_group()
            self.model.enable_context_parallel(cp_group)
            tokens = get_batch_on_this_cp_rank(tokens)
            masks = get_batch_on_this_cp_rank(masks)
            if labels is not None:
                labels = get_batch_on_this_cp_rank(labels)
        if self.vision_encoder is None:
            logits = self.model.forward(
                tokens=tokens,
                input_pos=None,
                context=context,
                context_mask=context_mask,
                action=action,
                total_seq_len=seq_len,
            )
        else:
            assert "images" in data_batch
            images = data_batch["images"]
            if images.ndim == 5:
                # Flatten the extra leading dimension so the vision encoder sees a 4D (N, C, H, W) batch.
                images = images.view(-1, *images.shape[2:])
            assert images.ndim == 4, f"Invalid shape: {images.shape}"
            token_embeddings = self.embed_vision_language_features(tokens, images)
            logits = self.model.forward(
                token_embeddings=token_embeddings,
                input_pos=None,
                context=context,
                context_mask=context_mask,
                action=action,
                total_seq_len=seq_len,
            )
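
        # If explicit labels are not provided, derive them for next-token prediction:
        # drop the last logit and shift the tokens left by one so position i predicts token i+1.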
        if labels is None:
            logits = logits[:, :-1]
            masks = masks[:, :-1]
            labels = tokens[:, 1:].clone()

            batch_size = tokens.shape[0]

            for sample_num in range(batch_size):
                if self.tokenizer.training_type == "text_to_video":
                    # Do not compute loss on the text prompt tokens.
                    if len(token_boundaries["text"]) > 0:
                        labels[sample_num][0 : token_boundaries["text"][sample_num][1] - 1] = self.tokenizer.ignore_index
                elif self.tokenizer.training_type == "class_to_image":
                    # Do not compute loss on the class-conditioning tokens.
                    labels[sample_num][0 : token_boundaries["class"][sample_num][1] - 1] = self.tokenizer.ignore_index

        ignore_index = self.tokenizer.ignore_index
        if self.config.ignore_first_num_tokens > 0 or apply_token_mask:
            labels[~masks] = ignore_index

        output_batch = {
            "encode_tokens": tokens,
            "logits": logits.detach(),
            "labels": labels.detach(),
            "ignore_index": ignore_index,
        }

        if self.monitor_output_logits:
            self.gather_output_logits_stats(logits, labels, output_batch, ignore_index)

        logits = logits.flatten(0, 1)
        labels = labels.flatten(0, 1)
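
        # Standard next-token cross-entropy; positions whose label equals ignore_index
        # (prompt/context tokens, masked tokens) do not contribute to the loss.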
        ce_loss = F.cross_entropy(
            input=logits,
            target=labels,
            ignore_index=ignore_index,
        )
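
        # z-loss: penalize the squared log-partition function (logsumexp of the logits) to keep
        # the logits from drifting to large magnitudes; scaled by config.z_loss_coeff.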
        log_z = torch.logsumexp(logits, dim=-1)
        z_loss = self.config.z_loss_coeff * (log_z**2).mean()

        total_loss = ce_loss + z_loss

        return output_batch, total_loss

    @torch.no_grad()
    def validation_step(
        self, data_batch: dict[str, torch.Tensor], iteration: int
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """
        Perform a validation step for the model, which is the same as the training step (but without backpropagation).
        """
        return self.training_step(data_batch, iteration)

    @torch.no_grad()
    def gather_output_logits_stats(
        self, logits: torch.Tensor, labels: torch.Tensor, output_batch: Dict, ignore_index: int = None
    ):
        """
        Gather statistics of the output logits, including mean, norm, and max values.
        """
        bs, seq_len, dim = logits.shape
        logits = logits.reshape(-1, dim)
        if ignore_index is not None:
            select_index = labels.view(-1) != ignore_index
            acc = labels.view(-1)[select_index] == logits.argmax(dim=1)[select_index]
            acc = acc.float().mean().view(-1, 1)

            logits = logits[select_index]
        output_batch.update(
            {
                "logits_mean": logits.mean(dim=1).detach(),
                "logits_norm": torch.linalg.vector_norm(logits, dim=1).detach(),
                "logits_max": logits.max(dim=1).values.detach(),
                "acc": acc.detach() * 100,
            }
        )

    @torch.no_grad()
    def image_encode(self, state: torch.Tensor) -> torch.Tensor:
        """
        Encode the image input state to continuous latent and discrete indices.
        """
        latent, indices = self.tokenizer.image_tokenizer.encode(state)
        return latent, indices

    @torch.no_grad()
    def image_decode(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Decode the discrete indices to RGB images.
        """
        return self.tokenizer.image_tokenizer.decode(indices)

    @torch.no_grad()
    def video_encode(self, state: torch.Tensor) -> torch.Tensor:
        """
        Encode the video input state to continuous latent and discrete indices.
        """
        latent, indices = self.tokenizer.video_tokenizer.encode(state)
        return latent, indices

    @torch.no_grad()
    def video_decode(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Decode the discrete indices to RGB videos.
        """
        if self.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap > 0:
            return self.tokenizer.video_tokenizer.decode_with_overlap(
                indices, temporal_overlap=self.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap
            )
        else:
            return self.tokenizer.video_tokenizer.decode(indices)

    @staticmethod
    def load_llm_checkpoint(
        ckpt_path: str = "",
        model: Transformer = None,
        **kwargs,
    ) -> None:
        """
        Load an LLM checkpoint from the specified path.
        """
        with misc.timer(f"loading checkpoint from {ckpt_path}"):
            checkpoint = torch.load(
                ckpt_path,
                map_location="cpu",
                mmap=True,
            )
        llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint
        llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.")
        llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
        with misc.timer("loading state_dict into model"):
            missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True)

    @staticmethod
    def build(
        seed: int = 1,
        train_from_scratch: bool = False,
        model_config: ModelConfig = ModelConfig(),
        fsdp_checkpointer: Any = None,
        tokenizer_config: TokenizerConfig = None,
        model_parallel: ModelParallelConfig = None,
        shard_checkpoint: bool = True,
        download_rank_sync: bool = True,
        **kwargs,
    ) -> "AutoRegressiveTrainingModel":
        """
        Build a model instance by initializing and loading a model checkpoint.

        Args:
            seed (int, optional): Random seed for reproducibility. Defaults to 1.
            train_from_scratch (bool, optional): Flag indicating whether to train the model from scratch. Defaults to False.
            model_config (ModelConfig, optional): The model configuration. Defaults to ModelConfig().
            fsdp_checkpointer (Any, optional): The FSDP checkpointer. Defaults to None.
            tokenizer_config (TokenizerConfig, optional): The tokenizer configuration. Defaults to None.
            model_parallel (ModelParallelConfig, optional): Model-parallel configuration. Defaults to None.
            shard_checkpoint (bool, optional): Whether to split the checkpoint by tensor parallelism before loading. Defaults to True.
            download_rank_sync (bool, optional): Whether to download the checkpoint in a rank-synchronized manner. Defaults to True.

        Returns:
            AutoRegressiveTrainingModel: An instance with the loaded model and tokenizer.

        Raises:
            AssertionError: If there are no checkpoint files in the specified directory.

        Note:
            This method sets the device to CUDA and loads the pre-trained model and tokenizer.
        """
        tensor_parallel_size = 1 if model_parallel is None else model_parallel.tensor_model_parallel_size

        torch.manual_seed(seed)

        llama_params = {}

        if not train_from_scratch:
            if model_config.ckpt_path is None:
                # If no single checkpoint path is given, sync the whole checkpoint directory
                # from S3 (or reuse the local cache).
                ckpt_dir = sync_s3_dir_to_local(
                    s3_dir=model_config.ckpt_dir,
                    s3_credential_path=model_config.s3_credential_path,
                    cache_dir=model_config.cache_dir,
                )

                checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors"))
                if len(checkpoints) == 0:
                    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))

                assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
                assert (
                    len(checkpoints) == 1
                ), f"multiple checkpoint files found in {ckpt_dir} (currently only one is supported)"
                ckpt_path = str(checkpoints[0])

                if os.path.exists(Path(ckpt_dir) / "params.json"):
                    with open(Path(ckpt_dir) / "params.json", "r") as f:
                        llama_params = json.loads(f.read())
                else:
                    log.info(
                        f"No params.json found in the checkpoint directory ({ckpt_dir}). "
                        f"Using default model config."
                    )

            else:
                # Download (or load from cache) a single checkpoint file.
                ckpt_path = download_from_s3_with_cache(
                    s3_path=model_config.ckpt_path,
                    s3_credential_path=model_config.s3_credential_path,
                    cache_dir=model_config.cache_dir,
                    rank_sync=download_rank_sync,
                )

            for key, value in llama_params.items():
                setattr(model_config, key, value)

            with misc.timer(f"loading checkpoint from {ckpt_path}"):
                if ckpt_path.endswith("safetensors"):
                    checkpoint = load_file(ckpt_path, device="cpu")
                else:
                    checkpoint = torch.load(
                        ckpt_path,
                        map_location="cpu",
                        mmap=True,
                    )
            llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint

            llm_checkpoint = maybe_convert_checkpoint_to_backend(
                llm_checkpoint,
                target_backend=model_config.backend,
                model_config=model_config,
                tensor_parallel_size=tensor_parallel_size if not shard_checkpoint else 1,
                is_tensor_parallel_shard=tensor_parallel_size > 1 and not shard_checkpoint,
            )
            if model_config.vision_encoder is not None:
                if "vision_encoder" in checkpoint:
                    log.info("Using pretrained vision_encoder")
                    vit_checkpoint = checkpoint["vision_encoder"]
                else:
                    log.info("Using fine-tuned vision_encoder")
                    vit_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="vision_encoder.")
                    vit_checkpoint = process_state_dict(vit_checkpoint, prefix_to_remove="vision_encoder.")
                if "mm_projector" in checkpoint:
                    log.info("Using pretrained mm_projector")
                    projector_checkpoint = checkpoint["mm_projector"]
                else:
                    log.info("Using fine-tuned mm_projector")
                    projector_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="mm_projector.")
                    projector_checkpoint = process_state_dict(projector_checkpoint, prefix_to_remove="mm_projector.")
                assert (
                    len(vit_checkpoint) > 0 and len(projector_checkpoint) > 0
                ), "vit_checkpoint and projector_checkpoint cannot be empty. We do not support random initialization for vision_encoder and mm_projector."

        tokenizer = DiscreteMultimodalTokenizer(tokenizer_config)

        precision = getattr(torch, model_config.precision)
        torch.set_default_dtype(precision)
        log.info(f"Setting torch default dtype to {precision}")

        model = Transformer(
            params=model_config,
            model_parallel=model_parallel,
            tokenizer_config=tokenizer_config,
            init_weights=train_from_scratch,
        )
        model_kwargs = {}

        if model_config.vision_encoder is not None:
            assert model_config.mm_projector is not None, "mm_projector must be provided if vision_encoder is provided."
            vit_config = get_vit_config(model_config.vision_encoder)
            vision_encoder = VisionTransformer.build(
                vit_config,
                hidden_dropout=model_config["hidden_dropout"],
                attention_dropout=model_config["attention_dropout"],
                set_parallel_mode=model_config["set_parallel_mode"],
                model_parallel=model_parallel,
                attention_tp=tensor_parallel_size > 1,
            )

            mm_projector = MultimodalProjector(
                mm_projector_type=model_config.mm_projector, in_dim=vit_config["dim"], out_dim=model_config["dim"]
            )
            model_kwargs.update({"vision_encoder": vision_encoder, "mm_projector": mm_projector})

        if tokenizer.vocab_size > model.vocab_size:
            log.info(f"Expanding vocab size to {tokenizer.vocab_size}")
            # For text-to-video training, only the input embedding table is expanded; the output
            # layer keeps its original vocabulary.
            expand_output_layer = not (tokenizer.training_type == "text_to_video")
            model.expand_vocab(tokenizer.vocab_size, init_method="gaussian", expand_output_layer=expand_output_layer)
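
        # When resuming from a pretrained checkpoint, optionally shard it along the tensor-parallel
        # dimension for this rank before loading the weights into the freshly built model.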
        if not train_from_scratch:
            if shard_checkpoint:
                with misc.timer("sharding checkpoint according to tensor parallelism"):
                    if model_parallel is not None:
                        assert model_parallel.tensor_model_parallel_size == model_config["tensor_model_parallel_size"]
                    llm_checkpoint = obtain_tensor_parallel_state_dict(
                        llm_checkpoint,
                        tensor_parallel_size=tensor_parallel_size,
                        tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(),
                        model_config=model_config,
                    )
                    if model_config.vision_encoder is not None:
                        vit_checkpoint = obtain_tensor_parallel_state_dict(
                            vit_checkpoint,
                            tensor_parallel_size=tensor_parallel_size,
                            tensor_parallel_rank=parallel_state.get_tensor_model_parallel_rank(),
                            model_config=vit_config,
                        )

            if model_config.vision_encoder is not None:
                llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.")

            llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
            with misc.timer("loading state_dict into model"):
                missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True)

            missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
            assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"

            if model_config.vision_encoder is not None:
                vision_encoder.load_state_dict(vit_checkpoint)
                mm_projector.load_state_dict(projector_checkpoint)
                if model_config.vision_encoder_in_channels != 3:
                    vision_encoder.expand_in_channels(model_config.vision_encoder_in_channels)

        model = model.to(precision)
        log.info(f"Model config: {model_config}")

        model_class = AutoRegressiveTrainingModel
        if model_config.fsdp_enabled:
            raise NotImplementedError("FSDP is not implemented for AutoRegressiveTrainingModel")

        return model_class(model, tokenizer, model_config, **model_kwargs)

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
        logit_clipping_range: list = [],
        seed: int = 0,
        images: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        context_mask: Optional[torch.Tensor] = None,
        action: Optional[torch.Tensor] = None,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        """
        Generate text sequences based on provided prompts using the language generation model.

        Args:
            prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
            max_gen_len (int): Maximum length of the generated text sequence.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            top_k (int, optional): Top-k value for top-k sampling. Defaults to None. If not None, top-k sampling is used instead of top-p sampling.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
            Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.

        Note:
            This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
            If logprobs is True, token log probabilities are computed for each generated token.
        """
        assert top_k is None or top_p is None, f"Only one of top_k ({top_k}) or top_p ({top_p}) should be specified."
        if top_p is not None:
            log.info(f"Using top-p sampling with p={top_p} and temperature={temperature}")
        elif top_k is not None:
            log.info(f"Using top-k sampling with k={top_k} and temperature={temperature}")
        else:
            log.info("Not applying top-k or top-p sampling. Will use top-k sampling with k=None")

        self.model.set_inference_flag(True)
        misc.set_random_seed(seed)

        if isinstance(self.model.params, list):
            log.info(
                f"self.model.params is a list, so self.config is used instead. Got max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}"
            )
            params = self.config
        else:
            params = self.model.params
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        if self.config.backend == "transformer_engine":
            self.inference_params = InferenceParams(
                max_batch_size=params.max_batch_size, max_sequence_length=params.max_seq_len
            )

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert max_prompt_len <= params.max_seq_len
        total_len = params.max_seq_len
        assert (
            max_gen_len + max_prompt_len <= total_len
        ), f"max_gen_len + max_prompt_len={max_gen_len + max_prompt_len} exceeds max_seq_len={total_len}"

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")

        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
        eos_reached = torch.tensor([False] * bsz, device="cuda")
        input_text_mask = tokens != pad_id

        passed_image_embeddings = False

        if min_prompt_len == total_len:
            input_pos = torch.arange(tokens.shape[1], dtype=torch.long, device="cuda")
            if images is None:
                logits = self.model.forward(
                    tokens=tokens,
                    input_pos=input_pos,
                    inference_params=self.inference_params,
                    context=context,
                    context_mask=context_mask,
                    action=action,
                )
            else:
                token_embeddings = self.embed_vision_language_features(tokens, images)
                logits = self.model.forward(
                    token_embeddings=token_embeddings,
                    input_pos=input_pos,
                    inference_params=self.inference_params,
                    context=context,
                    context_mask=context_mask,
                    action=action,
                )
                passed_image_embeddings = True
            token_logprobs = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens,
                reduction="none",
                ignore_index=pad_id,
            )

        stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens), dtype=torch.long, device="cuda")

        log.info(f"Start generating the next {total_len - min_prompt_len} tokens. This will take a while...")
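
        # Incremental decoding: at each step, feed only the tokens produced since the previous
        # position (earlier positions are served from the model's attention cache) and sample
        # the next token.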
        for cur_pos in range(min_prompt_len, total_len):
            input_pos = torch.arange(prev_pos, cur_pos, dtype=torch.long, device="cuda")
            if images is not None and not passed_image_embeddings:
                token_embeddings = self.embed_vision_language_features(tokens[:, prev_pos:cur_pos], images)
                logits = self.model.forward(
                    token_embeddings=token_embeddings,
                    input_pos=input_pos,
                    inference_params=self.inference_params,
                    context=context,
                    context_mask=context_mask,
                    action=action,
                )
                passed_image_embeddings = True
            else:
                logits = self.model.forward(
                    tokens=tokens[:, prev_pos:cur_pos],
                    input_pos=input_pos,
                    inference_params=self.inference_params,
                    context=context,
                    context_mask=context_mask,
                    action=action,
                )

            if self.config.backend == "transformer_engine":
                self.inference_params.sequence_len_offset += logits.shape[1]

            if len(logit_clipping_range) > 0:
                min_clip_index = logit_clipping_range[0]
                max_clip_index = logit_clipping_range[1]
                logits_clipped = logits[:, :, min_clip_index:max_clip_index]
            else:
                logits_clipped = logits
                min_clip_index = 0

            if temperature > 0:
                if top_p is not None:
                    next_token = sample_top_p(logits_clipped, temperature=temperature, top_p=top_p)[0]
                else:
                    next_token = sample_top_k(logits_clipped, temperature=temperature, top_k=top_k)[0]
            else:
                next_token = torch.argmax(logits_clipped[:, -1, :], dim=-1)

            next_token += min_clip_index

            next_token = next_token.reshape(-1)
            # Only overwrite positions that are not part of the original prompt.
            next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token

            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )

            eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
            prev_pos = cur_pos

            if all(eos_reached):
                log.info(f"Reached end of sequence, current pos: {cur_pos}; maximum pos: {total_len}")
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []

        for i, toks in enumerate(tokens.tolist()):
            # Cut to max generation length (optionally echoing the prompt).
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # Cut at the first stop token, if any.
            for stop_token in self.tokenizer.stop_tokens:
                try:
                    eos_idx = toks.index(stop_token)
                    toks = toks[:eos_idx]
                    probs = probs[:eos_idx] if logprobs else None
                except ValueError:
                    pass
            out_tokens.append(toks)
            out_logprobs.append(probs)
        self.model.set_inference_flag(False)
        return (out_tokens, out_logprobs if logprobs else None)

    @torch.no_grad()
    def fast_generate(
        self,
        prompt_tokens: List[List[int]] | torch.Tensor,
        max_gen_len: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        num_gen_seq: int = 1,
        logprobs: bool = False,
        echo: bool = False,
        seed: int = 0,
        context: Optional[torch.Tensor] = None,
        context_mask: Optional[torch.Tensor] = None,
        action: Optional[torch.Tensor] = None,
        compile_decode: bool = True,
        compile_prefill: bool = False,
        verbose: bool = True,
        stop_tokens: Optional[Set[int]] = None,
    ):
        """
        Fast auto-regressive generation. Currently only supports an input batch size of 1.

        Args:
            prompt_tokens (List[List[int]] | torch.Tensor): A single prompt of shape (1, seq_len).
            max_gen_len (int): Maximum length of the generated text sequence.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 1.0.
            top_k (int, optional): Top-k value for top-k sampling. Defaults to None.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to None.
            num_gen_seq (int, optional): Number of outputs to generate given the same prompt. Defaults to 1. When temperature == 0, num_gen_seq must be 1 because the generation is deterministic.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
            seed (int, optional): Random seed for reproducibility. Defaults to 0.
            compile_decode (bool, optional): Flag indicating whether to compile the decoding function. Defaults to True.
            compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False.
            verbose (bool, optional): Flag indicating whether to log prefill/decode timings. Defaults to True.
            stop_tokens (Set[int], optional): Stop tokens to end generation; defaults to the tokenizer's stop tokens.
        """
        assert (
            top_p is None or top_k is None
        ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
        if top_p is not None:
            log.info(f"Using top-p sampling with p={top_p} and temperature={temperature}")
        elif top_k is not None:
            log.info(f"Using top-k sampling with k={top_k} and temperature={temperature}")
        else:
            log.info("Not applying top-k or top-p sampling. Will use top-k sampling with k=None")
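
        # Inductor settings used for fast decoding: coordinate-descent autotuning, unique kernel
        # names for easier profiling, and the FX graph cache so recompilation is cheap across runs.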
        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.triton.unique_kernel_names = True
        torch._inductor.config.fx_graph_cache = True

        self.model.set_inference_flag(True)
        misc.set_random_seed(seed)

        assert not logprobs, "logprobs are not supported for fast_generate yet"

        if compile_decode and not getattr(self, "inference_decode_compiled", False):
            self.decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
            self.inference_decode_compiled = True
            log.critical("Compiled decode_one_token function. Note: the first run will be slower due to compilation")
        if compile_prefill and not getattr(self, "inference_prefill_compiled", False):
            self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
            self.inference_prefill_compiled = True
            log.critical("Compiled prefill function. Note: the first run will be slower due to compilation")

        if not hasattr(self, "decode_one_token"):
            self.decode_one_token = decode_one_token
        if not hasattr(self, "prefill"):
            self.prefill = prefill

        if isinstance(self.model.params, list):
            log.info(
                f"self.model.params is a list, so self.config is used instead. Got max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}"
            )
            params = self.config
        else:
            params = self.model.params
        if isinstance(prompt_tokens, list):
            prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cuda")
        if prompt_tokens.ndim == 1:
            prompt_tokens = prompt_tokens.view(1, -1)
        else:
            assert prompt_tokens.ndim == 2, f"prompt_tokens has shape {prompt_tokens.shape}"
        batch_size, prompt_len = prompt_tokens.shape
        total_len = min(params.max_seq_len, max_gen_len + prompt_len)
        if max_gen_len + prompt_len > params.max_seq_len:
            log.warning(
                f"max_gen_len + prompt_len={max_gen_len + prompt_len} exceeds max_seq_len={params.max_seq_len}, truncate max_gen_len to {params.max_seq_len - prompt_len}"
            )
            max_gen_len = params.max_seq_len - prompt_len

        if context_mask is not None:
            context_mask = context_mask.to(dtype=torch.bool)
            if context_mask.ndim == 2:
                assert (
                    context_mask.shape[0] == batch_size
                ), f"batch_size mismatch: {context_mask.shape[0]} != {batch_size}"
                # Reshape the 2D mask to a broadcastable attention-mask shape.
                context_mask = context_mask.view(batch_size, 1, 1, -1)

        if num_gen_seq > 1:
            assert (
                batch_size == 1
            ), f"num_gen_seq > 1 is only supported for a single prompt, got {len(prompt_tokens)} prompts"
            log.critical(f"Generating {num_gen_seq} sequences with the same prompt")
            assert (
                num_gen_seq <= params.max_batch_size
            ), f"num_gen_seq={num_gen_seq} exceeds max_batch_size={params.max_batch_size}"
            # Replicate the prompt so all sequences are generated in one batch.
            prompt_tokens = prompt_tokens.repeat(num_gen_seq, 1)
            assert prompt_tokens.shape == (
                num_gen_seq,
                prompt_len,
            ), f"prompt_tokens must be of shape (num_gen_seq, seq_len), got {prompt_tokens.shape}"
            batch_size = len(prompt_tokens)

        # Pre-allocate the output sequence buffer and copy in the prompt.
        empty = torch.empty(batch_size, total_len, dtype=prompt_tokens.dtype, device=prompt_tokens.device)
        empty[:, :prompt_len] = prompt_tokens
        seq = empty
        input_pos = torch.arange(0, prompt_len, device="cuda")

        if verbose:
            prefill_start = time.time()

        # Prefill: run the whole prompt through the model once and sample the first new token.
        next_token = self.prefill(
            self.model,
            prompt_tokens,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            context=context,
            context_mask=context_mask,
            action=action,
        )
        if verbose:
            prefill_time = time.time() - prefill_start

        seq[:, [prompt_len]] = next_token.to(dtype=seq.dtype)
        input_pos = torch.tensor([prompt_len], dtype=torch.long, device="cuda")
        stop_tokens = self.tokenizer.stop_tokens if stop_tokens is None else stop_tokens
        stop_tokens = torch.tensor(list(stop_tokens), dtype=torch.long, device="cuda")

        if verbose:
            decode_start = time.time()

        # Decode: generate the remaining tokens one at a time.
        generated_tokens = decode_n_tokens(
            self.model,
            next_token.view(batch_size, -1),
            input_pos,
            max_gen_len - 1,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            stop_tokens=stop_tokens,
            decode_one_token_function=self.decode_one_token,
            context=context,
            context_mask=context_mask,
            action=action,
        )
        gen_len = len(generated_tokens)
        if verbose:
            decode_time = time.time() - decode_start
            prefill_throughput = prompt_len / prefill_time
            decode_throughput = gen_len / decode_time
            log.info(f"[Prefill] Time: {prefill_time:.2f}s; Throughput: {prefill_throughput:.2f} tokens/s")
            log.info(f"[Decode] Time: {decode_time:.2f}s; Throughput: {decode_throughput:.2f} tokens/s")

        generated_tokens = torch.cat(generated_tokens, dim=1)

        log.critical(f"generated_tokens: {generated_tokens.shape}")
        seq = seq[:, : prompt_len + 1 + gen_len]
        seq[:, prompt_len + 1 :] = generated_tokens
        if not echo:
            seq = seq[:, prompt_len:]
        return seq, None

    def embed_vision_language_features(self, input_ids: torch.Tensor, images: torch.Tensor) -> torch.Tensor:
        """
        Embed vision and language features into a combined representation.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            images (torch.Tensor): Input images.

        Returns:
            torch.Tensor: Combined vision-language features.

        Raises:
            AssertionError: If the vision encoder or mm projector is not initialized,
                or if dimensions mismatch.
        """
        assert self.vision_encoder is not None
        assert self.mm_projector is not None

        image_token_id = self.vision_encoder.image_token_id
        assert isinstance(image_token_id, int) and image_token_id >= 0, f"Invalid image_token_id: {image_token_id}"
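
        # Split the sequence into text positions and image-placeholder positions: text tokens go
        # through the token embedding table, image placeholders are filled with projected
        # vision-encoder patch features.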
        text_locations = input_ids != image_token_id
        image_locations = input_ids == image_token_id

        text_features = self.model.tok_embeddings(input_ids[text_locations])

        images = images.to(device=text_features.device, dtype=text_features.dtype)
        vit_outputs = self.vision_encoder(images)
        image_features = self.mm_projector(vit_outputs)

        B, seq_len = input_ids.shape
        N_total = B * seq_len
        N_txt, D_txt = text_features.shape
        N_img, N_patch, D_img = image_features.shape

        image_features = image_features.reshape(N_img * N_patch, D_img)

        assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}"
        assert (
            N_total == N_txt + N_img * N_patch
        ), f"N_total {N_total} should be equal to N_txt + N_img * N_patch, got N_txt={N_txt}, N_img * N_patch={N_img * N_patch}, num image locations={image_locations.sum().item()}"

        combined_features = torch.empty(
            (B, seq_len, D_txt),
            dtype=text_features.dtype,
            device=text_features.device,
        )
        combined_features[text_locations, :] = text_features
        combined_features[image_locations, :] = image_features

        return combined_features

    def on_after_backward(self, iteration: int = 0):
        """
        Hook after loss.backward() is called.

        This method is called immediately after the backward pass, allowing for custom operations
        or modifications to be performed on the gradients before the optimizer step.

        So far, this method is used to all-reduce layernorm grads for tensor/sequence parallelism.

        Args:
            iteration (int): Current iteration number.
        """
        for module in self.children():
            if hasattr(module, "on_after_backward"):
                module.on_after_backward(iteration)

    def on_before_zero_grad(
        self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
    ) -> None:
        """Hook before zero_grad() is called.

        Args:
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            iteration (int): Current iteration number.
        """
        for module in self.children():
            if hasattr(module, "on_before_zero_grad"):
                module.on_before_zero_grad(optimizer, scheduler, iteration)

    @property
    def fsdp_wrap_block_cls(self):
        """
        Return the transformer block class to wrap with FSDP.
        """
        if self.config.backend == "pytorch":
            return TransformerBlock
        elif self.config.backend == "transformer_engine":
            return TransformerBlockTE
        else:
            raise ValueError(f"Unknown backend: {self.config.backend}")

    def state_dict(self, *args, **kwargs):
        """
        Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8).
        """
        state_dict = super().state_dict(*args, **kwargs)
        return process_state_dict(state_dict)

    def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False):
        """
        Ignore the missing keys with substrings matching `substrings_to_ignore` (e.g., "_extra_state" keys imposed by
        TransformerEngine for FP8).
        """
        state_dict = process_state_dict(state_dict)
        missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign)
        actual_missing_keys = []
        for key in missing_keys:
            if not any(substring in key for substring in substrings_to_ignore):
                actual_missing_keys.append(key)
        if strict:
            if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0:
                raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}")
        return _IncompatibleKeys(actual_missing_keys, unexpected_keys)
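

# Example usage (sketch only; the config objects below are placeholders and must come from the
# project's own experiment configs, not from this module):
#
#   model = AutoRegressiveTrainingModel.build(
#       model_config=my_model_config,          # a TrainingModelConfig instance (hypothetical name)
#       tokenizer_config=my_tokenizer_config,  # a TokenizerConfig instance (hypothetical name)
#   )
#   output_batch, loss = model.training_step(data_batch, iteration=0)
#   tokens, _ = model.fast_generate(prompt_tokens=[[1, 2, 3]], max_gen_len=64, top_p=0.9)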