import os

import streamlit as st
from flashrank import Ranker
from langchain.chat_models import init_chat_model
from langchain.retrievers import ContextualCompressionRetriever
from langchain_aws import BedrockEmbeddings
from langchain_community.document_compressors import FlashrankRerank
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage
from langchain_core.tools import tool
from langchain_milvus import Milvus
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import BaseModel, Field
from typing_extensions import Dict, List, TypedDict

from logging_config import setup_logger

logger = setup_logger(__name__)


def init_vector_db(embeddings):
    """Initialize (or reuse) a local Milvus vector store and build a
    rerank-compressed retriever over it.

    Probes the existing collection with a throwaway similarity search; if
    the collection is missing or empty, loads every PDF under the ``docs``
    folder, splits it into overlapping chunks, and ingests the chunks.
    Finally wraps a wide (k=100) similarity retriever with a FlashRank
    reranker that keeps the 30 most relevant chunks.

    Args:
        embeddings: Embedding model passed to Milvus as its
            ``embedding_function`` for indexing and search.

    Returns:
        tuple: ``(vector_store, compression_retriever)`` where
            ``vector_store`` is the ``Milvus`` instance and
            ``compression_retriever`` is a ``ContextualCompressionRetriever``
            over it.
    """
    URI = "db/vectordb_milvus.db"
    collection_name = "my_collection"

    # Probe the collection with a test query to decide whether documents
    # still need to be ingested.
    try:
        st.info("Checking for existing Milvus db...")
        vector_store = Milvus(
            embedding_function=embeddings,
            connection_args={"uri": URI},
            auto_id=True,
            collection_name=collection_name,
            index_params={"index_type": "FLAT", "metric_type": "COSINE"},
        )
        results = vector_store.similarity_search("test query", k=1)
        if results:
            st.success("Document data found in existing collection.")
            documents_loaded = True
        else:
            st.info("Collection exists but might be empty. Will check for documents.")
            documents_loaded = False
    except Exception:
        # Probe failed (e.g. collection does not exist yet): record why,
        # then fall through and create the collection from scratch.
        logger.exception("Milvus probe failed; creating a new collection.")
        st.info("Creating new Milvus collection...")
        vector_store = Milvus(
            embedding_function=embeddings,
            connection_args={"uri": URI},
            auto_id=True,
            collection_name=collection_name,
            index_params={"index_type": "FLAT", "metric_type": "COSINE"},
        )
        documents_loaded = False

    # Ingest PDFs only when the probe found no existing documents.
    if not documents_loaded:
        folder_path = "docs"
        loader = DirectoryLoader(
            folder_path, glob="**/*.pdf", loader_cls=PyPDFLoader
        )
        try:
            documents = loader.load()
            st.info(f"Loaded {len(documents)} PDF pages.")
            if documents:
                # Chunk pages with overlap so context straddling page/chunk
                # boundaries is preserved for retrieval.
                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=1000, chunk_overlap=200
                )
                all_splits = text_splitter.split_documents(documents)
                st.info(f"Total Document splits: {len(all_splits)}")
                vector_store.add_documents(documents=all_splits)
                st.success("Documents added to vector store.")
            else:
                st.warning("No PDF documents found in the 'docs' folder.")
        except Exception as e:
            # Best-effort ingestion: surface the error in the UI and the
            # log, but still return a usable (possibly empty) store.
            logger.exception("Error loading documents")
            st.error(f"Error loading documents: {e}")

    # Retrieve wide (k=100), then let the FlashRank cross-encoder rerank
    # down to the 30 most relevant chunks.
    retriever = vector_store.as_retriever(search_kwargs={"k": 100})
    model_name = "ms-marco-MultiBERT-L-12"
    # model_name = "ms-marco-MiniLM-L-12-v2"
    ranker_client = Ranker(model_name=model_name, cache_dir="./models")
    compressor = FlashrankRerank(client=ranker_client, top_n=30)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )
    return vector_store, compression_retriever