|
|
|
|
|
|
|
|
|
|
|
import os

import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import snapshot_download

# Read the .env file BEFORE any os.getenv lookup below. In the original file
# load_dotenv() ran much later (after these constants were evaluated), so
# HF_INDEX_* overrides placed in a .env file were silently ignored.
# load_dotenv() is idempotent, so the later call elsewhere in the file stays
# harmless.
load_dotenv()

# Hugging Face dataset that holds the pre-built, persisted LlamaIndex store.
LLAMA_INDEX_DATASET_ID = os.getenv("HF_INDEX_DATASET_ID", "alperensn/llamaIndexVectorBase_fda")
# Optional subdirectory (inside the local store path) where the index lives.
LLAMA_INDEX_SUBDIR = os.getenv("HF_INDEX_SUBDIR", "").strip()

# Marker files whose joint presence identifies a complete persisted index.
# Two on-disk layouts are recognised: the classic file names and a
# "default__"-prefixed variant.
MARKERS_CLASSIC = {
    "index_store.json",
    "docstore.json",
    "graph_store.json",
    "default__vector_store.json",
    "image__vector_store.json",
}
# NOTE(review): "default_image__vector_store.json" has a single underscore
# after "default", unlike its siblings — confirm this matches the dataset.
MARKERS_DEFAULT = {
    "default__index_store.json",
    "default__docstore.json",
    "default__vector_store.json",
    "default_image__vector_store.json",
    "default__graph_store.json",
}
|
|
|
|
|
|
|
|
def _persist_path(base_dir: str) -> str:
    """Resolve the directory the index is persisted in.

    Returns *base_dir* extended by the configured LLAMA_INDEX_SUBDIR, or
    *base_dir* unchanged when no subdirectory is configured.
    """
    if LLAMA_INDEX_SUBDIR:
        return os.path.join(base_dir, LLAMA_INDEX_SUBDIR)
    return base_dir
|
|
|
|
|
def llama_index_exists(base_dir: str) -> bool:
    """Return True when a complete persisted LlamaIndex store sits under
    *base_dir* (taking the optional configured subdirectory into account).

    "Complete" means every marker file of either recognised layout
    (MARKERS_CLASSIC or MARKERS_DEFAULT) is present in the directory.
    """
    store_dir = _persist_path(base_dir)
    if not os.path.isdir(store_dir):
        return False
    present = set(os.listdir(store_dir))
    return any(markers <= present for markers in (MARKERS_CLASSIC, MARKERS_DEFAULT))
|
|
|
|
|
|
|
|
def download_llama_index_if_needed(base_dir: str):
    """Ensure the persisted index is on disk, downloading it when absent.

    Creates the persist directory, then fetches a snapshot of the configured
    Hugging Face dataset into it — unless the marker files already show a
    complete store, in which case this is a no-op.
    """
    target = _persist_path(base_dir)
    os.makedirs(target, exist_ok=True)
    if llama_index_exists(base_dir):
        # Store already present — nothing to download.
        return
    # NOTE(review): local_dir_use_symlinks is deprecated (and ignored) in
    # recent huggingface_hub releases; kept here for older installs.
    snapshot_download(
        repo_id=LLAMA_INDEX_DATASET_ID,
        repo_type="dataset",
        local_dir=target,
        local_dir_use_symlinks=False,
    )
|
|
|
|
|
|
|
|
def find_llama_index_dir(base_dir: str) -> str:
    """Locate the directory under *base_dir* that holds a persisted index.

    A directory qualifies when it contains every marker file of either the
    classic or the "default__"-prefixed persistence layout. Returns the first
    qualifying directory found, or *base_dir* itself when none is found
    (including when *base_dir* does not exist — os.walk then yields nothing).

    Note: os.walk(base_dir) yields base_dir itself as its first root, so the
    original's separate pre-check of base_dir was redundant and is dropped.
    """
    marker_sets = (MARKERS_CLASSIC, MARKERS_DEFAULT)
    for root, _, filenames in os.walk(base_dir):
        present = set(filenames)
        if any(markers.issubset(present) for markers in marker_sets):
            return root
    return base_dir
|
|
|
|
|
|
|
|
|
|
|
# Pull environment variables (API keys, model/config overrides) from a .env
# file. NOTE(review): this runs AFTER the HF_INDEX_* os.getenv lookups at the
# top of the file, so .env values do not affect those constants — consider
# calling load_dotenv() before them.
load_dotenv()

# Project-local modules, imported after load_dotenv() so they see the .env
# values at import time.
import config
import rag_pipeline
|
|
|
|
|
|
|
|
# Configure the Streamlit page. Streamlit requires set_page_config() to be
# the first st.* command executed and allows it only once per page run.
st.set_page_config(
    page_title="PharmaBot",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)
|
|
|
|
|
|
|
|
def initialize_state():
    """Seed st.session_state with every key the app relies on.

    Keys already present (from a previous rerun) are left untouched.
    """
    defaults = {
        "messages": [{"role": "assistant", "content": "Welcome to PharmaBot! How can I help you today?"}],
        "query_engine": None,
        "initialized": False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
|
|
|
|
|
|
|
|
def setup_sidebar():
    """Sets up the sidebar with app information.

    Renders four sections: an about box, a medical disclaimer, a divider,
    and technical details pulled from the config module.
    """
    with st.sidebar:
        st.header("About PharmaBot")
        st.info(
            "PharmaBot is an AI assistant designed to answer questions about "
            "pharmaceuticals based on a knowledge base of RAG documents. "
            "It uses a Retrieval-Augmented Generation (RAG) pipeline to provide accurate, "
            "context-aware answers."
        )
        # Keep the disclaimer prominent: the bot must not be mistaken for
        # medical advice.
        st.warning("**Disclaimer: I am an AI assistant, not a medical professional. This information is for educational purposes only. Please consult with a qualified healthcare provider for any health concerns or before making any medical decisions.**"
        )
        st.markdown("---")
        st.header("Technical Details")
        # f-string content is kept flush-left: st.markdown does not dedent,
        # and indented lines would render as a markdown code block.
        st.markdown(
            f"""
- **LLM Model:** `{config.LLM_MODEL_ID}`
- **Embedding Model:** `{config.EMBEDDING_MODEL_NAME}`
- **Vector Type:** `LLama Index Vector Store`
- **Vector Store:** `{config.VECTOR_STORE_PATH}`
"""
        )
|
|
|
|
|
def display_chat_history():
    """Render every stored chat message in chronological order."""
    for entry in st.session_state.messages:
        role, content = entry["role"], entry["content"]
        with st.chat_message(role):
            st.write(content)
|
|
|
|
|
def handle_user_input(chat_engine):
    """Read one user prompt, query *chat_engine*, and render the exchange.

    Both the user prompt and the assistant reply are appended to the
    session's message history so they survive Streamlit reruns.
    """
    user_text = st.chat_input("Ask me anything about pharmaceuticals...")
    if not user_text:
        # No new input this rerun.
        return

    st.session_state.messages.append({"role": "user", "content": user_text})
    with st.chat_message("user"):
        st.write(user_text)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            answer = str(chat_engine.chat(user_text))
            st.write(answer)

    st.session_state.messages.append({"role": "assistant", "content": answer})
|
|
|
|
|
import time |
|
|
from build_knowledge_base import build_vector_store |
|
|
import os |
|
|
|
|
|
|
|
|
def main():
    """Main function to run the Streamlit app.

    Flow: seed session state and draw the static UI; on the first run of a
    session, ensure the persisted knowledge base exists (download, falling
    back to a local build), initialize the RAG pipeline, and rerun; on
    subsequent runs, serve the chat interface.
    """
    # BUG FIX: the original called st.set_page_config(...) again here. The
    # module-level call at the top of the file already configured the page,
    # and Streamlit raises StreamlitAPIException when set_page_config is
    # executed more than once (or after other st.* commands).
    initialize_state()
    st.title("💊 PharmaBot: Your AI Pharmaceutical Assistant")
    setup_sidebar()

    if not st.session_state.initialized:
        # One-time setup per session: make sure a persisted LlamaIndex store
        # is on disk before building the pipeline.
        if not llama_index_exists(config.LLAMA_INDEX_STORE_PATH):
            with st.status("Knowledge base not found locally. Downloading from dataset...", expanded=True) as status:
                try:
                    status.write(f"Downloading persisted index from: {LLAMA_INDEX_DATASET_ID}")
                    download_llama_index_if_needed(config.LLAMA_INDEX_STORE_PATH)
                    # The dataset snapshot may nest the persisted files in a
                    # subdirectory; point config at the directory that
                    # actually contains the index markers.
                    detected_dir = find_llama_index_dir(config.LLAMA_INDEX_STORE_PATH)
                    if detected_dir != config.LLAMA_INDEX_STORE_PATH:
                        config.LLAMA_INDEX_STORE_PATH = detected_dir
                    status.update(label="Index downloaded from dataset.", state="complete", expanded=False)
                    time.sleep(1)
                except Exception as e:
                    # Download failed — fall back to building the vector
                    # store locally from the source documents.
                    status.update(label="Dataset download failed. Falling back to local build...", state="running", expanded=True)
                    try:
                        status.write("This is a one-time setup and may take a few minutes...")
                        build_vector_store()
                        status.update(label="Knowledge base built successfully!", state="complete", expanded=False)
                        time.sleep(1)
                    except Exception as be:
                        status.update(label="Build Failed", state="error", expanded=True)
                        st.error(f"An error occurred while preparing the knowledge base:\n- dataset error: {e}\n- build error: {be}")
                        st.stop()

        with st.status("Initializing the RAG pipeline...", expanded=True) as status:
            try:
                status.write("Step 1/3: Initializing LLM and embedding models...")
                rag_pipeline.initialize_llm_and_embed_model()

                status.write("Step 2/3: Loading vector index from storage...")
                index = rag_pipeline.load_vector_index()

                status.write("Step 3/3: Building the conversational chat engine...")
                st.session_state.query_engine = rag_pipeline.build_query_engine(index)

                st.session_state.initialized = True
                status.update(label="Initialization Complete!", state="complete", expanded=False)
                time.sleep(1)
            except Exception as e:
                status.update(label="Initialization Failed", state="error")
                st.error(f"An unexpected error occurred during initialization: {e}")
                return

        # Rerun so the chat UI renders with a fully initialized engine.
        st.rerun()

    if st.session_state.initialized:
        display_chat_history()
        handle_user_input(st.session_state.query_engine)
|
|
|
|
|
if __name__ == "__main__":
    # Streamlit re-executes this module on every rerun; the guard keeps
    # main() from running if the module is ever imported elsewhere.
    main()
|
|
|
|
|
|