shan gao commited on
Commit
efb3827
·
1 Parent(s): 3496b64

modify streamlit_app.py

Browse files
Files changed (3) hide show
  1. .DS_Store +0 -0
  2. requirements.txt +5 -1
  3. src/streamlit_app.py +230 -35
.DS_Store ADDED
Binary file (6.15 kB). View file
 
requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
  altair
2
  pandas
3
- streamlit
 
 
 
 
 
1
  altair
2
  pandas
3
+ streamlit
4
+ llama-index
5
+ llama-cloud-services
6
+ llama-index-llms-openai
7
+ llama-index-embeddings-openai
src/streamlit_app.py CHANGED
@@ -1,40 +1,235 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
  """
7
- # Welcome to Streamlit!
 
 
 
 
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
 
14
  """
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Streamlit app for the RYBREVANT RAG chatbot.
3
+
4
+ - Builds one vector + summary tool per document (PI and brochure)
5
+ - Routes questions to the right tool with a FunctionAgent and ObjectIndex
6
+ - Surfaces concise answers with citations plus a safety disclaimer
7
 
8
+ Set environment variables before running:
9
+ - OPENAI_API_KEY
10
+ - LLAMA_CLOUD_API_KEY (for LlamaParse PDF ingestion)
11
 
12
+ Run locally:
13
+ streamlit run rybrevant_streamlit.py
14
  """
15
 
16
+ import asyncio
17
+ import os
18
+ from pathlib import Path
19
+ from typing import Dict, List, Optional, Tuple
20
+
21
+ import requests
22
+ import streamlit as st
23
+ from llama_index.core import SimpleDirectoryReader, SummaryIndex, VectorStoreIndex
24
+ from llama_index.core.agent.workflow import FunctionAgent
25
+ from llama_index.core.node_parser import SentenceSplitter
26
+ from llama_index.core.objects import ObjectIndex
27
+ from llama_index.core.tools import FunctionTool, QueryEngineTool
28
+ from llama_index.core.vector_stores import FilterCondition, MetadataFilters
29
+ from llama_index.embeddings.openai import OpenAIEmbedding
30
+ from llama_index.llms.openai import OpenAI
31
+
32
+
33
# LlamaCloud EU regional endpoint. NOTE(review): currently unused by this app —
# pass base_url=EU_BASE_URL to LlamaParse if EU data residency is required.
EU_BASE_URL = "https://api.cloud.eu.llamaindex.ai"
# Import only LlamaParse: the original also re-imported EU_BASE_URL, which
# silently shadowed the constant above (and breaks on package versions that
# do not export it).
from llama_cloud_services import LlamaParse


# Local cache directory for the downloaded PDFs.
DATA_DIR = Path("data")
# filename -> (download URL, short display name used in tool names / citations).
DOC_SOURCES: Dict[str, Tuple[str, str]] = {
    "rybrevant.pdf": (
        "https://www.jnjlabels.com/package-insert/product-monograph/prescribing-information/RYBREVANT-pi.pdf",
        "PI",
    ),
    "brochure.pdf": (
        "https://www.rybrevant.com/documents/RYBREVANT_Patient_Brochure_Digital.pdf",
        "brochure",
    ),
}
# Agent system prompt. The original concatenation was missing separating
# spaces between sentences ("documents.Please", "knowledge.When").
BASE_SYSTEM_PROMPT = (
    "You are an agent designed to answer queries over a set of RYBREVANT documents. "
    "Please always use the tools provided to answer a question. Do not rely on prior knowledge. "
    "When responding, keep answers concise, always mention the source: exact document + page "
    "(e.g., 'PI p.12' or 'brochure p.5'), and end with a brief safety disclaimer "
    "('Not medical advice; consult your healthcare professional')."
)
58
+
59
+
60
def _ensure_data_files() -> None:
    """Fetch any source PDF that is not already cached under DATA_DIR."""
    DATA_DIR.mkdir(exist_ok=True)
    for filename, (url, _) in DOC_SOURCES.items():
        target = DATA_DIR / filename
        if target.exists():
            continue
        # Download once; raise_for_status surfaces HTTP errors immediately.
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        target.write_bytes(response.content)
70
+
71
+
72
def _build_tools_for_doc(file_path: Path, name: str):
    """Build (vector_tool, summary_tool) for one PDF.

    Parses the PDF with LlamaParse, normalizes page/source metadata, chunks it,
    then wraps a filtered vector query engine and a tree-summarize engine as
    agent tools named after *name* (e.g. "PI" or "brochure").
    """
    parser = LlamaParse(result_type="text", language="en", api_key=os.getenv("LLAMA_CLOUD_API_KEY"))
    documents = SimpleDirectoryReader(
        input_files=[str(file_path)],
        file_extractor={".pdf": parser},
    ).load_data()

    # Normalize page metadata: different loaders use different keys, so fall
    # back through the known variants and finally to the document's position.
    for i, doc in enumerate(documents):
        page = (
            doc.metadata.get("page_label")
            or doc.metadata.get("page")
            or doc.metadata.get("page_number")
            or doc.metadata.get("page_idx")
            or i + 1
        )
        doc.metadata["page_label"] = page
        doc.metadata["source"] = name

    splitter = SentenceSplitter(chunk_size=800, chunk_overlap=120)
    nodes = splitter.get_nodes_from_documents(documents)
    # Re-apply the normalization at node level in case the splitter dropped or
    # renamed metadata keys.
    for node in nodes:
        if "page_label" not in node.metadata:
            node.metadata["page_label"] = (
                node.metadata.get("page") or node.metadata.get("page_number") or node.metadata.get("page_idx")
            )
        node.metadata["source"] = node.metadata.get("source") or name

    embed_model = OpenAIEmbedding(model="text-embedding-3-large")
    vector_index = VectorStoreIndex(nodes, embed_model=embed_model)

    def vector_query(query: str, page_numbers: Optional[List[int]] = None) -> str:
        """Grounded Q&A with optional page filters + citations.

        Useful if you have specific questions over the document.
        Always leave page_numbers as None UNLESS there is a specific page you want to search for.

        Args:
            query (str): the string query to be embedded.
            page_numbers (Optional[List[int]]): Filter by set of pages. Leave as NONE
                if we want to perform a vector search
                over all pages. Otherwise, filter by the set of specified pages.

        """
        page_numbers = page_numbers or []
        # Empty metadata_dicts means no page filter is applied — the search
        # runs over all pages (assumed llama-index behavior; confirm on upgrade).
        metadata_dicts = [{"key": "page_label", "value": p} for p in page_numbers]
        query_engine = vector_index.as_query_engine(
            similarity_top_k=4,
            filters=MetadataFilters.from_dicts(metadata_dicts, condition=FilterCondition.OR),
        )
        response = query_engine.query(query)

        # Collect "source p.page" citations from retrieved nodes, dropping
        # duplicates while preserving first-seen order via dict.fromkeys.
        citations = []
        for sn in response.source_nodes:
            page = sn.node.metadata.get("page_label")
            src = sn.node.metadata.get("source", name)
            citations.append(f"{src} p.{page}" if page else src)
        citations = list(dict.fromkeys(citations))

        if citations:
            return f"{response}\n\nSources: {', '.join(citations)}"
        return str(response)

    vector_tool = FunctionTool.from_defaults(
        name=f"vector_tool_{name}",
        fn=vector_query,
        description=f"Vector search over {name}; responds with grounded answer + page citations.",
    )

    # Summary tool: tree_summarize over all nodes for whole-document questions.
    summary_index = SummaryIndex(nodes)
    summary_query_engine = summary_index.as_query_engine(response_mode="tree_summarize", use_async=True)
    summary_tool = QueryEngineTool.from_defaults(
        name=f"summary_tool_{name}",
        query_engine=summary_query_engine,
        description=f"Useful for summarization questions related to {name}.",
    )

    return vector_tool, summary_tool
150
+
151
+
152
@st.cache_resource(show_spinner=False)
def build_agent():
    """Build and cache the FunctionAgent with tool routing."""
    _ensure_data_files()

    # Two tools (vector + summary) per source document.
    all_tools = []
    for filename, (_, display_name) in DOC_SOURCES.items():
        all_tools.extend(_build_tools_for_doc(DATA_DIR / filename, display_name))

    # Retrieve the 2 most relevant tools per query instead of exposing all.
    tool_retriever = ObjectIndex.from_objects(
        all_tools, index_cls=VectorStoreIndex
    ).as_retriever(similarity_top_k=2)

    return FunctionAgent(
        tool_retriever=tool_retriever,
        llm=OpenAI(model="gpt-3.5-turbo", temperature=0),
        system_prompt=BASE_SYSTEM_PROMPT,
        verbose=False,
    )
168
+
169
+
170
def run_agent(agent: "FunctionAgent", prompt: str) -> str:
    """Run the async agent synchronously inside Streamlit.

    The original called ``asyncio.run(agent.run(prompt))`` in two steps:
    ``agent.run()`` returns a workflow handler (an awaitable, not a
    coroutine), which ``asyncio.run`` rejects with ``ValueError: a coroutine
    was expected``; the handler was also created before the event loop that
    would await it existed. Creating and awaiting it inside one coroutine
    fixes both problems.
    """

    async def _execute() -> str:
        return str(await agent.run(prompt))

    return asyncio.run(_execute())
174
+
175
+
176
+ def _require_env(var_name: str) -> bool:
177
+ """Check that required environment variables are set; inform user if missing."""
178
+ if os.getenv(var_name):
179
+ return True
180
+ st.error(f"Missing environment variable: {var_name}")
181
+ return False
182
+
183
+
184
def main() -> None:
    """Render the chat UI and answer questions via the cached agent."""
    st.set_page_config(page_title="RYBREVANT Q&A", page_icon="🩺", layout="wide")
    st.title("RYBREVANT Q&A RAG")
    st.write(
        "Ask about the RYBREVANT prescribing information (PI) or patient brochure. "
        "Responses stay grounded in the source documents and include page citations."
    )

    with st.sidebar:
        st.header("About")
        st.markdown(
            "- Sources: PI and patient brochure\n"
            "- Answers include citations and a safety disclaimer\n"
            "- Data/parsing cached in this Space runtime"
        )
        st.divider()
        st.markdown("Need to deploy? Push this app to Hugging Face Spaces with your API keys as secrets.")

    # Both keys are required: OpenAI for LLM/embeddings, LlamaCloud for parsing.
    if not (_require_env("OPENAI_API_KEY") and _require_env("LLAMA_CLOUD_API_KEY")):
        st.stop()

    agent = build_agent()

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far (Streamlit reruns the script each turn).
    for role, content in st.session_state.messages:
        with st.chat_message(role):
            st.markdown(content)

    user_text = st.chat_input("Ask a RYBREVANT question, e.g., dosing, administration, safety...")
    if user_text:
        st.session_state.messages.append(("user", user_text))
        with st.chat_message("user"):
            st.markdown(user_text)

        with st.chat_message("assistant"):
            with st.spinner("Grounding answer in the documents..."):
                try:
                    answer = run_agent(agent, user_text)
                except Exception as exc:  # pylint: disable=broad-except
                    st.error(f"Something went wrong: {exc}")
                    return
            st.markdown(answer)
            st.session_state.messages.append(("assistant", answer))

    st.caption("Not medical advice; always consult a healthcare professional.")
232
+
233
+
234
# Script entry point: `streamlit run` executes this module as __main__.
if __name__ == "__main__":
    main()