Spaces:
Sleeping
Sleeping
"""
Streamlit app for the RYBREVANT RAG chatbot.

- Builds one vector + summary tool per document (PI and brochure)
- Routes questions to the right tool with a FunctionAgent and ObjectIndex
- Surfaces concise answers with citations plus a safety disclaimer

Set environment variables before running:
- OPENAI_API_KEY
- LLAMA_CLOUD_API_KEY (for LlamaParse PDF ingestion)

Run locally:
    streamlit run rybrevant_streamlit.py
"""
| import asyncio | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import requests | |
| import streamlit as st | |
| from llama_index.core import SimpleDirectoryReader, SummaryIndex, VectorStoreIndex | |
| from llama_index.core.agent.workflow import FunctionAgent | |
| from llama_index.core.node_parser import SentenceSplitter | |
| from llama_index.core.objects import ObjectIndex | |
| from llama_index.core.tools import FunctionTool, QueryEngineTool | |
| from llama_index.core.vector_stores import FilterCondition, MetadataFilters | |
| from llama_index.embeddings.openai import OpenAIEmbedding | |
| from llama_index.llms.openai import OpenAI | |
| from llama_cloud_services import LlamaParse, EU_BASE_URL | |
# Local directory where the source PDFs are downloaded and cached.
DATA_DIR = Path("data")

# Maps local filename -> (download URL, short citation label, routing description).
# The short label is what appears in citations (e.g. "PI p.12"); the long
# description helps the agent's tool retriever route questions.
DOC_SOURCES: Dict[str, Tuple[str, str, str]] = {
    # filename: (url, short_label, long_description)
    "rybrevant.pdf": (
        "https://www.jnjlabels.com/package-insert/product-monograph/prescribing-information/RYBREVANT-pi.pdf",
        "PI",
        "RYBREVANT prescribing information (official label; use for dosing, safety, administration)",
    ),
    "brochure.pdf": (
        "https://www.rybrevant.com/documents/RYBREVANT_Patient_Brochure_Digital.pdf",
        "brochure",
        "RYBREVANT patient brochure (patient-friendly overview; use for general awareness)",
    ),
}
# System prompt for the routing agent. Each sentence-ending fragment carries a
# trailing space so the implicitly concatenated string literals do not run
# words together (the original produced "documents.Please" / "knowledge.When").
BASE_SYSTEM_PROMPT = (
    "You are an agent designed to answer queries over a set of RYBREVANT documents. "
    "Please always use the tools provided to answer a question. Do not rely on prior knowledge. "
    "When responding, keep answers concise, always mention the source: exact document + page "
    "(e.g., 'PI p.12' or 'brochure p.5'), and end with a brief safety disclaimer "
    "('Not medical advice; consult your healthcare professional')."
)
def _ensure_data_files() -> None:
    """Make sure every source PDF exists under DATA_DIR, downloading any that are absent."""
    DATA_DIR.mkdir(exist_ok=True)
    for filename, (url, _, _) in DOC_SOURCES.items():
        target = DATA_DIR / filename
        if not target.exists():
            # Fetch the PDF; fail loudly on HTTP errors rather than caching junk.
            resp = requests.get(url, timeout=60)
            resp.raise_for_status()
            target.write_bytes(resp.content)
def _build_tools_for_doc(file_path: Path, name: str):
    """Parse one PDF and return its (vector_tool, summary_tool) pair.

    The vector tool answers targeted questions with page-level citations;
    the summary tool handles whole-document summarization requests.
    """
    pdf_parser = LlamaParse(
        result_type="text",
        language="en",
        api_key=os.getenv("LLAMA_CLOUD_API_KEY"),
        base_url=EU_BASE_URL,
    )
    docs = SimpleDirectoryReader(
        input_files=[str(file_path)],
        file_extractor={".pdf": pdf_parser},
    ).load_data()

    # Normalize page/source metadata up front so citations can be built later.
    # Different extractors expose the page under different keys; fall back to
    # the 1-based document position when none is present.
    for idx, doc in enumerate(docs):
        meta = doc.metadata
        label = (
            meta.get("page_label")
            or meta.get("page")
            or meta.get("page_number")
            or meta.get("page_idx")
            or idx + 1
        )
        meta["page_label"] = label
        meta["source"] = name

    nodes = SentenceSplitter(chunk_size=800, chunk_overlap=120).get_nodes_from_documents(docs)
    # Splitting can drop metadata on some nodes; re-derive what is missing.
    for node in nodes:
        meta = node.metadata
        if "page_label" not in meta:
            meta["page_label"] = meta.get("page") or meta.get("page_number") or meta.get("page_idx")
        meta["source"] = meta.get("source") or name
    if not nodes:
        raise ValueError(f"No text nodes parsed from {file_path}. Check parser credentials or PDF availability.")

    vector_index = VectorStoreIndex(
        nodes, embed_model=OpenAIEmbedding(model="text-embedding-3-large")
    )

    def vector_query(query: str, page_numbers: Optional[List[int]] = None) -> str:
        """Answer *query* from this document, optionally restricted to given pages.

        Args:
            query (str): the string query to be embedded.
            page_numbers (Optional[List[int]]): pages to restrict the search to.
                Leave as None to search over all pages.
        """
        wanted_pages = page_numbers or []
        filters = MetadataFilters.from_dicts(
            [{"key": "page_label", "value": p} for p in wanted_pages],
            condition=FilterCondition.OR,
        )
        engine = vector_index.as_query_engine(similarity_top_k=4, filters=filters)
        response = engine.query(query)
        # Dict keys give order-preserving de-duplication of citations.
        citations = {}
        for sn in response.source_nodes:
            page = sn.node.metadata.get("page_label")
            src = sn.node.metadata.get("source", name)
            citations[f"{src} p.{page}" if page else src] = None
        if citations:
            return f"{response}\n\nSources: {', '.join(citations)}"
        return str(response)

    vector_tool = FunctionTool.from_defaults(
        name=f"vector_tool_{name}",
        fn=vector_query,
        description=f"Vector search over {name}; responds with grounded answer + page citations. Primary source for {name}.",
    )

    summary_engine = SummaryIndex(nodes).as_query_engine(response_mode="tree_summarize", use_async=True)
    summary_tool = QueryEngineTool.from_defaults(
        name=f"summary_tool_{name}",
        query_engine=summary_engine,
        description=f"Useful for summarization questions related to {name}.",
    )
    return vector_tool, summary_tool
@st.cache_resource
def build_agent():
    """Build (once per Space runtime) the FunctionAgent with tool routing.

    The original docstring claimed the agent was cached, but nothing cached it:
    Streamlit re-executes the script on every chat interaction, so the documents
    were re-downloaded, re-parsed, and re-embedded per message. `st.cache_resource`
    makes the expensive build happen once and reuses the agent across reruns.
    """
    _ensure_data_files()
    tool_sets = []
    for filename, (_, display_name, _long_desc) in DOC_SOURCES.items():
        # One vector tool + one summary tool per document.
        tools = _build_tools_for_doc(DATA_DIR / filename, display_name)
        tool_sets.extend(tools)
    # Index the tools themselves so the agent retrieves only the relevant ones.
    obj_index = ObjectIndex.from_objects(tool_sets, index_cls=VectorStoreIndex)
    obj_retriever = obj_index.as_retriever(similarity_top_k=4)
    llm = OpenAI(model="gpt-3.5-turbo", temperature=0)
    agent = FunctionAgent(tool_retriever=obj_retriever, llm=llm, system_prompt=BASE_SYSTEM_PROMPT, verbose=False)
    return agent
async def _arun_agent(agent: FunctionAgent, prompt: str) -> str:
    """Kick off the agent workflow for *prompt* and return its final answer as text."""
    result = await agent.run(prompt)
    return str(result)
def run_agent(agent: FunctionAgent, prompt: str) -> str:
    """Run the agent synchronously from Streamlit, with or without a live event loop.

    The original `else` branch called `loop.run_until_complete` on the loop
    returned by `asyncio.get_running_loop()`; calling `run_until_complete` on
    an already-running loop raises `RuntimeError: This event loop is already
    running`. When a loop is active, run the coroutine on a fresh loop in a
    worker thread instead.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: safe to start one here.
        return asyncio.run(_arun_agent(agent, prompt))
    # A loop is already running; hand the coroutine to a worker thread that
    # owns its own event loop.
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, _arun_agent(agent, prompt)).result()
def _require_env(var_name: str) -> bool:
    """Return True when *var_name* is set; otherwise show a Streamlit error and return False."""
    if not os.getenv(var_name):
        st.error(f"Missing environment variable: {var_name}")
        return False
    return True
def main() -> None:
    """Streamlit entry point: render the chat UI and answer questions via the agent."""
    st.set_page_config(page_title="RYBREVANT Q&A", page_icon="🩺", layout="wide")
    st.title("RYBREVANT Q&A RAG")
    st.write(
        "Ask about the RYBREVANT prescribing information (PI) or patient brochure. "
        "Responses stay grounded in the source documents and include page citations."
    )
    with st.sidebar:
        st.header("About")
        st.markdown(
            "- Sources: PI and patient brochure\n"
            "- Answers include citations and a safety disclaimer\n"
            "- Data/parsing cached in this Space runtime"
        )
        st.divider()
        st.markdown("Need to deploy? Push this app to Hugging Face Spaces with your API keys as secrets.")
    # Both keys are required: OpenAI for the LLM/embeddings, LlamaCloud for PDF parsing.
    has_keys = _require_env("OPENAI_API_KEY") and _require_env("LLAMA_CLOUD_API_KEY")
    if not has_keys:
        st.stop()
    agent = build_agent()
    # Chat history is kept in session state as (role, content) tuples and
    # replayed on every rerun before handling new input.
    if "messages" not in st.session_state:
        st.session_state.messages = []
    for role, content in st.session_state.messages:
        with st.chat_message(role):
            st.markdown(content)
    prompt = st.chat_input("Ask a RYBREVANT question, e.g., dosing, administration, safety...")
    if prompt:
        st.session_state.messages.append(("user", prompt))
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("Grounding answer in the documents..."):
                try:
                    response = run_agent(agent, prompt)
                except Exception as exc:  # pylint: disable=broad-except
                    # Surface the failure in the UI instead of a raw traceback;
                    # the failed exchange is deliberately not saved to history.
                    st.error(f"Something went wrong: {exc}")
                    return
                st.markdown(response)
        st.session_state.messages.append(("assistant", response))
    st.caption("Not medical advice; always consult a healthcare professional.")
if __name__ == "__main__":
    # Streamlit executes this module as a script on every rerun.
    main()