File size: 853 Bytes
846f8be f6dd13c eb4f615 f6dd13c eb4f615 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
from transformers import pipeline
import string
# Replace this with your own checkpoint
model_checkpoint = "results/checkpoint-16000"
question_answerer = pipeline("question-answering", model = model_checkpoint)
def predict(question, context):
answer = question_answerer(question = question,
context = context)
exclude = set(string.punctuation)
text = answer['answer']
text = ''.join(ch for ch in text if ch not in exclude)
answer['answer'] = text
return answer
if __name__ == '__main__':
question = 'Combi cao bao nhiêu.'
context = 'Combi là sinh viên năm 2 trường Ecole Polytechnique. Chiều cao của Combi là 1m73, cân nặng là 63kg. Combi thích học Machine Learning vì Machine Learning cần nhiều toán.'
print(predict(question, context)) |