# =================================================================================
# app.py: Main application file for the Streamlit web interface
# =================================================================================
import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
import os
# === Helpers: check for a persisted LlamaIndex store and download it from a HF dataset ===
LLAMA_INDEX_DATASET_ID = os.getenv("HF_INDEX_DATASET_ID", "alperensn/llamaIndexVectorBase_fda")
LLAMA_INDEX_SUBDIR = os.getenv("HF_INDEX_SUBDIR", "").strip()  # set this if the index lives in a subfolder of the dataset
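# Example (hypothetical values): both settings can be overridden via environment
# variables, e.g. HF_INDEX_DATASET_ID=your-user/your-index-dataset HF_INDEX_SUBDIR=persist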
# Check both the legacy and the newer (default__-prefixed) file naming schemes
MARKERS_CLASSIC = {"index_store.json", "docstore.json", "graph_store.json", "default__vector_store.json", "image__vector_store.json"}
MARKERS_DEFAULT = {"default__index_store.json", "default__docstore.json", "default__vector_store.json", "default_image__vector_store.json", "default__graph_store.json"}
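# For reference: a directory persisted via LlamaIndex's StorageContext.persist()
# typically contains docstore.json, index_store.json, graph_store.json and
# default__vector_store.json; exact file names vary by LlamaIndex version.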
def _persist_path(base_dir: str) -> str:
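    """Return the persist directory, appending the optional dataset subfolder."""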
return os.path.join(base_dir, LLAMA_INDEX_SUBDIR) if LLAMA_INDEX_SUBDIR else base_dir
def llama_index_exists(base_dir: str) -> bool:
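    """Return True if a complete set of persisted index marker files exists under base_dir."""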
path = _persist_path(base_dir)
if not os.path.isdir(path):
return False
files = set(os.listdir(path))
return (MARKERS_CLASSIC.issubset(files) or MARKERS_DEFAULT.issubset(files))
def download_llama_index_if_needed(base_dir: str):
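    """Ensure the persisted index is available locally, downloading it from the HF dataset if needed."""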
path = _persist_path(base_dir)
os.makedirs(path, exist_ok=True)
if llama_index_exists(base_dir):
return
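    # local_dir_use_symlinks is deprecated (and ignored) in recent huggingface_hub
    # releases; it is kept below for compatibility with older versions.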
snapshot_download(
repo_id=LLAMA_INDEX_DATASET_ID,
repo_type="dataset",
local_dir=path,
local_dir_use_symlinks=False,
)
# If the persisted index ended up nested inside the downloaded tree, locate it:
def find_llama_index_dir(base_dir: str) -> str:
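    """Walk base_dir and return the first directory containing a full marker set,
    falling back to base_dir itself if none is found."""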
wanted_sets = [MARKERS_CLASSIC, MARKERS_DEFAULT]
if os.path.isdir(base_dir):
files = set(os.listdir(base_dir))
if any(ws.issubset(files) for ws in wanted_sets):
return base_dir
for root, _, files in os.walk(base_dir):
files = set(files)
if any(ws.issubset(files) for ws in wanted_sets):
return root
return base_dir
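# Typical startup flow (mirrors main() below):
#   download_llama_index_if_needed(config.LLAMA_INDEX_STORE_PATH)
#   persist_dir = find_llama_index_dir(config.LLAMA_INDEX_STORE_PATH)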
# Load environment variables from .env file
load_dotenv()
# Import the modules we've created
import config
import rag_pipeline # Now using the LlamaIndex pipeline
# --- Page Configuration ---
st.set_page_config(
page_title="PharmaBot",
page_icon="🤖",
layout="wide",
initial_sidebar_state="expanded",
)
# --- State Management ---
def initialize_state():
"""Initializes session state variables."""
if "messages" not in st.session_state:
st.session_state.messages = [{"role": "assistant", "content": "Welcome to PharmaBot! How can I help you today?"}]
if "query_engine" not in st.session_state:
st.session_state.query_engine = None
if "initialized" not in st.session_state:
st.session_state.initialized = False
# --- UI Components ---
def setup_sidebar():
"""Sets up the sidebar with app information."""
with st.sidebar:
st.header("About PharmaBot")
        st.info(
            "PharmaBot is an AI assistant designed to answer questions about "
            "pharmaceuticals, grounded in a curated document knowledge base. "
            "It uses a Retrieval-Augmented Generation (RAG) pipeline to provide "
            "accurate, context-aware answers."
)
st.warning("**Disclaimer: I am an AI assistant, not a medical professional. This information is for educational purposes only. Please consult with a qualified healthcare provider for any health concerns or before making any medical decisions.**"
)
st.markdown("---")
st.header("Technical Details")
        st.markdown(
            f"""
- **LLM Model:** `{config.LLM_MODEL_ID}`
- **Embedding Model:** `{config.EMBEDDING_MODEL_NAME}`
- **Index Type:** `LlamaIndex Vector Store`
- **Vector Store Path:** `{config.VECTOR_STORE_PATH}`
"""
        )
def display_chat_history():
"""Displays the chat history."""
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.write(message["content"])
def handle_user_input(chat_engine):
"""Handles user input and displays the response."""
if prompt := st.chat_input("Ask me anything about pharmaceuticals..."):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.write(prompt)
with st.chat_message("assistant"):
with st.spinner("Thinking..."):
response = chat_engine.chat(prompt)
st.write(str(response))
st.session_state.messages.append({"role": "assistant", "content": str(response)})
# Deferred imports for the one-time local knowledge-base build fallback
import time
from build_knowledge_base import build_vector_store
# --- Main Application Logic ---
def main():
"""Main function to run the Streamlit app."""
    # st.set_page_config() may only be called once per app run; it is already called at module import time above.
initialize_state()
st.title("💊 PharmaBot: Your AI Pharmaceutical Assistant")
setup_sidebar()
# Initialize the RAG pipeline if it hasn't been already
if not st.session_state.initialized:
        # 1) First, try to download the persisted index from the dataset into the local folder
if not llama_index_exists(config.LLAMA_INDEX_STORE_PATH):
with st.status("Knowledge base not found locally. Downloading from dataset...", expanded=True) as status:
try:
status.write(f"Downloading persisted index from: {LLAMA_INDEX_DATASET_ID}")
download_llama_index_if_needed(config.LLAMA_INDEX_STORE_PATH)
detected_dir = find_llama_index_dir(config.LLAMA_INDEX_STORE_PATH)
if detected_dir != config.LLAMA_INDEX_STORE_PATH:
config.LLAMA_INDEX_STORE_PATH = detected_dir
status.update(label="Index downloaded from dataset.", state="complete", expanded=False)
time.sleep(1)
except Exception as e:
status.update(label="Dataset download failed. Falling back to local build...", state="running", expanded=True)
try:
status.write("This is a one-time setup and may take a few minutes...")
build_vector_store()
status.update(label="Knowledge base built successfully!", state="complete", expanded=False)
time.sleep(1)
except Exception as be:
status.update(label="Build Failed", state="error", expanded=True)
st.error(f"An error occurred while preparing the knowledge base:\n- dataset error: {e}\n- build error: {be}")
st.stop()
# 2) RAG pipeline init
with st.status("Initializing the RAG pipeline...", expanded=True) as status:
try:
status.write("Step 1/3: Initializing LLM and embedding models...")
rag_pipeline.initialize_llm_and_embed_model()
status.write("Step 2/3: Loading vector index from storage...")
                index = rag_pipeline.load_vector_index()  # no change needed if this already reads from config.LLAMA_INDEX_STORE_PATH
status.write("Step 3/3: Building the conversational chat engine...")
st.session_state.query_engine = rag_pipeline.build_query_engine(index)
st.session_state.initialized = True
status.update(label="Initialization Complete!", state="complete", expanded=False)
time.sleep(1)
except Exception as e:
status.update(label="Initialization Failed", state="error")
st.error(f"An unexpected error occurred during initialization: {e}")
return
st.rerun()
# Display chat and handle input if initialized
if st.session_state.initialized:
display_chat_history()
handle_user_input(st.session_state.query_engine)
if __name__ == "__main__":
main()