zhangfan commited on
Commit ·
e87e483
1
Parent(s): 2f72338
Update README.md
Browse files
README.md
CHANGED
|
@@ -58,6 +58,7 @@ model = AutoModelForSequenceClassification.from_pretrained("Fan-s/reddit-tc-bert
|
|
| 58 |
post = "don't make gravy with asbestos."
|
| 59 |
response = "i'd expect someone with a culinary background to know that. since we're talking about school dinner ladies, they need to learn this pronto."
|
| 60 |
|
|
|
|
| 61 |
def predict(post, response, max_seq_length=128):
|
| 62 |
with torch.no_grad():
|
| 63 |
args = (post, response)
|
|
@@ -67,8 +68,8 @@ def predict(post, response, max_seq_length=128):
|
|
| 67 |
item = torch.argmax(logits, dim=1)
|
| 68 |
predict_label = label_list[item]
|
| 69 |
return predict_label, logits
|
| 70 |
-
|
| 71 |
-
# predict whether the two sentences match
|
| 72 |
predict_label, logits = predict(post, response)
|
|
|
|
| 73 |
print("predict_label:", predict_label)
|
| 74 |
```
|
|
|
|
| 58 |
post = "don't make gravy with asbestos."
|
| 59 |
response = "i'd expect someone with a culinary background to know that. since we're talking about school dinner ladies, they need to learn this pronto."
|
| 60 |
|
| 61 |
+
# Predict whether the two sentences match
|
| 62 |
def predict(post, response, max_seq_length=128):
|
| 63 |
with torch.no_grad():
|
| 64 |
args = (post, response)
|
|
|
|
| 68 |
item = torch.argmax(logits, dim=1)
|
| 69 |
predict_label = label_list[item]
|
| 70 |
return predict_label, logits
|
| 71 |
+
|
|
|
|
| 72 |
predict_label, logits = predict(post, response)
|
| 73 |
+
# Matched
|
| 74 |
print("predict_label:", predict_label)
|
| 75 |
```
|