# MRC001 / inference.py
# Last commit eb4f615 by combi2k2: "Update inference.py" (853 bytes)
from transformers import pipeline
import string
# Checkpoint the QA pipeline loads its model from. Replace this with your own
# checkpoint. NOTE(review): presumably a local training output directory —
# confirm where it is resolved from when the script is run.
model_checkpoint = "results/checkpoint-16000"

# Extractive question-answering pipeline. It is created once, when the module
# is imported, and shared by every call to predict().
question_answerer = pipeline(
    task="question-answering",
    model=model_checkpoint,
)
def predict(question, context):
    """Answer ``question`` from ``context`` with punctuation removed from the answer.

    Runs the module-level ``question_answerer`` pipeline. It then deletes every
    character listed in ``string.punctuation`` (ASCII punctuation only) from
    the extracted answer text.

    Args:
        question: The question to ask.
        context: The passage the answer is extracted from.

    Returns:
        The dict returned by the pipeline, with its ``'answer'`` entry replaced
        by the punctuation-free text. Any other entries the pipeline produced
        are left untouched. NOTE(review): if those include character offsets,
        they describe the span *before* punctuation was removed — confirm.
    """
    result = question_answerer(question=question, context=context)

    # A deletion table (third argument of str.maketrans) maps each punctuation
    # character to None, so translate() drops them all in a single pass.
    strip_table = str.maketrans("", "", string.punctuation)
    result['answer'] = result['answer'].translate(strip_table)
    return result
if __name__ == '__main__':
    # Demo run: a Vietnamese question asking how tall Combi is, answered
    # from a short Vietnamese passage that states Combi's height.
    demo_question = 'Combi cao bao nhiêu.'
    demo_context = 'Combi là sinh viên năm 2 trường Ecole Polytechnique. Chiều cao của Combi là 1m73, cân nặng là 63kg. Combi thích học Machine Learning vì Machine Learning cần nhiều toán.'

    prediction = predict(demo_question, demo_context)
    print(prediction)