from .configuration_diff_llama import DiffusionLlamaConfig
from lit_gpt.diffmodel import TransEncoder
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
import torch
import torch.nn as nn
import math
from typing import Optional, Union, Tuple


class DiffusionLlamaLM(PreTrainedModel):
    config_class = DiffusionLlamaConfig
    base_model_prefix = "model"

    def __init__(self, config: DiffusionLlamaConfig):
        super().__init__(config)
        self.model = TransEncoder(config)
        # Initialize weights (training feature).
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """Initialization logic for training.

        Adapted from the original TransEncoder._init_weights.
        """
        n_layer = self.config.n_layer
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        # Special, depth-scaled initialization for projection weights. In HF's
        # _init_weights, `module` is the submodule currently being visited; the
        # lit_gpt classes (LLaMAMLP, SwiGLU, SelfAttention) are not imported
        # here, so parameters are matched by name instead of isinstance checks.
        # Parent modules re-apply this to the same tensors, which is harmless:
        # every pass redraws from the same distribution.
        # if isinstance(module, LLaMAMLP):
        for name, p in module.named_parameters():
            if "proj.weight" in name:
                nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)
        # if isinstance(module, SwiGLU):
        #     for name, p in module.named_parameters():
        #         if "w3.weight" in name:
        #             nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)
        # if isinstance(module, SelfAttention):
        #     for name, p in module.named_parameters():
        #         if "proj.weight" in name:
        #             nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,  # accepted for HF Trainer compatibility; unused by TransEncoder
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)
        return CausalLMOutputWithPast(loss=loss, logits=logits)
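

# --- Usage sketch -------------------------------------------------------------
# A minimal smoke test, run via `python -m` on this module (the relative import
# above prevents direct script execution). It assumes DiffusionLlamaConfig is
# constructible with defaults and exposes a vocabulary size; "vocab_size" below
# is an assumed attribute name with an arbitrary fallback, not a confirmed field
# of this config class.
if __name__ == "__main__":
    config = DiffusionLlamaConfig()
    model = DiffusionLlamaLM(config)

    vocab_size = getattr(config, "vocab_size", 32000)  # assumed attribute; fallback is arbitrary
    input_ids = torch.randint(0, vocab_size, (2, 16))  # batch of 2, sequence length 16

    # Passing labels=input_ids yields the standard next-token loss via the
    # shift-by-one logic in forward().
    out = model(input_ids, labels=input_ids)
    print(out.loss.item(), out.logits.shape)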