"""
Streamlit app for the RYBREVANT RAG chatbot.

- Builds one vector + summary tool per document (PI and brochure)
- Routes questions to the right tool with a FunctionAgent and ObjectIndex
- Surfaces concise answers with citations plus a safety disclaimer

Set environment variables before running:
  - OPENAI_API_KEY
  - LLAMA_CLOUD_API_KEY (for LlamaParse PDF ingestion)

Run locally:
    streamlit run rybrevant_streamlit.py
"""

import asyncio
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import requests
import streamlit as st
from llama_index.core import SimpleDirectoryReader, SummaryIndex, VectorStoreIndex
from llama_index.core.agent.workflow import FunctionAgent
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.objects import ObjectIndex
from llama_index.core.tools import FunctionTool, QueryEngineTool
from llama_index.core.vector_stores import FilterCondition, MetadataFilters
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI


from llama_cloud_services import LlamaParse, EU_BASE_URL


DATA_DIR = Path("data")
# filename -> (download URL, short label used in tool names/citations, routing description)
DOC_SOURCES: Dict[str, Tuple[str, str, str]] = {
    "rybrevant.pdf": (
        "https://www.jnjlabels.com/package-insert/product-monograph/prescribing-information/RYBREVANT-pi.pdf",
        "PI",
        "RYBREVANT prescribing information (official label; use for dosing, safety, administration)",
    ),
    "brochure.pdf": (
        "https://www.rybrevant.com/documents/RYBREVANT_Patient_Brochure_Digital.pdf",
        "brochure",
        "RYBREVANT patient brochure (patient-friendly overview; use for general awareness)",
    ),
}
# System prompt for the routing agent. The previous version concatenated
# adjacent string literals without separating spaces, which produced
# "documents.Please" and "knowledge.When" in the prompt sent to the LLM.
BASE_SYSTEM_PROMPT = (
    "You are an agent designed to answer queries over a set of RYBREVANT documents. "
    "Please always use the tools provided to answer a question. Do not rely on prior knowledge. "
    "When responding, keep answers concise, always mention the source: exact document + page "
    "(e.g., 'PI p.12' or 'brochure p.5'), and end with a brief safety disclaimer "
    "('Not medical advice; consult your healthcare professional')."
)


def _ensure_data_files() -> None:
    """Fetch any missing source PDFs into DATA_DIR (no-op for files already present)."""
    DATA_DIR.mkdir(exist_ok=True)
    for filename, (url, _label, _desc) in DOC_SOURCES.items():
        target = DATA_DIR / filename
        if not target.exists():
            response = requests.get(url, timeout=60)
            response.raise_for_status()
            target.write_bytes(response.content)


def _build_tools_for_doc(file_path: Path, name: str):
    """Create vector and summary tools for a single document.

    Args:
        file_path: Path to the local PDF to ingest.
        name: Short label (e.g. "PI" or "brochure") embedded in tool names,
            node metadata, and citations.

    Returns:
        A ``(vector_tool, summary_tool)`` pair for the document.

    Raises:
        ValueError: If parsing yields no text nodes (e.g. bad credentials or
            an unavailable PDF).
    """
    # LlamaParse extracts text from the PDF; the EU endpoint is selected explicitly.
    parser = LlamaParse(
        result_type="text",
        language="en",
        api_key=os.getenv("LLAMA_CLOUD_API_KEY"),
        base_url=EU_BASE_URL,
    )
    documents = SimpleDirectoryReader(
        input_files=[str(file_path)],
        file_extractor={".pdf": parser},
    ).load_data()

    # Normalize page metadata: different loaders use different keys, so fall
    # back through the known variants and finally to the 1-based position.
    for i, doc in enumerate(documents):
        page = (
            doc.metadata.get("page_label")
            or doc.metadata.get("page")
            or doc.metadata.get("page_number")
            or doc.metadata.get("page_idx")
            or i + 1
        )
        doc.metadata["page_label"] = page
        doc.metadata["source"] = name

    splitter = SentenceSplitter(chunk_size=800, chunk_overlap=120)
    nodes = splitter.get_nodes_from_documents(documents)
    # Re-apply metadata on the chunked nodes in case the splitter dropped or
    # failed to propagate it from the parent documents.
    for node in nodes:
        if "page_label" not in node.metadata:
            node.metadata["page_label"] = (
                node.metadata.get("page") or node.metadata.get("page_number") or node.metadata.get("page_idx")
            )
        node.metadata["source"] = node.metadata.get("source") or name

    if not nodes:
        raise ValueError(f"No text nodes parsed from {file_path}. Check parser credentials or PDF availability.")

    embed_model = OpenAIEmbedding(model="text-embedding-3-large")
    vector_index = VectorStoreIndex(nodes, embed_model=embed_model)

    # Closure over vector_index: exposed to the agent as a FunctionTool below.
    def vector_query(query: str, page_numbers: Optional[List[int]] = None) -> str:
        """Grounded Q&A with optional page filters + citations.

        Useful if you have specific questions over the document.
        Always leave page_numbers as None UNLESS there is a specific page you want to search for.

        Args:
            query (str): the string query to be embedded.
            page_numbers (Optional[List[int]]): Filter by set of pages. Leave as NONE
                if we want to perform a vector search
                over all pages. Otherwise, filter by the set of specified pages.

        """
        page_numbers = page_numbers or []
        metadata_dicts = [{"key": "page_label", "value": p} for p in page_numbers]
        # NOTE(review): with no page filters this builds an empty OR filter set —
        # presumably equivalent to "no filtering" in llama_index; confirm against
        # the installed version's MetadataFilters semantics.
        query_engine = vector_index.as_query_engine(
            similarity_top_k=4,
            filters=MetadataFilters.from_dicts(metadata_dicts, condition=FilterCondition.OR),
        )
        response = query_engine.query(query)

        # Collect "source p.page" citations from retrieved nodes, de-duplicated
        # while preserving first-seen order via dict.fromkeys.
        citations = []
        for sn in response.source_nodes:
            page = sn.node.metadata.get("page_label")
            src = sn.node.metadata.get("source", name)
            citations.append(f"{src} p.{page}" if page else src)
        citations = list(dict.fromkeys(citations))

        if citations:
            return f"{response}\n\nSources: {', '.join(citations)}"
        return str(response)

    vector_tool = FunctionTool.from_defaults(
        name=f"vector_tool_{name}",
        fn=vector_query,
        description=f"Vector search over {name}; responds with grounded answer + page citations. Primary source for {name}.",
    )

    # Summary tool covers whole-document summarization questions.
    summary_index = SummaryIndex(nodes)
    summary_query_engine = summary_index.as_query_engine(response_mode="tree_summarize", use_async=True)
    summary_tool = QueryEngineTool.from_defaults(
        name=f"summary_tool_{name}",
        query_engine=summary_query_engine,
        description=f"Useful for summarization questions related to {name}.",
    )

    return vector_tool, summary_tool


@st.cache_resource(show_spinner=False)
def build_agent():
    """Build and cache the FunctionAgent with per-document tool routing.

    Downloads missing PDFs, builds a vector + summary tool pair for each
    document, indexes the tools for retrieval-based routing, and wires
    everything into a FunctionAgent. Cached for the Streamlit session.
    """
    _ensure_data_files()

    all_tools = []
    for filename, (_url, label, _desc) in DOC_SOURCES.items():
        vector_tool, summary_tool = _build_tools_for_doc(DATA_DIR / filename, label)
        all_tools.append(vector_tool)
        all_tools.append(summary_tool)

    # Tool routing: the agent retrieves the most relevant tools per query.
    retriever = ObjectIndex.from_objects(all_tools, index_cls=VectorStoreIndex).as_retriever(
        similarity_top_k=4
    )
    return FunctionAgent(
        tool_retriever=retriever,
        llm=OpenAI(model="gpt-3.5-turbo", temperature=0),
        system_prompt=BASE_SYSTEM_PROMPT,
        verbose=False,
    )


async def _arun_agent(agent: FunctionAgent, prompt: str) -> str:
    """Execute the agent workflow asynchronously and return its textual result."""
    result = await agent.run(prompt)
    return str(result)


def run_agent(agent: FunctionAgent, prompt: str) -> str:
    """Run the agent from Streamlit, whether or not an event loop is already running.

    Args:
        agent: The cached FunctionAgent built by ``build_agent``.
        prompt: The user's question.

    Returns:
        The agent's answer as a string.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread (the usual Streamlit case): start one normally.
        return asyncio.run(_arun_agent(agent, prompt))
    # A loop is already running in this thread. Calling run_until_complete on it
    # would raise "This event loop is already running", so execute the coroutine
    # on a fresh loop in a worker thread and block on the result instead.
    from concurrent.futures import ThreadPoolExecutor  # local: only needed on this path

    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, _arun_agent(agent, prompt)).result()



def _require_env(var_name: str) -> bool:
    """Check that required environment variables are set; inform user if missing."""
    if os.getenv(var_name):
        return True
    st.error(f"Missing environment variable: {var_name}")
    return False


def main() -> None:
    """Render the Streamlit chat UI and drive the RAG agent.

    Flow: configure the page, validate required API keys, build (or fetch the
    cached) agent, replay chat history, then handle one new user turn.
    """
    st.set_page_config(page_title="RYBREVANT Q&A", page_icon="🩺", layout="wide")
    st.title("RYBREVANT Q&A RAG")
    st.write(
        "Ask about the RYBREVANT prescribing information (PI) or patient brochure. "
        "Responses stay grounded in the source documents and include page citations."
    )

    with st.sidebar:
        st.header("About")
        st.markdown(
            "- Sources: PI and patient brochure\n"
            "- Answers include citations and a safety disclaimer\n"
            "- Data/parsing cached in this Space runtime"
        )
        st.divider()
        st.markdown("Need to deploy? Push this app to Hugging Face Spaces with your API keys as secrets.")

    # Halt the script early if either required API key is missing.
    has_keys = _require_env("OPENAI_API_KEY") and _require_env("LLAMA_CLOUD_API_KEY")
    if not has_keys:
        st.stop()

    agent = build_agent()

    # Chat history is stored as (role, content) tuples in session state.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay prior turns so the conversation persists across Streamlit reruns.
    for role, content in st.session_state.messages:
        with st.chat_message(role):
            st.markdown(content)

    prompt = st.chat_input("Ask a RYBREVANT question, e.g., dosing, administration, safety...")
    if prompt:
        st.session_state.messages.append(("user", prompt))
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Grounding answer in the documents..."):
                try:
                    response = run_agent(agent, prompt)
                except Exception as exc:  # pylint: disable=broad-except
                    # On failure, return early: the user message stays in
                    # history but no assistant reply is appended.
                    st.error(f"Something went wrong: {exc}")
                    return
                st.markdown(response)
        st.session_state.messages.append(("assistant", response))

    st.caption("Not medical advice; always consult a healthcare professional.")


if __name__ == "__main__":
    main()