"""Streamlit front-end for querying a vector store.

The sidebar selects a retrieval strategy: sparse (BM25 / TF-IDF) or dense
(sentence-embedding model). The main panel takes a question and shows the
retrieved documents both as raw JSON and as rendered markdown.
"""
import json
import os

import numpy as np
import streamlit as st
from dotenv import load_dotenv

from embedding_function import sync_embed
from retrieval import get_retrieval_bm25, get_retrieval_dense, get_retrieval_tf_idf
from retrievals import BM25Retriever, TFIDFRetriever

# Populate os.environ from a local .env file (HF_API_KEY is read below).
load_dotenv()

# st.set_page_config must be the first Streamlit command executed.
st.set_page_config(
    page_title="Vector Store Query App",
    layout="wide",
    initial_sidebar_state="expanded",
)

# NOTE(review): injects an empty HTML/CSS snippet; presumably a placeholder
# for custom styles - confirm or remove.
st.markdown(""" """, unsafe_allow_html=True)

# Sparse retrieval method label -> retrieval function taking the query string.
_SPARSE_RETRIEVERS = {
    "BM25": get_retrieval_bm25,
    "TF-IDF": get_retrieval_tf_idf,
}

with st.sidebar:
    st.title("About")
    st.markdown(
        "This app allows you to query a vector store and view results in both "
        "JSON format and rendered markdown. Enter your question in the main "
        "panel and click 'Search'."
    )

    retrieval_method = st.selectbox(
        "Choose the retrieval method:",
        ["Sparse Retrievals", "Dense Retrievals", "Hybrid Retrievals"],
    )

    if retrieval_method == "Sparse Retrievals":
        sparse_method = st.selectbox(
            "Choose a Sparse Retrieval method:",
            list(_SPARSE_RETRIEVERS),
        )
        st.write(f"Selected Sparse Method: {sparse_method}")
    elif retrieval_method == "Dense Retrievals":
        model_selection = st.selectbox(
            "Choose a model:",
            [
                "sentence-transformers/all-MiniLM-L6-v2",
                "intfloat/multilingual-e5-large",
            ],
        )
        st.write(f"Selected model: {model_selection}")
        # Kept in session state so the search handler below can read it.
        st.session_state.model_selection = model_selection

st.title("Vector Store Query Interface")

# `None` means "no search has been run yet in this session".
if "results" not in st.session_state:
    st.session_state.results = None

with st.form("query_form"):
    col1, col2 = st.columns([4, 1])
    with col1:
        query = st.text_input(
            "Enter your question:",
            placeholder="What are you looking for?",
            label_visibility="collapsed",
        )
    with col2:
        st.write("")  # spacer to vertically align the button with the input
        submitted = st.form_submit_button("Search", use_container_width=True)

# Handle the submission outside the form columns so warnings render at full
# width instead of inside the narrow button column.
if submitted:
    cleaned_query = query.strip()
    if not cleaned_query:
        st.warning("Please enter a question")
    elif retrieval_method == "Dense Retrievals":
        # The raw query is passed straight to get_retrieval_dense; presumably
        # it embeds the query itself (the previous standalone sync_embed call
        # discarded its result) - confirm against retrieval.py.
        st.session_state.results = get_retrieval_dense(
            cleaned_query,
            model=st.session_state.get("model_selection"),
            api_key=os.getenv("HF_API_KEY"),
        )
    elif retrieval_method == "Sparse Retrievals":
        st.session_state.results = _SPARSE_RETRIEVERS[sparse_method](cleaned_query)
    else:
        # Listed in the sidebar, but no retrieval backend is wired up for it.
        st.warning(f"{retrieval_method} is not implemented yet.")

results = st.session_state.results
if results:
    st.divider()
    st.subheader("Results")

    # Assumes results["json"] is a dict with a "results" list whose items carry
    # "content" and "metadata" keys - that is how it is consumed below.
    payload = results["json"]
    col_left, col_right = st.columns([1, 2], gap="large")

    with col_left:
        st.markdown("**JSON Output**")
        # ensure_ascii=False keeps non-ASCII text (e.g. multilingual corpora)
        # readable instead of \uXXXX escapes.
        st.code(json.dumps(payload, indent=2, ensure_ascii=False), language="json")

    with col_right:
        st.markdown("**Document Content**")
        documents = payload["results"]
        if not documents:
            st.info("No matching documents found.")
        for number, doc in enumerate(documents, start=1):
            with st.container():
                st.markdown(f"### Document {number}")
                st.markdown(doc["content"])
                st.markdown(f"**Source:** {doc['metadata']}")
                st.divider()
elif results is None:
    st.info("👈 Enter a question and click Search to get started")