# innofacisteven's picture
# Update rag.py
# 1a8deeb verified
import os
import shutil # <-- Added this missing import
from llama_index.core import (
VectorStoreIndex,
SimpleDirectoryReader,
StorageContext,
load_index_from_storage,
Settings
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.llms import MockLLM
from llama_index.readers.file import PyMuPDFReader
# Cache directory for Hugging Face model downloads.
HF_CACHE_DIR = "/tmp/hf"
# NOTE: Hugging Face libraries typically read HF_HOME when they are first
# imported, and the llama_index HuggingFace import above may already have
# imported them, in which case this assignment alone would not redirect the
# model cache. It is kept for anything that reads it later, and the same
# directory is passed explicitly to the embedding model via cache_folder.
os.environ["HF_HOME"] = HF_CACHE_DIR

# Global llama_index settings.
# BAAI/bge-small-zh-v1.5: a small embedding model for Chinese/English text.
Settings.embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-zh-v1.5",
    cache_folder=HF_CACHE_DIR,
)
# Retrieval only: MockLLM avoids loading a real LLM or calling an external API.
Settings.llm = MockLLM()

# Directory where the vector index is persisted.
PERSIST_DIR = "./storage"
# Directory holding the source documents that get indexed.
DATA_DIR = "./data"
def initialize_index():
    """Load the persisted index, or build one from the PDFs in DATA_DIR.

    If ``PERSIST_DIR`` already contains ``docstore.json``, the index is loaded
    from disk. Otherwise every ``.pdf`` file under ``DATA_DIR`` (searched
    recursively) is parsed with ``PyMuPDFReader`` and embedded. The resulting
    index is persisted to ``PERSIST_DIR``. When no PDF is available yet, an
    empty index is returned and nothing is persisted, so a later start
    builds from scratch again.

    Returns:
        The loaded or newly built ``VectorStoreIndex``.
    """
    if os.path.exists(os.path.join(PERSIST_DIR, "docstore.json")):
        storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
        return load_index_from_storage(storage_context)

    os.makedirs(DATA_DIR, exist_ok=True)
    try:
        # SimpleDirectoryReader raises ValueError while being constructed when
        # no file matches, so an empty or PDF-less DATA_DIR has to be handled
        # here; checking os.listdir() after construction is too late.
        loader = SimpleDirectoryReader(
            DATA_DIR,
            recursive=True,
            required_exts=[".pdf"],
            file_extractor={".pdf": PyMuPDFReader()},
        )
    except ValueError:
        # Nothing to index yet: return an empty, non-persisted index.
        return VectorStoreIndex.from_documents([])

    documents = loader.load_data()
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=PERSIST_DIR)
    return index
# Module-level cache for the engine built by get_retriever(); add_document()
# sets it back to None so the next call rebuilds it.
_cached_engine = None


def get_retriever():
    """Return the shared retrieval-only query engine, building it on first use.

    The engine fetches the 8 most similar nodes and uses
    ``response_mode="no_text"``, so it returns the source nodes without any
    LLM answer synthesis. Later calls reuse the cached engine.
    """
    global _cached_engine
    if _cached_engine is not None:
        return _cached_engine

    _cached_engine = initialize_index().as_query_engine(
        similarity_top_k=8,
        response_mode="no_text",
    )
    return _cached_engine
def add_document(file_path: str):
    """Move an uploaded file into DATA_DIR and rebuild the whole index.

    The file is moved (not copied) to ``DATA_DIR/<basename>``, replacing any
    existing file of that name. The index is then rebuilt from every ``.pdf``
    under ``DATA_DIR``, using the same reader configuration as
    ``initialize_index`` (recursive, PDF-only, ``PyMuPDFReader``). It is
    persisted to ``PERSIST_DIR`` and the cached query engine is cleared, so
    the next ``get_retriever()`` call reloads it.

    Args:
        file_path: Path of the uploaded file.

    Returns:
        int: Number of Document objects loaded from ``DATA_DIR``.
    """
    global _cached_engine
    os.makedirs(DATA_DIR, exist_ok=True)

    dest = os.path.join(DATA_DIR, os.path.basename(file_path))
    # If the upload already sits at the destination, deleting dest first would
    # delete the upload itself and make the move fail; nothing to do then.
    if os.path.realpath(file_path) != os.path.realpath(dest):
        # Remove a same-named file first: moving onto an existing file is not
        # portable (os.rename refuses to overwrite on Windows).
        if os.path.exists(dest):
            os.remove(dest)
        shutil.move(file_path, dest)

    try:
        # Same reader setup as initialize_index, so a rebuilt index matches a
        # freshly built one.
        loader = SimpleDirectoryReader(
            DATA_DIR,
            recursive=True,
            required_exts=[".pdf"],
            file_extractor={".pdf": PyMuPDFReader()},
        )
    except ValueError:
        # No PDF in DATA_DIR (e.g. only non-PDF uploads): index nothing.
        documents = []
    else:
        documents = loader.load_data()

    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=PERSIST_DIR)

    # Force get_retriever() to rebuild the engine from the new index.
    _cached_engine = None
    return len(documents)