import math

import torch
import torch.nn as nn
|
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to token embeddings."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions: sine
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions: cosine
        pe = pe.unsqueeze(0)  # [1, max_len, d_model], broadcasts over the batch
        self.register_buffer('pe', pe)
|
    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, embedding_dim]
        """
        return x + self.pe[:, :x.size(1), :]
|
class NovelTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=4,
                 dim_feedforward=512, dropout=0.1, max_len=2048):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout, batch_first=True),
            num_layers=num_layers
        )
        self.d_model = d_model
        self.linear = nn.Linear(d_model, vocab_size)

        self.init_weights()
|
    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)
|
    def forward(self, src, src_mask=None):
        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)

        # Inference-only memory optimization: encode very long sequences in
        # independent chunks. Caveat: tokens cannot attend across chunk
        # boundaries, and src_mask is not applied on this path.
        if src.size(1) > 1024 and not torch.is_grad_enabled():
            chunks = []
            chunk_size = 1024
            for i in range(0, src.size(1), chunk_size):
                end = min(i + chunk_size, src.size(1))
                chunk = self.transformer_encoder(src[:, i:end, :])
                chunks.append(chunk)
            output = torch.cat(chunks, dim=1)
        else:
            output = self.transformer_encoder(src, mask=src_mask)

        output = self.linear(output)
        return output
|
class NovelLM(nn.Module):
    """Wrapper model used for instruction fine-tuning."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, input_ids, attention_mask=None):
        # attention_mask is forwarded as the encoder's src_mask. Note that
        # nn.TransformerEncoder expects an attention mask here, not a
        # HuggingFace-style [batch, seq] padding mask (a padding mask would
        # go through src_key_padding_mask instead).
        return self.base_model(input_ids, src_mask=attention_mask)
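# --- Usage sketch (not part of the original module; values are illustrative
# assumptions) --- a minimal smoke test showing the expected tensor shapes.
if __name__ == "__main__":
    vocab_size = 1000
    base = NovelTransformer(vocab_size)
    model = NovelLM(base)
    # Random token ids: [batch=2, seq_len=16]
    input_ids = torch.randint(0, vocab_size, (2, 16))
    logits = model(input_ids)
    print(logits.shape)  # expected: torch.Size([2, 16, 1000])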