File size: 3,941 Bytes
b974dc8
07c3ebf
 
86b5672
630b876
07c3ebf
 
ae47781
07c3ebf
ae47781
07c3ebf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194d499
 
 
07c3ebf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b974dc8
07c3ebf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import streamlit as st
import json
from retrievals import TFIDFRetriever, BM25Retriever
from retrieval import get_retrieval_dense, get_retrieval_tf_idf, get_retrieval_bm25
from embedding_function import sync_embed
import numpy as np
import os
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment; the
# dense-retrieval path later reads HF_API_KEY via os.getenv.
load_dotenv()

# Page-level layout options. Keep this ahead of any other st.* call: older
# Streamlit versions require set_page_config to be the first Streamlit command.
_PAGE_OPTIONS = dict(
    page_title="Vector Store Query App",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_OPTIONS)

# Custom CSS injected as raw HTML: tighter vertical padding on the main
# container, plus a `.section` card style.
_CUSTOM_CSS = """
    <style>
        .block-container {
            padding-top: 2rem;
            padding-bottom: 2rem;
        }
        .section {
            padding: 1rem;
            border-radius: 0.5rem;
            margin-bottom: 1rem;
        }
    </style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

with st.sidebar:
    # Static description shown at the top of the sidebar.
    _about_text = """
    This app allows you to query a vector store and view results in both JSON format
    and rendered markdown. Enter your question in the main panel and click 'Search'.
    """
    st.title("About")
    st.markdown(_about_text)

    # Top-level retrieval family; the search form below dispatches on this value.
    _retrieval_options = ["Sparse Retrievals", "Dense Retrievals", "Hybrid Retrievals"]
    retrieval_method = st.selectbox("Choose the retrieval method:", _retrieval_options)

    # Each family gets its own secondary selector; "Hybrid Retrievals" has none,
    # so neither `sparse_method` nor `model_selection` is bound in that case.
    if retrieval_method == "Dense Retrievals":
        _dense_models = [
            "sentence-transformers/all-MiniLM-L6-v2",
            "intfloat/multilingual-e5-large",
        ]
        model_selection = st.selectbox("Choose a model:", _dense_models)
        st.write(f"Selected model: {model_selection}")
        # Stored in session state so the search form reads the same choice back.
        st.session_state.model_selection = model_selection
    elif retrieval_method == "Sparse Retrievals":
        _sparse_options = ["BM25", "TF-IDF"]
        sparse_method = st.selectbox("Choose a Sparse Retrieval method:", _sparse_options)
        st.write(f"Selected Sparse Method: {sparse_method}")

st.title("Vector Store Query Interface")

# Seed session state once per browser session so the results section can tell
# "no search yet" (None) apart from a completed search.
if "results" not in st.session_state:
    st.session_state.results = None

with st.form("query_form"):
    col1, col2 = st.columns([4, 1])
    with col1:
        query = st.text_input(
            "Enter your question:",
            placeholder="What are you looking for?",
            label_visibility="collapsed",
        )
    with col2:
        st.write("")  # spacer so the button lines up with the text input
        submitted = st.form_submit_button("Search", use_container_width=True)

# Handle the submission outside the narrow button column so that any warning
# renders at full width.
if submitted:
    # Whitespace-only input is treated the same as an empty question.
    cleaned_query = query.strip()
    if not cleaned_query:
        st.warning("Please enter a question")
    elif retrieval_method == "Dense Retrievals":
        # The model was chosen in the sidebar and stored in session state there.
        model_selection = st.session_state.get("model_selection")
        api_key = os.getenv("HF_API_KEY")
        # get_retrieval_dense is handed the model and key directly; presumably it
        # embeds the query itself (verify in retrieval.py). A separate sync_embed
        # call here would issue an extra embedding request whose result is unused.
        st.session_state.results = get_retrieval_dense(
            cleaned_query, model=model_selection, api_key=api_key
        )
    elif retrieval_method == "Sparse Retrievals":
        # `sparse_method` is only bound by the sidebar when Sparse is selected,
        # so it is referenced only inside this branch.
        if sparse_method == "TF-IDF":
            st.session_state.results = get_retrieval_tf_idf(cleaned_query)
        elif sparse_method == "BM25":
            st.session_state.results = get_retrieval_bm25(cleaned_query)
    else:
        # e.g. "Hybrid Retrievals": offered in the sidebar but not wired to a
        # retriever here, so say so instead of silently doing nothing.
        st.warning(f"{retrieval_method} is not implemented yet.")

if st.session_state.results:
    st.divider()
    st.subheader("Results")

    # Assumes every retriever returns {"json": {"results": [{"content": ...,
    # "metadata": ...}, ...]}}, which is the shape indexed below.
    # NOTE(review): confirm against retrieval.py.
    payload = st.session_state.results["json"]

    col_left, col_right = st.columns([1, 2], gap="large")

    with col_left:
        st.markdown("**JSON Output**")
        st.code(json.dumps(payload, indent=2), language="json")

    with col_right:
        st.markdown("**Document Content**")
        documents = payload["results"]
        if not documents:
            # Otherwise the pane would show only its heading with nothing under it.
            st.info("No matching documents found.")
        for number, doc in enumerate(documents, start=1):
            with st.container():
                st.markdown(f"### Document {number}")
                st.markdown(doc["content"])
                st.markdown(f"**Source:** {doc['metadata']}")
                st.divider()

elif st.session_state.results is None:
    # No search has been run yet in this session.
    st.info("👈 Enter a question and click Search to get started")
else:
    # A retriever returned a falsy, non-None value (e.g. an empty dict); tell the
    # user instead of rendering nothing.
    st.info("No results returned for this query.")