# multimodal_rag_chatbot / load_vector_db.py
# Uploaded by vamsidharmuthireddy ("Upload 6 files", commit b3f819d, verified)
import os
import streamlit as st
from langchain_aws import BedrockEmbeddings
from langchain.chat_models import init_chat_model
from langchain_core.documents import Document
from typing_extensions import List, Dict, TypedDict
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph, END
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langgraph.graph import MessagesState
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_milvus import Milvus
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from logging_config import setup_logger
from flashrank import Ranker
from langchain_community.document_compressors import FlashrankRerank
from langchain.retrievers import ContextualCompressionRetriever
# Module-level logger wired through the project's shared logging configuration.
logger = setup_logger(__name__)
def init_vector_db(embeddings):
    """Initialize (or reuse) the local Milvus vector store and build a
    rerank-backed retriever on top of it.

    Parameters
    ----------
    embeddings : object
        LangChain-compatible embedding model (e.g. BedrockEmbeddings) used
        for both indexing and querying.

    Returns
    -------
    tuple
        ``(vector_store, compression_retriever)`` where ``vector_store`` is
        the :class:`Milvus` instance and ``compression_retriever`` is a
        :class:`ContextualCompressionRetriever` that over-fetches 100
        candidates and reranks them down to 30 with FlashRank.
    """
    uri = "db/vectordb_milvus.db"
    collection_name = "my_collection"

    # Construct the store exactly once: the original code built an identical
    # Milvus instance in both the try and except branches.
    # NOTE(review): presumably Milvus() creates the collection on demand when
    # it does not yet exist — confirm against langchain_milvus docs.
    vector_store = Milvus(
        embedding_function=embeddings,
        connection_args={"uri": uri},
        auto_id=True,
        collection_name=collection_name,
        index_params={"index_type": "FLAT", "metric_type": "COSINE"},
    )

    documents_loaded = False
    try:
        st.info("Checking for existing Milvus db...")
        # A non-empty probe search proves the collection already holds data.
        if vector_store.similarity_search("test query", k=1):
            st.success("Document data found in existing collection.")
            documents_loaded = True
        else:
            st.info("Collection exists but might be empty. Will check for documents.")
    except Exception as e:
        # The probe fails when the collection is brand new; log (the original
        # silently discarded the exception) and fall through to ingestion.
        logger.warning("Probe search on Milvus collection failed: %s", e)
        st.info("Creating new Milvus collection...")

    if not documents_loaded:
        _ingest_pdf_documents(vector_store)

    # Over-fetch broadly (k=100), then let FlashRank keep the best 30 chunks.
    retriever = vector_store.as_retriever(search_kwargs={"k": 100})
    model_name = "ms-marco-MultiBERT-L-12"
    # model_name = "ms-marco-MiniLM-L-12-v2"
    ranker_client = Ranker(model_name=model_name,
                           cache_dir="./models")
    compressor = FlashrankRerank(client=ranker_client, top_n=30)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )
    return vector_store, compression_retriever


def _ingest_pdf_documents(vector_store):
    """Load every PDF under ``docs/``, split into chunks, and index them.

    Best-effort: load/index failures are logged and surfaced via Streamlit
    rather than raised, matching the original behavior.
    """
    loader = DirectoryLoader(
        "docs",
        glob="**/*.pdf",
        loader_cls=PyPDFLoader
    )
    try:
        documents = loader.load()
        st.info(f"Loaded {len(documents)} PDF pages.")
        if documents:
            # 1000-char chunks with 200-char overlap preserve context across
            # split boundaries.
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            all_splits = text_splitter.split_documents(documents)
            st.info(f"Total Document splits: {len(all_splits)}")
            _ = vector_store.add_documents(documents=all_splits)
            st.success("Documents added to vector store.")
        else:
            st.warning("No PDF documents found in the 'docs' folder.")
    except Exception as e:
        logger.exception("Error loading documents")
        st.error(f"Error loading documents: {e}")