Update README.md
Browse files
README.md
CHANGED
|
@@ -39,22 +39,21 @@ def predict_dialect(sent):
|
|
| 39 |
CAMeL Tools MADAR 6 DID model"""
|
| 40 |
|
| 41 |
predictions = DID.predict([sent])
|
|
|
|
| 42 |
|
| 43 |
if predictions[0].top != "MSA":
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
name = highest[0]
|
| 48 |
-
score = highest[1]
|
| 49 |
-
|
| 50 |
else:
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
| 56 |
|
| 57 |
-
return
|
| 58 |
|
| 59 |
tokenizer = AutoTokenizer.from_pretrained('CAMeL-Lab/arat5-coda-did')
|
| 60 |
model = AutoModelForSeq2SeqLM.from_pretrained('CAMeL-Lab/arat5-coda-did')
|
|
|
|
| 39 |
CAMeL Tools MADAR 6 DID model"""
|
| 40 |
|
| 41 |
predictions = DID.predict([sent])
|
| 42 |
+
scores = predictions[0].scores
|
| 43 |
|
| 44 |
if predictions[0].top != "MSA":
|
| 45 |
+
# get the highest pred
|
| 46 |
+
pred = sorted(scores.items(),
|
| 47 |
+
key=lambda x: x[1], reverse=True)[0]
|
|
|
|
|
|
|
|
|
|
| 48 |
else:
|
| 49 |
+
# get the second highest pred
|
| 50 |
+
pred = sorted(scores.items(),
|
| 51 |
+
key=lambda x: x[1], reverse=True)[1]
|
| 52 |
+
|
| 53 |
+
dialect = pred[0]
|
| 54 |
+
score = pred[1]
|
| 55 |
|
| 56 |
+
return dialect, score
|
| 57 |
|
| 58 |
tokenizer = AutoTokenizer.from_pretrained('CAMeL-Lab/arat5-coda-did')
|
| 59 |
model = AutoModelForSeq2SeqLM.from_pretrained('CAMeL-Lab/arat5-coda-did')
|