"""
Streamlit app for the RYBREVANT RAG chatbot.
- Builds one vector + summary tool per document (PI and brochure)
- Routes questions to the right tool with a FunctionAgent and ObjectIndex
- Surfaces concise answers with citations plus a safety disclaimer
Set environment variables before running:
- OPENAI_API_KEY
- LLAMA_CLOUD_API_KEY (for LlamaParse PDF ingestion)
Run locally:
streamlit run rybrevant_streamlit.py
"""
import asyncio
import concurrent.futures
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import requests
import streamlit as st
from llama_index.core import SimpleDirectoryReader, SummaryIndex, VectorStoreIndex
from llama_index.core.agent.workflow import FunctionAgent
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.objects import ObjectIndex
from llama_index.core.tools import FunctionTool, QueryEngineTool
from llama_index.core.vector_stores import FilterCondition, MetadataFilters
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_cloud_services import LlamaParse, EU_BASE_URL
DATA_DIR = Path("data")
DOC_SOURCES: Dict[str, Tuple[str, str, str]] = {
# filename: (url, short_label, long_description)
"rybrevant.pdf": (
"https://www.jnjlabels.com/package-insert/product-monograph/prescribing-information/RYBREVANT-pi.pdf",
"PI",
"RYBREVANT prescribing information (official label; use for dosing, safety, administration)",
),
"brochure.pdf": (
"https://www.rybrevant.com/documents/RYBREVANT_Patient_Brochure_Digital.pdf",
"brochure",
"RYBREVANT patient brochure (patient-friendly overview; use for general awareness)",
),
}
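# Adding another source document only requires a new DOC_SOURCES entry; the download,
# parsing, indexing, and tool-building steps below are all driven by this mapping.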
BASE_SYSTEM_PROMPT = (
    "You are an agent designed to answer queries over a set of RYBREVANT documents. "
    "Always use the provided tools to answer a question; do not rely on prior knowledge. "
    "When responding, keep answers concise, always mention the source: exact document + page "
    "(e.g., 'PI p.12' or 'brochure p.5'), and end with a brief safety disclaimer "
    "('Not medical advice; consult your healthcare professional')."
)
def _ensure_data_files() -> None:
"""Download source PDFs if they are missing."""
DATA_DIR.mkdir(exist_ok=True)
for filename, (url, _, _) in DOC_SOURCES.items():
path = DATA_DIR / filename
if path.exists():
continue
resp = requests.get(url, timeout=60)
resp.raise_for_status()
path.write_bytes(resp.content)
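# Downloaded PDFs live in ./data; on a default Space (no persistent storage) this
# cache lasts only for the lifetime of the running container.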
def _build_tools_for_doc(file_path: Path, name: str):
"""Create vector and summary tools for a single document."""
parser = LlamaParse(
result_type="text",
language="en",
api_key=os.getenv("LLAMA_CLOUD_API_KEY"),
base_url=EU_BASE_URL,
)
documents = SimpleDirectoryReader(
input_files=[str(file_path)],
file_extractor={".pdf": parser},
).load_data()
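    # Normalize page metadata: different parsers expose the page under different keys,
    # and both the citations and the page filters in vector_query rely on "page_label".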
for i, doc in enumerate(documents):
page = (
doc.metadata.get("page_label")
or doc.metadata.get("page")
or doc.metadata.get("page_number")
or doc.metadata.get("page_idx")
or i + 1
)
doc.metadata["page_label"] = page
doc.metadata["source"] = name
splitter = SentenceSplitter(chunk_size=800, chunk_overlap=120)
nodes = splitter.get_nodes_from_documents(documents)
for node in nodes:
if "page_label" not in node.metadata:
node.metadata["page_label"] = (
node.metadata.get("page") or node.metadata.get("page_number") or node.metadata.get("page_idx")
)
node.metadata["source"] = node.metadata.get("source") or name
if not nodes:
raise ValueError(f"No text nodes parsed from {file_path}. Check parser credentials or PDF availability.")
embed_model = OpenAIEmbedding(model="text-embedding-3-large")
vector_index = VectorStoreIndex(nodes, embed_model=embed_model)
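    # vector_query closes over this document's vector_index and name, so each
    # document gets its own independently filterable retrieval tool.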
def vector_query(query: str, page_numbers: Optional[List[int]] = None) -> str:
"""Grounded Q&A with optional page filters + citations.
Useful if you have specific questions over the document.
Always leave page_numbers as None UNLESS there is a specific page you want to search for.
Args:
query (str): the string query to be embedded.
            page_numbers (Optional[List[int]]): Pages to restrict the search to. Leave as
                None to search across all pages; otherwise retrieval is limited to the
                listed pages.
"""
page_numbers = page_numbers or []
        metadata_dicts = [{"key": "page_label", "value": str(p)} for p in page_numbers]
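        # With no page_numbers, metadata_dicts is empty and retrieval runs over all
        # pages; otherwise the listed pages are OR-combined as metadata filters.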
query_engine = vector_index.as_query_engine(
similarity_top_k=4,
filters=MetadataFilters.from_dicts(metadata_dicts, condition=FilterCondition.OR),
)
response = query_engine.query(query)
citations = []
for sn in response.source_nodes:
page = sn.node.metadata.get("page_label")
src = sn.node.metadata.get("source", name)
citations.append(f"{src} p.{page}" if page else src)
citations = list(dict.fromkeys(citations))
if citations:
return f"{response}\n\nSources: {', '.join(citations)}"
return str(response)
vector_tool = FunctionTool.from_defaults(
name=f"vector_tool_{name}",
fn=vector_query,
description=f"Vector search over {name}; responds with grounded answer + page citations. Primary source for {name}.",
)
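    # The summary index feeds a tree_summarize query engine for broad,
    # whole-document questions that targeted vector retrieval handles poorly.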
summary_index = SummaryIndex(nodes)
summary_query_engine = summary_index.as_query_engine(response_mode="tree_summarize", use_async=True)
summary_tool = QueryEngineTool.from_defaults(
name=f"summary_tool_{name}",
query_engine=summary_query_engine,
description=f"Useful for summarization questions related to {name}.",
)
return vector_tool, summary_tool
@st.cache_resource(show_spinner=False)
def build_agent():
"""Build and cache the FunctionAgent with tool routing."""
_ensure_data_files()
tool_sets = []
for filename, (_, display_name, _long_desc) in DOC_SOURCES.items():
tools = _build_tools_for_doc(DATA_DIR / filename, display_name)
tool_sets.extend(tools)
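    # ObjectIndex embeds the tool descriptions so the agent retrieves only the few
    # tools most relevant to each question instead of seeing all of them at once.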
obj_index = ObjectIndex.from_objects(tool_sets, index_cls=VectorStoreIndex)
obj_retriever = obj_index.as_retriever(similarity_top_k=4)
llm = OpenAI(model="gpt-3.5-turbo", temperature=0)
agent = FunctionAgent(tool_retriever=obj_retriever, llm=llm, system_prompt=BASE_SYSTEM_PROMPT, verbose=False)
return agent
async def _arun_agent(agent: FunctionAgent, prompt: str) -> str:
"""Await the agent workflow and return the stringified response."""
handler = agent.run(prompt)
return str(await handler)
def run_agent(agent: FunctionAgent, prompt: str) -> str:
    """Run the agent from Streamlit, whether or not an event loop is already running."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread (the usual Streamlit case), so asyncio.run is safe.
        return asyncio.run(_arun_agent(agent, prompt))
    # A loop is already running here; run_until_complete would raise RuntimeError,
    # so execute the coroutine on a worker thread with its own event loop.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, _arun_agent(agent, prompt)).result()
def _require_env(var_name: str) -> bool:
"""Check that required environment variables are set; inform user if missing."""
if os.getenv(var_name):
return True
st.error(f"Missing environment variable: {var_name}")
return False
def main() -> None:
st.set_page_config(page_title="RYBREVANT Q&A", page_icon="🩺", layout="wide")
st.title("RYBREVANT Q&A RAG")
st.write(
"Ask about the RYBREVANT prescribing information (PI) or patient brochure. "
"Responses stay grounded in the source documents and include page citations."
)
with st.sidebar:
st.header("About")
st.markdown(
"- Sources: PI and patient brochure\n"
"- Answers include citations and a safety disclaimer\n"
"- Data/parsing cached in this Space runtime"
)
st.divider()
st.markdown("Need to deploy? Push this app to Hugging Face Spaces with your API keys as secrets.")
has_keys = _require_env("OPENAI_API_KEY") and _require_env("LLAMA_CLOUD_API_KEY")
if not has_keys:
st.stop()
agent = build_agent()
if "messages" not in st.session_state:
st.session_state.messages = []
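    # Streamlit reruns this script on every interaction, so replay the stored
    # transcript before handling the new prompt.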
for role, content in st.session_state.messages:
with st.chat_message(role):
st.markdown(content)
prompt = st.chat_input("Ask a RYBREVANT question, e.g., dosing, administration, safety...")
if prompt:
st.session_state.messages.append(("user", prompt))
with st.chat_message("user"):
st.markdown(prompt)
with st.chat_message("assistant"):
with st.spinner("Grounding answer in the documents..."):
try:
response = run_agent(agent, prompt)
except Exception as exc: # pylint: disable=broad-except
st.error(f"Something went wrong: {exc}")
return
st.markdown(response)
st.session_state.messages.append(("assistant", response))
st.caption("Not medical advice; always consult a healthcare professional.")
if __name__ == "__main__":
main()