Spaces:
Runtime error
Runtime error
| """Visualize some sense vectors""" | |
| import torch | |
| import argparse | |
| import transformers | |
def visualize_word(word, tokenizer, vecs, lm_head, count=20, contents=None):
    """
    Print the top- and bottom-scoring vocabulary items for each sense vector.

    Every row of ``contents`` is scored against the vocabulary with
    ``row @ lm_head.t()``. The logits are sorted in descending order and the
    ``count`` highest-scoring ("Positive") and ``count`` lowest-scoring
    ("Negative") tokens are printed, one per line, with their scores.

    Args:
        word: String to look up. Only consulted when ``contents`` is None.
            Only the first token id returned by ``tokenizer(word)`` is used,
            so a word that splits into several tokens is represented by its
            first piece.
        tokenizer: Callable returning a mapping with an ``'input_ids'``
            sequence; must also provide ``decode`` for a single token id.
        vecs: Object indexable by token id that yields that token's sense
            vectors (the original author notes a shape of [16, 768]).
        lm_head: 2-D tensor whose transpose maps a sense vector to vocabulary
            logits (the original author notes [768, 50257] after transpose).
        count: Number of tokens to print at each end of the ranking. Values
            larger than the number of logits are clamped so that indexing
            never runs past the end of the sorted results.
        contents: Optional pre-selected sense vectors. When given, ``word``
            and ``vecs`` are not used and nothing is looked up.

    Returns:
        The sense-vector tensor that was visualized (either the ``contents``
        argument or ``vecs[token_id]``).

    Raises:
        IndexError: If ``contents`` is None and ``tokenizer(word)`` yields an
            empty ``'input_ids'`` sequence.
    """
    if contents is None:
        print(word)
        token_id = tokenizer(word)['input_ids'][0]
        contents = vecs[token_id]  # torch.Size([16, 768])

    for sense_idx in range(contents.shape[0]):
        print('~~~~~~~~~~~~~~~~~~~~~~~{}~~~~~~~~~~~~~~~~~~~~~~~~'.format(sense_idx))
        # (vocab,) [768] @ [768, 50257] -> [50257]
        logits = contents[sense_idx, :] @ lm_head.t()
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)

        # Never ask for more entries than there are logits; otherwise the
        # loops below would index past the end of the sorted tensors.
        shown = min(count, logits.shape[0])

        print('~~~Positive~~~')
        for j in range(shown):
            print(tokenizer.decode(sorted_indices[j]), '\t', '{:.2f}'.format(sorted_logits[j].item()))

        print('~~~Negative~~~')
        for j in range(shown):
            # Walk backwards from the end of the descending sort.
            print(tokenizer.decode(sorted_indices[-j - 1]), '\t', '{:.2f}'.format(sorted_logits[-j - 1].item()))

    return contents
# Command-line entry: load the sense vectors and LM head from disk, prompt for
# a word on stdin, and print its top/bottom-scoring tokens per sense.
argp = argparse.ArgumentParser(
    description='Print the highest- and lowest-scoring words for each sense vector of a word.'
)
argp.add_argument('vecs_path', help='file passed to torch.load to obtain the sense vectors')
argp.add_argument('lm_head_path', help='file passed to torch.load to obtain the LM head matrix')
args = argp.parse_args()

# Load tokenizer and parameters
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')

# NOTE(security): torch.load deserializes with pickle, which can execute
# arbitrary code from the file. Only point these paths at files you trust.
# NOTE(review): if the files hold plain tensors, passing weights_only=True
# would restrict unpickling -- confirm the installed torch version supports it.
vecs = torch.load(args.vecs_path)
lm_head = torch.load(args.lm_head_path)

# The typed word is passed through unmodified; tokenization may be sensitive
# to surrounding whitespace, so it is deliberately not stripped here.
visualize_word(input('Enter a word:'), tokenizer, vecs, lm_head, count=5)