# RWKV-TTS / third_party / train_scripts / train_functions.py
# Hosting-page residue kept for provenance: "kevinwang676's picture",
# commit "Add files using upload-large-folder tool" (fd82c69, verified).
import torch
import torch.nn as nn
def train_step(model, batch):
    """Move a batch onto the model's device and run one forward pass.

    Args:
        model: Callable model exposing a ``device`` attribute.
        batch: Mapping of field names to values.  Values that provide a
            ``.to`` method (e.g. tensors) are moved to ``model.device``;
            anything else (strings, ints, lists, ...) is passed through
            unchanged instead of raising ``AttributeError``.

    Returns:
        Whatever ``model(batch)`` returns.
    """
    device = model.device
    batch = {
        key: value.to(device) if hasattr(value, "to") else value
        for key, value in batch.items()
    }
    return model(batch)
def alter_emb_and_head(model, vocab_size, audio_token_size):
    """Grow the input embedding table if needed and replace the output head.

    When ``vocab_size`` is larger than ``model.config.vocab_size`` the
    embedding at ``model.model.embeddings`` is replaced by a larger one:
    existing rows are copied over and the additional rows are sampled from a
    normal distribution whose standard deviation matches the old weights.
    ``model.config.vocab_size`` is updated accordingly.  When the table is
    already large enough, the embeddings and the config are left untouched.

    ``model.lm_head`` is always replaced by a new
    ``nn.Linear(model.config.hidden_size, audio_token_size + 1)`` whose
    weight is initialised from N(0, 0.02).

    Args:
        model: Object exposing ``model.embeddings`` (an embedding layer with
            a ``weight``), ``config.vocab_size``, ``config.hidden_size`` and
            ``lm_head``.
        vocab_size: Requested number of embedding rows.
        audio_token_size: Number of audio tokens.  The head gets one extra
            output (presumably for an end/special token -- confirm with the
            caller).

    Returns:
        The same ``model`` object, modified in place.
    """
    old_embeddings = model.model.embeddings
    old_weight = old_embeddings.weight
    # New layers are created on the same device and with the same dtype as
    # the existing embedding so they match a model that was already moved to
    # an accelerator or cast to a reduced-precision dtype.
    factory_kwargs = {"device": old_weight.device, "dtype": old_weight.dtype}

    if vocab_size <= model.config.vocab_size:
        print(f'No need to enlarge the vocabulary size: {model.config.vocab_size}')
    else:
        # Create and initialise the new, larger embedding layer.
        print(f'Enlarging vocabulary size from {model.config.vocab_size} to {vocab_size}')
        current_vocab_size, embedding_dim = old_weight.shape
        new_embeddings = nn.Embedding(vocab_size, embedding_dim, **factory_kwargs)
        with torch.no_grad():
            # Keep the already-trained rows as they are.
            new_embeddings.weight[:current_vocab_size].copy_(old_weight)
            # New rows follow the spread of the existing weights.
            std = old_weight.std().item()
            new_embeddings.weight[current_vocab_size:].normal_(mean=0.0, std=std)
        model.model.embeddings = new_embeddings
        model.config.vocab_size = vocab_size

    head_dim = model.config.hidden_size
    head_size = audio_token_size + 1
    new_head = nn.Linear(head_dim, head_size, **factory_kwargs)
    with torch.no_grad():
        # Initialise the new head with small random values.
        new_head.weight.normal_(mean=0.0, std=0.02)
    model.lm_head = new_head
    print(f'Replaced head: hidden size {head_dim} -> output size {head_size}')
    return model