import re

import numpy as np
import openai
import pinecone
import streamlit as st
import streamlit_scrollable_textbox as stx

st.set_page_config(layout="wide")  # isort: split

from utils import nltkmodules  # noqa: F401 -- presumably imported for its NLTK setup side effects
from utils.entity_extraction import (
    clean_companies,
    clean_keywords_all_combs,
    extract_entities_docs,
    extract_entities_keywords,
    ticker_year_quarter_tuples_creator,
    year_quarter_range,
)
from utils.models import (
    get_alpaca_model,
    get_data,
    get_instructor_embedding_model_api,
    get_vicuna_ner_1_model,
    get_vicuna_ner_2_model,
    get_vicuna_text_gen_model,
    gpt_turbo_model,
    save_key,
    vicuna_text_generate,
)
from utils.prompts import (
    generate_multi_doc_context,
    generate_prompt_alpaca_style,
)
from utils.retriever import (
    get_indices_bm25,
    query_pinecone,
    sentence_id_combine,
)
from utils.transcript_retrieval import retrieve_transcript
from utils.vector_index import create_dense_embeddings

st.title("Question Answering on Earnings Call Transcripts")
st.write(
    "The app uses the quarterly earnings call transcripts for 10 companies "
    "(Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) "
    "for the years 2016 to 2020."
)

# Caching Resources and Model APIs
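# The get_* loaders below are assumed to be wrapped with Streamlit caching
# (e.g. st.cache_resource) inside utils.models, so data and models load once per session.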
data = get_data()
alpaca_model = get_alpaca_model()
vicuna_ner_1_model = get_vicuna_ner_1_model()
vicuna_ner_2_model = get_vicuna_ner_2_model()
vicuna_text_gen_model = get_vicuna_text_gen_model()

# Sidebar Options
decoder_models_choice = ["GPT-3.5 Turbo", "Vicuna-7B"]

with st.sidebar:
    st.subheader("Select Options:")
    use_bm25 = st.checkbox("Use 2-Stage Retrieval (BM25)", value=True)
    use_keyword_matching = st.checkbox(
        "Use Exact Keyword Matching", value=False
    )
    num_results = int(
        st.number_input("Number of Results to query", 1, 15, value=4)
    )
    window = int(st.number_input("Sentence Window Size", 0, 10, value=1))
    threshold = float(
        st.number_input(
            label="Similarity Score Threshold",
            step=0.05,
            format="%.2f",
            value=0.6,
        )
    )
    num_candidates = int(
        st.number_input(
            "Number of Candidates to Generate:",
            25,
            200,
            step=25,
            value=50,
        )
    )
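
# Two-column layout: query input and editable prompt on the left,
# extracted entities and the generated answer on the right.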
col1, col2 = st.columns([3, 3], gap="medium")

with col1:
    query_text = st.text_area(
        "Input Query",
        value="How has the growth been for AMD in the PC market in Q1 and Q2 2020?",
    )

# Extracting Document Entities from the Question
(
    companies,
    start_quarter,
    start_year,
    end_quarter,
    end_year,
) = extract_entities_docs(query_text, vicuna_ner_1_model)
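
# Expand the extracted duration into individual (year, quarter) entries and
# pair them with tickers to identify which transcripts to search.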
year_quarter_range_list = year_quarter_range(
    start_quarter, start_year, end_quarter, end_year
)
ticker_list = clean_companies(companies)
ticker_year_quarter_tuples_list = ticker_year_quarter_tuples_creator(
    ticker_list, year_quarter_range_list
)

with col2:
    if ticker_year_quarter_tuples_list:
        st.markdown("**Companies mentioned in the question:**")
        for ticker in ticker_list:
            st.markdown("- " + ticker)
        st.write("**Duration:**")
        st.write(f"{start_quarter} {start_year} - {end_quarter} {end_year}")

# Extract keywords from the query
all_keywords = extract_entities_keywords(query_text, vicuna_ner_2_model)

# Only pass keywords to the retriever when exact keyword matching is enabled
# and keywords were actually extracted; otherwise use None so the retriever
# skips the keyword filter. (The previous two-step assignment could raise a
# NameError when matching was enabled but no keywords were found.)
if use_keyword_matching and all_keywords:
    keywords = clean_keywords_all_combs(all_keywords)
else:
    keywords = None

# Connect to the Pinecone Vector Database - Instructor Model
pinecone.init(
    api_key=st.secrets["pinecone_instructor"],
    environment="us-west4-gcp-free",
)
pinecone_index_name = "week13-instructor-xl"
pinecone_index = pinecone.Index(pinecone_index_name)
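
# Embed the query with the Instructor model; the instruction string steers
# the embedding toward finance-domain retrieval.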
retriever_model = get_instructor_embedding_model_api()
instruction = "Represent the finance query for retrieving related documents:"
dense_query_embedding = create_dense_embeddings(
    query_text, retriever_model, instruction
)
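
# Retrieve context separately for each (ticker, year, quarter) document when
# entities were found in the question; otherwise run a single unfiltered query.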
context_group = []
if ticker_year_quarter_tuples_list:
    for ticker, quarter, year in ticker_year_quarter_tuples_list:
        if use_bm25:
            # Ticker, quarter, and year are passed as None to trigger global BM25
            indices = get_indices_bm25(
                data, query_text, None, None, None, num_candidates
            )
        else:
            indices = None
        query_results = query_pinecone(
            dense_query_embedding,
            num_results,
            pinecone_index,
            year,
            quarter,
            ticker,
            keywords,
            indices,
            threshold,
        )
        context = sentence_id_combine(data, query_results, lag=window)
        context_group.append((context, year, quarter, ticker))
    multi_doc_context = generate_multi_doc_context(context_group)
else:
    indices = None
    query_results = query_pinecone(
        dense_query_embedding,
        num_results,
        pinecone_index,
        None,
        None,
        None,
        keywords,
        indices,
        threshold,
    )
    multi_doc_context = sentence_id_combine(data, query_results, lag=window)
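
# Build the Alpaca-style prompt from the query and retrieved context;
# the user can edit it before it is sent to the decoder model.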
prompt = generate_prompt_alpaca_style(query_text, multi_doc_context)

with col1:
    edited_prompt = st.text_area(
        label="Model Prompt", value=prompt, height=400
    )

with st.sidebar:
    decoder_model = st.selectbox(
        "Select Text Generation Model", decoder_models_choice
    )
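
# GPT-3.5 Turbo branch: collect the user's OpenAI key in a form, then
# generate an answer and render it sentence by sentence as a bulleted list.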
if decoder_model == "GPT-3.5 Turbo":
    with col2:
        with st.form("gpt_form"):
            openai_key = st.text_input(
                "Enter OpenAI key",
                value="",
                type="password",
            )
            gpt_submitted = st.form_submit_button("Submit")
        if gpt_submitted:
            api_key = save_key(openai_key)
            openai.api_key = api_key
            generated_text = gpt_turbo_model(edited_prompt)
            st.subheader("Answer:")
            # Split on sentence boundaries while ignoring common abbreviations
            regex_pattern_sentences = (
                r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s"
            )
            generated_text_list = re.split(
                regex_pattern_sentences, generated_text
            )
            for answer_text in generated_text_list:
                st.write(
                    f"<ul><li><p>{answer_text}</p></li></ul>",
                    unsafe_allow_html=True,
                )
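
# Vicuna-7B branch: generation runs on the hosted Vicuna model, so a spinner
# is shown while the user waits.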
if decoder_model == "Vicuna-7B":
    with col2:
        with st.spinner(
            text="The Vicuna model is running. It takes approximately 10-15 minutes to generate the text."
        ):
            generated_text = vicuna_text_generate(
                prompt, vicuna_text_gen_model
            )
        st.subheader("Answer:")
        # Split on sentence boundaries while ignoring common abbreviations
        regex_pattern_sentences = r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s"
        generated_text_list = re.split(regex_pattern_sentences, generated_text)
        for answer_text in generated_text_list:
            st.write(
                f"<ul><li><p>{answer_text}</p></li></ul>",
                unsafe_allow_html=True,
            )
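
# Bottom tabs: the combined retrieved context, and the full source transcripts.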
tab1, tab2 = st.tabs(["Retrieved Text", "Retrieved Documents"])

with tab1:
    with st.expander("See Retrieved Text"):
        st.subheader("Retrieved Text:")
        st.write(
            f"<p>{multi_doc_context}</p>",
            unsafe_allow_html=True,
        )

with tab2:
    if ticker_year_quarter_tuples_list:
        for ticker, quarter, year in ticker_year_quarter_tuples_list:
            file_text = retrieve_transcript(data, year, quarter, ticker)
            with st.expander(f"See Transcript - {quarter} {year}"):
                st.subheader(f"Earnings Call Transcript - {quarter} {year}:")
                stx.scrollableTextbox(
                    file_text,
                    height=700,
                    border=False,
                    fontFamily="Helvetica",
                )
    else:
        st.write(
            "No specific documents found. Please mention a ticker and a duration in the question."
        )