MedLLM-Assistant / test /data_ingest.py
VuvanAn's picture
Upload 47 files
09dc9d3 verified
import argparse
import os
from typing import List
from ..rag_pipeline import get_embeddings, load_data
from ..utils import load_local, save_local
def main(args):
print(f"Log: {args}")
if args.clear_vectorstore:
import shutil
if os.path.isdir(args.vectorstore_dir):
shutil.rmtree(args.vectorstore_dir)
embed_model = get_embeddings(args.embed_model_name)
vectorstore, docs = load_local(args.vectorstore_dir, embed_model)
new_docs = []
for data_path in args.data_paths:
new_docs.extend(load_data(data_path, args.file_type))
print(f"Got {len(new_docs)} documents.")
if args.chunk_method == "recursive":
from ..rag_pipeline import recursive_chunking
new_docs = recursive_chunking(new_docs, args.chunk_size, args.chunk_overlap)
elif args.chunk_method == "markdown":
from ..rag_pipeline import markdown_chunking
new_docs = markdown_chunking(new_docs, args.chunk_size, args.chunk_overlap)
print(f"Got {len(new_docs)} chunks.")
from langchain_community.vectorstores import FAISS
if vectorstore is None:
vectorstore = FAISS.from_documents(new_docs, embed_model)
docs = new_docs
print(f"Successfully consumed {len(new_docs)} documents.")
else:
docs.extend(new_docs)
vectorstore.add_documents(new_docs)
save_local(args.vectorstore_dir, vectorstore, docs)
import json
with open(os.path.join(args.vectorstore_dir, "config.json"), "a") as f:
json.dump(vars(args), f)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
data_paths = [
'dataset/RAG_Data/wiki_vi',
'dataset/RAG_Data/youmed',
'dataset/RAG_Data/mimic_ex_report',
'dataset/RAG_Data/Download sach y/OCR',
]
# Dataset params
parser.add_argument("--data_paths", type=List[str], required=False, default=data_paths)
parser.add_argument("--vectorstore_dir", type=str, required=False, default="notebook/An/master/knowledge/vectorstore_full")
parser.add_argument("--file_type", type=str, choices=["pdf", "txt"], default="txt")
# Model params
parser.add_argument("--embed_model_name", type=str, default="alibaba-nlp/gte-multilingual-base")
# Index params
parser.add_argument("--chunk_size", type=int, default=2048)
parser.add_argument("--chunk_overlap", type=int, default=512)
parser.add_argument("--chunk_method", type=str, choices=["recursive", "markdown"], default="markdown")
# Vectorstore params
parser.add_argument("--vectorstore", type=str, choices=["faiss", "chroma"], default="faiss")
parser.add_argument("--clear_vectorstore", action="store_true", default=True)
args = parser.parse_args()
main(args)