Spaces:
Runtime error
Runtime error
"""Streamlit demo: plain GPT answers vs. retrieval-augmented answers.

The page shows a question box, five example-question buttons and a radio
selector for the retrieval source. When "Run" is pressed with a non-empty
question, the plain pipeline is run first, then either the local
retrieval-augmented pipeline or the web retrieval-augmented pipeline,
depending on the radio selection. Each answer is written into its own
placeholder below the matching heading.
"""
import streamlit as st

from backend_utils import (
    GPT_LOCAL_RET_AUG_ANS,
    GPT_WEB_RET_AUG_ANS,
    PLAIN_GPT_ANS,
    QUERIES,
    get_plain_pipeline,
    get_retrieval_augmented_pipeline,
    get_web_retrieval_augmented_pipeline,
    set_q1,
    set_q2,
    set_q3,
    set_q4,
    set_q5,
)

# Radio options. They are defined once so that the widget, the heading
# selection and the pipeline selection can never drift apart.
STATIC_DATASET_OPTION = "Retrieval Augmented (Static news dataset)"
WEB_SEARCH_OPTION = "Retrieval Augmented with Web Search"

# Suffix shared by every spinner message.
WAIT_NOTICE = "\n This may take a few mins and might also fail if OpenAI API server is down."

st.set_page_config(
    page_title="Retrieval Augmentation with Haystack",
)

st.markdown(
    "<center> <h2> Reduce Hallucinations with Retrieval Augmentation </h2> </center>",
    unsafe_allow_html=True,
)
st.markdown(
    "Ask a question about the collapse of the Silicon Valley Bank (SVB).",
    unsafe_allow_html=True,
)

# Disabled: eager pipeline loading at start-up (pipelines are now loaded
# lazily when "Run" is pressed, see below).
# if not st.session_state.get('pipelines_loaded', False):
#     with st.spinner('Loading pipelines... \n This may take a few mins and might also fail if OpenAI API server is down.'):
#         p1, p2, p3 = app_init()
#     st.success('Pipelines are loaded', icon="✅")
#     st.session_state['pipelines_loaded'] = True

# Question input and "Run" button side by side.
placeholder = st.empty()
with placeholder:
    search_bar, button = st.columns([3, 1])
    with search_bar:
        # The text is read back through st.session_state['query'], so the
        # widget's return value is not needed.
        st.text_area(" ", max_chars=200, key='query')
    with button:
        # Two blank writes push the button down next to the text area.
        st.write(" ")
        st.write(" ")
        run_pressed = st.button("Run")

st.markdown("<center> <h5> Example questions </h5> </center>", unsafe_allow_html=True)
st.write(" ")
st.write(" ")

# One button per example question; each callback comes from backend_utils.
example_setters = (set_q1, set_q2, set_q3, set_q4, set_q5)
for column, question, setter in zip(st.columns(5), QUERIES, example_setters):
    with column:
        st.button(question, on_click=setter)

st.write(" ")
st.radio("Answer Type:", (STATIC_DATASET_OPTION, WEB_SEARCH_OPTION), key="query_type")

# Disabled: sidebar drop-down alternative to the example buttons.
# st.sidebar.selectbox(
#     "Example Questions:",
#     QUERIES,
#     key='q_drop_down', on_change=set_question)

# Evaluate the selection once and reuse it for both the heading and the
# pipeline choice so they always agree.
use_static_dataset = (
    st.session_state.get("query_type", STATIC_DATASET_OPTION) == STATIC_DATASET_OPTION
)

st.markdown(f"<h5> {PLAIN_GPT_ANS} </h5>", unsafe_allow_html=True)
placeholder_plain_gpt = st.empty()
st.text(" ")
st.text(" ")

if use_static_dataset:
    st.markdown(f"<h5> {GPT_LOCAL_RET_AUG_ANS} </h5>", unsafe_allow_html=True)
else:
    st.markdown(f"<h5>{GPT_WEB_RET_AUG_ANS} </h5>", unsafe_allow_html=True)
placeholder_retrieval_augmented = st.empty()

if st.session_state.get('query') and run_pressed:
    ip = st.session_state['query']

    # Plain answer, without retrieval.
    with st.spinner('Loading pipelines... ' + WAIT_NOTICE):
        p1 = get_plain_pipeline()
    with st.spinner('Fetching answers from plain GPT... ' + WAIT_NOTICE):
        answers = p1.run(ip)
    placeholder_plain_gpt.markdown(answers['results'][0])

    # Retrieval-augmented answer from the source selected in the radio.
    if use_static_dataset:
        with st.spinner(
                'Loading Retrieval Augmented pipeline that can fetch relevant documents from local data store... '
                + WAIT_NOTICE):
            p2 = get_retrieval_augmented_pipeline()
        with st.spinner(
                'Getting relevant documents from documented stores and calculating answers... '
                + WAIT_NOTICE):
            answers_2 = p2.run(ip)
    else:
        with st.spinner(
                'Loading Retrieval Augmented pipeline that can fetch relevant documents from the web... '
                + WAIT_NOTICE):
            p3 = get_web_retrieval_augmented_pipeline()
        with st.spinner(
                'Getting relevant documents from the Web and calculating answers... '
                + WAIT_NOTICE):
            answers_2 = p3.run(ip)
    placeholder_retrieval_augmented.markdown(answers_2['results'][0])