# Spaces:
# Runtime error
# Runtime error
import logging
import os
from json import JSONDecodeError
from operator import index

import haystack
import pandas as pd
import streamlit as st
from annotated_text import annotation
from markdown import markdown

from utils.config import parser
from utils.haystack import initialize_pipeline, query, start_document_store
from utils.ui import reset_results, set_initial_state
# Minimum extractive-reader score; when no answer exceeds it the user is
# warned that the answer is probably not in the searched documents.
ANSWER_SCORE_THRESHOLD = 0.2

# Retriever snippets longer than this are truncated in the results table.
MAX_SNIPPET_CHARS = 150


def _exit_code(code):
    """Translate a ``SystemExit.code`` into the int that ``os._exit`` requires.

    ``None`` means success (0) and a non-int value such as a message string
    means failure (1), matching how the interpreter treats ``sys.exit``.
    """
    if code is None:
        return 0
    if isinstance(code, int):
        return code
    return 1


def _run_pipeline(pipeline, question, result_key, task, check_api_key=False):
    """Run *question* through *pipeline* and store the output in session state.

    Calls ``reset_results()`` first and records *question* as the last asked
    question. On success the output of ``query`` is stored under
    ``st.session_state[result_key]`` and *task* under ``st.session_state.task``.
    Errors are logged and shown in the UI instead of being raised. With
    *check_api_key*, an exception mentioning an invalid API key gets a
    dedicated OpenAI-key message.
    """
    reset_results()
    st.session_state.question = question
    with st.spinner("π Running your pipeline"):
        try:
            st.session_state[result_key] = query(pipeline, question)
            st.session_state.task = task
        except JSONDecodeError:
            st.error(
                "π An error occurred reading the results. Is the document store working?"
            )
        except Exception as e:
            logging.exception(e)
            if check_api_key and "API key is invalid" in str(e):
                st.error("π incorrect API key provided. You can find your API key at https://platform.openai.com/account/api-keys.")
            else:
                st.error("π An error occurred during the request.")


def _highlight_answer(text, context, score):
    """Return *context* as HTML with *text* wrapped in a score annotation.

    If *text* does not occur verbatim in *context* (``str.find`` returns -1),
    the annotated answer is placed in front of the unchanged context instead
    of being spliced in at a wrong position. A missing context is treated as
    an empty string.
    """
    context = context or ""
    tag = str(annotation(body=text, label=f'SCORE {score}', background='#964448', color='#ffffff'))
    start_idx = context.find(text)
    if start_idx == -1:
        return f"{tag} {context}"
    end_idx = start_idx + len(text)
    return context[:start_idx] + tag + context[end_idx:]


def _show_extractive(results):
    """Render extractive answers, highlighting each answer inside its context."""
    st.subheader("Extracted Answers:")
    if 'answers' not in results:
        return
    answers = results['answers']
    if not any(ans.score > ANSWER_SCORE_THRESHOLD for ans in answers):
        st.markdown(f"<span style='color:red'>Please note none of the answers achieved a score higher than {ANSWER_SCORE_THRESHOLD:.0%}. Which probably means that the desired answer is not in the searched documents.</span>", unsafe_allow_html=True)
    for count, answer in enumerate(answers):
        if answer.answer:
            score = round(answer.score, 3)
            st.markdown(f"**Answer {count + 1}:**")
            st.markdown(
                _highlight_answer(answer.answer, answer.context, score),
                unsafe_allow_html=True,
            )
        else:
            st.info(
                "π€ Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
            )


def _show_generative(results):
    """Render the first generated answer, if the pipeline produced any."""
    st.subheader("Generated Answer:")
    # Guard against an empty 'results' list, which would raise IndexError.
    if 'results' in results and results['results']:
        st.markdown("**Answer:**")
        st.write(results['results'][0])


def _show_documents(results):
    """Render the retrieved documents as a ranked table of truncated snippets."""
    if 'documents' not in results:
        return
    st.subheader("Retriever Results:")
    rows = []
    for rank, document in enumerate(results['documents'], start=1):
        content = document.content
        if len(content) > MAX_SNIPPET_CHARS:
            content = content[:MAX_SNIPPET_CHARS] + '...'
        rows.append([rank, document.meta.get('name', ''), content])
    df = pd.DataFrame(rows, columns=['Ranked Context', 'Document Name', 'Content'])
    st.table(df)


try:
    args = parser.parse_args()
    document_store = start_document_store(type=args.store)
    st.set_page_config(
        page_title="MLReplySearch",
        layout="centered",
        page_icon=":shark:",
        # NOTE(review): these look like placeholder links/text; replace with real ones.
        menu_items={
            'Get Help': 'https://www.extremelycoolapp.com/help',
            'Report a bug': "https://www.extremelycoolapp.com/bug",
            'About': "# This is a header. This is an *extremely* cool app!"
        }
    )
    st.sidebar.image("ml_logo.png", use_column_width=True)

    # Sidebar for task selection; the generative (RAG) task is only offered
    # once an OpenAI key has been entered.
    st.sidebar.header('Options:')
    openai_key = st.sidebar.text_input("Enter OpenAI Key:", type="password")
    task_options = ['Extractive', 'Generative'] if openai_key else ['Extractive']
    task_selection = st.sidebar.radio('Select the task:', task_options)

    # Only the pipeline for the selected task is built on this run.
    if task_selection == 'Extractive':
        pipeline_extractive = initialize_pipeline("extractive", document_store)
    elif task_selection == 'Generative' and openai_key:
        pipeline_rag = initialize_pipeline("rag", document_store, openai_key=openai_key)

    set_initial_state()
    st.write('# ' + args.name)
    if "question" not in st.session_state:
        st.session_state.question = ""

    # Search bar; editing it calls reset_results via on_change.
    question = st.text_input("", value=st.session_state.question, max_chars=100, on_change=reset_results)
    run_pressed = st.button("Run")
    # Query when Run is pressed or the text differs from the last question run.
    run_query = run_pressed or question != st.session_state.question

    if run_query and question:
        if task_selection == 'Extractive':
            _run_pipeline(pipeline_extractive, question, "results_extractive", task_selection)
        elif task_selection == 'Generative':
            _run_pipeline(pipeline_rag, question, "results_generative", task_selection, check_api_key=True)

    # Results are rendered only on runs where run_query is true. Look only at
    # the result set of the selected task so a stale result of the other task
    # cannot lead to a lookup on None.
    if run_query:
        if task_selection == 'Extractive':
            results, show_answers = st.session_state.results_extractive, _show_extractive
        else:
            results, show_answers = st.session_state.results_generative, _show_generative
        if results:
            show_answers(results)
            # NOTE(review): shown for both tasks whenever the output has
            # 'documents'; confirm this is the intended placement.
            _show_documents(results)
except SystemExit as e:
    # parser.parse_args() raises SystemExit (e.g. on --help or invalid
    # arguments). os._exit needs an int, so normalise None / message codes.
    os._exit(_exit_code(e.code))