File size: 1,402 Bytes
fd82c69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn

def train_step(model, batch):
    """Run a single forward pass of ``model`` on ``batch``.

    Every value in ``batch`` is moved onto ``model.device`` with ``.to()``
    (so each value must provide a ``.to`` method -- presumably tensors).
    The resulting mapping, with the same keys in the same order, is passed to
    ``model`` as one positional argument. Whatever the model returns is
    handed back unchanged.

    Note: despite the name, no loss/backward/optimizer step happens here.
    """
    on_device = {}
    for name, value in batch.items():
        # model.device is read per entry, exactly as the original comprehension did.
        on_device[name] = value.to(model.device)
    return model(on_device)

def _enlarge_embeddings(model, vocab_size):
    """Grow ``model.model.embeddings`` to ``vocab_size`` rows, keeping existing rows.

    Existing rows are copied verbatim. New rows are drawn from N(0, std), where
    ``std`` is the standard deviation of the existing embedding weights.
    The new layer keeps the old layer's device, dtype and ``padding_idx``.
    ``model.config.vocab_size`` is updated to match.

    If ``vocab_size`` is not larger than the current number of embedding rows,
    only a message is printed and the model is left unchanged.
    """
    old_embeddings = model.model.embeddings
    old_weight = old_embeddings.weight
    current_vocab_size, embedding_dim = old_weight.shape
    if vocab_size <= current_vocab_size:
        # Shrinking would discard trained rows (and the row copy below would
        # fail on a shape mismatch); an equal size would be a pointless rebuild.
        print(f'No need to enlarge the vocabulary size: {current_vocab_size}')
        return

    # Create and initialize the new embedding layer.
    print(f'Enlarging vocabulary size from {current_vocab_size} to {vocab_size}')
    new_embeddings = nn.Embedding(
        vocab_size,
        embedding_dim,
        padding_idx=getattr(old_embeddings, 'padding_idx', None),
        # Match the old layer so a half-precision / GPU model is not silently
        # given a float32 CPU embedding.
        device=old_weight.device,
        dtype=old_weight.dtype,
    )
    with torch.no_grad():
        new_embeddings.weight[:current_vocab_size].copy_(old_weight)
        # Compute the statistic in float32 for precision with low-precision weights.
        std = old_weight.float().std().item()
        new_embeddings.weight[current_vocab_size:].normal_(mean=0.0, std=std)
    model.model.embeddings = new_embeddings
    model.config.vocab_size = vocab_size


def _replace_head(model, audio_token_size):
    """Install a freshly initialized ``Linear(hidden_size, audio_token_size + 1)`` as ``model.lm_head``.

    The head weight is drawn from N(0, 0.02); the bias keeps ``nn.Linear``'s
    default initialization. Device and dtype are taken from the model's
    embedding weights (assumes the embeddings live where the rest of the model
    does -- TODO confirm for sharded/offloaded models).
    """
    head_dim = model.config.hidden_size
    ref_weight = model.model.embeddings.weight
    # NOTE(review): the extra "+ 1" output is kept from the original; presumably
    # a stop/end token on top of the audio codebook -- confirm with the caller.
    out_features = audio_token_size + 1
    new_head = nn.Linear(
        head_dim,
        out_features,
        device=ref_weight.device,
        dtype=ref_weight.dtype,
    )
    with torch.no_grad():
        new_head.weight.normal_(mean=0.0, std=0.02)
    model.lm_head = new_head
    print(f'Replacing head with Linear({head_dim} -> {out_features})')


def alter_emb_and_head(model, vocab_size, audio_token_size):
    """Enlarge the input embeddings and replace the output head of ``model``.

    Args:
        model: object exposing ``model.model.embeddings`` (with ``.weight``),
            ``model.config.vocab_size``, ``model.config.hidden_size`` and a
            settable ``model.lm_head``.
        vocab_size: target number of embedding rows. It is ignored (with a
            message) if it is not larger than the current embedding size.
        audio_token_size: the new head produces ``audio_token_size + 1`` outputs.

    Returns:
        The same ``model`` object, modified in place.
    """
    _enlarge_embeddings(model, vocab_size)
    _replace_head(model, audio_token_size)
    return model