lukecq commited on
Commit
9ab5792
·
1 Parent(s): 7599afa

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -7
README.md CHANGED
@@ -48,13 +48,11 @@ def add_prefix(text, list_label, label_num = 20, shuffle = False):
48
 
49
  text_new, list_label_new = add_prefix(text,list_label,shuffle=False)
50
 
51
- ids = tokenizer.encode(text_new)
52
- tokens = tokenizer.convert_ids_to_tokens(ids)
53
- encoding = tokenizer([text],truncation=True, padding='max_length',max_length=512)
54
- item = {key: torch.tensor(val) for key, val in encoding.items()}
55
- logits = model(**item).logits
56
- probs = torch.nn.functional.softmax(logits, dim = -1).tolist()
57
- predictions = torch.argmax(logits, dim=-1)
58
  ```
59
 
60
 
 
48
 
49
  text_new, list_label_new = add_prefix(text,list_label,shuffle=False)
50
 
51
+ encoding = tokenizer([text_new],truncation=True, padding='max_length',max_length=512, return_tensors='pt')
52
+ with torch.no_grad():
53
+ logits = model(**item).logits
54
+ probs = torch.nn.functional.softmax(logits, dim = -1).tolist()
55
+ predictions = torch.argmax(logits, dim=-1)
 
 
56
  ```
57
 
58