gabrielaltay's picture
update
ac2020e
"""Single RAG tab implementation"""
import streamlit as st
from legisqa_local.tabs.base import BaseTab
from legisqa_local.components.forms import get_generative_config, get_retrieval_config
from legisqa_local.components.display import render_example_queries, render_response
from legisqa_local.core.rag import process_query
from legisqa_local.config.models import PROVIDER_MODELS
class RAGTab(BaseTab):
    """Single RAG query tab.

    Lets the user submit one query, adjust generative and retrieval
    configuration, and view the most recent response. The response is kept
    in ``st.session_state`` so it survives Streamlit reruns.
    """

    def __init__(self):
        # NOTE(review): BaseTab presumably treats these as the tab title and
        # the widget-key prefix exposed as ``self.key_prefix`` — confirm.
        super().__init__("RAG", "query_rag")

    def render(self):
        """Render the RAG tab.

        Draws the example queries, the query form, the generative/retrieval
        config expanders, and (if one is stored) the last response with a
        debug expander.

        The provider/model used to label a stored response are the ones that
        were active when it was generated. The config widgets live outside the
        form, so they can change (and rerun the script) without a new
        submission; using the live selection would mislabel the old response.
        """
        SS = st.session_state
        rkey = f"{self.key_prefix}|response"
        # Snapshot of the generative config that produced SS[rkey].
        ckey = f"{self.key_prefix}|response_gen_config"

        render_example_queries()

        with st.form(f"{self.key_prefix}|query_form"):
            query = st.text_area(
                "Enter a query that can be answered with congressional legislation:"
            )
            cols = st.columns(2)
            with cols[0]:
                query_submitted = st.form_submit_button("Submit")
            with cols[1]:
                status_placeholder = st.empty()

        col1, col2 = st.columns(2)
        with col1:
            with st.expander("Generative Config"):
                gen_config = get_generative_config(self.key_prefix)
        with col2:
            with st.expander("Retrieval Config"):
                ret_config = get_retrieval_config(self.key_prefix)

        if query_submitted:
            if not (query or "").strip():
                # Skip retrieval + generation for an empty query; any
                # previously stored response stays on screen.
                status_placeholder.warning("Please enter a query before submitting.")
            else:
                with status_placeholder:
                    with st.spinner("generating response"):
                        SS[rkey] = process_query(gen_config, ret_config, query)
                # Only recorded once process_query has returned, so the
                # snapshot always matches the stored response.
                SS[ckey] = dict(gen_config)

        if response := SS.get(rkey):
            # Fall back to the live config if no snapshot exists yet
            # (e.g. a response stored before this key was introduced).
            used_config = SS.get(ckey, gen_config)
            provider = used_config["provider"]
            model_info = PROVIDER_MODELS[provider][used_config["model_name"]]
            render_response(
                response,
                model_info,
                provider,
                # Display-only switches: follow the current settings so they
                # can be toggled without regenerating the response.
                gen_config["should_escape_markdown"],
                gen_config["should_add_legis_urls"],
            )
            with st.expander("Debug"):
                st.write(response)