Spaces:
Sleeping
Sleeping
File size: 5,655 Bytes
1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 1f725d8 5551822 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | import os
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from utils.asyncHandler import asyncHandler
from src.MultiRag.constants import EMBEDDING_MODEL
from src.MultiRag.constants import EXCEPTED_FILE_TYPE, RETREIVER_DEFAULT_K
import logging
# ---------------- Embedding Model ----------------
class CompatibleEmbeddings(HuggingFaceEmbeddings):
    """HuggingFaceEmbeddings subclass that is directly callable.

    Some consumers invoke the embeddings object as a function instead of
    calling ``embed_query``; ``__call__`` delegates so both styles work.
    """

    def __call__(self, text: str):
        """Embed a single query string by delegating to ``embed_query``."""
        return self.embed_query(text)
embedding_model = CompatibleEmbeddings(model=EMBEDDING_MODEL)
# ---------------- Document Fetcher ----------------
@asyncHandler
async def document_fetcher(docs: str = "data"):
    """Load documents from a URL, a single file, or a directory.

    Supported local types: .txt, .pdf, .docx and images (png/jpg/jpeg).
    Images produce one Document combining an EasyOCR word-for-word
    transcript with a caption from ``image_to_text``; if image processing
    fails, a minimal fallback Document is appended instead. Files with
    unrecognized extensions are silently skipped.

    Args:
        docs: An http(s) URL, a path to one file, or a directory path.

    Returns:
        list: Loaded LangChain ``Document`` objects; empty list when the
        path is missing or a URL fails to load (never raises, so callers
        don't crash on bad input).
    """
    # 1. Handle URL case
    if docs.startswith(("http://", "https://")):
        logging.info(f"Detected URL: {docs}. Loading via WebBaseLoader...")
        from langchain_community.document_loaders import WebBaseLoader
        try:
            return WebBaseLoader(docs).load()
        except Exception as e:
            logging.error(f"Failed to load URL {docs}: {e}")
            return []

    # 2. Handle Local File/Dir case
    if not os.path.exists(docs):
        logging.error(f"Docs path not found: {docs}")
        return []  # Return empty instead of raising to prevent crash

    if os.path.isfile(docs):
        files = [os.path.basename(docs)]
        docs_dir = os.path.dirname(docs) or "."
    else:
        files = os.listdir(docs)
        docs_dir = docs

    from langchain_community.document_loaders import TextLoader, PyPDFLoader

    documents = []
    ocr_reader = None  # EasyOCR reader is heavy to build; create once, reuse per image
    for file in files:
        file_path = os.path.join(docs_dir, file)
        # os.listdir also yields subdirectories; skip anything that isn't a file.
        if not os.path.isfile(file_path):
            continue
        # splitext is robust for dotless names ("" instead of the whole name).
        ext = os.path.splitext(file)[1].lstrip(".").lower()
        try:
            if ext == "txt":
                documents.extend(TextLoader(file_path, encoding="utf-8").load())
            elif ext == "pdf":
                documents.extend(PyPDFLoader(file_path).load())
            elif ext == "docx":
                from langchain_community.document_loaders import Docx2txtLoader
                documents.extend(Docx2txtLoader(file_path).load())
            elif ext in ("png", "jpg", "jpeg"):
                import easyocr
                from langchain_core.documents import Document
                from src.MultiRag.utils.image_embedding import image_to_text
                logging.info(f"Processing image {file} with EasyOCR and BLIP...")
                # Building a Reader loads model weights — do it lazily, once.
                if ocr_reader is None:
                    ocr_reader = easyocr.Reader(['en'], gpu=False)
                # 1. Word-to-word transcript
                ocr_results = ocr_reader.readtext(file_path)
                transcript = " ".join([res[1] for res in ocr_results])
                # 2. Image caption
                caption = await image_to_text(file_path)
                logging.info(f"Image processed. Transcript length: {len(transcript)}")
                documents.append(Document(
                    page_content=f"IMAGE TRANSCRIPT: {transcript}\n\nIMAGE DESCRIPTION: {caption}",
                    metadata={"source": file_path}
                ))
        except Exception as e:
            logging.error(f"Failed to load {file_path}: {e}")
            if ext in ("png", "jpg", "jpeg"):
                from langchain_core.documents import Document
                logging.info(f"Using fallback for image: {file_path}")
                documents.append(Document(page_content=f"Image file: {file}\nNote: Word-to-word extraction failed.", metadata={"source": file_path}))
    return documents
# ---------------- Chunking ----------------
@asyncHandler
async def chunking_documents(documents, chunk_size: int = 200, chunk_overlap: int = 0):
    """Split documents into character-based chunks.

    Args:
        documents: LangChain documents to split.
        chunk_size: Maximum characters per chunk.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        The list of chunked documents.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    chunks = text_splitter.split_documents(documents)
    return chunks
# ---------------- FAISS Vector Store ----------------
@asyncHandler
async def create_vector_store(path: str = "db", docs: str = "data"):
    """Load an existing FAISS index from *path*, or build one from *docs*.

    A previously persisted index is loaded when present; on load failure
    (or when none exists) a fresh index is built from the fetched and
    chunked documents, then saved back to *path*.

    Returns:
        The FAISS vector store, or None when no documents/chunks could be
        produced from *docs*.
    """
    index_file = os.path.join(path, "index.faiss")
    if os.path.exists(path) and os.path.exists(index_file):
        try:
            logging.info("Existing FAISS DB found. Loading...")
            return FAISS.load_local(path, embedding_model, allow_dangerous_deserialization=True)
        except Exception as e:
            # Fall through and rebuild from scratch.
            logging.warning(f"Failed to load existing FAISS DB: {e}. Creating new one.")

    logging.info("Creating new FAISS DB...")
    documents = await document_fetcher(docs=docs)
    if not documents:
        logging.warning(f"No documents found or failed to load any documents from {docs}. Skipping FAISS creation.")
        return None

    chunks = await chunking_documents(documents)
    if not chunks:
        logging.warning(f"No chunks created from documents in {docs}. Skipping FAISS creation.")
        return None

    store = FAISS.from_documents(documents=chunks, embedding=embedding_model)
    # Persist so the fast-load path above works on the next call.
    store.save_local(path)
    return store
# ---------------- Retriever ----------------
@asyncHandler
async def create_retreiver(vectorstore, k: int = RETREIVER_DEFAULT_K):
    """Wrap *vectorstore* as a retriever returning the top *k* matches."""
    return vectorstore.as_retriever(search_kwargs={"k": k})
# ---------------- Get Raw Documents ----------------
async def get_documents(docs: str = "data") -> str:
    """Fetch documents from *docs* and return their contents as one string.

    Page contents are joined with newlines; an empty fetch yields "".
    """
    documents = await document_fetcher(docs=docs)
    return "\n".join(doc.page_content for doc in documents)