Update README.md
Browse files
README.md
CHANGED
|
@@ -54,7 +54,6 @@ language:
|
|
| 54 |
truncation=True
|
| 55 |
)
|
| 56 |
|
| 57 |
-
# ... 其他代码不变
|
| 58 |
|
| 59 |
# 取出输入对应的编码
|
| 60 |
input_ids = encoding['input_ids'].to(device)
|
|
@@ -62,8 +61,8 @@ language:
|
|
| 62 |
|
| 63 |
# 不计算梯度
|
| 64 |
with torch.no_grad():
|
| 65 |
-
|
| 66 |
-
|
| 67 |
|
| 68 |
# 使用softmax将logits转换为概率
|
| 69 |
probs = softmax(outputs.logits, dim=1)
|
|
|
|
| 54 |
truncation=True
|
| 55 |
)
|
| 56 |
|
|
|
|
| 57 |
|
| 58 |
# 取出输入对应的编码
|
| 59 |
input_ids = encoding['input_ids'].to(device)
|
|
|
|
| 61 |
|
| 62 |
# 不计算梯度
|
| 63 |
with torch.no_grad():
|
| 64 |
+
# 产生情感预测的logits
|
| 65 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
| 66 |
|
| 67 |
# 使用softmax将logits转换为概率
|
| 68 |
probs = softmax(outputs.logits, dim=1)
|