# rybrevant_rag / src / streamlit_app.py
# Last commit: "fix" (d2aac9d) by shan gao
"""
Streamlit app for the RYBREVANT RAG chatbot.
- Builds one vector + summary tool per document (PI and brochure)
- Routes questions to the right tool with a FunctionAgent and ObjectIndex
- Surfaces concise answers with citations plus a safety disclaimer
Set environment variables before running:
- OPENAI_API_KEY
- LLAMA_CLOUD_API_KEY (for LlamaParse PDF ingestion)
Run locally:
    streamlit run streamlit_app.py
"""
import asyncio
import concurrent.futures
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import requests
import streamlit as st
from llama_cloud_services import LlamaParse, EU_BASE_URL
from llama_index.core import SimpleDirectoryReader, SummaryIndex, VectorStoreIndex
from llama_index.core.agent.workflow import FunctionAgent
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.objects import ObjectIndex
from llama_index.core.tools import FunctionTool, QueryEngineTool
from llama_index.core.vector_stores import FilterCondition, MetadataFilters
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
# Directory where the source PDFs are downloaded and cached.
DATA_DIR = Path("data")

# filename -> (download URL, short label used in tool names/citations, routing description)
DOC_SOURCES: Dict[str, Tuple[str, str, str]] = {
    "rybrevant.pdf": (
        "https://www.jnjlabels.com/package-insert/product-monograph/prescribing-information/RYBREVANT-pi.pdf",
        "PI",
        "RYBREVANT prescribing information (official label; use for dosing, safety, administration)",
    ),
    "brochure.pdf": (
        "https://www.rybrevant.com/documents/RYBREVANT_Patient_Brochure_Digital.pdf",
        "brochure",
        "RYBREVANT patient brochure (patient-friendly overview; use for general awareness)",
    ),
}

# System prompt for the routing agent.
# BUG FIX: the original concatenated adjacent string literals with no separating
# spaces, producing "...documents.Please always..." and "...knowledge.When
# responding..." in the prompt actually sent to the LLM.
BASE_SYSTEM_PROMPT = (
    "You are an agent designed to answer queries over a set of RYBREVANT documents. "
    "Please always use the tools provided to answer a question. Do not rely on prior knowledge. "
    "When responding, keep answers concise, always mention the source: exact document + page "
    "(e.g., 'PI p.12' or 'brochure p.5'), and end with a brief safety disclaimer "
    "('Not medical advice; consult your healthcare professional')."
)
def _ensure_data_files() -> None:
    """Download any source PDF that is not already present in DATA_DIR."""
    DATA_DIR.mkdir(exist_ok=True)
    for filename, (url, _label, _desc) in DOC_SOURCES.items():
        target = DATA_DIR / filename
        if not target.exists():
            response = requests.get(url, timeout=60)
            response.raise_for_status()
            target.write_bytes(response.content)
def _build_tools_for_doc(file_path: Path, name: str) -> Tuple[FunctionTool, QueryEngineTool]:
    """Create vector and summary tools for a single document.

    Parses the PDF with LlamaParse, normalizes page/source metadata onto every
    document and node, then builds two tools over the resulting chunks:
    a page-filterable vector-search tool and a tree-summarize tool.

    Args:
        file_path: Path to the local PDF to ingest.
        name: Short label (e.g. "PI", "brochure") used in tool names and citations.

    Returns:
        (vector_tool, summary_tool) for this document.

    Raises:
        ValueError: if parsing produced no text nodes.
    """
    # LlamaParse (EU endpoint) does the PDF extraction; needs LLAMA_CLOUD_API_KEY.
    parser = LlamaParse(
        result_type="text",
        language="en",
        api_key=os.getenv("LLAMA_CLOUD_API_KEY"),
        base_url=EU_BASE_URL,
    )
    documents = SimpleDirectoryReader(
        input_files=[str(file_path)],
        file_extractor={".pdf": parser},
    ).load_data()
    # Different extractors expose the page under different metadata keys; try
    # each known variant and fall back to the document's 1-based position.
    for i, doc in enumerate(documents):
        page = (
            doc.metadata.get("page_label")
            or doc.metadata.get("page")
            or doc.metadata.get("page_number")
            or doc.metadata.get("page_idx")
            or i + 1
        )
        doc.metadata["page_label"] = page
        doc.metadata["source"] = name
    splitter = SentenceSplitter(chunk_size=800, chunk_overlap=120)
    nodes = splitter.get_nodes_from_documents(documents)
    # Re-apply page/source labels at node level so citations survive chunking
    # even if the splitter did not carry every metadata key through.
    for node in nodes:
        if "page_label" not in node.metadata:
            node.metadata["page_label"] = (
                node.metadata.get("page") or node.metadata.get("page_number") or node.metadata.get("page_idx")
            )
        node.metadata["source"] = node.metadata.get("source") or name
    if not nodes:
        raise ValueError(f"No text nodes parsed from {file_path}. Check parser credentials or PDF availability.")
    embed_model = OpenAIEmbedding(model="text-embedding-3-large")
    vector_index = VectorStoreIndex(nodes, embed_model=embed_model)

    def vector_query(query: str, page_numbers: Optional[List[int]] = None) -> str:
        """Grounded Q&A with optional page filters + citations.
        Useful if you have specific questions over the document.
        Always leave page_numbers as None UNLESS there is a specific page you want to search for.
        Args:
            query (str): the string query to be embedded.
            page_numbers (Optional[List[int]]): Filter by set of pages. Leave as NONE
                                                if we want to perform a vector search
                                                over all pages. Otherwise, filter by the set of specified pages.
        """
        # Empty filter list => no page restriction (search the whole document).
        # NOTE(review): filter values are ints while page_label may be stored as
        # a string by some parsers — confirm the filter matches in practice.
        page_numbers = page_numbers or []
        metadata_dicts = [{"key": "page_label", "value": p} for p in page_numbers]
        query_engine = vector_index.as_query_engine(
            similarity_top_k=4,
            filters=MetadataFilters.from_dicts(metadata_dicts, condition=FilterCondition.OR),
        )
        response = query_engine.query(query)
        # Collect "source p.page" citations from the retrieved nodes,
        # de-duplicated while preserving first-seen order.
        citations = []
        for sn in response.source_nodes:
            page = sn.node.metadata.get("page_label")
            src = sn.node.metadata.get("source", name)
            citations.append(f"{src} p.{page}" if page else src)
        citations = list(dict.fromkeys(citations))
        if citations:
            return f"{response}\n\nSources: {', '.join(citations)}"
        return str(response)

    vector_tool = FunctionTool.from_defaults(
        name=f"vector_tool_{name}",
        fn=vector_query,
        description=f"Vector search over {name}; responds with grounded answer + page citations. Primary source for {name}.",
    )
    # Summary tool: iterates all nodes with tree_summarize, for whole-document questions.
    summary_index = SummaryIndex(nodes)
    summary_query_engine = summary_index.as_query_engine(response_mode="tree_summarize", use_async=True)
    summary_tool = QueryEngineTool.from_defaults(
        name=f"summary_tool_{name}",
        query_engine=summary_query_engine,
        description=f"Useful for summarization questions related to {name}.",
    )
    return vector_tool, summary_tool
@st.cache_resource(show_spinner=False)
def build_agent():
    """Construct (and cache for this runtime) the tool-routing FunctionAgent."""
    _ensure_data_files()
    all_tools = []
    for filename, (_url, display_name, _desc) in DOC_SOURCES.items():
        vector_tool, summary_tool = _build_tools_for_doc(DATA_DIR / filename, display_name)
        all_tools.append(vector_tool)
        all_tools.append(summary_tool)
    # Index the tools themselves so the agent retrieves only the relevant ones.
    tool_index = ObjectIndex.from_objects(all_tools, index_cls=VectorStoreIndex)
    retriever = tool_index.as_retriever(similarity_top_k=4)
    return FunctionAgent(
        tool_retriever=retriever,
        llm=OpenAI(model="gpt-3.5-turbo", temperature=0),
        system_prompt=BASE_SYSTEM_PROMPT,
        verbose=False,
    )
async def _arun_agent(agent: FunctionAgent, prompt: str) -> str:
    """Run the agent workflow for *prompt* and return its result as text."""
    result = await agent.run(prompt)
    return str(result)
def run_agent(agent: FunctionAgent, prompt: str) -> str:
    """Run the agent from Streamlit, whether or not an event loop is already running.

    BUG FIX: the original called ``loop.run_until_complete`` when a loop was
    already running in this thread, which raises
    ``RuntimeError: This event loop is already running`` — the exact situation
    that branch was meant to handle. Instead, execute the coroutine on a
    dedicated worker thread that gets its own fresh loop via ``asyncio.run``.

    Args:
        agent: The cached FunctionAgent built by build_agent().
        prompt: The user's question.

    Returns:
        The agent's final response as a string.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread (the usual Streamlit case): run directly.
        return asyncio.run(_arun_agent(agent, prompt))
    # A loop is already running here; block on a fresh loop in a worker thread.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, _arun_agent(agent, prompt)).result()
def _require_env(var_name: str) -> bool:
"""Check that required environment variables are set; inform user if missing."""
if os.getenv(var_name):
return True
st.error(f"Missing environment variable: {var_name}")
return False
def main() -> None:
    """Streamlit entry point: page chrome, API-key checks, and the chat loop."""
    st.set_page_config(page_title="RYBREVANT Q&A", page_icon="🩺", layout="wide")
    st.title("RYBREVANT Q&A RAG")
    st.write(
        "Ask about the RYBREVANT prescribing information (PI) or patient brochure. "
        "Responses stay grounded in the source documents and include page citations."
    )
    with st.sidebar:
        st.header("About")
        st.markdown(
            "- Sources: PI and patient brochure\n"
            "- Answers include citations and a safety disclaimer\n"
            "- Data/parsing cached in this Space runtime"
        )
        st.divider()
        st.markdown("Need to deploy? Push this app to Hugging Face Spaces with your API keys as secrets.")
    # Refuse to proceed without both keys; st.stop() halts this script rerun.
    has_keys = _require_env("OPENAI_API_KEY") and _require_env("LLAMA_CLOUD_API_KEY")
    if not has_keys:
        st.stop()
    # build_agent is wrapped in @st.cache_resource, so download/parse/index
    # work happens once per runtime, not on every rerun.
    agent = build_agent()
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # Replay the conversation so far — Streamlit reruns the whole script on
    # every user interaction.
    for role, content in st.session_state.messages:
        with st.chat_message(role):
            st.markdown(content)
    prompt = st.chat_input("Ask a RYBREVANT question, e.g., dosing, administration, safety...")
    if prompt:
        st.session_state.messages.append(("user", prompt))
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("Grounding answer in the documents..."):
                try:
                    response = run_agent(agent, prompt)
                except Exception as exc: # pylint: disable=broad-except
                    # Surface the failure in the UI instead of crashing the app.
                    st.error(f"Something went wrong: {exc}")
                    return
            st.markdown(response)
            st.session_state.messages.append(("assistant", response))
    st.caption("Not medical advice; always consult a healthcare professional.")
if __name__ == "__main__":
    # Streamlit executes this module top-to-bottom on each rerun.
    main()