AI_Toolkit / src /rag_engine.py
NavyDevilDoc's picture
Update src/rag_engine.py
99043ee verified
raw
history blame
5.36 kB
import os
import logging
from typing import List, Literal
# LangChain imports for the Markdown logic
from langchain_core.documents import Document
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
# Custom Core Imports
from core.ParagraphChunker import ParagraphChunker
from core.TokenChunker import TokenChunker
# Configure Logging
# NOTE(review): basicConfig at import time reconfigures the process-wide root
# logger; library modules usually leave this to the application — confirm intended.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
def _process_markdown(file_path: str, chunk_size: int = 1000, chunk_overlap: int = 100) -> List[Document]:
    """
    Internal helper to process Markdown files using header semantic splitting.

    Two-stage pipeline:
      1. Split on Markdown headers (#, ##, ###) so each chunk keeps its
         section titles attached as metadata.
      2. Recursively re-split any oversized section to respect chunk_size.

    Args:
        file_path: Path to the Markdown file on disk (read as UTF-8).
        chunk_size: Maximum characters per final chunk.
        chunk_overlap: Character overlap between adjacent chunks.

    Returns:
        List of LangChain Documents tagged with 'source' and 'file_type'
        metadata; an empty list if any step fails (error is logged).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            markdown_text = f.read()
        # Split on the first three header levels only; deeper headers stay
        # attached to their parent section so context is preserved.
        headers_to_split_on = [
            ("#", "Header 1"),
            ("##", "Header 2"),
            ("###", "Header 3"),
        ]
        # Stage 1: structural split by headers.
        markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
        md_header_splits = markdown_splitter.split_text(markdown_text)
        # Stage 2: size-based split for sections longer than chunk_size.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        final_docs = text_splitter.split_documents(md_header_splits)
        # Tag every chunk with its provenance for retrieval-time filtering.
        for doc in final_docs:
            doc.metadata['source'] = file_path
            doc.metadata['file_type'] = 'md'
        logger.info(f"Markdown processing complete: {len(final_docs)} chunks created.")
        return final_docs
    except Exception as e:
        # logger.exception records the full traceback (logger.error did not),
        # while keeping the best-effort contract of returning [] on failure.
        logger.exception(f"Error processing Markdown file {file_path}: {e}")
        return []
def process_file(
    file_path: str,
    chunking_strategy: Literal["paragraph", "token"] = "paragraph",
    chunk_size: int = 512,
    chunk_overlap: int = 50,
    model_name: str = "gpt-4o"  # Used for token counting in your custom classes
) -> List[Document]:
    """
    Main entry point for processing a single file.

    Routes Markdown files to the header-aware splitter and PDF/TXT files
    to the selected custom chunker. Missing files and unsupported
    extensions yield an empty list (logged).
    """
    # Guard clause: bail out early if the path does not exist.
    if not os.path.exists(file_path):
        logger.error(f"File not found: {file_path}")
        return []

    ext = os.path.splitext(file_path)[1].lower()
    logger.info(f"Processing {file_path} using strategy: {chunking_strategy}")

    # Markdown has its own specialized header-splitting pipeline.
    if ext == ".md":
        return _process_markdown(file_path, chunk_size, chunk_overlap)

    # Everything other than PDF/TXT is unsupported.
    if ext not in (".pdf", ".txt"):
        logger.warning(f"Unsupported file extension: {ext}")
        return []

    # Select the custom chunker. Construction is intentionally outside the
    # try below, matching the original flow (constructor errors propagate).
    if chunking_strategy == "token":
        selected_chunker = TokenChunker(
            model_name=model_name,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
    else:
        # Paragraph chunking follows semantic boundaries, not strict sizes.
        selected_chunker = ParagraphChunker(model_name=model_name)

    try:
        if ext == ".pdf":
            # PDFs go through OCREnhancedPDFLoader internally via BaseChunker.
            return selected_chunker.process_document(file_path)
        # TXT: direct text reading with paragraph preservation.
        return selected_chunker.process_text_file(file_path)
    except Exception as e:
        logger.error(f"Error using {chunking_strategy} chunker on {file_path}: {e}")
        return []
def load_documents_from_directory(
    directory_path: str,
    chunking_strategy: Literal["paragraph", "token"] = "paragraph",
    chunk_size: int = 512,
    chunk_overlap: int = 50,
    model_name: str = "gpt-4o"
) -> List[Document]:
    """
    Batch helper to process a directory tree of files.

    Walks directory_path recursively and runs process_file on every
    supported file (.pdf, .txt, .md), concatenating the resulting chunks.

    Args:
        directory_path: Root directory to walk (a missing directory
            simply yields an empty list, as os.walk produces nothing).
        chunking_strategy: "paragraph" or "token", forwarded to process_file.
        chunk_size: Forwarded to process_file (new parameter; the default
            matches process_file's, preserving previous behavior).
        chunk_overlap: Forwarded to process_file (new parameter, same note).
        model_name: Forwarded to process_file for token counting (new
            parameter, same note).

    Returns:
        All Document chunks from every supported file, in walk order.
    """
    all_docs: List[Document] = []
    for root, _, files in os.walk(directory_path):
        for file in files:
            # Skip unsupported extensions up front.
            if not file.lower().endswith(('.pdf', '.txt', '.md')):
                continue
            all_docs.extend(
                process_file(
                    os.path.join(root, file),
                    chunking_strategy=chunking_strategy,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                    model_name=model_name,
                )
            )
    return all_docs
def list_documents(username: str = "default", base_dir: str = "source_documents") -> List[str]:
    """
    Lists all supported documents (.pdf, .txt, .md) for a specific user.

    Args:
        username: Per-user sub-folder name inside base_dir.
        base_dir: Root folder holding per-user document directories.
            Defaults to "source_documents" to preserve prior behavior;
            now overridable instead of hard-coded.

    Returns:
        Sorted list of matching filenames (names only, not full paths);
        empty if the user's folder does not exist.
    """
    user_dir = os.path.join(base_dir, username)
    # isdir (not exists): a stray file at this path previously made
    # os.listdir raise NotADirectoryError.
    if not os.path.isdir(user_dir):
        return []
    # Sorted for deterministic output — os.listdir order is arbitrary.
    return sorted(
        f for f in os.listdir(user_dir)
        if f.lower().endswith(('.pdf', '.txt', '.md'))
    )