# novel_model.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, embedding_dim]
        """
        return x + self.pe[:, :x.size(1), :]
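

# The buffer above follows the sinusoidal scheme from "Attention Is All You
# Need": PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and
# PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). A minimal shape check
# (illustrative only, not part of the original file):
#
#   pe = PositionalEncoding(d_model=128)
#   x = torch.zeros(2, 16, 128)           # [batch_size, seq_len, d_model]
#   assert pe(x).shape == (2, 16, 128)    # encoding is added elementwise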


class NovelTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=4,
                 dim_feedforward=512, dropout=0.1, max_len=2048):
        # Deliberately small defaults: reduced width, fewer layers, and fewer
        # attention heads to keep the parameter count down.
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout, batch_first=True),
            num_layers=num_layers
        )
        self.d_model = d_model
        self.linear = nn.Linear(d_model, vocab_size)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask=None):
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        # At inference time, long sequences are processed in fixed-size chunks
        # to bound memory use. Each chunk is encoded independently, so tokens
        # cannot attend across chunk boundaries, and src_mask is not applied
        # on this path.
        if src.size(1) > 1024 and not torch.is_grad_enabled():
            chunks = []
            chunk_size = 1024
            for i in range(0, src.size(1), chunk_size):
                end = min(i + chunk_size, src.size(1))
                chunk = self.transformer_encoder(src[:, i:end, :])
                chunks.append(chunk)
            output = torch.cat(chunks, dim=1)
        else:
            output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output
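

# Helper sketch (an addition, not in the original file): autoregressive
# language modelling normally requires a causal src_mask so that position i
# attends only to positions j <= i. An additive mask with the shape contract
# expected by nn.TransformerEncoder can be built like this:
def generate_causal_mask(seq_len):
    # Upper-triangular [seq_len, seq_len] mask: -inf above the diagonal
    # blocks attention to future tokens, 0 elsewhere allows it.
    return torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)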


class NovelLM(nn.Module):
    """Wrapper around the base model for instruction fine-tuning."""
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, input_ids, attention_mask=None):
        # Forwarded to NovelTransformer's src_mask argument, which expects an
        # attention mask (e.g. [seq_len, seq_len]), not a Hugging Face-style
        # [batch_size, seq_len] padding mask.
        return self.base_model(input_ids, attention_mask)
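

# Minimal smoke test (a sketch; the vocabulary size and batch shape below are
# illustrative assumptions, not values from the original training setup):
if __name__ == "__main__":
    vocab_size = 1000
    model = NovelLM(NovelTransformer(vocab_size))
    tokens = torch.randint(0, vocab_size, (2, 32))  # [batch_size, seq_len]
    mask = generate_causal_mask(tokens.size(1))     # causal attention mask
    logits = model(tokens, mask)
    print(logits.shape)                             # torch.Size([2, 32, 1000])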