Spaces:
Sleeping
Sleeping
shan gao
committed on
Commit
·
efb3827
1
Parent(s):
3496b64
modify streamit_app.py
Browse files- .DS_Store +0 -0
- requirements.txt +5 -1
- src/streamlit_app.py +230 -35
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
requirements.txt
CHANGED
|
@@ -1,3 +1,7 @@
|
|
| 1 |
altair
|
| 2 |
pandas
|
| 3 |
-
streamlit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
altair
|
| 2 |
pandas
|
| 3 |
+
streamlit
|
| 4 |
+
llama-index
|
| 5 |
+
llama-cloud-services
|
| 6 |
+
llama-index-llms-openai
|
| 7 |
+
llama-index-embeddings-openai
|
src/streamlit_app.py
CHANGED
|
@@ -1,40 +1,235 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
|
| 6 |
"""
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
|
| 13 |
-
|
|
|
|
| 14 |
"""
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Streamlit app for the RYBREVANT RAG chatbot.
|
| 3 |
+
|
| 4 |
+
- Builds one vector + summary tool per document (PI and brochure)
|
| 5 |
+
- Routes questions to the right tool with a FunctionAgent and ObjectIndex
|
| 6 |
+
- Surfaces concise answers with citations plus a safety disclaimer
|
| 7 |
|
| 8 |
+
Set environment variables before running:
|
| 9 |
+
- OPENAI_API_KEY
|
| 10 |
+
- LLAMA_CLOUD_API_KEY (for LlamaParse PDF ingestion)
|
| 11 |
|
| 12 |
+
Run locally:
|
| 13 |
+
streamlit run src/streamlit_app.py
|
| 14 |
"""
|
| 15 |
|
| 16 |
+
import asyncio
|
| 17 |
+
import os
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Dict, List, Optional, Tuple
|
| 20 |
+
|
| 21 |
+
import requests
|
| 22 |
+
import streamlit as st
|
| 23 |
+
from llama_index.core import SimpleDirectoryReader, SummaryIndex, VectorStoreIndex
|
| 24 |
+
from llama_index.core.agent.workflow import FunctionAgent
|
| 25 |
+
from llama_index.core.node_parser import SentenceSplitter
|
| 26 |
+
from llama_index.core.objects import ObjectIndex
|
| 27 |
+
from llama_index.core.tools import FunctionTool, QueryEngineTool
|
| 28 |
+
from llama_index.core.vector_stores import FilterCondition, MetadataFilters
|
| 29 |
+
from llama_index.embeddings.openai import OpenAIEmbedding
|
| 30 |
+
from llama_index.llms.openai import OpenAI
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
EU_BASE_URL = "https://api.cloud.eu.llamaindex.ai"
|
| 34 |
+
from llama_cloud_services import (
|
| 35 |
+
LlamaParse,
|
| 36 |
+
EU_BASE_URL,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Local cache directory for the downloaded source PDFs.
DATA_DIR = Path("data")

# filename -> (download URL, short display name used in tool names and citations).
DOC_SOURCES: Dict[str, Tuple[str, str]] = {
    "rybrevant.pdf": (
        "https://www.jnjlabels.com/package-insert/product-monograph/prescribing-information/RYBREVANT-pi.pdf",
        "PI",
    ),
    "brochure.pdf": (
        "https://www.rybrevant.com/documents/RYBREVANT_Patient_Brochure_Digital.pdf",
        "brochure",
    ),
}

# System prompt for the routing agent.
# FIX: the original adjacent string literals were joined without separating
# spaces, producing run-on text such as "documents.Please" and
# "knowledge.When"; the trailing spaces below restore sentence boundaries.
# (The backslash continuations were also redundant inside parentheses.)
BASE_SYSTEM_PROMPT = (
    "You are an agent designed to answer queries over a set of RYBREVANT documents. "
    "Please always use the tools provided to answer a question. Do not rely on prior knowledge. "
    "When responding, keep answers concise, always mention the source: exact document + page "
    "(e.g., 'PI p.12' or 'brochure p.5'), and end with a brief safety disclaimer "
    "('Not medical advice; consult your healthcare professional')."
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _ensure_data_files() -> None:
    """Fetch any missing source PDFs into ``DATA_DIR``.

    Files already present on disk are left untouched. Downloads use a
    60-second timeout and raise ``requests.HTTPError`` on a bad status.
    """
    DATA_DIR.mkdir(exist_ok=True)
    for filename, (url, _label) in DOC_SOURCES.items():
        target = DATA_DIR / filename
        if not target.exists():
            download = requests.get(url, timeout=60)
            download.raise_for_status()
            target.write_bytes(download.content)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _build_tools_for_doc(file_path: Path, name: str):
    """Create vector and summary tools for a single document.

    Parses the PDF with LlamaParse, normalizes page/source metadata onto
    every document and node, builds a vector index and a summary index,
    and returns ``(vector_tool, summary_tool)`` for the routing agent.

    Args:
        file_path: Path to the local PDF file.
        name: Short display name ("PI" or "brochure") embedded in tool
            names and citations.
    """
    # LlamaParse does the PDF text extraction; it needs LLAMA_CLOUD_API_KEY.
    parser = LlamaParse(result_type="text", language="en", api_key=os.getenv("LLAMA_CLOUD_API_KEY"))
    documents = SimpleDirectoryReader(
        input_files=[str(file_path)],
        file_extractor={".pdf": parser},
    ).load_data()

    # Normalize page metadata: fall back through the key variants different
    # loaders emit, then to the 1-based document position as a last resort.
    for i, doc in enumerate(documents):
        page = (
            doc.metadata.get("page_label")
            or doc.metadata.get("page")
            or doc.metadata.get("page_number")
            or doc.metadata.get("page_idx")
            or i + 1
        )
        doc.metadata["page_label"] = page
        doc.metadata["source"] = name

    splitter = SentenceSplitter(chunk_size=800, chunk_overlap=120)
    nodes = splitter.get_nodes_from_documents(documents)
    # Defensive: make sure every chunk still carries page_label/source so
    # the citation strings below always have something to show.
    for node in nodes:
        if "page_label" not in node.metadata:
            node.metadata["page_label"] = (
                node.metadata.get("page") or node.metadata.get("page_number") or node.metadata.get("page_idx")
            )
        node.metadata["source"] = node.metadata.get("source") or name

    embed_model = OpenAIEmbedding(model="text-embedding-3-large")
    vector_index = VectorStoreIndex(nodes, embed_model=embed_model)

    def vector_query(query: str, page_numbers: Optional[List[int]] = None) -> str:
        """Grounded Q&A with optional page filters + citations.

        Useful if you have specific questions over the document.
        Always leave page_numbers as None UNLESS there is a specific page you want to search for.

        Args:
            query (str): the string query to be embedded.
            page_numbers (Optional[List[int]]): Filter by set of pages. Leave as NONE
                if we want to perform a vector search
                over all pages. Otherwise, filter by the set of specified pages.

        """
        page_numbers = page_numbers or []
        metadata_dicts = [{"key": "page_label", "value": p} for p in page_numbers]
        # NOTE(review): with no page_numbers this builds an empty OR filter;
        # presumably the installed llama-index treats that as "match all
        # pages" — confirm, since some versions match nothing instead.
        query_engine = vector_index.as_query_engine(
            similarity_top_k=4,
            filters=MetadataFilters.from_dicts(metadata_dicts, condition=FilterCondition.OR),
        )
        response = query_engine.query(query)

        # Collect "source p.N" citations from the retrieved nodes,
        # de-duplicated while preserving first-seen order.
        citations = []
        for sn in response.source_nodes:
            page = sn.node.metadata.get("page_label")
            src = sn.node.metadata.get("source", name)
            citations.append(f"{src} p.{page}" if page else src)
        citations = list(dict.fromkeys(citations))

        if citations:
            return f"{response}\n\nSources: {', '.join(citations)}"
        return str(response)

    vector_tool = FunctionTool.from_defaults(
        name=f"vector_tool_{name}",
        fn=vector_query,
        description=f"Vector search over {name}; responds with grounded answer + page citations.",
    )

    # Summary tool: tree-summarize over all nodes, for whole-document questions.
    summary_index = SummaryIndex(nodes)
    summary_query_engine = summary_index.as_query_engine(response_mode="tree_summarize", use_async=True)
    summary_tool = QueryEngineTool.from_defaults(
        name=f"summary_tool_{name}",
        query_engine=summary_query_engine,
        description=f"Useful for summarization questions related to {name}.",
    )

    return vector_tool, summary_tool
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@st.cache_resource(show_spinner=False)
def build_agent():
    """Assemble and cache the tool-routing FunctionAgent.

    Downloads any missing PDFs, builds one vector + one summary tool per
    document, indexes the tools with an ObjectIndex, and wires a retriever
    over them into a FunctionAgent. Cached for the Streamlit session.
    """
    _ensure_data_files()

    # Two tools (vector + summary) per source document, flattened.
    all_tools = [
        tool
        for filename, (_, display_name) in DOC_SOURCES.items()
        for tool in _build_tools_for_doc(DATA_DIR / filename, display_name)
    ]

    tool_index = ObjectIndex.from_objects(all_tools, index_cls=VectorStoreIndex)
    retriever = tool_index.as_retriever(similarity_top_k=2)

    return FunctionAgent(
        tool_retriever=retriever,
        llm=OpenAI(model="gpt-3.5-turbo", temperature=0),
        system_prompt=BASE_SYSTEM_PROMPT,
        verbose=False,
    )
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def run_agent(agent: FunctionAgent, prompt: str) -> str:
    """Run the async agent to completion from Streamlit's sync context.

    ``FunctionAgent.run`` returns an awaitable workflow handler, not a
    coroutine, so passing it straight to ``asyncio.run`` raises
    ``ValueError: a coroutine was expected``. Wrapping the call in a
    coroutine fixes that and also ensures the handler is created inside
    the same event loop that awaits it.

    Args:
        agent: The tool-routing agent built by ``build_agent``.
        prompt: The user's question.

    Returns:
        The agent's final response rendered as a string.
    """
    async def _invoke() -> str:
        return str(await agent.run(prompt))

    return asyncio.run(_invoke())
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _require_env(var_name: str) -> bool:
|
| 177 |
+
"""Check that required environment variables are set; inform user if missing."""
|
| 178 |
+
if os.getenv(var_name):
|
| 179 |
+
return True
|
| 180 |
+
st.error(f"Missing environment variable: {var_name}")
|
| 181 |
+
return False
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def main() -> None:
    """Render the Streamlit chat UI and drive the RAG agent."""
    st.set_page_config(page_title="RYBREVANT Q&A", page_icon="🩺", layout="wide")
    st.title("RYBREVANT Q&A RAG")
    st.write(
        "Ask about the RYBREVANT prescribing information (PI) or patient brochure. "
        "Responses stay grounded in the source documents and include page citations."
    )

    with st.sidebar:
        st.header("About")
        st.markdown(
            "- Sources: PI and patient brochure\n"
            "- Answers include citations and a safety disclaimer\n"
            "- Data/parsing cached in this Space runtime"
        )
        st.divider()
        st.markdown("Need to deploy? Push this app to Hugging Face Spaces with your API keys as secrets.")

    # Stop rendering early when either required API key is absent.
    keys_present = _require_env("OPENAI_API_KEY") and _require_env("LLAMA_CLOUD_API_KEY")
    if not keys_present:
        st.stop()

    agent = build_agent()

    # Chat history lives in session state as (role, content) tuples.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay earlier turns so the conversation survives Streamlit reruns.
    for role, content in st.session_state.messages:
        with st.chat_message(role):
            st.markdown(content)

    user_message = st.chat_input("Ask a RYBREVANT question, e.g., dosing, administration, safety...")
    if user_message:
        st.session_state.messages.append(("user", user_message))
        with st.chat_message("user"):
            st.markdown(user_message)

        with st.chat_message("assistant"):
            with st.spinner("Grounding answer in the documents..."):
                try:
                    reply = run_agent(agent, user_message)
                except Exception as exc:  # pylint: disable=broad-except
                    st.error(f"Something went wrong: {exc}")
                    return
            st.markdown(reply)
            st.session_state.messages.append(("assistant", reply))

    st.caption("Not medical advice; always consult a healthcare professional.")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# Script entry point: launch the Streamlit UI when executed directly
# (e.g. via `streamlit run src/streamlit_app.py`).
if __name__ == "__main__":
    main()
|