IAMJB commited on
Commit
927fbc8
·
verified ·
1 Parent(s): 2f98b6b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +10 -3
README.md CHANGED
@@ -48,10 +48,17 @@ inputs = tokenizer(
48
  with torch.no_grad():
49
  outputs = model(**inputs)
50
  logits = outputs.logits
51
- pred_class = torch.argmax(logits, dim=-1)
 
52
 
53
- print("logits shape:", logits.shape) # [batch_size, num_labels]
54
- print("pred classes:", pred_class.tolist())
 
 
 
 
 
 
55
  ```
56
 
57
 
 
48
  with torch.no_grad():
49
  outputs = model(**inputs)
50
  logits = outputs.logits
51
+ probs = torch.sigmoid(logits)
52
+ pred_mask = probs > 0.5
53
 
54
+ print("logits:", logits)
55
+ print("logits shape:", logits.shape)
56
+ print("probs over 0.5:", probs > 0.5) # [batch_size, num_labels]
57
+ print("pred label mask:", pred_mask.tolist())
58
+ print(
59
+ "pred label indices:",
60
+ [[i for i, on in enumerate(row) if on] for row in pred_mask.tolist()],
61
+ )
62
  ```
63
 
64