# "Spaces: Sleeping" — Hugging Face Spaces status banner captured when this
# file was scraped from the web UI; not part of the program source.
import os, time, transformers
import streamlit as st
from model import MRCQuestionAnswering
from relevance_ranking import rel_ranking
from huggingface_hub import login
# NOTE(review): star import — presumably supplies tokenize_function,
# data_collator and extract_answer used in Chatbot.generate_response; confirm.
from infer import *
from gg_search import GoogleSearch, getContent

# Single shared search client used by Chatbot.generate_response.
ggsearch = GoogleSearch()
class Chatbot():
    """Streamlit chat UI that answers questions with web search + an MRC model.

    Relies on module-level globals defined elsewhere in this file:
    ``ggsearch`` (Google search client), ``tokenizer`` and ``model``
    (loaded after this class), and the ``tokenize_function`` /
    ``data_collator`` / ``extract_answer`` helpers imported from ``infer``.
    """

    def __init__(self):
        """Render the static page chrome and sidebar controls."""
        st.header('🦜 Question answering')
        st.warning("Warning: the processing may take long cause I have no any GPU now...")
        st.info("This app uses google search engine for each input question...")
        st.info("About me: namnh113")
        # Collected but not used for inference yet (see the label text).
        self.API_KEY = st.sidebar.text_input(
            'API key (not necessary for now)',
            type='password',
            help="Type in your HuggingFace API key to use this app")
        self.model_checkpoint = 'namnh113/vi-mrc-large'
        self.checkpoint = st.sidebar.selectbox(
            label="Choose model",
            options=[self.model_checkpoint],
            help="List available model to predict"
        )

    def generate_response(self, question):
        """Search the web for ``question``, rank passages, extract an answer.

        Parameters
        ----------
        question : str
            The user's question; leading/trailing whitespace is ignored.

        Returns
        -------
        str
            The extracted answer, or ``'No answer found !'`` when the
            question is blank or the model produced an empty answer.
        """
        question = question.strip()
        # BUGFIX: the original implicitly returned None for blank questions.
        if not question:
            return 'No answer found !'
        documents = []  # BUGFIX: was undefined at rel_ranking() if search raised
        try:
            links, documents = ggsearch.search(question)
            if not documents:
                # Search returned no snippets: fetch page contents ourselves.
                for url in links:
                    try:
                        docs = getContent(url)
                    except Exception:
                        # Best-effort: one unreachable page should not abort
                        # the whole fallback loop (original aborted it).
                        continue
                    # BUGFIX: the original tested `not in doc` (undefined name),
                    # which raised NameError on every page and was silently
                    # swallowed, leaving `documents` empty.
                    if len(docs) > 20 and 'The security system for this website has been triggered. Completing the challenge below verifies you are a human and gives you access.' not in docs:
                        documents.append(docs)
        except Exception:
            # Best-effort: continue with whatever documents were gathered.
            pass
        passages = rel_ranking(question, documents)
        # Keep the top 40 relevant passages, flattened into one context string.
        passages = '. '.join([p.replace('\n', ', ') for p in passages[:40]])
        QA_input = {
            'question': question,
            'context': passages}
        start = time.time()
        inputs = [tokenize_function(QA_input, tokenizer)]
        inputs_ids = data_collator(inputs, tokenizer)
        outputs = model(**inputs_ids)
        answer = extract_answer(inputs, outputs, tokenizer)[0]
        during = time.time() - start
        print("answer: {}. \nScore start: {}, Score end: {}, Time: {}".format(
            answer['answer'], answer['score_start'], answer['score_end'], during))
        # Collapse runs of whitespace inside the answer to single spaces.
        answer = ' '.join(answer['answer'].split())
        return answer if answer else 'No answer found !'

    def form_data(self):
        """Render the chat transcript and process one round of user input."""
        try:
            if "messages" not in st.session_state:
                st.session_state.messages = []
            st.write(f"You are using {self.checkpoint} model")
            # Replay the stored transcript so history survives reruns.
            for message in st.session_state.messages:
                with st.chat_message(message.get('role')):
                    st.write(message.get("content"))
            text = st.chat_input(disabled=False)
            if text:
                st.session_state.messages.append(
                    {
                        "role": "user",
                        "content": text
                    }
                )
                with st.chat_message("user"):
                    st.write(text)
                # Typing "clear" wipes the transcript instead of asking.
                if text.lower() == "clear":
                    del st.session_state.messages
                    return
                result = self.generate_response(text)
                st.session_state.messages.append(
                    {
                        "role": "assistant",
                        "content": result
                    }
                )
                with st.chat_message('assistant'):
                    st.markdown(result)
        except Exception as e:
            # Surface any failure in the UI rather than crashing the app.
            st.error(e, icon="🚨")
# --- app bootstrap ----------------------------------------------------------
# Build the UI first so the page renders while the model weights load.
chatbot = Chatbot()
# ROBUSTNESS: the original did os.environ['hf_api_key'] unconditionally and
# crashed with KeyError when the secret was absent; log in only if present.
hf_token = os.environ.get('hf_api_key')
if hf_token:
    login(token=hf_token)
# Silence tokenizer fork-parallelism warnings before the tokenizer is used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
tokenizer = transformers.AutoTokenizer.from_pretrained(chatbot.model_checkpoint)
model = MRCQuestionAnswering.from_pretrained(chatbot.model_checkpoint)
chatbot.form_data()