AI_Toolkit / src /rag_engine.py
NavyDevilDoc's picture
Update src/rag_engine.py
ae9a929 verified
raw
history blame
11.3 kB
import os
import shutil
import logging
from typing import List, Literal, Optional, Tuple

# --- LANGCHAIN & DB IMPORTS ---
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder

# --- CUSTOM CORE IMPORTS ---
from core.ParagraphChunker import ParagraphChunker
from core.TokenChunker import TokenChunker
from core.AcronymManager import AcronymManager
# --- CONFIGURATION ---
CHROMA_PATH = "chroma_db"
UPLOAD_DIR = "source_documents"
EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
RERANK_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# Configure Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- LAZY LOADING GLOBALS ---
_embedding_func = None
_rerank_model = None
def get_embedding_func():
    """Return the shared embedding model, instantiating it on first use."""
    global _embedding_func
    # Fast path: the model was already constructed by an earlier call.
    if _embedding_func is not None:
        return _embedding_func
    logger.info(f"⏳ Loading Embedding Model: {EMBED_MODEL_NAME}...")
    _embedding_func = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
    logger.info("✅ Embedding Model Loaded.")
    return _embedding_func
def get_rerank_model():
    """Return the shared Cross-Encoder reranker, instantiating it on first use."""
    global _rerank_model
    # Fast path: reuse the cached model from a previous call.
    if _rerank_model is not None:
        return _rerank_model
    logger.info(f"⏳ Loading Reranker: {RERANK_MODEL_NAME}...")
    _rerank_model = CrossEncoder(RERANK_MODEL_NAME)
    logger.info("✅ Reranker Loaded.")
    return _rerank_model
# --- PART 1: CHUNKING LOGIC (The New System) ---
def _process_markdown(file_path: str, chunk_size: int = 1000, chunk_overlap: int = 100) -> List[Document]:
    """Internal helper to process Markdown files using Header Semantic Splitting.

    Splits on H1/H2/H3 headers first, then recursively splits oversized
    sections to the requested chunk size. Returns [] on any read/parse error.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as fh:
            raw_md = fh.read()

        # Pass 1: split along the document's own header structure.
        header_levels = [
            ("#", "Header 1"),
            ("##", "Header 2"),
            ("###", "Header 3"),
        ]
        header_sections = MarkdownHeaderTextSplitter(
            headers_to_split_on=header_levels
        ).split_text(raw_md)

        # Pass 2: cap each section at chunk_size characters with overlap.
        chunks = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        ).split_documents(header_sections)

        base_name = os.path.basename(file_path)
        for chunk in chunks:
            chunk.metadata['source'] = base_name
            chunk.metadata['file_type'] = 'md'
            chunk.metadata['strategy'] = 'markdown_header'
        return chunks
    except Exception as e:
        logger.error(f"Error processing Markdown file {file_path}: {e}")
        return []
def process_file(
    file_path: str,
    chunking_strategy: Literal["paragraph", "token"] = "paragraph",
    chunk_size: int = 512,
    chunk_overlap: int = 100,
    model_name: str = "gpt-4o"
) -> List[Document]:
    """
    Main chunking engine. Routes file to specific chunkers based on type/strategy.

    Args:
        file_path: Path to the document on disk (.md, .pdf, or .txt).
        chunking_strategy: "paragraph" (semantic) or "token" (fixed-size) for
            PDF/TXT inputs; markdown always uses header splitting.
        chunk_size: Target chunk size forwarded to the selected splitter.
        chunk_overlap: Overlap between consecutive chunks.
        model_name: Tokenizer/model name passed through to the chunkers.

    Returns:
        Chunked Documents tagged with 'source', 'strategy', and 'file_type'
        metadata, or [] for a missing file, unsupported type, or chunker error.
    """
    if not os.path.exists(file_path):
        logger.error(f"File not found: {file_path}")
        return []

    file_extension = os.path.splitext(file_path)[1].lower()
    file_name = os.path.basename(file_path)
    logger.info(f"Processing {file_name} using strategy: {chunking_strategy}")

    # 1. Markdown gets header-aware semantic splitting regardless of strategy.
    if file_extension == ".md":
        return _process_markdown(file_path, chunk_size, chunk_overlap)

    # 2. PDF and TXT route through the strategy-selected custom chunkers.
    if file_extension in (".pdf", ".txt"):
        if chunking_strategy == "token":
            chunker = TokenChunker(
                model_name=model_name,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap
            )
        else:
            chunker = ParagraphChunker(model_name=model_name)
        try:
            if file_extension == ".pdf":
                docs = chunker.process_document(file_path)
            else:  # ".txt"
                docs = chunker.process_text_file(file_path)
            # Normalize metadata so all ingestion paths agree.
            # FIX: previously only the markdown path set 'file_type',
            # leaving pdf/txt chunks without it.
            for doc in docs:
                doc.metadata["source"] = file_name
                doc.metadata["strategy"] = chunking_strategy
                doc.metadata["file_type"] = file_extension.lstrip(".")
            return docs
        except Exception as e:
            logger.error(f"Error using {chunking_strategy} chunker on {file_name}: {e}")
            return []

    logger.warning(f"Unsupported file extension: {file_extension}")
    return []
# --- PART 2: DATABASE & FILE MANAGEMENT (The Old Stable System) ---
def save_uploaded_file(uploaded_file, username: str = "default") -> Optional[str]:
    """Saves a StreamlitUploadedFile to disk so the loaders can read it.

    Args:
        uploaded_file: Object exposing .name and .getbuffer() (Streamlit upload).
        username: Per-user subdirectory created under UPLOAD_DIR.

    Returns:
        The saved file path, or None if the write failed.
        (FIX: annotation was `-> str` but the error path returns None.)
    """
    try:
        user_dir = os.path.join(UPLOAD_DIR, username)
        os.makedirs(user_dir, exist_ok=True)
        file_path = os.path.join(user_dir, uploaded_file.name)
        with open(file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        logger.info(f"File saved: {file_path}")
        return file_path
    except Exception as e:
        logger.error(f"Error saving file: {e}")
        return None
def process_and_add_text(text: str, source_name: str, username: str) -> Tuple[bool, str]:
    """
    Ingests raw text string (e.g., from the Flattener tool) directly into Chroma.

    Wraps the text in a single Document tagged as generated/flattened content
    and appends it to the user's vector store.
    """
    try:
        # Open (or create) this user's persistent vector store.
        store = Chroma(
            persist_directory=os.path.join(CHROMA_PATH, username),
            embedding_function=get_embedding_func(),
        )
        # One Document carries the entire flattened text.
        store.add_documents([
            Document(
                page_content=text,
                metadata={
                    "source": source_name,
                    "strategy": "flattened_text",
                    "file_type": "generated"
                },
            )
        ])
        return True, f"Successfully indexed flattened text: {source_name}"
    except Exception as e:
        logger.error(f"Error indexing raw text: {e}")
        return False, f"Error: {str(e)}"
def ingest_file(file_path: str, username: str, strategy: str = "paragraph") -> Tuple[bool, str]:
    """Chunk a file, learn acronyms from its text, and index it into the user's Chroma DB."""
    try:
        # Step 1: chunk the file with the requested strategy.
        chunks = process_file(file_path, chunking_strategy=strategy)
        if not chunks:
            return False, "No valid chunks generated from file."

        # Step 2: scan chunk text so newly seen acronym definitions are learned.
        manager = AcronymManager()
        for chunk in chunks:
            manager.scan_text_for_acronyms(chunk.page_content)

        # Step 3: persist the chunks into this user's vector store.
        store = Chroma(
            persist_directory=os.path.join(CHROMA_PATH, username),
            embedding_function=get_embedding_func(),
        )
        store.add_documents(chunks)
        return True, f"Successfully indexed {len(chunks)} chunks."
    except Exception as e:
        logger.error(f"Ingestion failed: {e}")
        return False, f"System Error: {str(e)}"
def search_knowledge_base(query: str, username: str, k: int = 10, final_k: int = 4) -> List[Document]:
    """Two-stage retrieval: acronym-expanded vector search, then cross-encoder rerank.

    Returns up to final_k Documents, or [] when the user has no index,
    nothing matches, or any stage raises.
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    if not os.path.exists(user_db_path):
        return []
    try:
        # Expand known acronyms so retrieval also matches spelled-out forms.
        expanded_query = AcronymManager().expand_query(query)
        if expanded_query != query:
            logger.info(f"Query Expanded: '{query}' -> '{expanded_query}'")

        # Stage 1: dense retrieval of k candidate chunks.
        store = Chroma(
            persist_directory=user_db_path,
            embedding_function=get_embedding_func(),
        )
        hits = store.similarity_search_with_relevance_scores(expanded_query, k=k)
        if not hits:
            return []

        # Stage 2: rerank candidates with the cross-encoder, keep the best final_k.
        candidates = [doc for doc, _ in hits]
        pair_inputs = [[expanded_query, doc.page_content] for doc in candidates]
        rerank_scores = get_rerank_model().predict(pair_inputs)
        ranked = sorted(
            zip(candidates, rerank_scores),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return [doc for doc, _ in ranked[:final_k]]
    except Exception as e:
        logger.error(f"Search Error: {e}")
        return []
def list_documents(username: str) -> List[dict]:
    """
    Returns a list of unique files currently in the vector database.
    (Used for the sidebar list)
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    if not os.path.exists(user_db_path):
        return []
    try:
        store = Chroma(
            persist_directory=user_db_path,
            embedding_function=get_embedding_func(),
        )
        # Aggregate chunk counts per source file from all stored metadata;
        # the first chunk seen for a file determines its reported strategy.
        summary = {}
        for meta in store.get()['metadatas']:
            name = meta.get('source', 'Unknown')
            entry = summary.setdefault(
                name,
                {"chunks": 0, "strategy": meta.get('strategy', 'unknown')},
            )
            entry["chunks"] += 1
        # Both "filename" and "source" carry the name for UI compatibility.
        return [
            {"filename": name, "chunks": info["chunks"], "strategy": info["strategy"], "source": name}
            for name, info in summary.items()
        ]
    except Exception as e:
        logger.error(f"Error listing docs: {e}")
        return []
def delete_document(username: str, filename: str) -> Tuple[bool, str]:
    """Removes a document from the vector database.

    Args:
        username: Owner of the per-user Chroma store.
        filename: The 'source' metadata value identifying the document.

    Returns:
        (success, human-readable status message).
    """
    user_db_path = os.path.join(CHROMA_PATH, username)
    try:
        emb_fn = get_embedding_func()
        db = Chroma(persist_directory=user_db_path, embedding_function=emb_fn)
        data = db.get()
        # Collect the IDs of every chunk whose source matches the file.
        ids_to_delete = [
            chunk_id
            for chunk_id, meta in zip(data['ids'], data['metadatas'])
            if meta.get('source') == filename
        ]
        if not ids_to_delete:
            return False, "File not found in index."
        db.delete(ids=ids_to_delete)
        # FIX: message was a leftover placeholder f"Deleted (unknown)." with
        # no interpolation; report what was actually removed.
        return True, f"Deleted {len(ids_to_delete)} chunks from '{filename}'."
    except Exception as e:
        return False, f"Delete failed: {e}"
def reset_knowledge_base(username: str) -> Tuple[bool, str]:
    """Nukes the user's database folder."""
    target = os.path.join(CHROMA_PATH, username)
    # Guard: nothing on disk means nothing to delete.
    if not os.path.exists(target):
        return False, "Database already empty."
    shutil.rmtree(target)
    return True, "Database Reset."