File size: 3,762 Bytes
3cfb95f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3f819d
3cfb95f
b3f819d
 
 
3cfb95f
 
b3f819d
3cfb95f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import streamlit as st
from langchain_aws import BedrockEmbeddings

from langchain.chat_models import init_chat_model
from langchain_core.documents import Document
from typing_extensions import List, Dict, TypedDict

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph, END

from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langgraph.graph import MessagesState
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_milvus import Milvus
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from logging_config import setup_logger
from flashrank import Ranker
from langchain_community.document_compressors import FlashrankRerank
from langchain.retrievers import ContextualCompressionRetriever

logger = setup_logger(__name__)


def init_vector_db(embeddings):
    """Initialize (or reuse) the local Milvus vector store and build a
    reranked retriever over it.

    Probes the existing collection with a test similarity search; if the
    probe finds no documents (or fails, e.g. because the collection is
    brand new), PDFs from the ``docs`` folder are loaded, split, and
    indexed.

    Args:
        embeddings: Embedding model (e.g. ``BedrockEmbeddings``) used by
            the vector store for both indexing and querying.

    Returns:
        tuple: ``(vector_store, compression_retriever)`` where
        ``vector_store`` is the ``Milvus`` instance and
        ``compression_retriever`` is a ``ContextualCompressionRetriever``
        that fetches k=100 candidates and reranks them down to 30 with
        FlashRank.
    """
    uri = "db/vectordb_milvus.db"
    collection_name = "my_collection"

    # One handle serves both the "exists" and "create" paths — the Milvus
    # constructor lazily creates the collection if it is missing, so there
    # is no need to construct it twice (the original duplicated this call
    # in both the try and except branches).
    vector_store = Milvus(
        embedding_function=embeddings,
        connection_args={"uri": uri},
        auto_id=True,
        collection_name=collection_name,
        index_params={"index_type": "FLAT", "metric_type": "COSINE"},
    )

    documents_loaded = False
    try:
        st.info("Checking for existing Milvus db...")
        # Cheap probe: any hit at all means the collection already holds data.
        results = vector_store.similarity_search("test query", k=1)
        if results:
            st.success("Document data found in existing collection.")
            documents_loaded = True
        else:
            st.info("Collection exists but might be empty. Will check for documents.")
    except Exception as e:
        # A failed probe most likely means a brand-new/empty collection;
        # log the cause instead of silently discarding it.
        logger.warning("Similarity probe failed; treating collection as new: %s", e)
        st.info("Creating new Milvus collection...")

    # Load documents if needed
    if not documents_loaded:
        _index_pdf_documents(vector_store)

    # Wide first-stage retrieval (k=100), then FlashRank reranking to top 30.
    retriever = vector_store.as_retriever(search_kwargs={"k": 100})

    model_name = "ms-marco-MultiBERT-L-12"
    # model_name = "ms-marco-MiniLM-L-12-v2"
    ranker_client = Ranker(model_name=model_name, cache_dir="./models")

    compressor = FlashrankRerank(client=ranker_client, top_n=30)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )

    return vector_store, compression_retriever


def _index_pdf_documents(vector_store):
    """Load PDFs from the ``docs`` folder, split them, and add them to
    *vector_store*.

    Best-effort: loading errors are reported to the Streamlit UI rather
    than raised, matching the original behavior.
    """
    folder_path = "docs"
    loader = DirectoryLoader(
        folder_path,
        glob="**/*.pdf",
        loader_cls=PyPDFLoader,
    )

    try:
        documents = loader.load()
        st.info(f"Loaded {len(documents)} PDF pages.")

        if documents:
            # Split documents into overlapping chunks for retrieval.
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            all_splits = text_splitter.split_documents(documents)
            st.info(f"Total Document splits: {len(all_splits)}")

            # Add documents to vector store
            _ = vector_store.add_documents(documents=all_splits)
            st.success("Documents added to vector store.")
        else:
            st.warning("No PDF documents found in the 'docs' folder.")
    except Exception as e:
        st.error(f"Error loading documents: {e}")