| import copy |
| import inspect |
| import sys |
| from abc import ABC, abstractmethod |
| from collections import OrderedDict |
| from dataclasses import dataclass, field |
| from typing import Any, Dict, Optional, Tuple, Union |
|
|
| import hydra.utils |
| import torch |
| from hydra.errors import InstantiationException |
| from transformers import ( |
| AutoTokenizer, |
| DynamicCache, |
| GenerationConfig, |
| LogitsProcessorList, |
| PretrainedConfig, |
| PreTrainedModel, |
| StoppingCriteriaList, |
| ) |
| from transformers.cache_utils import Cache |
| from transformers.generation.utils import GenerateOutput |
| from transformers.modeling_outputs import ModelOutput |
|
|
| |
| |
| from .backbone_automodel import AutoModelFromPreTrained |
| from .backbone_encoder_decoder import ( |
| LLMasEncoderDecoder, |
| LLMasEncoderDecoderShareKV, |
| ) |
| from .noise_schedule_noise_schedules import ( |
| CosineNoise, |
| ExponentialNoise, |
| LinearNoise, |
| LogarithmicNoise, |
| ) |
|
|
|
|
@dataclass
class DenoiserInput(OrderedDict):
    """Input to the denoiser model.

    Bundles the (possibly noised) token sequence with the masks and
    noise-schedule values a parametrization's ``_prepare_inputs`` produces,
    so the backbone forward receives a single object. Subclassing
    ``OrderedDict`` gives dict-style access alongside attribute access.
    """

    # Token ids fed to the backbone (passed as the first positional argument
    # in `_backbone_forward`); presumably the noised sequence x_t.
    xt: torch.LongTensor
    # Presumably the clean/original sequence x_0; optional (e.g. at inference).
    x0: Optional[torch.LongTensor] = None
    # Attention mask forwarded verbatim to the backbone.
    attention_mask: Optional[torch.FloatTensor] = None
    # KV cache, either legacy tensor format or a transformers `Cache`.
    past_key_values: Optional[Union[torch.FloatTensor, Cache]] = None
    # Indicator for context (conditioning) tokens.
    context_mask: Optional[torch.FloatTensor] = None
    # Mask returned through `DenoiserOutput.tokens_mask`; presumably marks
    # tokens that participate in the loss — confirm against subclasses.
    tokens_mask: Optional[torch.FloatTensor] = None
    # Denoising time step.
    t: Optional[torch.FloatTensor] = None
    # Noise-schedule value at t (passed to time-conditioned backbones as
    # `noise`) and, presumably, its time derivative.
    alpha_t: Optional[torch.FloatTensor] = None
    alpha_t_prime: Optional[torch.FloatTensor] = None
    # Extra keyword arguments expanded into the backbone call.
    backbone_kwargs: dict[str, Any] = field(default_factory=dict)
|
|
|
|
@dataclass
class LossAndNllOutput(OrderedDict):
    """Loss output for denoiser models.

    Returned by ``Denoiser._compute_loss`` implementations.
    """

    # Training loss (reduction is up to the `_compute_loss` implementation).
    loss: torch.FloatTensor
    # Per-token negative log-likelihoods.
    nlls: torch.FloatTensor
    # Auxiliary loss terms keyed by name (e.g. for logging); empty by default.
    other_loss_terms: dict = field(default_factory=dict)
|
|
|
|
@dataclass
class DenoiserOutput(ModelOutput):
    """Output of the denoiser model.

    Produced by ``Denoiser.forward``; fields mirror what that method
    assembles.
    """

    # Output of `Denoiser._forward` — by default the log-softmax of the
    # backbone logits (subclasses may override `_forward`).
    denoiser_output: Optional[torch.FloatTensor] = None
    # Raw backbone logits.
    logits: Optional[torch.FloatTensor] = None
    # Tokens mask propagated from the prepared `DenoiserInput`.
    tokens_mask: Optional[torch.FloatTensor] = None
    # Updated KV cache if the backbone returned one, else None.
    past_key_values: Optional[Cache] = None
    # Loss / per-token NLLs; None when forward was called with
    # compute_loss=False.
    loss: Optional[torch.FloatTensor] = None
    nlls: Optional[torch.FloatTensor] = None
    # Auxiliary loss terms from `_compute_loss` ({} when loss was skipped).
    other_loss_terms: Optional[dict[str, Any]] = None
|
|
|
|
class DenoiserConfig(PretrainedConfig):
    """Configuration class for Denoiser models.

    Holds the architecture parameters needed to instantiate a Denoiser:
    hydra configs for the backbone and the noise schedule, tokenization
    parameters, sequence length, and attention/training flags.
    """

    model_type = "denoiser"

    def __init__(
        self,
        length: Optional[int] = None,
        backbone_config: Optional[Dict[str, Any]] = None,
        noise_config: Optional[Dict[str, Any]] = None,
        tokenization_config: Optional[Dict[str, Any]] = None,
        time_conditioned_backbone: Optional[bool] = None,
        attn_backend: str = "sdpa",
        train_on_context: bool = False,
        **kwargs,
    ):
        """Initialize the config.

        Parameters:
            length: Maximum/target sequence length.
            backbone_config: Hydra config dict used to instantiate the backbone.
            noise_config: Hydra config dict for the noise schedule (optional).
            tokenization_config: Dict providing tokenization attributes
                (vocab_size, mask/pad/bos/eos token ids, pad_vocab_size_multiple).
            time_conditioned_backbone: Whether the backbone takes a noise/time
                input; None means "auto-detect later".
            attn_backend: Attention implementation identifier.
            train_on_context: Whether context tokens contribute to training.
        """
        super().__init__(**kwargs)
        # Resolve tokenization-derived attributes. An attribute is taken from
        # `tokenization_config` when it is still unset on self or the key is
        # explicitly present there; in every other case it is reset to None.
        tokenization_attrs = (
            "vocab_size",
            "mask_token_id",
            "pad_token_id",
            "bos_token_id",
            "eos_token_id",
            "pad_vocab_size_multiple",
        )
        for attr in tokenization_attrs:
            value = None
            if tokenization_config is not None:
                currently_unset = getattr(self, attr, None) is None
                if currently_unset or attr in tokenization_config:
                    value = tokenization_config.get(attr)
            setattr(self, attr, value)
        self.backbone_config = backbone_config
        self.noise_config = noise_config
        self.tokenization_config = tokenization_config
        self.length = length
        self.time_conditioned_backbone = time_conditioned_backbone
        self.attn_backend = attn_backend
        self.train_on_context = train_on_context
|
|
|
|
class Denoiser(ABC, PreTrainedModel):
    """Abstract base class for denoising models.

    This class defines the interface for AR, Diffusion, and Flow-based
    parametrizations. Concrete subclasses must implement ``_prepare_inputs``,
    ``_compute_loss``, and ``generate``; ``_prepare_inputs_inference`` is
    required only when cached inference (``update_cache``) is used.
    """

    config_class = DenoiserConfig

    def __init__(
        self,
        config: DenoiserConfig,
        **kwargs,
    ):
        """
        Initialize the Denoiser with a configuration.

        Instantiates the backbone and noise schedule from their hydra
        configs and loads the tokenizer.

        Parameters:
            config (DenoiserConfig): Configuration object for the model.
        """
        super().__init__(config)
        self.config = config
        # Mirror tokenization-related config values as attributes for
        # convenient access.
        self.vocab_size = config.vocab_size
        self.mask_token_id = config.mask_token_id
        self.pad_token_id = config.pad_token_id
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        try:
            # backbone_config is a hydra config (dict carrying a `_target_`).
            self.backbone = hydra.utils.instantiate(config.backbone_config)
        except InstantiationException:
            # Retry path: the `_target_` may reference this repo's modules by
            # their short (unqualified) names — e.g. a checkpoint saved from a
            # flat module layout. Alias every submodule of this package under
            # its short name in sys.modules, then instantiate again.
            # NOTE(review): assumes the failure was module resolution; any
            # other cause will simply re-raise on the retry below.
            sys_modules = copy.deepcopy(list(sys.modules.keys()))
            repo_root_module = ".".join(__name__.split(".")[:-1])
            for name in sys_modules:
                if name.startswith(repo_root_module):
                    short = name.split(".")[-1]
                    # Never clobber an existing top-level module.
                    if short not in sys.modules:
                        sys.modules[short] = sys.modules[name]
            del sys_modules
            self.backbone = hydra.utils.instantiate(config.backbone_config)
        # NOTE(review): `tokenizer_name` is not declared in DenoiserConfig's
        # __init__ signature; it presumably arrives via **kwargs — confirm
        # that callers always set it.
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.tokenizer_name,
            trust_remote_code=True,
        )
        self.noise_schedule = (
            hydra.utils.instantiate(config.noise_config)
            if config.noise_config is not None
            else None
        )
        # If not set explicitly in the config, auto-detect time conditioning
        # by checking whether the backbone's forward accepts a `noise` arg.
        self.time_conditioned_backbone = (
            config.time_conditioned_backbone
            if config.time_conditioned_backbone is not None
            else "noise" in inspect.getfullargspec(self.backbone.forward).args
        )

        # Parameter names to exclude when pushing to the hub; subclasses may
        # append entries — TODO confirm intended usage.
        self.skip_params_for_push = []

    @abstractmethod
    def _prepare_inputs(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        t: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
    ) -> DenoiserInput:
        """
        Prepare inputs for the model.

        Parameters:
            input_ids (LongTensor): Input tensor to the model.
            attention_mask (Optional[FloatTensor]): Attention mask for the model.
            context_mask (Optional[FloatTensor]): Indicator for context tokens.
            t (Optional[FloatTensor]): Time step for the model.
            past_key_values (Optional[Cache]): Past key values for the model.
        Returns:
            Denoiser inputs.
        """
        raise NotImplementedError("Denoiser subclasses must implement _prepare_inputs")

    def _prepare_inputs_inference(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        context: Optional[torch.LongTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        cache: Optional[Dict[str, Any]] = None,
        **backbone_kwargs: Any,
    ) -> Tuple[DenoiserInput, Dict[str, Any]]:
        """Prepare inputs for (cached) inference.

        Deliberately not marked @abstractmethod so subclasses that do not
        support cached inference need not implement it; `update_cache`
        depends on it.

        Returns:
            Tuple of (DenoiserInput, updated cache dict).
        """
        raise NotImplementedError(
            "Denoiser subclasses must implement _prepare_inputs_inference"
        )

    @abstractmethod
    def _compute_loss(
        self,
        model_output: torch.FloatTensor,
        denoiser_inputs: DenoiserInput,
        **kwargs: Any,
    ) -> LossAndNllOutput:
        """
        Compute the loss for the denoising model.

        Parameters:
            model_output (FloatTensor): Output tensor from self.forward.
            denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model.

        Returns:
            LossAndNllOutput: loss (FloatTensor) and nlls (FloatTensor).
        """
        raise NotImplementedError("Denoiser subclasses must implement _compute_loss")

    def _forward(
        self,
        backbone_output: torch.FloatTensor,
        denoiser_inputs: DenoiserInput,
        **kwargs: Any,
    ) -> torch.FloatTensor:
        """
        Forward pass for the denoiser model returns probabilities over denoised
        sequence.

        The base implementation simply log-normalizes the backbone logits
        over the vocabulary dimension. Some classes may need to override
        this method.

        Parameters:
            backbone_output (FloatTensor): Output tensor from the backbone model.
            denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model.

        Returns:
            Model outputs (FloatTensor): log-probabilities over the vocabulary.
        """
        return torch.log_softmax(backbone_output, dim=-1)

    def _backbone_forward(
        self,
        denoiser_inputs: DenoiserInput,
        **backbone_kwargs: Any,
    ) -> ModelOutput:
        """Forward pass for the backbone model (should return logits).

        Some classes may need to override this method.

        Parameters:
            denoiser_inputs (DenoiserInput): Inputs passed to the denoiser model.
            return_updated_cache (bool): If True, return past_key_values instead of
                logits. (Forwarded through **backbone_kwargs.)

        Returns:
            Backbone output (ModelOutput instance).
        """
        if self.time_conditioned_backbone:
            # Time-conditioned backbones additionally receive the schedule
            # value alpha_t as their `noise` argument.
            return self.backbone(
                denoiser_inputs.xt,
                attention_mask=denoiser_inputs.attention_mask,
                past_key_values=denoiser_inputs.past_key_values,
                noise=denoiser_inputs.alpha_t,
                **denoiser_inputs.backbone_kwargs,
                **backbone_kwargs,
            )
        return self.backbone(
            denoiser_inputs.xt,
            attention_mask=denoiser_inputs.attention_mask,
            past_key_values=denoiser_inputs.past_key_values,
            **denoiser_inputs.backbone_kwargs,
            **backbone_kwargs,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        context_mask: Optional[torch.FloatTensor] = None,
        t: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        compute_loss: Optional[bool] = True,
        **kwargs,
    ) -> DenoiserOutput:
        """
        Perform a forward pass through the denoising model and
        (optionally) compute the loss.

        Parameters:
            input_ids (LongTensor): Input tensor to the model.
            attention_mask (Optional[FloatTensor]): Attention mask for the model.
            context_mask (Optional[FloatTensor]): Indicator for context tokens.
            t (Optional[FloatTensor]): Denoising time step for the model.
            past_key_values (Optional[Cache]): KV cache.
            compute_loss (Optional[bool]): Flag to compute loss.

        Returns:
            DenoiserOutput
        """
        # Delegate noising / mask construction to the parametrization.
        denoiser_inputs = self._prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            context_mask=context_mask,
            past_key_values=past_key_values,
            t=t,
        )

        backbone_output = self._backbone_forward(denoiser_inputs, **kwargs)
        # The backbone may or may not return an updated KV cache.
        new_past_key_values = getattr(backbone_output, "past_key_values", None)
        # Tuple-style outputs without a `.logits` attribute fall back to the
        # first element (note: the default is evaluated eagerly, so the
        # output must support indexing — ModelOutput does).
        backbone_output = getattr(backbone_output, "logits", backbone_output[0])
        denoiser_output = self._forward(
            backbone_output,
            denoiser_inputs,
            **kwargs,
        )

        if compute_loss:
            loss_and_nll = self._compute_loss(
                model_output=denoiser_output, denoiser_inputs=denoiser_inputs, **kwargs
            )
            loss = loss_and_nll.loss
            nlls = loss_and_nll.nlls
            other_loss_terms = loss_and_nll.other_loss_terms
        else:
            loss, nlls = None, None
            other_loss_terms = {}

        return DenoiserOutput(
            denoiser_output=denoiser_output,
            logits=backbone_output,
            past_key_values=new_past_key_values,
            tokens_mask=denoiser_inputs.tokens_mask,
            loss=loss,
            nlls=nlls,
            other_loss_terms=other_loss_terms,
        )

    @staticmethod
    def _sample_categorical(categorical_probs, do_sample=True):
        """Helper function to sample from a categorical distribution.

        Sampling uses the exponential-race (Gumbel-max style) trick:
        dividing the probabilities by i.i.d. Exp(1) noise and taking the
        argmax selects an index with probability proportional to the
        (unnormalized) probabilities.

        Parameters:
            categorical_probs: (Unnormalized) probabilities with the category
                dimension last.
            do_sample (bool): If False, return the greedy argmax instead.

        Returns:
            LongTensor of sampled category indices.
        """
        # Work in float64 for numerical stability of the division below.
        categorical_probs = categorical_probs.to(torch.float64)
        if not do_sample:
            # Greedy decoding: most likely category.
            return categorical_probs.argmax(dim=-1)
        # -log(U), U ~ Uniform(0,1), i.e. Exp(1) noise; the 1e-10 terms guard
        # against log(0).
        gumbel_norm = (1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()).to(
            categorical_probs.dtype
        )
        return (categorical_probs / gumbel_norm).argmax(dim=-1)

    @staticmethod
    def _preprocess_attention_mask(attention_mask, dtype):
        """Convert a {0, 1} attention mask into an additive mask.

        Masked positions (value 0) become the most negative representable
        value of `dtype` (so they vanish after softmax); kept positions
        become 0.
        """
        min_dtype = torch.finfo(dtype).min
        attention_mask = torch.where(
            (attention_mask == 0.0).bool(),
            min_dtype,
            0.0,
        ).to(dtype)
        return attention_mask

    @staticmethod
    def _get_past_key_values_seq_length(past_key_values: DynamicCache):
        """Return the longest cached sequence length across cache layers.

        Layers whose cached key tensor is empty along dim 0 are skipped;
        the sequence length is read from dim -2 of the key tensor.
        """
        seq_length = 0
        for i in range(len(past_key_values)):
            # past_key_values[i][0] is the key tensor of layer i.
            if past_key_values[i][0].shape[0] > 0:
                seq_length = max(
                    past_key_values[i][0].shape[-2],
                    seq_length,
                )
        return seq_length

    def update_cache(
        self,
        inputs: torch.LongTensor,
        cache: Optional[Dict[str, Any]] = None,
        **backbone_kwargs: Any,
    ) -> Dict[str, Any]:
        """
        Cache the key-value pairs for the context.

        Args:
            inputs (torch.LongTensor): The context tensor.
            cache (Dict[str, Any] | None): Cache objects, e.g., past_key_values.
        Returns:
            Dict: Updated cache objects, e.g., past_key_values.
        """
        # Requires a subclass that implements _prepare_inputs_inference.
        context_input, cache = self._prepare_inputs_inference(
            input_ids=inputs, cache=cache, return_updated_cache=True, **backbone_kwargs
        )
        backbone_output = self._backbone_forward(
            context_input,
            return_updated_cache=True,
            **cache,
        )
        # Keep everything the backbone returned except the logits; the
        # remaining entries (e.g. past_key_values) overwrite/extend the cache.
        backbone_output = {k: v for k, v in backbone_output.items()}
        backbone_output.pop("logits", None)
        cache = cache | backbone_output
        return cache

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.LongTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        batch_size: Optional[int] = None,
        device: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generates sample from denoising model.
        Follows signature of transformers.GenerationMixin.
        """
        raise NotImplementedError("Denoiser subclasses must implement generate")
|
|