Spaces:
Sleeping
Sleeping
| """Side-by-side RAG tab implementation""" | |
| import streamlit as st | |
| from legisqa_local.tabs.base import BaseTab | |
| from legisqa_local.components.forms import get_generative_config, get_retrieval_config | |
| from legisqa_local.components.display import render_response | |
| from legisqa_local.core.rag import process_query | |
| from legisqa_local.config.models import PROVIDER_MODELS | |
class RAGSideBySideTab(BaseTab):
    """Side-by-side RAG comparison tab.

    Sends one query through two independently configured RAG pipelines
    ("Group 1" and "Group 2") and shows the two responses in adjacent columns.
    """

    # (session-key suffix, display name) for each comparison group, in column order.
    _GROUPS = (("grp1", "Group 1"), ("grp2", "Group 2"))

    def __init__(self):
        super().__init__("RAG (side-by-side)", "query_rag_sbs")

    def render(self):
        """Render the side-by-side RAG tab.

        Draws the query form, one pair of config expanders per group, and one
        response column per group. When the form is submitted with a
        non-blank query, each group's query is run through ``process_query``
        with that group's generative and retrieval configs. The response, and
        the generative config that produced it, are stored in
        ``st.session_state`` so they are still shown on later reruns.
        """
        SS = st.session_state

        with st.form(f"{self.key_prefix}|query_form"):
            query = st.text_area(
                "Enter a query that can be answered with congressional legislation:"
            )
            cols = st.columns(2)
            with cols[0]:
                query_submitted = st.form_submit_button("Submit")
            with cols[1]:
                status_placeholder = st.empty()

        # Only call the RAG pipelines when there is actually something to ask.
        should_generate = query_submitted and bool(query.strip())
        if query_submitted and not should_generate:
            status_placeholder.warning("Please enter a query before submitting.")

        # Row 1: per-group configuration widgets.
        gen_configs = {}
        ret_configs = {}
        for (grp_key, grp_name), col in zip(self._GROUPS, st.columns(2)):
            with col:
                st.header(grp_name)
                key_prefix = f"{self.key_prefix}|{grp_key}"
                with st.expander("Generative Config"):
                    gen_configs[grp_key] = get_generative_config(key_prefix)
                with st.expander("Retrieval Config"):
                    ret_configs[grp_key] = get_retrieval_config(key_prefix)

        # Row 2: per-group responses.
        for (grp_key, grp_name), col in zip(self._GROUPS, st.columns(2)):
            with col:
                key_prefix = f"{self.key_prefix}|{grp_key}"
                rkey = f"{key_prefix}|response"
                # Generative config that produced the response stored under rkey.
                ckey = f"{key_prefix}|response_gen_config"

                if should_generate:
                    with status_placeholder:
                        with st.spinner(f"generating response for {grp_name}"):
                            SS[rkey] = process_query(
                                gen_configs[grp_key],
                                ret_configs[grp_key],
                                query,
                            )
                    # Recorded only after process_query succeeds, so it always
                    # matches the stored response.
                    SS[ckey] = gen_configs[grp_key]

                response = SS.get(rkey)
                if not response:
                    continue

                # Attribute the response to the provider/model that generated it.
                # The config widgets are outside the form, so editing them reruns
                # the app without producing a new response. Responses stored
                # before a snapshot existed fall back to the current config.
                used_config = SS.get(ckey, gen_configs[grp_key])
                provider = used_config["provider"]
                model_info = PROVIDER_MODELS[provider][used_config["model_name"]]

                # The two flags only affect how render_response displays the
                # response, so the current settings apply.
                current_config = gen_configs[grp_key]
                render_response(
                    response,
                    model_info,
                    provider,
                    current_config["should_escape_markdown"],
                    current_config["should_add_legis_urls"],
                    tag=grp_name,
                )