|
|
import argparse
|
|
|
import os
|
|
|
from typing import List
|
|
|
|
|
|
from ..rag_pipeline import get_embeddings, load_data
|
|
|
from ..utils import load_local, save_local
|
|
|
|
|
|
def main(args):
    """Build or extend the on-disk vector store from the configured data paths.

    Steps, in order: optionally delete ``args.vectorstore_dir``; load the
    embedding model and whatever store already exists in that directory;
    load and chunk the new documents; index them with FAISS; save the store;
    and write the run configuration to ``config.json`` inside the store
    directory.

    Args:
        args: Parsed command-line namespace. This function reads
            ``clear_vectorstore``, ``vectorstore_dir``, ``embed_model_name``,
            ``data_paths``, ``file_type``, ``chunk_method``, ``chunk_size``
            and ``chunk_overlap``.
    """
    import json

    print(f"Log: {args}")

    if args.clear_vectorstore:
        import shutil

        # Start from scratch: drop any previously saved store.
        if os.path.isdir(args.vectorstore_dir):
            shutil.rmtree(args.vectorstore_dir)

    embed_model = get_embeddings(args.embed_model_name)
    # ``vectorstore`` is compared against None below, so load_local is
    # presumably returning None when nothing is saved yet -- TODO confirm.
    vectorstore, docs = load_local(args.vectorstore_dir, embed_model)

    new_docs = []
    for data_path in args.data_paths:
        new_docs.extend(load_data(data_path, args.file_type))
    print(f"Got {len(new_docs)} documents.")

    # The chunkers are imported lazily so only the selected one is loaded.
    if args.chunk_method == "recursive":
        from ..rag_pipeline import recursive_chunking

        new_docs = recursive_chunking(new_docs, args.chunk_size, args.chunk_overlap)
    elif args.chunk_method == "markdown":
        from ..rag_pipeline import markdown_chunking

        new_docs = markdown_chunking(new_docs, args.chunk_size, args.chunk_overlap)
    print(f"Got {len(new_docs)} chunks.")

    # NOTE(review): ``args.vectorstore`` is never read here, so FAISS is used
    # even when "chroma" is requested on the command line.
    from langchain_community.vectorstores import FAISS

    if vectorstore is None:
        # No existing store: create one from the new chunks.
        vectorstore = FAISS.from_documents(new_docs, embed_model)
        docs = new_docs
        print(f"Successfully consumed {len(new_docs)} documents.")
    else:
        # Existing store: append the new chunks to both the doc list and the index.
        docs.extend(new_docs)
        vectorstore.add_documents(new_docs)

    save_local(args.vectorstore_dir, vectorstore, docs)

    # Overwrite rather than append: appending a second JSON object on a later
    # run would leave config.json unparseable.
    config_path = os.path.join(args.vectorstore_dir, "config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(vars(args), f)
|
|
|
|
|
|
|
|
|
def build_parser():
    """Create the command-line parser for the vector-store ingestion script.

    Returns:
        argparse.ArgumentParser: Parser whose namespace carries ``data_paths``
        (list of str), ``vectorstore_dir``, ``file_type``,
        ``embed_model_name``, ``chunk_size``, ``chunk_overlap``,
        ``chunk_method``, ``vectorstore`` and ``clear_vectorstore``.
    """
    parser = argparse.ArgumentParser()

    data_paths = [
        'dataset/RAG_Data/wiki_vi',
        'dataset/RAG_Data/youmed',
        'dataset/RAG_Data/mimic_ex_report',
        'dataset/RAG_Data/Download sach y/OCR',
    ]

    # ``nargs="+"`` collects one or more paths into a list. A typing generic
    # such as ``List[str]`` cannot be used as an argparse ``type`` converter.
    parser.add_argument("--data_paths", type=str, nargs="+", required=False, default=data_paths)
    parser.add_argument("--vectorstore_dir", type=str, required=False, default="notebook/An/master/knowledge/vectorstore_full")
    parser.add_argument("--file_type", type=str, choices=["pdf", "txt"], default="txt")

    parser.add_argument("--embed_model_name", type=str, default="alibaba-nlp/gte-multilingual-base")

    parser.add_argument("--chunk_size", type=int, default=2048)
    parser.add_argument("--chunk_overlap", type=int, default=512)
    parser.add_argument("--chunk_method", type=str, choices=["recursive", "markdown"], default="markdown")

    parser.add_argument("--vectorstore", type=str, choices=["faiss", "chroma"], default="faiss")
    # Clearing stays the default. ``--keep_vectorstore`` is the only way to
    # switch it off, because a ``store_true`` flag that defaults to True can
    # never become False.
    parser.add_argument("--clear_vectorstore", action="store_true", default=True)
    parser.add_argument("--keep_vectorstore", dest="clear_vectorstore", action="store_false")

    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    main(args)