# Source: legisqa-local/src/legisqa_local/tabs/rag_sbs_tab.py
# (commit ac2020e, "update", by gabrielaltay)
"""Side-by-side RAG tab implementation"""
import streamlit as st
from legisqa_local.tabs.base import BaseTab
from legisqa_local.components.forms import get_generative_config, get_retrieval_config
from legisqa_local.components.display import render_response
from legisqa_local.core.rag import process_query
from legisqa_local.config.models import PROVIDER_MODELS
class RAGSideBySideTab(BaseTab):
    """Side-by-side RAG comparison tab.

    One query is entered once and answered by two independently configured
    groups ("Group 1" and "Group 2"), whose responses are shown in adjacent
    columns for comparison.
    """

    # Internal group ids (used in widget / session-state keys) -> display names.
    _GROUPS = {"grp1": "Group 1", "grp2": "Group 2"}

    def __init__(self):
        super().__init__("RAG (side-by-side)", "query_rag_sbs")

    def _group_key_prefix(self, group):
        """Return the widget/session-state key prefix for ``group``."""
        return f"{self.key_prefix}|{group}"

    def render(self):
        """Render the side-by-side RAG tab.

        Draws the shared query form, one config column per group, and one
        response column per group. On a submit with a non-blank query,
        ``process_query`` is called once per group (sequentially) and the
        result is cached in ``st.session_state`` under
        ``"<key_prefix>|<group>|response"`` so it survives Streamlit reruns.
        The generative config used for that call is cached alongside it
        (``"<key_prefix>|<group>|response_gen_config"``) so a cached response
        keeps being rendered with the model/flags that produced it, even if
        the user changes the config afterwards without resubmitting.
        """
        SS = st.session_state

        with st.form(f"{self.key_prefix}|query_form"):
            query = st.text_area(
                "Enter a query that can be answered with congressional legislation:"
            )
            cols = st.columns(2)
            with cols[0]:
                query_submitted = st.form_submit_button("Submit")
            with cols[1]:
                status_placeholder = st.empty()

        # Only spend generation calls on a real query; keep any previously
        # cached responses visible otherwise.
        run_query = query_submitted and bool(query.strip())
        if query_submitted and not run_query:
            status_placeholder.warning("Please enter a query before submitting.")

        # Per-group configuration widgets (Group 1 fully, then Group 2).
        gen_configs = {}
        ret_configs = {}
        config_cols = st.columns(len(self._GROUPS))
        for group, config_col in zip(self._GROUPS, config_cols):
            with config_col:
                st.header(self._GROUPS[group])
                key_prefix = self._group_key_prefix(group)
                with st.expander("Generative Config"):
                    gen_configs[group] = get_generative_config(key_prefix)
                with st.expander("Retrieval Config"):
                    ret_configs[group] = get_retrieval_config(key_prefix)

        # Per-group generation (on submit) and rendering of cached responses.
        response_cols = st.columns(len(self._GROUPS))
        for (group, group_name), response_col in zip(
            self._GROUPS.items(), response_cols
        ):
            with response_col:
                key_prefix = self._group_key_prefix(group)
                rkey = f"{key_prefix}|response"
                ckey = f"{key_prefix}|response_gen_config"

                if run_query:
                    with status_placeholder:
                        with st.spinner(f"generating response for {group_name}"):
                            SS[rkey] = process_query(
                                gen_configs[group],
                                ret_configs[group],
                                query,
                            )
                    # Remember which config produced this response.
                    SS[ckey] = dict(gen_configs[group])

                response = SS.get(rkey)
                if not response:
                    continue

                # Fall back to the live config if no generation-time config
                # was recorded for this response.
                gen_config = SS.get(ckey) or gen_configs[group]
                provider = gen_config["provider"]
                model_info = PROVIDER_MODELS[provider][gen_config["model_name"]]
                render_response(
                    response,
                    model_info,
                    provider,
                    gen_config["should_escape_markdown"],
                    gen_config["should_add_legis_urls"],
                    tag=group_name,
                )