# retrieval_evaluation/src/streamlit_app.py
# Last commit: ae47781 "first commit" by samiha123
import streamlit as st
import json
from retrievals import TFIDFRetriever, BM25Retriever
from retrieval import get_retrieval_dense, get_retrieval_tf_idf, get_retrieval_bm25
from embedding_function import sync_embed
import numpy as np
import os
from dotenv import load_dotenv
# Populate os.environ from a local .env file (e.g. HF_API_KEY, which the
# dense-retrieval branch reads via os.getenv).
load_dotenv()

# Page-level settings; set_page_config must be the first Streamlit command.
_PAGE_CONFIG = dict(
    page_title="Vector Store Query App",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)

# Small CSS override: tighter vertical padding for the main container plus a
# reusable ".section" card style.
_CUSTOM_CSS = """
<style>
.block-container {
padding-top: 2rem;
padding-bottom: 2rem;
}
.section {
padding: 1rem;
border-radius: 0.5rem;
margin-bottom: 1rem;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Sidebar content and selector options.
_ABOUT_TEXT = """
This app allows you to query a vector store and view results in both JSON format
and rendered markdown. Enter your question in the main panel and click 'Search'.
"""
_RETRIEVAL_METHODS = ["Sparse Retrievals", "Dense Retrievals", "Hybrid Retrievals"]
_SPARSE_METHODS = ["BM25", "TF-IDF"]
_DENSE_MODELS = [
    "sentence-transformers/all-MiniLM-L6-v2",
    "intfloat/multilingual-e5-large",
]

with st.sidebar:
    st.title("About")
    st.markdown(_ABOUT_TEXT)

    retrieval_method = st.selectbox("Choose the retrieval method:", _RETRIEVAL_METHODS)

    # Only the sub-options of the chosen family are rendered. "Hybrid
    # Retrievals" has no sub-options, so neither name below is bound for it.
    if retrieval_method == "Dense Retrievals":
        model_selection = st.selectbox("Choose a model:", _DENSE_MODELS)
        st.write(f"Selected model: {model_selection}")
        # Mirrored into session state; the query handler reads it from there.
        st.session_state.model_selection = model_selection
    elif retrieval_method == "Sparse Retrievals":
        sparse_method = st.selectbox("Choose a Sparse Retrieval method:", _SPARSE_METHODS)
        st.write(f"Selected Sparse Method: {sparse_method}")
st.title("Vector Store Query Interface")

# The last results live in session state so they survive Streamlit reruns;
# None means "no search has been run yet" (the results view checks for it).
if "results" not in st.session_state:
    st.session_state.results = None

with st.form("query_form"):
    col1, col2 = st.columns([4, 1])
    with col1:
        query = st.text_input(
            "Enter your question:",
            placeholder="What are you looking for?",
            label_visibility="collapsed",
        )
    with col2:
        st.write("")  # spacer so the button lines up with the text input
        submitted = st.form_submit_button("Search", use_container_width=True)

# Handle the submission outside the form so feedback messages span the full
# page width instead of the narrow button column.
if submitted:
    # Treat a whitespace-only entry the same as an empty one.
    question = query.strip()
    if not question:
        st.warning("Please enter a question")
    elif retrieval_method == "Dense Retrievals":
        # No separate sync_embed call here: its result was never used, and
        # get_retrieval_dense is given the raw query, model and key (presumably
        # it embeds the query itself -- confirm in retrieval.py).
        st.session_state.results = get_retrieval_dense(
            question,
            model=st.session_state.get("model_selection"),
            api_key=os.getenv("HF_API_KEY"),
        )
    elif retrieval_method == "Sparse Retrievals" and sparse_method == "TF-IDF":
        st.session_state.results = get_retrieval_tf_idf(question)
    elif retrieval_method == "Sparse Retrievals" and sparse_method == "BM25":
        st.session_state.results = get_retrieval_bm25(question)
    else:
        # "Hybrid Retrievals" is selectable in the sidebar but has no backend
        # here; tell the user instead of silently doing nothing.
        st.warning(f"{retrieval_method} are not implemented yet.")
if st.session_state.results:
    st.divider()
    st.subheader("Results")
    col_left, col_right = st.columns([1, 2], gap="large")

    # The retrieval helpers return a dict whose "json" entry holds the payload;
    # its "results" list carries one dict per document.
    payload = st.session_state.results["json"]

    with col_left:
        st.markdown("**JSON Output**")
        # ensure_ascii=False keeps accented / non-Latin text readable instead of
        # \uXXXX escapes (one of the offered dense models is multilingual).
        st.code(json.dumps(payload, indent=2, ensure_ascii=False), language="json")

    with col_right:
        st.markdown("**Document Content**")
        for number, doc in enumerate(payload["results"], start=1):
            with st.container():
                st.markdown(f"### Document {number}")
                st.markdown(doc["content"])
                st.markdown(f"**Source:** {doc['metadata']}")
                st.divider()
elif st.session_state.results is None:
    # No search has been run yet in this session.
    st.info("👈 Enter a question and click Search to get started")