import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from lit_gpt.diffmodel import TransEncoder

from .configuration_diff_llama import DiffusionLlamaConfig


class DiffusionLlamaLM(PreTrainedModel):
    """HuggingFace wrapper around the lit_gpt TransEncoder backbone."""

    config_class = DiffusionLlamaConfig
    base_model_prefix = "model"

    def __init__(self, config: DiffusionLlamaConfig):
        super().__init__(config)
        self.model = TransEncoder(config)
        self.post_init()
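
    # PreTrainedModel.post_init() above triggers HF weight initialization,
    # which calls _init_weights once for every submodule via Module.apply.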

    def _init_weights(self, module: nn.Module) -> None:
        """
        Initialization logic for training.

        Adapted from the original TransEncoder._init_weights.
        """
        n_layer = self.config.n_layer

        # Small init: std = sqrt(2 / (5 * n_embd)) for embeddings and linears.
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
        elif isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        # Residual output projections get an init scaled down with depth.
        for name, p in module.named_parameters():
            if "proj.weight" in name:
                nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)
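
    # Worked numbers for the scales above (illustrative values, not from any
    # particular config): with n_embd = 4096 the small-init std is
    # sqrt(2 / (5 * 4096)) ≈ 0.0099, and with n_layer = 32 the projection std
    # is 1 / sqrt(4096) / 32 ≈ 0.00049, so deeper models start with weaker
    # residual branches.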

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Next-token objective: tokens < n predict token n, so shift by one.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)

        return CausalLMOutputWithPast(loss=loss, logits=logits)
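

# Minimal usage sketch (illustrative, not part of the model definition). It
# assumes DiffusionLlamaConfig accepts vocab_size / n_layer / n_embd keyword
# arguments, inferred from the fields this module reads, and that its
# model_type is "diffusion_llama"; check configuration_diff_llama for the
# real names. Because of the relative import above, run this via
# `python -m <package>.<this_module>` rather than as a plain script.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModel

    # Hypothetical kwargs; adjust to the actual DiffusionLlamaConfig signature.
    config = DiffusionLlamaConfig(vocab_size=32000, n_layer=2, n_embd=128)

    # Optional: expose the wrapper through the Auto* API. The string passed to
    # AutoConfig.register must match DiffusionLlamaConfig.model_type.
    AutoConfig.register("diffusion_llama", DiffusionLlamaConfig)
    AutoModel.register(DiffusionLlamaConfig, DiffusionLlamaLM)

    model = DiffusionLlamaLM(config)
    input_ids = torch.randint(0, 32000, (1, 16))
    out = model(input_ids, labels=input_ids)
    print(out.loss, out.logits.shape)  # scalar loss, logits (1, 16, vocab_size)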