import os
import logging

from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings

from utils.asyncHandler import asyncHandler
from src.MultiRag.constants import EMBEDDING_MODEL, EXCEPTED_FILE_TYPE, RETREIVER_DEFAULT_K
# ---------------- Embedding Model ----------------
class CompatibleEmbeddings(HuggingFaceEmbeddings):
    """HuggingFaceEmbeddings subclass that is directly callable.

    Some integrations invoke the embeddings object like a function;
    delegating __call__ to embed_query keeps them compatible.
    """

    def __call__(self, text: str):
        return self.embed_query(text)


# HuggingFaceEmbeddings takes the checkpoint via `model_name`
embedding_model = CompatibleEmbeddings(model_name=EMBEDDING_MODEL)
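# Example (sketch): the wrapper makes the instance callable directly;
# the query string below is illustrative only.
# vector = embedding_model("hello world")  # same as embedding_model.embed_query("hello world")
# len(vector)  # dimensionality of EMBEDDING_MODEL's output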
# ---------------- Document Fetcher ----------------
async def document_fetcher(docs: str = "data"):
    """Load documents from a URL, a single local file, or a directory."""
    # 1. Handle the URL case
    if docs.startswith("http://") or docs.startswith("https://"):
        logging.info(f"Detected URL: {docs}. Loading via WebBaseLoader...")
        from langchain_community.document_loaders import WebBaseLoader
        try:
            loader = WebBaseLoader(docs)
            return loader.load()
        except Exception as e:
            logging.error(f"Failed to load URL {docs}: {e}")
            return []

    # 2. Handle the local file/directory case
    if not os.path.exists(docs):
        logging.error(f"Docs path not found: {docs}")
        return []  # Return empty instead of raising, to prevent a crash

    if os.path.isfile(docs):
        files = [os.path.basename(docs)]
        docs_dir = os.path.dirname(docs) or "."
    else:
        files = os.listdir(docs)
        docs_dir = docs

    from langchain_community.document_loaders import TextLoader, PyPDFLoader

    documents = []
    for file in files:
        file_path = os.path.join(docs_dir, file)
        if not os.path.isfile(file_path):
            continue  # Skip subdirectories and other non-file entries
        ext = file.split(".")[-1].lower()
        try:
            if ext == "txt":
                loader = TextLoader(file_path, encoding="utf-8")
                documents.extend(loader.load())
            elif ext == "pdf":
                loader = PyPDFLoader(file_path)
                documents.extend(loader.load())
            elif ext == "docx":
                from langchain_community.document_loaders import Docx2txtLoader
                loader = Docx2txtLoader(file_path)
                documents.extend(loader.load())
            elif ext in ["png", "jpg", "jpeg"]:
                import easyocr
                from langchain_core.documents import Document
                from src.MultiRag.utils.image_embedding import image_to_text

                logging.info(f"Processing image {file} with EasyOCR and BLIP...")
                # 1. Word-for-word transcript via OCR. Note: the Reader is
                # built per image; hoist it out of the loop if many images
                # are processed.
                reader = easyocr.Reader(['en'], gpu=False)
                ocr_results = reader.readtext(file_path)
                transcript = " ".join([res[1] for res in ocr_results])
                # 2. Natural-language caption of the image
                caption = await image_to_text(file_path)
                logging.info(f"Image processed. Transcript length: {len(transcript)}")
                documents.append(Document(
                    page_content=f"IMAGE TRANSCRIPT: {transcript}\n\nIMAGE DESCRIPTION: {caption}",
                    metadata={"source": file_path},
                ))
        except Exception as e:
            logging.error(f"Failed to load {file_path}: {e}")
            if ext in ["png", "jpg", "jpeg"]:
                from langchain_core.documents import Document
                logging.info(f"Using fallback for image: {file_path}")
                documents.append(Document(
                    page_content=f"Image file: {file}\nNote: Word-for-word extraction failed.",
                    metadata={"source": file_path},
                ))
    return documents
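
# Example (sketch): both call shapes are supported; "data" and the URL
# below are illustrative.
# local_docs = await document_fetcher("data")                  # directory of txt/pdf/docx/images
# web_docs = await document_fetcher("https://example.com/")    # single page via WebBaseLoader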
# ---------------- Chunking ----------------
async def chunking_documents(documents, chunk_size: int = 200, chunk_overlap: int = 0):
    # chunk_size is measured in characters; a non-zero overlap can help
    # preserve context across chunk boundaries.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(documents)
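
# Example (sketch): larger chunks with some overlap, if the defaults
# prove too small; the numbers are illustrative.
# chunks = await chunking_documents(docs, chunk_size=500, chunk_overlap=50)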
# ---------------- FAISS Vector Store ----------------
async def create_vector_store(path: str = "db", docs: str = "data"):
    # Reuse a previously saved index if one exists on disk
    if os.path.exists(path) and os.path.exists(os.path.join(path, "index.faiss")):
        try:
            logging.info("Existing FAISS DB found. Loading...")
            vectorstore = FAISS.load_local(path, embedding_model, allow_dangerous_deserialization=True)
            return vectorstore
        except Exception as e:
            logging.warning(f"Failed to load existing FAISS DB: {e}. Creating a new one.")

    logging.info("Creating new FAISS DB...")
    documents = await document_fetcher(docs=docs)
    if not documents:
        logging.warning(f"No documents could be loaded from {docs}. Skipping FAISS creation.")
        return None

    chunks = await chunking_documents(documents)
    if not chunks:
        logging.warning(f"No chunks created from documents in {docs}. Skipping FAISS creation.")
        return None

    vectorstore = FAISS.from_documents(
        documents=chunks,
        embedding=embedding_model,
    )
    # Persist the index so subsequent runs can load it instead of rebuilding
    vectorstore.save_local(path)
    return vectorstore
# ---------------- Retriever ----------------
async def create_retreiver(vectorstore, k: int = RETREIVER_DEFAULT_K):
    retriever = vectorstore.as_retriever(search_kwargs={"k": k})
    return retriever
# ---------------- Get Raw Documents ----------------
async def get_documents(docs: str = "data") -> str:
    documents = await document_fetcher(docs=docs)
    text = "\n".join([doc.page_content for doc in documents])
    return text
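
# ---------------- Example Usage (sketch) ----------------
# A minimal end-to-end run; the query string and k=4 are illustrative,
# not part of the module's API. Guarded by __main__ so importing this
# module has no side effects.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        store = await create_vector_store(path="db", docs="data")
        if store is None:
            return
        retriever = await create_retreiver(store, k=4)
        # LangChain retrievers are Runnables: .invoke(query) returns documents
        for doc in retriever.invoke("example query about the corpus"):
            print(doc.metadata.get("source"), "->", doc.page_content[:80])

    asyncio.run(_demo())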