# distil-whisper / utils/model_utils.py
# Author: dmnph
# Last commit 4095301 (verified): "Saving train state of step 1000"
import torch
from transformers import WhisperForConditionalGeneration
def _language_token_id(tokenizer, language):
    """Return the vocabulary id of the ``<|language|>`` token, rejecting unknown codes.

    ``convert_tokens_to_ids`` may return ``None`` or the unknown-token id instead
    of raising for a token that is not in the vocabulary; either would silently
    select the wrong embedding row, so both are turned into a ``ValueError``.
    """
    token = f"<|{language}|>"
    token_id = tokenizer.convert_tokens_to_ids(token)
    if token_id is None or token_id == getattr(tokenizer, "unk_token_id", None):
        raise ValueError(f"Tokenizer has no language token {token!r}")
    return token_id


def mix_language_embeddings(
    model: "WhisperForConditionalGeneration",
    tokenizer,
    languages=("zh", "en"),
    target_language="zh",
    weights=None,
):
    """Overwrite one language token's decoder embedding with a weighted mix of others.

    The row of ``model.model.decoder.embed_tokens.weight`` belonging to
    ``<|target_language|>`` is replaced, in place, by
    ``sum(w * row(<|lang|>) for lang, w in zip(languages, weights))``.
    All source rows are read before the target row is written, so the target
    language may itself appear in ``languages``.

    Args:
        model: Whisper model whose decoder token embeddings are modified in place.
        tokenizer: Tokenizer exposing ``convert_tokens_to_ids`` for ``<|xx|>``
            language tokens.
        languages: Language codes whose embeddings are mixed.
        target_language: Language code whose embedding row is overwritten.
        weights: One coefficient per entry of ``languages``; defaults to a
            uniform average.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: if ``languages`` is empty, ``weights`` does not have one
            entry per language, or a language token is not in the vocabulary.
    """
    languages = list(languages)
    if not languages:
        raise ValueError("languages must contain at least one language code")
    if weights is None:
        weights = [1 / len(languages)] * len(languages)
    else:
        weights = list(weights)
        # zip() would silently drop the surplus entries and yield a wrong mix.
        if len(weights) != len(languages):
            raise ValueError(
                f"got {len(weights)} weights for {len(languages)} languages"
            )

    target_id = _language_token_id(tokenizer, target_language)
    source_ids = [_language_token_id(tokenizer, lang) for lang in languages]
    embeddings = model.model.decoder.embed_tokens.weight

    with torch.no_grad():
        # zeros_like keeps the embedding's dtype AND device; a bare
        # torch.zeros(...) is allocated on the CPU, and the += below then
        # fails when the model lives on a GPU.
        mixed = torch.zeros_like(embeddings[target_id])
        for token_id, coeff in zip(source_ids, weights):
            mixed += embeddings[token_id] * coeff
        embeddings[target_id] = mixed
    return model