import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from lit_gpt.diffmodel import TransEncoder

from .configuration_diff_llama import DiffusionLlamaConfig


class DiffusionLlamaLM(PreTrainedModel):
    config_class = DiffusionLlamaConfig
    base_model_prefix = "model"

    def __init__(self, config: DiffusionLlamaConfig):
        super().__init__(config)
        self.model = TransEncoder(config)
        # Initialize weights (training feature): post_init() runs HF's standard
        # weight-init hook, which calls _init_weights below on every submodule.
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """
        Initialization logic for training.
        Adapted from the original TransEncoder._init_weights.
        """
        n_layer = self.config.n_layer
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
        elif isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        # Depth-scaled initialization for residual output projections, matched by
        # parameter name. In HF's _init_weights, `module` is whichever submodule
        # self.apply() is visiting, so the substring check fires when a composite
        # module (LLaMAMLP, SelfAttention) exposes a `proj` Linear. The original
        # TransEncoder code also special-cased SwiGLU's `w3.weight` with the same
        # depth-scaled std.
        for name, p in module.named_parameters():
            if "proj.weight" in name:
                nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)
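
    # Numeric illustration of the scheme above (comment-only sketch; the n_embd
    # and n_layer values here are hypothetical, not taken from any checkpoint):
    #   base std = sqrt(2 / 5 / n_embd)        -> ~0.0140  for n_embd=2048
    #   proj std = 1 / sqrt(n_embd) / n_layer  -> ~0.00092 for n_embd=2048, n_layer=24
    # Downscaling the projection std with depth damps the variance each residual
    # branch adds at init, in the spirit of GPT-2's residual scaling (GPT-2 uses
    # a 1 / sqrt(2 * n_layer) factor, whereas this code uses 1 / n_layer).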

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)
        return CausalLMOutputWithPast(loss=loss, logits=logits)
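

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the checkpoint code). The config field
# names below (vocab_size, n_layer, n_embd) are assumptions inferred from how
# this class reads DiffusionLlamaConfig above; the concrete constructor
# signature lives in configuration_diff_llama.py.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = DiffusionLlamaConfig(vocab_size=32000, n_layer=2, n_embd=128)  # hypothetical sizes
    model = DiffusionLlamaLM(config)

    input_ids = torch.randint(0, 32000, (1, 16))
    out = model(input_ids, labels=input_ids)
    print(out.logits.shape, out.loss)

    # To resolve this class through the Auto* API (e.g. with
    # AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)),
    # the standard transformers registration hooks apply — assuming the
    # config defines a unique `model_type`:
    #
    #   from transformers import AutoConfig, AutoModelForCausalLM
    #   AutoConfig.register(DiffusionLlamaConfig.model_type, DiffusionLlamaConfig)
    #   AutoModelForCausalLM.register(DiffusionLlamaConfig, DiffusionLlamaLM)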