File size: 5,655 Bytes
1f725d8
 
5551822
1f725d8
 
 
 
5551822
1f725d8
 
 
5551822
 
 
1f725d8
5551822
1f725d8
5551822
1f725d8
 
5551822
 
 
 
 
 
 
 
 
 
1f725d8
5551822
1f725d8
5551822
 
1f725d8
5551822
 
 
 
 
 
1f725d8
 
 
 
 
5551822
1f725d8
 
 
 
 
 
 
 
 
 
 
5551822
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f725d8
 
 
5551822
 
 
 
1f725d8
 
 
 
 
 
 
 
 
 
 
5551822
1f725d8
 
5551822
1f725d8
5551822
 
 
 
 
 
 
 
 
 
 
1f725d8
5551822
 
 
 
1f725d8
5551822
 
 
 
 
1f725d8
5551822
1f725d8
 
5551822
 
 
1f725d8
 
5551822
 
1f725d8
 
 
 
 
 
5551822
 
1f725d8
5551822
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os

from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from utils.asyncHandler import asyncHandler
from src.MultiRag.constants import EMBEDDING_MODEL
from src.MultiRag.constants import EXCEPTED_FILE_TYPE, RETREIVER_DEFAULT_K
import logging

# ---------------- Embedding Model ----------------
class CompatibleEmbeddings(HuggingFaceEmbeddings):
    # Some consumers call the embeddings object directly as a function;
    # delegate that call to embed_query so a bare string is embedded the
    # same way as a query.
    def __call__(self, text: str):
        return self.embed_query(text)

# Module-level singleton: the HF model is loaded once at import time and
# shared by create_vector_store().
# NOTE(review): HuggingFaceEmbeddings conventionally takes `model_name=`;
# passing `model=` assumes a langchain-huggingface version that accepts
# that alias — confirm against the pinned dependency version.
embedding_model = CompatibleEmbeddings(model=EMBEDDING_MODEL)

# ---------------- Document Fetcher ----------------
@asyncHandler
async def document_fetcher(docs: str = "data"):
    """Load documents from a URL, a single file, or a directory.

    Supports .txt, .pdf, .docx and common image formats (images get an
    EasyOCR transcript plus a BLIP caption via ``image_to_text``).

    Args:
        docs: An http(s) URL, a path to one file, or a directory path.

    Returns:
        A list of langchain ``Document`` objects. Returns ``[]`` (instead
        of raising) when the path is missing or a source fails to load,
        so callers can degrade gracefully.
    """
    # 1. Handle URL case
    if docs.startswith(("http://", "https://")):
        logging.info(f"Detected URL: {docs}. Loading via WebBaseLoader...")
        from langchain_community.document_loaders import WebBaseLoader
        try:
            return WebBaseLoader(docs).load()
        except Exception as e:
            logging.error(f"Failed to load URL {docs}: {e}")
            return []

    # 2. Handle Local File/Dir case
    if not os.path.exists(docs):
        logging.error(f"Docs path not found: {docs}")
        return []  # Return empty instead of raising to prevent crash

    if os.path.isfile(docs):
        files = [os.path.basename(docs)]
        docs_dir = os.path.dirname(docs) or "."
    else:
        files = os.listdir(docs)
        docs_dir = docs

    from langchain_community.document_loaders import TextLoader, PyPDFLoader

    # EasyOCR model loading is expensive; build the reader lazily, once,
    # and reuse it for every image instead of re-instantiating per file.
    ocr_reader = None
    documents = []
    for file in files:
        file_path = os.path.join(docs_dir, file)

        # os.listdir() also yields subdirectories and other non-file
        # entries; skip them instead of letting a loader fail on them.
        if not os.path.isfile(file_path):
            continue

        ext = file.split(".")[-1].lower()

        try:
            if ext == "txt":
                documents.extend(TextLoader(file_path, encoding="utf-8").load())

            elif ext == "pdf":
                documents.extend(PyPDFLoader(file_path).load())

            elif ext == "docx":
                from langchain_community.document_loaders import Docx2txtLoader
                documents.extend(Docx2txtLoader(file_path).load())

            elif ext in ("png", "jpg", "jpeg"):
                from langchain_core.documents import Document
                from src.MultiRag.utils.image_embedding import image_to_text

                logging.info(f"Processing image {file} with EasyOCR and BLIP...")

                if ocr_reader is None:
                    import easyocr
                    ocr_reader = easyocr.Reader(['en'], gpu=False)

                # 1. Word-to-word transcript
                ocr_results = ocr_reader.readtext(file_path)
                transcript = " ".join(res[1] for res in ocr_results)

                # 2. Image caption
                caption = await image_to_text(file_path)

                logging.info(f"Image processed. Transcript length: {len(transcript)}")
                documents.append(Document(
                    page_content=f"IMAGE TRANSCRIPT: {transcript}\n\nIMAGE DESCRIPTION: {caption}",
                    metadata={"source": file_path}
                ))

            else:
                # Previously unsupported types were skipped silently;
                # log so missing content is diagnosable.
                logging.info(f"Skipping unsupported file type: {file_path}")

        except Exception as e:
            logging.error(f"Failed to load {file_path}: {e}")
            # For images, still index a stub so the file is discoverable.
            if ext in ("png", "jpg", "jpeg"):
                from langchain_core.documents import Document
                logging.info(f"Using fallback for image: {file_path}")
                documents.append(Document(page_content=f"Image file: {file}\nNote: Word-to-word extraction failed.", metadata={"source": file_path}))

    return documents


# ---------------- Chunking ----------------
@asyncHandler
async def chunking_documents(documents, chunk_size: int = 200, chunk_overlap: int = 0):
    """Split `documents` into chunks of `chunk_size` characters with
    `chunk_overlap` characters of overlap, for embedding/indexing."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    chunks = text_splitter.split_documents(documents)
    return chunks


# ---------------- FAISS Vector Store ----------------
@asyncHandler
async def create_vector_store(path: str = "db", docs: str = "data"):
    """Return a FAISS vector store for the documents under `docs`.

    Loads an existing index from `path` when one is present; otherwise
    fetches, chunks, embeds and persists a new index. Returns None when
    no documents or chunks could be produced.
    """
    # Reuse a previously persisted index when the index file exists.
    index_present = os.path.exists(path) and os.path.exists(os.path.join(path, "index.faiss"))
    if index_present:
        try:
            logging.info("Existing FAISS DB found. Loading...")
            return FAISS.load_local(path, embedding_model, allow_dangerous_deserialization=True)
        except Exception as e:
            # Fall through and rebuild from scratch.
            logging.warning(f"Failed to load existing FAISS DB: {e}. Creating new one.")

    logging.info("Creating new FAISS DB...")
    documents = await document_fetcher(docs=docs)
    if not documents:
        logging.warning(f"No documents found or failed to load any documents from {docs}. Skipping FAISS creation.")
        return None

    chunks = await chunking_documents(documents)
    if not chunks:
        logging.warning(f"No chunks created from documents in {docs}. Skipping FAISS creation.")
        return None

    store = FAISS.from_documents(documents=chunks, embedding=embedding_model)

    # Persist so subsequent runs can load instead of rebuilding.
    store.save_local(path)

    return store


# ---------------- Retriever ----------------
@asyncHandler
async def create_retreiver(vectorstore, k: int = RETREIVER_DEFAULT_K):
    """Build a retriever over `vectorstore` that returns the top-`k` matches."""
    return vectorstore.as_retriever(search_kwargs={"k": k})


# ---------------- Get Raw Documents ----------------
async def get_documents(docs: str = "data") -> str:
    """Fetch all documents from `docs` and return their text joined by newlines."""
    documents = await document_fetcher(docs=docs)
    return "\n".join(doc.page_content for doc in documents)