File size: 853 Bytes
846f8be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6dd13c
 
 
eb4f615
 
f6dd13c
eb4f615
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from transformers import pipeline
import string

# Replace this with your own checkpoint
model_checkpoint = "results/checkpoint-16000"
question_answerer = pipeline("question-answering", model = model_checkpoint)

def predict(question, context):
    answer = question_answerer(question = question,
                               context  = context)
    
    exclude = set(string.punctuation)
    
    text = answer['answer']
    text = ''.join(ch for ch in text if ch not in exclude)
    
    answer['answer'] = text

    return  answer
 
 if __name__ == '__main__':
    question = 'Combi cao bao nhiêu.'
    context  = 'Combi là sinh viên năm 2 trường Ecole Polytechnique. Chiều cao của Combi là 1m73, cân nặng là 63kg. Combi thích học Machine Learning vì Machine Learning cần nhiều toán.'

    print(predict(question, context))