|
|
import os |
|
|
import streamlit as st |
|
|
from langchain_aws import BedrockEmbeddings |
|
|
|
|
|
from langchain.chat_models import init_chat_model |
|
|
from langchain_core.documents import Document |
|
|
from typing_extensions import List, Dict, TypedDict |
|
|
|
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
|
from langgraph.graph import START, StateGraph, END |
|
|
|
|
|
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader |
|
|
from langgraph.graph import MessagesState |
|
|
from langchain_core.tools import tool |
|
|
from langchain_core.messages import SystemMessage |
|
|
from langgraph.prebuilt import ToolNode, tools_condition |
|
|
from langchain_milvus import Milvus |
|
|
from langchain_openai import ChatOpenAI |
|
|
from pydantic import BaseModel, Field |
|
|
from logging_config import setup_logger |
|
|
from flashrank import Ranker |
|
|
from langchain_community.document_compressors import FlashrankRerank |
|
|
from langchain.retrievers import ContextualCompressionRetriever |
|
|
|
|
|
logger = setup_logger(__name__) |
|
|
|
|
|
|
|
|
def _make_vector_store(embeddings, uri, collection_name):
    """Construct a Milvus vector-store handle for *collection_name* at *uri*.

    Centralises the constructor arguments that were previously duplicated in
    both the probe and the create paths, so the two can never drift apart.
    """
    return Milvus(
        embedding_function=embeddings,
        connection_args={"uri": uri},
        auto_id=True,
        collection_name=collection_name,
        # FLAT + COSINE: exact (non-approximate) cosine-similarity search.
        index_params={"index_type": "FLAT", "metric_type": "COSINE"},
    )


def _ingest_documents(vector_store):
    """Load PDFs from the 'docs' folder, split them, and add them to the store.

    Best-effort: any failure is logged and surfaced in the UI via st.error
    rather than raised, matching the original behaviour.
    """
    folder_path = "docs"
    loader = DirectoryLoader(
        folder_path,
        glob="**/*.pdf",
        loader_cls=PyPDFLoader
    )

    try:
        documents = loader.load()
        st.info(f"Loaded {len(documents)} PDF pages.")

        if documents:
            # Overlapping chunks so context is not lost at split boundaries.
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            all_splits = text_splitter.split_documents(documents)
            st.info(f"Total Document splits: {len(all_splits)}")

            vector_store.add_documents(documents=all_splits)
            st.success("Documents added to vector store.")
        else:
            st.warning("No PDF documents found in the 'docs' folder.")
    except Exception as e:
        # Log the full traceback for operators; show a short message in the UI.
        logger.exception("Failed to load or ingest documents")
        st.error(f"Error loading documents: {e}")


def _build_retriever(vector_store):
    """Wrap the store's retriever with a FlashRank reranking compressor.

    Retrieves a wide candidate set (k=100) and reranks it down to the top 30
    with a cross-encoder, trading a little latency for better relevance.
    """
    retriever = vector_store.as_retriever(search_kwargs={"k": 100})

    model_name = "ms-marco-MultiBERT-L-12"
    ranker_client = Ranker(model_name=model_name,
                           cache_dir="./models")

    compressor = FlashrankRerank(client=ranker_client, top_n=30)
    return ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )


def init_vector_db(embeddings):
    """Initialise (or reuse) the Milvus vector store and build its retriever.

    Probes the local Milvus Lite database for an existing, populated
    collection; if the probe fails or the collection is empty, ingests the
    PDFs found under 'docs'. Progress is reported through Streamlit widgets.

    Args:
        embeddings: An embeddings object (e.g. BedrockEmbeddings) used by
            Milvus to embed both stored documents and queries.

    Returns:
        A ``(vector_store, compression_retriever)`` tuple: the Milvus store
        and a ContextualCompressionRetriever that reranks its results.
    """
    URI = "db/vectordb_milvus.db"
    collection_name = "my_collection"

    try:
        st.info("Checking for existing Milvus db...")
        vector_store = _make_vector_store(embeddings, URI, collection_name)

        # A one-shot search tells us whether the collection already holds data.
        results = vector_store.similarity_search("test query", k=1)

        if len(results) > 0:
            st.success("Document data found in existing collection.")
            documents_loaded = True
        else:
            st.info("Collection exists but might be empty. Will check for documents.")
            documents_loaded = False

    except Exception:
        # The probe raises when the collection does not exist yet; record why
        # (the original discarded the exception) and fall back to creating it.
        logger.exception("Probe of existing Milvus collection failed; creating a new one")
        st.info("Creating new Milvus collection...")
        vector_store = _make_vector_store(embeddings, URI, collection_name)
        documents_loaded = False

    if not documents_loaded:
        _ingest_documents(vector_store)

    compression_retriever = _build_retriever(vector_store)

    return vector_store, compression_retriever
|
|
|
|
|
|