# Multi-Rag/src/MultiRag/utils/ingestion_utils.py
# Uploaded by VashuTheGreat2 using huggingface_hub (commit 5551822, verified).
import os
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from utils.asyncHandler import asyncHandler
from src.MultiRag.constants import EMBEDDING_MODEL
from src.MultiRag.constants import EXCEPTED_FILE_TYPE, RETREIVER_DEFAULT_K
import logging
# ---------------- Embedding Model ----------------
class CompatibleEmbeddings(HuggingFaceEmbeddings):
    """HuggingFaceEmbeddings subclass that is directly callable.

    Some integrations invoke the embedding object itself with a single
    string instead of calling ``embed_query``; this shim forwards such
    calls to ``embed_query``.
    """

    def __call__(self, text: str):
        # Delegate to embed_query so `embedding_model(text)` embeds one string.
        return self.embed_query(text)


# Module-level singleton shared by create_vector_store below.
# NOTE(review): langchain_huggingface's HuggingFaceEmbeddings conventionally
# takes `model_name=`; confirm the pinned version accepts `model=`.
embedding_model = CompatibleEmbeddings(model=EMBEDDING_MODEL)
# ---------------- Document Fetcher ----------------
@asyncHandler
async def document_fetcher(docs: str = "data"):
    """Load documents from a URL, a single file path, or a directory.

    Supports .txt, .pdf and .docx files, plus png/jpg/jpeg images
    (transcribed with EasyOCR and captioned via ``image_to_text``).

    Args:
        docs: An http(s) URL, a path to one file, or a directory whose
            entries should be loaded. Defaults to "data".

    Returns:
        A list of langchain ``Document`` objects. Individual failures are
        logged and skipped, so the list may be empty but the call never
        raises.
    """
    # 1. Handle URL case
    if docs.startswith(("http://", "https://")):
        logging.info(f"Detected URL: {docs}. Loading via WebBaseLoader...")
        from langchain_community.document_loaders import WebBaseLoader
        try:
            loader = WebBaseLoader(docs)
            return loader.load()
        except Exception as e:
            logging.error(f"Failed to load URL {docs}: {e}")
            return []

    # 2. Handle Local File/Dir case
    if not os.path.exists(docs):
        logging.error(f"Docs path not found: {docs}")
        return []  # Return empty instead of raising to prevent crash

    if os.path.isfile(docs):
        files = [os.path.basename(docs)]
        docs_dir = os.path.dirname(docs) or "."
    else:
        files = os.listdir(docs)
        docs_dir = docs

    from langchain_community.document_loaders import TextLoader, PyPDFLoader

    documents = []
    ocr_reader = None  # EasyOCR model load is heavy: build once, reuse for every image.
    for file in files:
        file_path = os.path.join(docs_dir, file)
        if not os.path.isfile(file_path):
            continue  # skip sub-directories and other non-regular entries
        # os.path.splitext handles dotless and hidden filenames correctly,
        # unlike file.split(".")[-1], which returns the whole name for them.
        ext = os.path.splitext(file)[1].lstrip(".").lower()
        try:
            if ext == "txt":
                loader = TextLoader(file_path, encoding="utf-8")
                documents.extend(loader.load())
            elif ext == "pdf":
                loader = PyPDFLoader(file_path)
                documents.extend(loader.load())
            elif ext == "docx":
                from langchain_community.document_loaders import Docx2txtLoader
                loader = Docx2txtLoader(file_path)
                documents.extend(loader.load())
            elif ext in ("png", "jpg", "jpeg"):
                import easyocr
                from langchain_core.documents import Document
                from src.MultiRag.utils.image_embedding import image_to_text
                logging.info(f"Processing image {file} with EasyOCR and BLIP...")
                # 1. Word-to-word transcript
                if ocr_reader is None:
                    ocr_reader = easyocr.Reader(['en'], gpu=False)
                ocr_results = ocr_reader.readtext(file_path)
                transcript = " ".join(res[1] for res in ocr_results)
                # 2. Image caption
                caption = await image_to_text(file_path)
                logging.info(f"Image processed. Transcript length: {len(transcript)}")
                documents.append(Document(
                    page_content=f"IMAGE TRANSCRIPT: {transcript}\n\nIMAGE DESCRIPTION: {caption}",
                    metadata={"source": file_path}
                ))
        except Exception as e:
            logging.error(f"Failed to load {file_path}: {e}")
            if ext in ("png", "jpg", "jpeg"):
                # Keep a placeholder Document so image files remain visible
                # downstream even when OCR/captioning fails.
                from langchain_core.documents import Document
                logging.info(f"Using fallback for image: {file_path}")
                documents.append(Document(
                    page_content=f"Image file: {file}\nNote: Word-to-word extraction failed.",
                    metadata={"source": file_path}
                ))
    return documents
# ---------------- Chunking ----------------
@asyncHandler
async def chunking_documents(documents, chunk_size: int = 200, chunk_overlap: int = 0):
    """Split documents into character-based chunks.

    Args:
        documents: Sequence of langchain Documents to split.
        chunk_size: Maximum characters per chunk (default 200).
        chunk_overlap: Characters shared between adjacent chunks (default 0).

    Returns:
        The list of chunked Documents produced by the splitter.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    chunks = text_splitter.split_documents(documents)
    return chunks
# ---------------- FAISS Vector Store ----------------
@asyncHandler
async def create_vector_store(path: str = "db", docs: str = "data"):
    """Load the FAISS store at *path*, or build a new one from *docs*.

    Args:
        path: Directory holding (or to hold) the FAISS index. Default "db".
        docs: Source passed to ``document_fetcher``. Default "data".

    Returns:
        A FAISS vector store, or ``None`` when no documents or chunks
        could be produced.
    """
    # Reuse an existing index when both the directory and index file exist.
    index_file = os.path.join(path, "index.faiss")
    if os.path.exists(path) and os.path.exists(index_file):
        try:
            logging.info("Existing FAISS DB found. Loading...")
            return FAISS.load_local(
                path, embedding_model, allow_dangerous_deserialization=True
            )
        except Exception as e:
            logging.warning(f"Failed to load existing FAISS DB: {e}. Creating new one.")

    logging.info("Creating new FAISS DB...")
    documents = await document_fetcher(docs=docs)
    if not documents:
        logging.warning(f"No documents found or failed to load any documents from {docs}. Skipping FAISS creation.")
        return None

    chunks = await chunking_documents(documents)
    if not chunks:
        logging.warning(f"No chunks created from documents in {docs}. Skipping FAISS creation.")
        return None

    store = FAISS.from_documents(documents=chunks, embedding=embedding_model)
    # Persist so the next call loads instead of rebuilding.
    store.save_local(path)
    return store
# ---------------- Retriever ----------------
@asyncHandler
async def create_retreiver(vectorstore, k: int = RETREIVER_DEFAULT_K):
    """Wrap *vectorstore* as a retriever returning the top *k* matches."""
    return vectorstore.as_retriever(search_kwargs={"k": k})
# ---------------- Get Raw Documents ----------------
async def get_documents(docs: str = "data") -> str:
    """Fetch every document from *docs* and join their text with newlines.

    Args:
        docs: URL, file path, or directory forwarded to ``document_fetcher``.

    Returns:
        All page contents concatenated into a single newline-separated string.
    """
    loaded = await document_fetcher(docs=docs)
    return "\n".join(doc.page_content for doc in loaded)