|
|
from transformers import pipeline |
|
|
import string |
|
|
|
|
|
|
|
|
model_checkpoint = "results/checkpoint-16000" |
|
|
question_answerer = pipeline("question-answering", model = model_checkpoint) |
|
|
|
|
|
def predict(question, context): |
|
|
answer = question_answerer(question = question, |
|
|
context = context) |
|
|
|
|
|
exclude = set(string.punctuation) |
|
|
|
|
|
text = answer['answer'] |
|
|
text = ''.join(ch for ch in text if ch not in exclude) |
|
|
|
|
|
answer['answer'] = text |
|
|
|
|
|
return answer |
|
|
|
|
|
if __name__ == '__main__': |
|
|
question = 'Combi cao bao nhiêu.' |
|
|
context = 'Combi là sinh viên năm 2 trường Ecole Polytechnique. Chiều cao của Combi là 1m73, cân nặng là 63kg. Combi thích học Machine Learning vì Machine Learning cần nhiều toán.' |
|
|
|
|
|
print(predict(question, context)) |