AI_Toolkit / src /rag_engine.py
NavyDevilDoc's picture
Update src/rag_engine.py
10a9f86 verified
import os
import shutil
import logging
from typing import List, Tuple, Optional
from huggingface_hub import snapshot_download
# LangChain Imports
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import OpenAIEmbeddings
from langchain_core.documents import Document
# Internal Core Imports
from core.PineconeManager import PineconeManager
from core.AcronymManager import AcronymManager
from core.ChunkingManager import ChunkingManager
from flashrank import Ranker, RerankRequest
# CONFIGURATION
# Pinecone credentials come from the environment; the Pinecone-backed
# operations below are skipped or fail fast when this is unset.
PINECONE_KEY = os.getenv("PINECONE_API_KEY")
# Root of the local document cache; each user gets a sub-directory.
UPLOAD_DIR = "source_documents"
logger = logging.getLogger(__name__)

# Cross-encoder reranker used by search_knowledge_base. Loading is best-effort:
# on any failure the module still imports, `reranker` is None, and search
# falls back to plain vector-similarity ordering.
try:
    reranker = Ranker(
        model_name="ms-marco-TinyBERT-L-2-v2",
        cache_dir="/tmp/flashrank_cache",
    )
except Exception as load_err:
    logger.warning(f"Reranker failed to load: {load_err}")
    reranker = None
def get_embedding_func(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Return a LangChain embeddings object for ``model_name``.

    Resolution order:
      1. Names containing "openai" or "text-embedding" -> ``OpenAIEmbeddings``
         (requires ``OPENAI_API_KEY`` in the environment).
      2. Names containing "navy-custom-models" -> a folder inside a Hugging Face
         model repo, written ``"<owner>/<repo>/<folder>[/<subfolder>...]"``.
         Only that folder is downloaded with ``snapshot_download`` and then
         loaded locally with ``HuggingFaceEmbeddings``.
      3. Anything else -> ``HuggingFaceEmbeddings(model_name)``.

    Every Hugging Face model is pinned to CPU. An empty or ``None``
    ``model_name`` selects the default MiniLM model. On any error the failure
    is logged and the MiniLM default is returned instead. The caller therefore
    always gets an embeddings object, though possibly one with a different
    vector dimension than requested.
    """
    default_model = "sentence-transformers/all-MiniLM-L6-v2"

    def _hf_cpu(name_or_path):
        # Explicit CPU device avoids the meta-tensor errors seen when the
        # device is left for the library to pick.
        return HuggingFaceEmbeddings(
            model_name=name_or_path,
            model_kwargs={'device': 'cpu'}
        )

    if not model_name:
        # e.g. process_and_add_text defaults embed_model_name to None; resolve
        # it directly instead of via an AttributeError in the except branch.
        return _hf_cpu(default_model)

    try:
        lowered = model_name.lower()

        # CHECK 1: OpenAI
        if "openai" in lowered or "text-embedding" in lowered:
            if not os.getenv("OPENAI_API_KEY"):
                raise ValueError("OpenAI API Key not found.")
            return OpenAIEmbeddings(model=model_name)

        # CHECK 2: custom fine-tune stored as a folder inside a HF repo
        if "navy-custom-models" in model_name:
            logger.info(f"Downloading custom model from: {model_name}")
            # Drop empty segments so leading or trailing slashes are harmless.
            parts = [p for p in model_name.split("/") if p]
            if len(parts) < 3:
                raise ValueError(
                    f"Custom model '{model_name}' must look like '<owner>/<repo>/<folder>'."
                )
            repo_id = f"{parts[0]}/{parts[1]}"
            # Keep every remaining segment so nested folders resolve correctly.
            folder_name = "/".join(parts[2:])
            storage_path = snapshot_download(
                repo_id=repo_id,
                repo_type="model",
                allow_patterns=f"{folder_name}/*"
            )
            return _hf_cpu(os.path.join(storage_path, folder_name))

        # CHECK 3: standard public models
        return _hf_cpu(model_name)
    except Exception as e:
        logger.error(f"Failed to load embedding model '{model_name}': {e}")
        return _hf_cpu(default_model)
def save_uploaded_file(uploaded_file, username: str) -> str:
    """Write an uploaded file into the user's cache directory and return its path.

    ``uploaded_file`` must expose ``.name`` and ``.getbuffer()``. It is
    presumably a Streamlit ``UploadedFile``; confirm against callers. The name
    is client-supplied, so only its final path component is used. Names such
    as ``"../../x.txt"`` therefore cannot escape ``UPLOAD_DIR/<username>``.
    This also matches the basename-only ``source`` metadata used at ingest
    time.

    Raises:
        ValueError: if the uploaded name has no usable file-name component.
    """
    # Normalise Windows separators so basename strips them on POSIX too.
    safe_name = os.path.basename(uploaded_file.name.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid upload file name: {uploaded_file.name!r}")
    user_dir = os.path.join(UPLOAD_DIR, username)
    os.makedirs(user_dir, exist_ok=True)
    file_path = os.path.join(user_dir, safe_name)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return file_path
def process_file(file_path: str, chunking_strategy: str = "paragraph", embed_model_name: str = "all-mpnet-base-v2") -> List[Document]:
    """Chunk ``file_path`` via ChunkingManager and return a flat list of Documents.

    ``ChunkingManager.process_document`` may return either a list of chunks
    or a dict. For a dict, every list-valued entry is concatenated in dict
    order and non-list entries are ignored. Any exception is logged and an
    empty list is returned.
    """
    try:
        logger.info(f"Initializing ChunkingManager for {file_path} using {chunking_strategy}")
        chunker = ChunkingManager(embedding_model_name=embed_model_name)
        result = chunker.process_document(file_path=file_path, strategy=chunking_strategy, preprocess=True)
        if not isinstance(result, dict):
            return result
        return [
            chunk
            for group in result.values()
            if isinstance(group, list)
            for chunk in group
        ]
    except Exception as e:
        logger.error(f"Error processing {file_path}: {e}")
        return []
def ingest_file(file_path: str, username: str, index_name: str, embed_model_name: str = "sentence-transformers/all-MiniLM-L6-v2", strategy: str = "paragraph") -> Tuple[bool, str]:
    """Chunk, embed and upload ``file_path`` into the user's Pinecone namespace.

    Existing vectors for the same file (keyed by basename) are deleted first.
    The new chunks are then uploaded with IDs ``"<basename>_<i>"``. Each chunk's
    text is passed to ``AcronymManager.scan_text_for_acronyms``.

    Returns:
        (success, human-readable status message). Errors are logged and
        reported through the message rather than raised.
    """
    if not PINECONE_KEY or not index_name:
        return False, "Pinecone Configuration Missing."
    try:
        # 1. Chunking
        docs = process_file(file_path, chunking_strategy=strategy, embed_model_name=embed_model_name)
        if not docs:
            return False, "No valid chunks generated."

        # 2. Metadata sanitization. 'source' is the bare filename (it also
        #    forms the vector IDs). Absolute paths must not leak into Pinecone
        #    metadata.
        clean_filename = os.path.basename(file_path)
        for doc in docs:
            doc.metadata["source"] = clean_filename
            doc.metadata.pop("file_path", None)

        # 3. Acronym learning
        acronym_mgr = AcronymManager()
        for doc in docs:
            acronym_mgr.scan_text_for_acronyms(doc.page_content)

        # 4. Pinecone manager
        pm = PineconeManager(PINECONE_KEY)

        # 5. Safety check. Embed a probe string to learn the model's vector
        #    size. get_embedding_func falls back to a default model on load
        #    errors, so this can also catch a silent model switch when the
        #    dimensions differ.
        emb_fn = get_embedding_func(embed_model_name)
        model_dim = len(emb_fn.embed_query("test"))
        if not pm.check_dimension_compatibility(index_name, model_dim):
            return False, (
                f"Dimension Mismatch! Index '{index_name}' is not compatible with "
                f"{model_dim}d vectors from embedding model '{embed_model_name}'."
            )

        # 6. Pre-emptive delete, so re-ingesting a shorter file leaves no stale chunks.
        pm.delete_file(index_name, clean_filename, namespace=username)

        # 7. Upload with deterministic IDs: "filename.txt_0", "filename.txt_1", ...
        vstore = pm.get_vectorstore(index_name, emb_fn, namespace=username)
        custom_ids = [f"{clean_filename}_{i}" for i in range(len(docs))]
        vstore.add_documents(docs, ids=custom_ids)
        return True, f"Successfully updated {clean_filename} ({len(docs)} chunks)."
    except Exception as e:
        logger.exception(f"Ingestion failed: {e}")
        return False, str(e)
def process_and_add_text(text: str, source_name: str, username: str, index_name: str, embed_model_name: str = None) -> Tuple[bool, str]:
    """Index raw ``text`` under ``source_name`` in the user's Pinecone namespace.

    Steps:
      1. Write a UTF-8 backup to ``UPLOAD_DIR/<username>/<basename>``.
      2. Chunk it with the "paragraph" strategy.
      3. Verify the embedding dimension against the index.
      4. Replace any existing vectors for that source.
      5. Upload the chunks with IDs ``"<source>_<i>"``.

    Old vectors are deleted only after new chunks exist and the dimension
    check has passed. A failed chunking or embedding step therefore no
    longer wipes the previously indexed version.

    Returns:
        (success, human-readable status message).
    """
    if not PINECONE_KEY or not index_name:
        return False, "Pinecone Configuration Missing."
    try:
        pm = PineconeManager(PINECONE_KEY)
        clean_source = os.path.basename(source_name)

        # 1. BACKUP (also serves as the input file for the chunker)
        user_docs_dir = os.path.join(UPLOAD_DIR, username)
        os.makedirs(user_docs_dir, exist_ok=True)
        backup_path = os.path.join(user_docs_dir, clean_source)
        with open(backup_path, "w", encoding='utf-8') as f:
            f.write(text)

        # 2. CHUNK
        manager = ChunkingManager(embedding_model_name=embed_model_name)
        chunks = manager.process_document(backup_path, strategy="paragraph", preprocess=True)
        if isinstance(chunks, dict):
            # process_document may return a dict; concatenate its list values.
            chunks = [doc for group in chunks.values() if isinstance(group, list) for doc in group]
        docs = list(chunks or [])
        if not docs:
            return False, "No valid chunks generated."

        # 3. SANITIZE METADATA
        for doc in docs:
            doc.metadata["source"] = clean_source
            doc.metadata["file_type"] = "generated"
            doc.metadata["strategy"] = "flattened"

        # 4. SAFETY CHECK: refuse to upload vectors the index cannot store.
        emb_fn = get_embedding_func(embed_model_name)
        model_dim = len(emb_fn.embed_query("test"))
        if not pm.check_dimension_compatibility(index_name, model_dim):
            return False, (
                f"Dimension Mismatch! Index '{index_name}' is not compatible with "
                f"{model_dim}d vectors."
            )

        # 5. REPLACE: delete old vectors only now that the new ones are ready.
        pm.delete_file(index_name, clean_source, namespace=username)
        vstore = pm.get_vectorstore(index_name, emb_fn, namespace=username)
        custom_ids = [f"{clean_source}_{i}" for i in range(len(docs))]
        vstore.add_documents(docs, ids=custom_ids)
        return True, f"Updated: {clean_source} ({len(docs)} chunks)"
    except Exception as e:
        logger.exception(f"Error indexing text: {e}")
        return False, str(e)
def search_knowledge_base(query: str, username: str, index_name: str, embed_model_name: str, k: int = 5, final_k: int = 5):
    """Return up to ``final_k`` Documents relevant to ``query`` from the user's namespace.

    ``final_k * 3`` candidates are retrieved by vector similarity. If the
    module-level FlashRank ``reranker`` loaded, it re-scores them and the top
    ``final_k`` are kept, each with ``rerank_score`` in its metadata. Without
    a reranker, the first ``final_k`` similarity hits are returned unchanged.

    ``k`` is accepted for backward compatibility but is not used; the
    candidate pool size is derived from ``final_k``.

    Returns an empty list when Pinecone is not configured, or on any error
    (which is logged).
    """
    if not PINECONE_KEY or not index_name:
        return []
    try:
        pm = PineconeManager(PINECONE_KEY)
        emb_fn = get_embedding_func(embed_model_name)
        vstore = pm.get_vectorstore(index_name, emb_fn, namespace=username)

        # Over-fetch so the reranker has a broader pool to reorder.
        broad_k = final_k * 3
        initial_docs = vstore.similarity_search(query, k=broad_k)
        if not initial_docs or not reranker:
            return initial_docs[:final_k]

        passages = [
            {"id": str(i), "text": doc.page_content, "meta": doc.metadata}
            for i, doc in enumerate(initial_docs)
        ]
        ranked_results = reranker.rerank(RerankRequest(query=query, passages=passages))

        final_docs = []
        for res in ranked_results[:final_k]:
            # Copy so the candidate Document's metadata is not mutated in place;
            # tolerate a missing or None "meta".
            meta = dict(res.get("meta") or {})
            score = res.get("score")
            # NOTE(review): FlashRank may return NumPy scalar scores (verify for
            # the installed version); a plain float keeps metadata serializable.
            meta["rerank_score"] = float(score) if score is not None else None
            final_docs.append(Document(page_content=res["text"], metadata=meta))
        return final_docs
    except Exception as e:
        logger.exception(f"Search failed: {e}")
        return []
def delete_document(username: str, filename: str, index_name: str):
    """Remove ``filename`` from the user's local cache and from Pinecone.

    Only the base name of ``filename`` is used. This applies both to the
    local path, so the call cannot reach outside ``UPLOAD_DIR/<username>``,
    and to the Pinecone delete key, matching the basename-only ``source``
    written at ingest time.

    A missing local file is ignored. The Pinecone step is skipped when
    ``PINECONE_KEY`` or ``index_name`` is unset, and Pinecone failures are
    logged rather than raised.

    Raises:
        ValueError: if ``filename`` has no usable file-name component
            (e.g. "", ".", ".."). Previously this tried to os.remove a directory.
    """
    clean_name = os.path.basename(filename)
    if clean_name in ("", ".", ".."):
        raise ValueError(f"Invalid document name: {filename!r}")

    try:
        os.remove(os.path.join(UPLOAD_DIR, username, clean_name))
    except FileNotFoundError:
        pass  # Already gone locally; still clear the vectors below.

    if PINECONE_KEY and index_name:
        try:
            pm = PineconeManager(PINECONE_KEY)
            pm.delete_file(index_name, clean_name, namespace=username)
        except Exception as e:
            logger.error(f"Pinecone delete failed: {e}")
def list_documents(username: str) -> List[dict]:
    """List cached documents for ``username`` that have a supported extension.

    Returns one ``{"filename": name, "source": name}`` dict per entry in
    ``UPLOAD_DIR/<username>`` whose name ends (case-insensitively) in .txt,
    .md, .pdf or .docx, in ``os.listdir`` order. Returns an empty list if the
    directory does not exist.
    """
    user_dir = os.path.join(UPLOAD_DIR, username)
    if not os.path.exists(user_dir):
        return []
    supported = ('.txt', '.md', '.pdf', '.docx')
    documents = []
    for entry in os.listdir(user_dir):
        if entry.lower().endswith(supported):
            documents.append({"filename": entry, "source": entry})
    return documents
def _chunk_index_from_id(vec_id: str) -> int:
    """Return the trailing integer of a "<source>_<i>" vector ID, or 0 if there is none."""
    if "_" not in vec_id:
        return 0
    try:
        return int(vec_id.rsplit('_', 1)[-1])
    except ValueError:
        return 0


def rebuild_cache_from_pinecone(username: str, index_name: str) -> Tuple[bool, str]:
    """Recreate the user's local document cache from text stored in Pinecone.

    Every vector ID in the user's namespace is fetched in batches of 100.
    Chunk text (metadata ``text``, else ``page_content``) is grouped by the
    basename of metadata ``source`` and ordered by the numeric suffix of the
    vector ID. Each group is joined with blank lines and written as UTF-8
    text to ``UPLOAD_DIR/<username>/<source>``.

    All vectors are fetched before the local directory is wiped, so a failed
    fetch leaves the existing cache intact.

    NOTE(review): sources such as ``.pdf``/``.docx`` are written back as plain
    text under their original name; confirm downstream loaders tolerate this.

    Returns:
        (success, human-readable status message).
    """
    if not PINECONE_KEY or not index_name:
        return False, "Pinecone config missing."
    try:
        pm = PineconeManager(PINECONE_KEY)
        ids = pm.get_all_ids(index_name, username)
        if not ids:
            return False, "No data found in Pinecone."

        batch_size = 100
        reconstructed_files = {}
        for start in range(0, len(ids), batch_size):
            batch_ids = ids[start:start + batch_size]
            response = pm.fetch_vectors(index_name, batch_ids, username)
            for vec_id, vec_data in response.vectors.items():
                meta = vec_data.metadata or {}
                # Force basename to avoid "dir/dir/file" bugs. Fall back when the
                # source is missing/None or not a usable file name.
                source = os.path.basename(meta.get('source') or 'unknown.txt')
                if source in ("", ".", ".."):
                    source = 'unknown.txt'
                text = meta.get('text') or meta.get('page_content') or ''
                reconstructed_files.setdefault(source, []).append(
                    (_chunk_index_from_id(vec_id), text)
                )

        # Only now, with everything fetched, replace the local cache.
        user_dir = os.path.join(UPLOAD_DIR, username)
        if os.path.exists(user_dir):
            shutil.rmtree(user_dir)
        os.makedirs(user_dir, exist_ok=True)

        for filename, chunks in reconstructed_files.items():
            chunks.sort(key=lambda item: item[0])
            full_text = "\n\n".join(chunk_text for _, chunk_text in chunks)
            with open(os.path.join(user_dir, filename), "w", encoding="utf-8") as f:
                f.write(full_text)
        return True, f"Restored {len(reconstructed_files)} files from Pinecone!"
    except Exception as e:
        logger.exception(f"Cache rebuild failed: {e}")
        return False, str(e)