File size: 4,579 Bytes
c226c6e
 
 
444204c
 
c226c6e
 
 
444204c
c226c6e
 
 
 
 
 
 
444204c
c226c6e
 
7f0063b
c226c6e
 
359f142
 
c226c6e
 
 
359f142
c226c6e
 
 
 
3729140
c226c6e
 
 
 
 
 
 
 
 
 
87f0cb8
c226c6e
 
 
 
 
a79f0de
c226c6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bda394
 
c226c6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bd4359
8bda394
59f97fb
1bd4359
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os, time, transformers
import streamlit as st
from model import MRCQuestionAnswering
from relevance_ranking import rel_ranking
from huggingface_hub import login
from infer import *
from gg_search import GoogleSearch, getContent
# Shared Google search client; the Chatbot class calls its `.search(question)`
# to obtain `(links, documents)` for each user question.
# NOTE(review): this is constructed at module level, so it is presumably rebuilt
# on every Streamlit script rerun — consider caching it if construction is costly.
ggsearch = GoogleSearch()



class Chatbot:
    """Streamlit chat UI that answers questions using Google search results.

    For each question the app searches Google (module-level ``ggsearch``),
    ranks the retrieved pages with ``rel_ranking`` and feeds the top passages
    as context to the extractive QA model held in the module-level
    ``tokenizer`` / ``model`` globals.
    """

    # Text of an anti-bot challenge page that some sites return instead of
    # their content; such pages carry no useful context.
    _BOT_CHALLENGE_MARKER = (
        'The security system for this website has been triggered. '
        'Completing the challenge below verifies you are a human and gives you access.'
    )
    # Scraped pages whose length is not greater than this are treated as empty.
    _MIN_PAGE_LENGTH = 20
    # Number of top-ranked passages concatenated into the QA context.
    _TOP_PASSAGES = 40
    # Reply shown when no answer could be produced.
    _NO_ANSWER = 'No answer found !'

    def __init__(self):
        """Render the page header/notices and the sidebar controls."""
        st.header('🦜 Question answering')
        st.warning("Warning: the processing may take long cause I have no any GPU now...")
        st.info("This app uses google search engine for each input question...")
        st.info("About me: namnh113")

        # Currently not read anywhere; kept so the sidebar field stays available.
        self.API_KEY = st.sidebar.text_input(
            'API key (not necessary for now)',
            type='password',
            help="Type in your HuggingFace API key to use this app")

        self.model_checkpoint = 'namnh113/vi-mrc-large'

        self.checkpoint = st.sidebar.selectbox(
            label="Choose model",
            options=[self.model_checkpoint],
            help="List available model to predict"
        )

    @classmethod
    def _is_usable_page(cls, page):
        """Return True if scraped ``page`` is non-empty, long enough and not a bot-challenge page."""
        if not page:
            return False
        return len(page) > cls._MIN_PAGE_LENGTH and cls._BOT_CHALLENGE_MARKER not in page

    def _collect_documents(self, question):
        """Return a list of page texts to use as context for ``question``.

        Uses the documents returned by ``ggsearch.search``. When that yields
        none, it falls back to scraping each returned link with ``getContent``
        and keeps only usable pages. Errors are reported and skipped, so a
        failed search yields ``[]`` and one bad URL does not discard the rest.
        """
        try:
            links, documents = ggsearch.search(question)
        except Exception as err:
            # Network/parsing failures degrade to "no context" instead of
            # leaving `documents` unbound further down.
            print("Google search failed: {!r}".format(err))
            return []

        documents = list(documents or [])
        if documents:
            return documents

        for url in links or []:
            # Per-URL handling: a single failing fetch must not abort the loop.
            try:
                page = getContent(url)
            except Exception as err:
                print("Could not fetch {}: {!r}".format(url, err))
                continue
            if self._is_usable_page(page):
                documents.append(page)
        return documents

    def generate_response(self, question):
        """Answer ``question`` using web search results as context.

        Args:
            question: the user's question text.

        Returns:
            The extracted answer with runs of whitespace collapsed to single
            spaces, or ``'No answer found !'`` in three cases: the question is
            blank, no usable web content was found, or the model extracts an
            empty span.
        """
        # Guard first: a blank question previously left `answer` unbound.
        if not question or not question.strip():
            return self._NO_ANSWER

        documents = self._collect_documents(question)
        if not documents:
            # Nothing to extract from; skip the (slow, CPU-only) model pass.
            return self._NO_ANSWER

        passages = rel_ranking(question, documents)
        # Flatten the top-ranked passages into one context string; newlines
        # inside a passage become commas so the context reads as prose.
        context = '. '.join(p.replace('\n', ', ') for p in passages[:self._TOP_PASSAGES])
        qa_input = {
            'question': question,
            'context': context,
        }

        start = time.time()
        inputs = [tokenize_function(qa_input, tokenizer)]
        inputs_ids = data_collator(inputs, tokenizer)
        outputs = model(**inputs_ids)
        answer = extract_answer(inputs, outputs, tokenizer)[0]
        during = time.time() - start
        print("answer: {}. \nScore start: {}, Score end: {}, Time: {}".format(answer['answer'],
                                                                              answer['score_start'],
                                                                              answer['score_end'], during))
        text = ' '.join(answer['answer'].split())
        return text if text else self._NO_ANSWER

    def form_data(self):
        """Render the chat history, read a new message and append the reply.

        Typing ``clear`` (case-insensitive, surrounding whitespace ignored)
        wipes the stored history instead of querying the model. Any
        ``Exception`` raised while handling a message is shown on the page
        via ``st.error``.
        """
        try:
            if "messages" not in st.session_state:
                st.session_state.messages = []

            st.write(f"You are using {self.checkpoint} model")

            for message in st.session_state.messages:
                with st.chat_message(message.get('role')):
                    st.write(message.get("content"))

            text = st.chat_input(disabled=False)
            if not text:
                return

            st.session_state.messages.append({"role": "user", "content": text})
            with st.chat_message("user"):
                st.write(text)

            if text.strip().lower() == "clear":
                del st.session_state.messages
                return

            result = self.generate_response(text)
            st.session_state.messages.append({"role": "assistant", "content": result})
            with st.chat_message('assistant'):
                st.markdown(result)

        except Exception as e:
            st.error(e, icon="🚨")

@st.cache_resource
def _load_qa_model(checkpoint, _hf_token=None):
    """Log in to the Hugging Face Hub (when a token is given) and load the QA model.

    Wrapped in ``st.cache_resource`` so the tokenizer and model are loaded once
    per server process rather than on every script rerun. The leading
    underscore on ``_hf_token`` tells Streamlit not to hash it, so the cache
    key is just ``checkpoint``.

    Returns:
        ``(tokenizer, model)`` for ``checkpoint``.
    """
    if _hf_token:
        login(token=_hf_token)
    loaded_tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
    loaded_model = MRCQuestionAnswering.from_pretrained(checkpoint)
    return loaded_tokenizer, loaded_model


chatbot = Chatbot()
# Set before the tokenizer is loaded/used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# The token is optional: login is skipped when `hf_api_key` is unset instead of
# crashing with a bare KeyError. `tokenizer` and `model` must remain
# module-level names because Chatbot.generate_response reads them as globals.
tokenizer, model = _load_qa_model(chatbot.model_checkpoint, os.environ.get('hf_api_key'))
chatbot.form_data()