|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
def train_step(model,batch): |
|
|
batch = {k: v.to(model.device) for k, v in batch.items()} |
|
|
output = model(batch) |
|
|
return output |
|
|
|
|
|
def alter_emb_and_head(model, vocab_size, audio_token_size): |
|
|
old_embeddings = model.model.embeddings |
|
|
if vocab_size < model.config.vocab_size: |
|
|
print(f'No need to enlarge the vocabulary size: {model.config.vocab_size}') |
|
|
|
|
|
|
|
|
print(f'Enlarging vocabulary size from {model.config.vocab_size} to {vocab_size}') |
|
|
embedding_dim = old_embeddings.weight.size(1) |
|
|
current_vocab_size = old_embeddings.weight.size(0) |
|
|
new_embeddings = nn.Embedding(vocab_size, embedding_dim) |
|
|
with torch.no_grad(): |
|
|
new_embeddings.weight[:current_vocab_size, :] = old_embeddings.weight.data |
|
|
std = old_embeddings.weight.std().item() |
|
|
new_embeddings.weight[current_vocab_size:, :].normal_(mean=0.0, std=std) |
|
|
model.model.embeddings = new_embeddings |
|
|
model.config.vocab_size = vocab_size |
|
|
|
|
|
head_dim = model.config.hidden_size |
|
|
new_head = nn.Linear(head_dim, audio_token_size+1) |
|
|
with torch.no_grad(): |
|
|
|
|
|
new_head.weight.normal_(mean=0.0, std=0.02) |
|
|
model.lm_head = new_head |
|
|
print(f'Enlarging head size from {head_dim} to {audio_token_size}') |
|
|
return model |