Update README.md
Browse files
README.md
CHANGED
|
@@ -30,16 +30,17 @@ label2diacritic = {0: 'ู', 1: 'ู', 2: 'ู', 3: 'ู', 4: ''}
|
|
| 30 |
|
| 31 |
def arabic2diacritics(text, model, tokenizer):
    """Return the diacritic marks predicted for *text*, joined into one string.

    This variant emits only the predicted marks: the characters of *text*
    are not copied into the result, and every position of the first
    sequence in the batch is used as-is (no positions are dropped).

    Args:
        text: Input string. It is passed to the tokenizer and also iterated
            character by character to bound the number of predictions used.
        model: Called with the tokenizer output as keyword arguments; its
            result must expose a ``logits`` tensor.
            NOTE(review): presumably a token-classification model with one
            output per diacritic label -- confirm.
        tokenizer: Called as ``tokenizer(text, return_tensors="pt")``.
            NOTE(review): the pairing below assumes one prediction row per
            character -- confirm against the tokenizer in use.

    Returns:
        A string made only of entries from ``label2diacritic``.
    """
    encoded = tokenizer(text, return_tensors="pt")
    # Per-label decision: sigmoid score above 0.5, first sequence of the batch.
    flags = (model(**encoded).logits.sigmoid() > 0.5)[0]
    marks = []
    # zip stops at the shorter input, so surplus rows or characters are ignored.
    for row, _ in zip(flags, text):
        marks.extend(label2diacritic[k] for k in range(1, 5) if row[k])
        # Label 0 (shadda, per the original comment) is emitted after the others.
        if row[0]:
            marks.append(label2diacritic[0])
    return "".join(marks)
|
| 45 |
|
|
|
|
| 30 |
|
| 31 |
def arabic2diacritics(text, model, tokenizer):
    """Return *text* with the predicted diacritic marks added after each character.

    The text is tokenized and run through the model. Each label's logit is
    passed through a sigmoid and thresholded at 0.5. Only the first sequence
    in the batch is used, and its first and last positions are dropped
    (the CLS and SEP tokens, per the original comment).

    Args:
        text: Input string; its characters are copied into the result in order.
        model: Called with the tokenizer output as keyword arguments; its
            result must expose a ``logits`` tensor.
            NOTE(review): presumably a token-classification model with one
            output per diacritic label -- confirm.
        tokenizer: Called as ``tokenizer(text, return_tensors="pt")``.
            NOTE(review): the pairing below assumes one prediction row per
            character of ``text`` -- confirm against the tokenizer in use.

    Returns:
        A string in which each character is followed by the
        ``label2diacritic`` entries whose label was predicted for it.
    """
    encoded = tokenizer(text, return_tensors="pt")
    # Per-label decision: sigmoid score above 0.5, first sequence of the
    # batch, with the leading and trailing special-token positions removed.
    flags = (model(**encoded).logits.sigmoid() > 0.5)[0][1:-1]
    pieces = []
    # zip stops at the shorter input, so characters without a prediction row
    # (or rows without a character) are dropped.
    for row, char in zip(flags, text):
        pieces.append(char)
        pieces.extend(label2diacritic[k] for k in range(1, 5) if row[k])
        # Label 0 (shadda, per the original comment) is emitted after the others.
        if row[0]:
            pieces.append(label2diacritic[0])
    return "".join(pieces)
|
| 46 |
|