Update app.py
Browse files
app.py
CHANGED
|
@@ -108,103 +108,13 @@ def build_experimental_ui():
|
|
| 108 |
|
| 109 |
button_query = st.button('Submit', disabled=False)
|
| 110 |
|
| 111 |
-
|
| 112 |
if button_query:
|
| 113 |
|
| 114 |
for question in questions_df['question']:
|
| 115 |
instruction = f'{prompt}.Question:{question}'
|
| 116 |
-
|
| 117 |
-
print('---- run query ----')
|
| 118 |
-
print(f'model: {selected_model} embeddings: {selected_embeddings}')
|
| 119 |
-
if selected_embeddings!=st.session_state['selected_embeddings']:
|
| 120 |
-
st.session_state['selected_embeddings'] = selected_embeddings
|
| 121 |
-
texts = load_pdf_document(pdf_docs)
|
| 122 |
-
st.session_state['retriever'] = get_retriever_from_text(texts, embeddings[selected_embeddings])
|
| 123 |
-
# qa = RetrievalQA.from_chain_type(llm=models[selected_model], chain_type="stuff",
|
| 124 |
-
# retriever=st.session_state['retriever'], return_source_documents=True)
|
| 125 |
-
st.session_state['docs'] = st.session_state['retriever'].get_relevant_documents(st.session_state.query)
|
| 126 |
-
context = '\n\n'.join([doc.page_content for doc in st.session_state['docs']])
|
| 127 |
-
st.session_state['context'] = context
|
| 128 |
-
source_files = get_pdf_file_names(st.session_state['pdf_file'])
|
| 129 |
-
#st.session_state['conversation']= get_conversation_chain(st.session_state['retriever'])
|
| 130 |
-
|
| 131 |
-
if strategy=='Without Chain-of-Thought':
|
| 132 |
-
user_token = model_configs[selected_model]['USER_TOKEN']
|
| 133 |
-
end_token = model_configs[selected_model]['END_TOKEN']
|
| 134 |
-
assistant_token = model_configs[selected_model]['ASSISTANT_TOKEN']
|
| 135 |
-
prompt_pattern, prompt = create_prompt(user_token, instruction, st.session_state.query, end_token, assistant_token, context)
|
| 136 |
-
updated_context = truncate_context(prompt_pattern, context,
|
| 137 |
-
max_token_len=model_configs[selected_model]['MAX_TOKENS'],
|
| 138 |
-
max_new_token_length=model_configs[selected_model]['MAX_NEW_TOKEN_LENGTH'])
|
| 139 |
-
updated_prompt = prompt_pattern.replace('{context}', updated_context)
|
| 140 |
-
print(updated_prompt)
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
with st.spinner():
|
| 144 |
-
answer = models[selected_model].generate([updated_prompt]).generations[0][0].text.strip()
|
| 145 |
-
st.write(answer)
|
| 146 |
-
chat_content['answer'] = answer
|
| 147 |
-
chat_content['source'] = source_files
|
| 148 |
-
chat_content['context']=st.session_state['context']
|
| 149 |
-
chat_content['time']=datetime.now().strftime("%d-%m-%Y %H:%M:%S")
|
| 150 |
-
if st.session_state['chat_history']:
|
| 151 |
-
st.session_state['chat_history'].append(chat_content)
|
| 152 |
-
else:
|
| 153 |
-
st.session_state['chat_history']=[chat_content]
|
| 154 |
-
print('------chat history-----',st.session_state['chat_history'])
|
| 155 |
-
if updated_prompt!=prompt:
|
| 156 |
-
st.caption(f"Note: The context has been truncated to fit model max tokens of {model_configs[selected_model]['MAX_TOKENS']}. Original context contains {len(context.split())} words. Truncated context contains {len(updated_context.split())} words.")
|
| 157 |
|
| 158 |
-
|
| 159 |
-
chain = PDSCoverageChain()
|
| 160 |
-
with st.spinner():
|
| 161 |
-
answer = chain.generate(models[selected_model], model_configs[selected_model], st.session_state.query, context)
|
| 162 |
-
st.write(answer)
|
| 163 |
-
chat_content['answer'] = answer
|
| 164 |
-
chat_content['source'] = source_files
|
| 165 |
-
chat_content['context']=st.session_state['context']
|
| 166 |
-
chat_content['time']=datetime.now().strftime("%d-%m-%Y %H:%M:%S")
|
| 167 |
-
if st.session_state['chat_history']:
|
| 168 |
-
st.session_state['chat_history'].append(chat_content)
|
| 169 |
-
else:
|
| 170 |
-
st.session_state['chat_history']=[chat_content]
|
| 171 |
-
print('------chat history-----',st.session_state['chat_history'])
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
if st.session_state['docs']:
|
| 175 |
-
|
| 176 |
-
docs = st.session_state['docs']
|
| 177 |
|
| 178 |
-
col3, col4, col5, col6 = st.columns([0.2,0.35, 0.65, 3.8])
|
| 179 |
-
if st.session_state.query is None:
|
| 180 |
-
disable_query = True
|
| 181 |
-
else:
|
| 182 |
-
disable_query = False
|
| 183 |
-
chat_history = st.session_state['chat_history']
|
| 184 |
-
with col3:
|
| 185 |
-
st.button(":thumbsup:", on_click = get_feedback,disabled=disable_query,
|
| 186 |
-
kwargs=dict(upvote=True, downvote=False,
|
| 187 |
-
button='upvote'))
|
| 188 |
-
with col4:
|
| 189 |
-
st.button(":thumbsdown:", on_click = get_feedback,disabled=disable_query,
|
| 190 |
-
kwargs=dict(upvote=False, downvote=True,
|
| 191 |
-
button='downvote'))
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
with st.expander("References"):
|
| 195 |
-
for doc in docs:
|
| 196 |
-
print('-------',doc)
|
| 197 |
-
#st.markdown('###### Page {}'.format(doc.metadata['page']))
|
| 198 |
-
st.write(doc.page_content.replace('\n','\n\n').replace('$','\$').replace('**',''))
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
st.button("End Chat", on_click = get_feedback,
|
| 203 |
-
kwargs=dict(button='end-chat',
|
| 204 |
-
chat_history=chat_history))
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
|
| 209 |
else:
|
| 210 |
st.info("Under Development")
|
|
|
|
| 108 |
|
| 109 |
button_query = st.button('Submit', disabled=False)
|
| 110 |
|
|
|
|
| 111 |
if button_query:
|
| 112 |
|
| 113 |
for question in questions_df['question']:
|
| 114 |
instruction = f'{prompt}.Question:{question}'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
else:
|
| 120 |
st.info("Under Development")
|