|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch import Tensor |
|
|
import torchviz |
|
|
import yaml |
|
|
import argparse |
|
|
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel |
|
|
import logging |
|
|
import hydra |
|
|
import torch.utils.checkpoint as checkpoint |
|
|
from typing import Optional, Dict, Any |
|
|
from rotary_embedding_torch import RotaryEmbedding |
|
|
from rich import traceback |
|
|
from attention import ComplexMultiHeadAttentionV2 |
|
|
from accelerate import Accelerator, DeepSpeedPlugin |
|
|
from accelerate.utils import DistributedType |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger. Records go only to logs/model.log: propagation to the
# root logger is switched off so they are not duplicated by root handlers.
logger: logging.Logger = logging.getLogger(__name__)
logger.propagate = False
# FileHandler opens the file immediately and fails if the directory is
# missing, so make sure it exists before attaching the handler.
os.makedirs("logs", exist_ok=True)
logger.addHandler(logging.FileHandler("logs/model.log"))
|
|
|
|
|
|
|
|
class CustomConfig(PretrainedConfig):
    """Hyper-parameter container for the ComplexFormer model.

    Args:
        vocab_size: Number of rows in the token embedding / size of the output
            projection.
        hidden_dim: Width of the token embeddings and of every block.
        intermediate_size: Width of the feed-forward hidden layer.
        max_seq_len: Number of positions in the positional-encoding table.
        n_layers: Number of transformer blocks.
        num_attention_heads: Number of attention heads.
        dropout: Stored on the config; not read anywhere in this file.
        **kwargs: Forwarded to ``PretrainedConfig``. Two extra keys are
            consumed here:

            * ``debug`` (default ``False``): the model lowers the module
              logger to WARNING when this is falsy.
            * ``complex_attention`` (default ``True``): selects the attention
              call path in the model and its blocks. The model code reads this
              attribute unconditionally, so it must always exist.
              NOTE(review): ``True`` was chosen because the only attention
              module this file constructs is ``ComplexMultiHeadAttentionV2``;
              confirm that this is the intended default.
    """

    model_type: str = "ComplexFormer"

    def __init__(self,
                 vocab_size: int = 30522,
                 hidden_dim: int = 512,
                 intermediate_size: int = 1024,
                 max_seq_len: int = 512,
                 n_layers: int = 8,
                 num_attention_heads: int = 8,
                 dropout: float = 0.0,
                 **kwargs: Any):
        # Take the file-specific flags out of kwargs so that they always end
        # up as attributes, whether or not the caller supplied them.
        debug: bool = kwargs.pop("debug", False)
        complex_attention: bool = kwargs.pop("complex_attention", True)

        super().__init__(**kwargs)

        self.vocab_size: int = vocab_size
        self.hidden_dim: int = hidden_dim
        self.intermediate_size: int = intermediate_size
        self.max_seq_len: int = max_seq_len
        self.n_layers: int = n_layers
        self.num_attention_heads: int = num_attention_heads
        self.dropout: float = dropout
        self.debug: bool = debug
        self.complex_attention: bool = complex_attention
|
|
|
|
|
|
|
|
class ComplexFormerModel(PreTrainedModel):
    """Language model: token embedding -> stack of ``TransformerBlock`` -> vocabulary projection.

    ``forward`` returns the raw output of the final linear layer (no softmax,
    no loss). ``generate`` samples tokens autoregressively by re-running
    ``forward`` on the whole sequence at every step (no key/value cache).
    """

    config_class = CustomConfig

    def __init__(self, config: CustomConfig):
        """Create the sub-modules described by ``config`` and initialise their weights."""
        super().__init__(config)
        self.config: CustomConfig = config
        self.embedding: nn.Embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.pos_embedding = PositionalEncoding(
            config.hidden_dim,
            config.max_seq_len,
            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        )
        self.head_dim = config.hidden_dim // config.num_attention_heads
        self.transformer_blocks: nn.ModuleList = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        self.linear: nn.Linear = nn.Linear(config.hidden_dim, config.vocab_size)
        # ``softmax`` and ``rope`` are not used by forward()/generate() in this
        # class; they are kept so the module tree stays unchanged.
        self.softmax: nn.Softmax = nn.Softmax(dim=-1)
        self.rope = RotaryEmbedding(dim=self.head_dim)
        self.gradient_checkpointing = False
        self.debug = config.debug
        self._togger_loger()
        self.apply(self._init_weights)

    def can_generate(self) -> bool:
        """Always return ``True``."""
        return True

    def forward(self, input_ids: Tensor, attention_mask: Tensor, labels: Optional[Tensor] = None, use_checkpoint: bool = False, token_type_ids=None) -> Tensor:
        """Compute vocabulary logits for every position.

        Args:
            input_ids: Token ids, indexed as (batch, sequence); every value
                must lie in ``[0, config.vocab_size)``.
            attention_mask: Passed unchanged to every block as ``padding_mask``.
            labels: Accepted for interface compatibility; not used, no loss is
                computed here.
            use_checkpoint: Run the blocks under activation checkpointing for
                this call, in addition to the model-wide switch set by
                ``gradient_checkpointing_enable``.
            token_type_ids: Accepted for interface compatibility; not used.

        Returns:
            The output of the final linear layer (last dimension
            ``config.vocab_size``).

        Raises:
            ValueError: If ``input_ids`` holds NaN/Inf or out-of-range ids.
        """
        # NaN/Inf can only occur in floating tensors; skip the scan (and its
        # device sync) for the usual integer ids.
        if input_ids.is_floating_point() and not torch.isfinite(input_ids).all():
            raise ValueError("Tensor contains NaN or Inf values.")

        if (input_ids >= self.config.vocab_size).any() or (input_ids < 0).any():
            raise ValueError("input_ids contain values outside the valid range.")

        logger.info(f"Input IDs shape: {input_ids.shape}")
        device = input_ids.device

        token_embeddings = self.embedding(input_ids).to(device)
        x: Tensor = token_embeddings

        # Sinusoidal positions are only added on the non-complex path. The
        # truthiness test matches the one used inside TransformerBlock, so both
        # places always agree on which path is active.
        if not self.config.complex_attention:
            position_embeddings = self.pos_embedding(input_ids).to(device)
            logger.info(f"Token Embeddings device: {token_embeddings.device}")
            logger.info(f"Position Embeddings device: {position_embeddings.device}")
            x = token_embeddings + position_embeddings
            logger.info(f"Input embeddings shape: {x.shape}")

        checkpointing = self.gradient_checkpointing or use_checkpoint
        for block in self.transformer_blocks:
            if checkpointing:
                # Bind the block and mask as defaults so the recomputation pass
                # calls exactly this layer with exactly this mask.
                def run_block(hidden, _block=block, _mask=attention_mask):
                    return _block(hidden, padding_mask=_mask)

                x = checkpoint.checkpoint(
                    run_block,
                    x,
                    use_reentrant=False,
                    preserve_rng_state=True,
                )
                logger.info(f"Transformer block output shape with ckpt: {x.shape}")
            else:
                x = block(x, padding_mask=attention_mask)
                logger.info(f"Transformer block output shape no ckpt: {x.shape}")

        logger.info(f"Transformer block output shape: {x.shape}")
        x = self.linear(x)
        logger.info(f"Linear layer output shape: {x.shape}")
        return x

    def gradient_checkpointing_enable(self, **kwargs):
        """Turn on activation checkpointing for all later ``forward`` calls (kwargs are ignored)."""
        self.gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, nonlinearity='leaky_relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.kaiming_normal_(module.weight, nonlinearity='leaky_relu')
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, (RMSNorm, nn.LayerNorm)):
            if hasattr(module, 'weight') and module.weight is not None:
                nn.init.ones_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def _togger_loger(self):
        """Raise the module logger's level to WARNING unless debug mode is on."""
        if not self.debug:
            logger.setLevel(logging.WARNING)

    @staticmethod
    def _apply_repetition_penalty(logits: Tensor, token_ids: Tensor, penalty: float) -> Tensor:
        """Make every token already present in ``token_ids`` less likely."""
        index = token_ids.long()
        seen = torch.gather(logits, 1, index)
        # Dividing a negative logit by penalty > 1 would move it towards zero
        # and make the token MORE likely, so negative logits are multiplied.
        seen = torch.where(seen < 0, seen * penalty, seen / penalty)
        return logits.scatter(1, index, seen)

    @staticmethod
    def _filter_top_k(logits: Tensor, top_k: int) -> Tensor:
        """Set everything below the k-th largest logit of each row to -inf."""
        k = min(top_k, logits.size(-1))  # torch.topk rejects k > vocabulary size
        kth_best = torch.topk(logits, k, dim=-1).values[..., -1:]
        return logits.masked_fill(logits < kth_best, float('-inf'))

    @staticmethod
    def _filter_top_p(logits: Tensor, top_p: float) -> Tensor:
        """Keep the smallest set of tokens whose cumulative probability exceeds ``top_p``."""
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_remove = cumulative_probs > top_p
        # Shift right by one so the first token that crosses the threshold is
        # kept; the most likely token is therefore never removed.
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
        return logits.masked_fill(remove, float('-inf'))

    @torch.no_grad()
    def generate(self,
                 input_ids: Tensor,
                 attention_mask: Optional[Tensor] = None,
                 max_length: int = 100,
                 temperature: float = 1.0,
                 top_k: int = 0,
                 top_p: float = 1.0,
                 repetition_penalty: float = 1.0,
                 eos_token_id: Optional[int] = None,
                 pad_token_id: Optional[int] = None,
                 use_cache: bool = True,
                 **kwargs):
        """Sample a continuation of ``input_ids`` token by token.

        Puts the model in eval mode. Every step re-runs ``forward`` on the full
        sequence generated so far and samples one token per row from the
        filtered distribution.

        Args:
            input_ids: Prompt token ids of shape (batch, prompt_length).
            attention_mask: Mask for the prompt; defaults to all ones. A column
                of ones is appended for every generated token.
            max_length: Total length (prompt + generated) at which to stop.
            temperature: Divides the logits before sampling; must be > 0.
            top_k: If > 0, sample only from the ``top_k`` highest logits.
            top_p: If < 1.0, apply nucleus filtering with this threshold.
            repetition_penalty: If != 1.0, penalise tokens already present in
                the sequence.
            eos_token_id: Token that ends a sequence; falls back to
                ``config.eos_token_id``.
            pad_token_id: Token written to rows that have already ended; falls
                back to ``config.pad_token_id`` and then to ``eos_token_id``.
            use_cache: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Tensor of shape (batch, <= max_length) holding the prompt followed
            by the generated tokens.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive.")

        self.eval()
        if pad_token_id is None and self.config.pad_token_id is not None:
            pad_token_id = self.config.pad_token_id
        if eos_token_id is None and self.config.eos_token_id is not None:
            eos_token_id = self.config.eos_token_id

        batch_size, cur_len = input_ids.shape
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        generated_ids = input_ids.clone().detach()

        # finished[i] becomes True once row i has produced EOS; such rows are
        # filled with the pad token while the remaining rows keep sampling.
        finished = torch.zeros(batch_size, dtype=torch.bool, device=generated_ids.device)
        filler_id = pad_token_id if pad_token_id is not None else eos_token_id

        for _ in range(max_length - cur_len):
            logits = self(input_ids=generated_ids, attention_mask=attention_mask)
            # Filter and sample in float32 regardless of the model dtype.
            next_token_logits = logits[:, -1, :].float()

            if repetition_penalty != 1.0:
                next_token_logits = self._apply_repetition_penalty(
                    next_token_logits, generated_ids, repetition_penalty
                )
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature
            if top_k > 0:
                next_token_logits = self._filter_top_k(next_token_logits, top_k)
            if top_p < 1.0:
                next_token_logits = self._filter_top_p(next_token_logits, top_p)

            probs = F.softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

            if eos_token_id is not None:
                # Rows that ended earlier must not emit new content.
                next_token_id = next_token_id.masked_fill(finished.unsqueeze(-1), filler_id)

            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((batch_size, 1))], dim=-1
            )

            if eos_token_id is not None:
                finished = finished | (next_token_id.squeeze(-1) == eos_token_id)
                if finished.all():
                    logger.info("EOS token generated for all sequences in batch.")
                    break

        return generated_ids
|
|
|
|
|
|
|
|
class FFN(nn.Module):
    """Gated feed-forward layer: ``linear3(linear1(x) * SiLU(linear2(x)))``.

    ``linear1`` and ``linear2`` both project ``hidden_dim`` to
    ``intermediate_size``; ``linear3`` projects back to ``hidden_dim``.
    """

    def __init__(self, config: "CustomConfig"):
        super().__init__()
        hidden, inner = config.hidden_dim, config.intermediate_size
        # Construction order (1, 3, 2) is kept on purpose: it determines the
        # order in which random initial weights are drawn.
        self.linear1: nn.Linear = nn.Linear(hidden, inner)
        self.linear3: nn.Linear = nn.Linear(inner, hidden)
        self.linear2: nn.Linear = nn.Linear(hidden, inner)
        self.activation: nn.Module = nn.SiLU()

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated feed-forward transform to ``x``."""
        gate_proj = self.linear1(x)
        silu_up = self.activation(self.linear2(x))

        if gate_proj.dtype != torch.float32:
            # Any other dtype: multiply in that dtype directly.
            return self.linear3(gate_proj * silu_up)

        # float32 activations: the element-wise product is taken in bfloat16
        # and cast back to float32 before the output projection.
        # NOTE(review): presumably a memory/bandwidth trade-off - confirm.
        product = gate_proj.to(torch.bfloat16) * silu_up.to(torch.bfloat16)
        return self.linear3(product.to(torch.float32))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional-encoding table.

    The table has ``max_len`` rows and ``d_model`` columns: even columns hold
    the sine and odd columns the cosine of ``pos / 10000 ** (2i / d_model)``.
    It is registered as a non-persistent buffer, so it follows the module
    through ``.to()`` / ``.cuda()`` / dtype casts but is not written to the
    state dict.
    """

    def __init__(self, d_model, max_len, device):
        """
        Build the encoding table.

        :param d_model: dimension of model (number of columns; may be odd)
        :param max_len: max sequence length (number of rows)
        :param device: device on which the table is created
        """
        super(PositionalEncoding, self).__init__()

        # Column vector of positions 0 .. max_len - 1.
        position = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)
        # The even column indices 0, 2, 4, ... act as the "2i" of the formula.
        even_dims = torch.arange(0, d_model, step=2, device=device).float()
        angles = position / (10000 ** (even_dims / d_model))

        table = torch.zeros(max_len, d_model, device=device)
        table[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns.
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer("encoding", table, persistent=False)

    def forward(self, x):
        """
        Return the encodings for the first ``seq_len`` positions.

        :param x: 2-D tensor of shape (batch_size, seq_len); only its shape is used
        :return: tensor of shape (seq_len, d_model)
        :raises ValueError: if ``seq_len`` exceeds the table length
        """
        _, seq_len = x.size()
        max_len = self.encoding.size(0)
        if seq_len > max_len:
            # Slicing would silently return only max_len rows and fail later
            # with an unrelated broadcasting error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum supported length {max_len}."
            )
        return self.encoding[:seq_len, :]
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Computes ``x / sqrt(mean(x ** 2) + eps) * weight``.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.weight: nn.Parameter = nn.Parameter(torch.ones(hidden_size))
        self.eps: float = 1e-5

    def forward(self, x: Tensor) -> Tensor:
        """Normalise ``x`` along its last dimension and scale by ``weight``."""
        # eps belongs inside the square root: it keeps rsqrt finite for an
        # all-zero row (otherwise 0 * inf = NaN).
        inv_rms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """One layer: norm -> attention -> residual + norm -> feed-forward -> residual.

    The same ``RMSNorm`` instance is applied before the attention and after
    the first residual addition.
    """

    def __init__(self, config: CustomConfig):
        super().__init__()
        self.attention = ComplexMultiHeadAttentionV2(config.hidden_dim, config.num_attention_heads)
        self.ffn: FFN = FFN(config)
        self.rmsnorm: RMSNorm = RMSNorm(config.hidden_dim)
        self.config: CustomConfig = config

    def forward(self, x: Tensor, padding_mask: Tensor) -> Tensor:
        """Run the block on ``x``.

        Args:
            x: Hidden states; dimension 1 is used as the sequence length.
            padding_mask: Forwarded to the attention module as ``mask`` on the
                complex-attention path; not used on the other path.

        Returns:
            Hidden states after attention and feed-forward sub-layers.

        Raises:
            ValueError: If NaN appears after the post-attention normalisation.
        """
        residual: Tensor = x
        x = self.rmsnorm(x)
        logger.info(f"TransformerBlock Input shape: {x.shape}")

        if self.config.complex_attention:
            logger.info('using complex attention')
            x = self.attention(x, x, x, mask=padding_mask)
        else:
            # The causal mask is only needed on this path; build it directly on
            # the input's device. True marks the strictly-upper-triangular
            # (future) positions.
            seq_len: int = x.shape[1]
            causal_mask: Tensor = torch.triu(
                torch.ones(seq_len, seq_len, device=x.device), diagonal=1
            ).bool()
            logger.info(f'before forward{x.dtype}')
            x = x.bfloat16()
            x = self.attention(x, x, x, attn_mask=causal_mask)[0]
            logger.info(f'after forward{x.dtype}')

        x = self.rmsnorm(x + residual)
        logger.info(f"TransformerBlock after rmsnorm shape: {x.shape}")

        if torch.isnan(x).any():
            logger.info(f"TransformerBlock after rmsnorm has nan")
            raise ValueError("TransformerBlock after rmsnorm has nan")

        # Residual connection around the feed-forward sub-layer. The skip
        # tensor used to be stored but never added back.
        # NOTE(review): this changes the output of checkpoints trained without
        # the skip connection - retrain or re-validate those.
        residual = x
        x = self.ffn(x)
        return x + residual
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_config(config_path: str) -> CustomConfig:
    """Build a ``CustomConfig`` from a YAML file.

    Args:
        config_path: Path to a YAML file whose top level is a mapping of
            ``CustomConfig`` keyword arguments. An empty file yields the
            default configuration.

    Returns:
        The configuration object.

    Raises:
        TypeError: If the YAML top level is not a mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)

    # safe_load returns None for an empty document.
    if config_dict is None:
        config_dict = {}
    if not isinstance(config_dict, dict):
        raise TypeError(
            f"Expected a mapping at the top level of {config_path}, "
            f"got {type(config_dict).__name__}."
        )
    return CustomConfig(**config_dict)
|
|
|
|
|
|
|
|
@hydra.main(config_path='.', config_name="config.yaml")
def main(config: Dict[str, Any]) -> None:
    """Build the model, smoke-test it, save it and print its parameters.

    Args:
        config: Configuration object supplied by Hydra. It is accessed by
            attribute (``architecture``, ``model.save_dir``) and is also
            handed to ``test_model``.

    The model hyper-parameters themselves come from the YAML file named by
    ``--config``, loaded through ``load_config``.
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--config', type=str, default='./pretrain/config.yaml', help='Path to the config file')
    # Hydra also reads sys.argv; tolerate its arguments (e.g. key=value
    # overrides) instead of aborting on them.
    args, _ = argparser.parse_known_args()

    AutoConfig.register(config.architecture, CustomConfig)
    AutoModel.register(CustomConfig, ComplexFormerModel)

    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Hydra may change the working directory before calling this function, so
    # resolve a relative path against the directory the program was started in.
    config_file = hydra.utils.to_absolute_path(args.config)
    model: ComplexFormerModel = ComplexFormerModel(config=load_config(config_file)).to(device)

    test_model(model, config)

    model.gradient_checkpointing_enable()
    model.save_pretrained(config.model.save_dir)
    logger.info(f"Model saved to {config.model.save_dir}")

    # Count only parameters that receive gradients, so the number matches the
    # "Trainable" label (state_dict() would also include buffers and frozen
    # tensors).
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable model parameters: {num_params/1e9:,}B")
    print_model_parameters_llm(model)
|
|
|
|
|
|
|
|
|
|
|
def print_model_parameters_llm(model_to_inspect):
    """Print a table of every named parameter followed by total/trainable counts.

    Args:
        model_to_inspect: A ``torch.nn.Module``. Anything else prints an error
            message and returns without printing a table.

    Returns:
        None. Everything is written to standard output.
    """
    # Only nn.Module instances are accepted (the former extra check against
    # AutoModel was dropped as unreachable).
    if not isinstance(model_to_inspect, torch.nn.Module):
        print("传入的不是一个有效的PyTorch模型。")
        return

    rule = "---------------------------------------------------------------------------------------------------------------"

    print("\n模型架构 (部分,如果太大):")
    print(type(model_to_inspect))

    print("\n参数详情:")
    print(rule)
    print(f"{'Parameter Name':<70} | {'Shape':<25} | {'Numel':<12} | {'Requires Grad':<15} | {'Dtype':<10}")
    print(rule)

    total_params = 0
    trainable_params = 0

    for name, param in model_to_inspect.named_parameters():
        numel = param.numel()
        total_params += numel
        if param.requires_grad:
            trainable_params += numel

        dtype_name = str(param.dtype).replace('torch.', '')
        print(f"{name:<70} | {str(param.shape):<25} | {numel:<12,} | {str(param.requires_grad):<15} | {dtype_name:<10}")

    print(rule)
    print(f"总参数量 (Total parameters): {total_params:,}")
    print(f"可训练参数量 (Trainable parameters): {trainable_params:,}")
    # The frozen count is only shown when at least one parameter is frozen.
    if total_params != trainable_params:
        print(f"不可训练/冻结参数量 (Non-trainable parameters): {total_params - trainable_params:,}")
    print(rule)
|
|
|
|
|
|
|
|
@torch.no_grad()
def test_model(model: "ComplexFormerModel", config) -> None:
    """Smoke-test ``model`` with one random batch.

    Runs a forward pass on random token ids, checks the output shape and logs
    the cross-entropy loss against random labels.

    Args:
        model: The model to exercise.
        config: Object providing ``vocab_size``, ``batch_size`` and
            ``max_seq_len``.

    Raises:
        AssertionError: If the output shape is not
            (batch_size, max_seq_len, vocab_size).
    """
    # Use the device the model actually lives on, so the inputs always match
    # its parameters.
    device: torch.device = next(model.parameters()).device
    logger.info(f"Testing model...{device}")

    batch_shape = (config.batch_size, config.max_seq_len)
    input_ids: Tensor = torch.randint(0, config.vocab_size, batch_shape, device=device)
    attention_mask: Tensor = torch.ones(batch_shape, dtype=torch.bool, device=device)
    labels: Tensor = torch.randint(0, config.vocab_size, batch_shape, device=device)

    output: Tensor = model(input_ids, attention_mask=attention_mask, labels=labels, use_checkpoint=True)
    logger.info(f"Model output shape: {output.shape}")
    assert output.shape == (config.batch_size, config.max_seq_len, config.vocab_size), "Output shape mismatch"

    criterion = nn.CrossEntropyLoss()
    loss = criterion(output.view(-1, config.vocab_size), labels.view(-1))
    # Log the loss value itself; loss.dim() is always 0 for the default reduction.
    logger.info(f"Loss: {loss.item()}")
|
|
|
|
|
# Register the custom classes with the Auto* factories. This runs on import
# and, when the file is executed as a script, before main() starts rather
# than after it has finished.
AutoConfig.register("ComplexFormer", CustomConfig)
AutoModel.register(CustomConfig, ComplexFormerModel)


if __name__ == '__main__':
    main()
|
|
|
|
|
|
|
|
|