"""
This script provides an example to wrap TencentPretrain for embedding extraction.
"""
import sys
import os
import argparse

import torch

# Make the repository root importable so the local `tencentpretrain` package
# is found when this script is run from the examples/scripts directory.
tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)

from tencentpretrain.utils.vocab import Vocab
| if __name__ == '__main__': | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--load_model_path", default=None, type=str, | |
| help="Path of the input model.") | |
| parser.add_argument("--vocab_path", default=None, type=str, | |
| help="Path of the vocabulary file.") | |
| parser.add_argument("--spm_model_path", default=None, type=str, | |
| help="Path of the sentence piece model.") | |
| parser.add_argument("--word_embedding_path", default=None, type=str, | |
| help="Path of the output word embedding.") | |
| args = parser.parse_args() | |
| if args.spm_model_path: | |
| try: | |
| import sentencepiece as spm | |
| except ImportError: | |
| raise ImportError("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece" | |
| "pip install sentencepiece") | |
| sp_model = spm.SentencePieceProcessor() | |
| sp_model.Load(args.spm_model_path) | |
| vocab = Vocab() | |
| vocab.i2w = {i: sp_model.IdToPiece(i) for i in range(sp_model.GetPieceSize())} | |
| else: | |
| vocab = Vocab() | |
| vocab.load(args.vocab_path) | |
| pretrained_model = torch.load(args.load_model_path) | |
| embedding = pretrained_model["embedding.word.embedding.weight"] | |
| with open(args.word_embedding_path, mode="w", encoding="utf-8") as f: | |
| head = str(list(embedding.size())[0]) + " " + str(list(embedding.size())[1]) + "\n" | |
| f.write(head) | |
| for i in range(len(vocab.i2w)): | |
| word = vocab.i2w[i] | |
| word_embedding = embedding[vocab.get(word), :] | |
| word_embedding = word_embedding.cpu().numpy().tolist() | |
| line = str(word) | |
| for j in range(len(word_embedding)): | |
| line = line + " " + str(word_embedding[j]) | |
| line += "\n" | |
| f.write(line) | |