Update README.md
Browse files
README.md
CHANGED
|
@@ -48,10 +48,17 @@ inputs = tokenizer(
|
|
| 48 |
with torch.no_grad():
|
| 49 |
outputs = model(**inputs)
|
| 50 |
logits = outputs.logits
|
| 51 |
-
|
|
|
|
| 52 |
|
| 53 |
-
print("logits
|
| 54 |
-
print("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
```
|
| 56 |
|
| 57 |
|
|
|
|
| 48 |
with torch.no_grad():
|
| 49 |
outputs = model(**inputs)
|
| 50 |
logits = outputs.logits
|
| 51 |
+
probs = torch.sigmoid(logits)
|
| 52 |
+
pred_mask = probs > 0.5
|
| 53 |
|
| 54 |
+
print("logits:", logits)
|
| 55 |
+
print("logits shape:", logits.shape)
|
| 56 |
+
print("probs over 0.5:", probs > 0.5) # [batch_size, num_labels]
|
| 57 |
+
print("pred label mask:", pred_mask.tolist())
|
| 58 |
+
print(
|
| 59 |
+
"pred label indices:",
|
| 60 |
+
[[i for i, on in enumerate(row) if on] for row in pred_mask.tolist()],
|
| 61 |
+
)
|
| 62 |
```
|
| 63 |
|
| 64 |
|