Spaces:
Runtime error
Runtime error
Commit ·
c1b0837
1
Parent(s): db188a6
Testing indicberta
Browse files- apps/inference.py +3 -11
apps/inference.py
CHANGED
|
@@ -13,21 +13,13 @@ def load_model(masked_text, model_name):
|
|
| 13 |
from_flax = False
|
| 14 |
if model_name == "flax-community/roberta-hindi":
|
| 15 |
from_flax = True
|
| 16 |
-
|
| 17 |
-
if model_name == "neuralspace-reverie/indic-transformers-hi-bert":
|
| 18 |
-
masked_text = masked_text.replace("<mask>", "[MASK]")
|
| 19 |
-
|
| 20 |
-
elif model_name == "ai4bharat/indic-bert":
|
| 21 |
-
masked_text = masked_text.replace("<mask>", "([MASK])")
|
| 22 |
-
|
| 23 |
-
st.write(model_name, masked_text)
|
| 24 |
-
|
| 25 |
model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=from_flax)
|
| 26 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
|
|
|
| 27 |
nlp = pipeline("fill-mask", model=model, tokenizer=tokenizer)
|
| 28 |
-
|
| 29 |
result_sentence = nlp(masked_text)
|
| 30 |
-
|
| 31 |
return result_sentence
|
| 32 |
|
| 33 |
|
|
|
|
| 13 |
from_flax = False
|
| 14 |
if model_name == "flax-community/roberta-hindi":
|
| 15 |
from_flax = True
|
| 16 |
+
# st.write(model_name, masked_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=from_flax)
|
| 18 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 19 |
+
MASK_TOKEN = tokenizer.mask_token
|
| 20 |
+
masked_text = masked_text.replcae("<mask>", MASK_TOKEN)
|
| 21 |
nlp = pipeline("fill-mask", model=model, tokenizer=tokenizer)
|
|
|
|
| 22 |
result_sentence = nlp(masked_text)
|
|
|
|
| 23 |
return result_sentence
|
| 24 |
|
| 25 |
|