| import torch | |
| from transformers import WhisperForConditionalGeneration | |
def _language_token_id(tokenizer, language):
    """Return the vocabulary id of the ``<|language|>`` token or raise ``ValueError``."""
    token = f"<|{language}|>"
    token_id = tokenizer.convert_tokens_to_ids(token)
    # NOTE(review): presumably the tokenizer answers an out-of-vocabulary
    # token with ``None`` or with its unknown-token id -- confirm against the
    # tokenizer in use. Either way, mixing such a row would be meaningless.
    if token_id is None or token_id == getattr(tokenizer, "unk_token_id", None):
        raise ValueError(f"language token {token!r} is not in the tokenizer vocabulary")
    return token_id


def mix_language_embeddings(
    model: "WhisperForConditionalGeneration",
    tokenizer,
    languages=("zh", "en"),
    target_language="zh",
    weights=None,
):
    """Overwrite one language token's decoder embedding with a weighted sum.

    The embedding rows of the ``<|lang|>`` tokens named in ``languages`` are
    scaled by ``weights`` and summed; the result replaces the row of the
    ``<|target_language|>`` token in ``model.model.decoder.embed_tokens``.
    The model is modified in place.

    Args:
        model: Model exposing ``model.decoder.embed_tokens.weight``.
        tokenizer: Object providing ``convert_tokens_to_ids``.
        languages: Language codes whose token embeddings are mixed.
        target_language: Language code whose token embedding is replaced.
        weights: One coefficient per entry of ``languages``. Defaults to a
            uniform ``1 / len(languages)`` each. The values are used as
            given; they are not normalised.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        ValueError: If ``languages`` is empty, if ``weights`` has a different
            length than ``languages``, or if a language token is unknown to
            the tokenizer.
    """
    languages = list(languages)
    if not languages:
        raise ValueError("languages must contain at least one language code")

    if weights is None:
        weights = [1 / len(languages)] * len(languages)
    else:
        weights = list(weights)
        # zip() would silently drop the surplus entries of the longer sequence.
        if len(weights) != len(languages):
            raise ValueError(
                f"got {len(weights)} weights for {len(languages)} languages"
            )

    # Resolve every id before touching the model so that a bad language code
    # cannot leave the embedding table half-updated.
    target_id = _language_token_id(tokenizer, target_language)
    source_ids = [_language_token_id(tokenizer, language) for language in languages]

    embedding_table = model.model.decoder.embed_tokens.weight
    with torch.no_grad():
        # zeros_like keeps the accumulator on the same device and in the same
        # dtype as the embedding table, so the in-place adds below also work
        # when the model lives on an accelerator.
        mixed = torch.zeros_like(embedding_table[target_id])
        for source_id, weight in zip(source_ids, weights):
            mixed += embedding_table[source_id] * weight
        # Written only after the sum is complete: the target row may itself be
        # one of the sources and must contribute its original value.
        embedding_table[target_id] = mixed
    return model