# NOTE(review): non-code page residue ("Spaces: Sleeping") replaced with this
# comment so the file is valid Python; the code below was extracted from a
# table-formatted paste and reformatted to conventional indentation.
| from llama_index.core import load_index_from_storage, StorageContext, SimpleDirectoryReader, VectorStoreIndex, QueryBundle | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from llama_index.core import Settings | |
| from llama_index.llms.groq import Groq | |
| from llama_index.llms.ollama import Ollama | |
| from llama_index.readers.file import DocxReader | |
| from llama_index.core.node_parser import SimpleFileNodeParser, SentenceSplitter, SimpleNodeParser | |
| from llama_index.core.storage.docstore import SimpleDocumentStore | |
| from llama_index.vector_stores.faiss import FaissVectorStore | |
| from llama_index.core.retrievers import RecursiveRetriever | |
| from llama_index.core.schema import IndexNode | |
| from llama_index.llms.openai import OpenAI | |
| from llama_index.embeddings.openai import OpenAIEmbedding | |
| from llama_index.core.response.notebook_utils import display_source_node | |
| from llama_index.core.query_engine import RetrieverQueryEngine | |
| import faiss | |
| import re | |
| from core.config import settings | |
| from llama_index.core.schema import MetadataMode | |
| import pickle | |
| from llama_index.core.node_parser import SentenceWindowNodeParser | |
| from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor | |
| from llama_index.postprocessor.cohere_rerank import CohereRerank | |
| from prompt.prompt import qa_prompt_tmpl, refine_prompt_tmpl | |
# Global llama_index Settings: every index/query function below relies on
# these module-level defaults for the embedding model and the LLM.
#
# Previous configuration (HuggingFace embedder + Groq LLM), kept for reference:
# Settings.embed_model = HuggingFaceEmbedding(
#     model_name= settings.EMBEDDING_MODEL
# )
# Settings.llm = Groq(model=settings.MODEL_ID, api_key= settings.MODEL_API_KEY)

# Active configuration: OpenAI embeddings + OpenAI chat model, both driven by
# values in core.config.settings. max_tokens caps the generated answer length.
Settings.embed_model = OpenAIEmbedding(
    model_name= settings.OPENAI_EMBEDDING_MODEL
)
Settings.llm = OpenAI(model = settings.OPENAI_MODEL,
                      api_key = settings.OPENAI_API_KEY, max_tokens = 512)
def windows_parser(documents: list) -> None:
    """Build a sentence-window index over *documents* and persist it to disk.

    Fix: the parameter was annotated ``str`` but a list of llama_index
    Document objects is what ``get_nodes_from_documents`` consumes (see the
    call in ``__main__`` passing ``document_prepare(...)``'s result).

    Args:
        documents: Documents loaded by ``document_prepare``.

    Side effects:
        Persists the index to the default ``./storage`` directory, which is
        where ``window_query`` reloads it from.
    """
    # FAISS index dimensionality must match the active (OpenAI) embedder.
    d = settings.OPENAI_EMBEDDING_MODEL_DIMS
    faiss_index = faiss.IndexFlatL2(d)
    # Assign FAISS as the vector_store backing the storage context.
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    # Sentence-window parsing: each node is one sentence, with a 50-sentence
    # surrounding window stored under the "window" metadata key.
    node_parser = SentenceWindowNodeParser.from_defaults(
        window_size=50,
        window_metadata_key="window",
        original_text_metadata_key="original_text",
    )
    sentence_nodes = node_parser.get_nodes_from_documents(documents)
    sentence_index = VectorStoreIndex(
        sentence_nodes,
        storage_context=storage_context,
        show_progress=True,
    )
    sentence_index.storage_context.persist()
def window_query(query: str, persist_dir: str = "./storage") -> str:
    """Answer *query* using the persisted sentence-window index.

    Generalization: the previously hard-coded ``"./storage"`` path is now a
    keyword parameter whose default preserves the old behavior.

    Args:
        query: Natural-language question.
        persist_dir: Directory the index was persisted to by ``windows_parser``.

    Returns:
        The synthesized answer as a string.
    """
    vector_store = FaissVectorStore.from_persist_dir(persist_dir)
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store, persist_dir=persist_dir
    )
    sentence_index = load_index_from_storage(storage_context=storage_context)
    query_engine = sentence_index.as_query_engine(
        similarity_top_k=3,
        # The target key defaults to `window` to match the node_parser's
        # default metadata key; Cohere reranks the 3 hits down to 2.
        node_postprocessors=[
            MetadataReplacementPostProcessor(target_metadata_key="window"),
            CohereRerank(api_key=settings.COHERE_API_KEY, top_n=2),
        ],
        verbose=True,
    )
    # Swap in the project's custom QA / refine prompt templates.
    query_engine.update_prompts(
        {
            "response_synthesizer:text_qa_template": qa_prompt_tmpl,
            "response_synthesizer:refine_template": refine_prompt_tmpl,
        }
    )
    response = query_engine.query(f"{query}")
    # Robustness fix: reranking can leave no source nodes for an off-corpus
    # query; the old unconditional [0] indexing raised IndexError then.
    if response.source_nodes:
        window = response.source_nodes[0].node.metadata["window"][:500]
        sentence = response.source_nodes[0].node.metadata["original_text"][:500]
        print(f"Window: {window}")
        print("------------------")
        print(f"Original Sentence: {sentence}")
    return str(response)
def document_prepare(path: str):
    """Load every document under *path*; .docx files go through DocxReader.

    Prints the number of loaded documents and returns them as a list.
    """
    reader = SimpleDirectoryReader(path, file_extractor={'.docx': DocxReader()})
    documents = reader.load_data()
    print(len(documents))
    return documents
def extract_metadata(docs: list) -> None:
    """Annotate each document's metadata with the law number/name in its text.

    Looks for a Vietnamese "Số: ..." document-number prefix and for a heading
    of the form "NGHỊ ĐỊNH/LUẬT <name> Căn cứ ...". Documents where a pattern
    does not match are left untouched.

    Args:
        docs: Objects exposing ``.text`` (str) and ``.metadata`` (dict);
            ``metadata`` is rebuilt in place with the extracted keys added:
            "law_number" (e.g. "59/2020/QH14") and "law_name"
            (e.g. "LUẬT DOANH NGHIỆP").
    """
    # Fix: patterns were recompiled on every loop iteration; compile once.
    law_number_re = re.compile(r"(?i)số[:\s]+([^\s.,]+)")
    law_name_re = re.compile(r"(NGHỊ ĐỊNH|LUẬT)\s+(.*?)\s+Căn cứ")
    for doc in docs:
        text = doc.text
        number_match = law_number_re.search(text)
        name_match = law_name_re.search(text)
        # Rebuild metadata via the attribute setter (kept from the original,
        # in case the Document type intercepts assignment).
        if number_match:
            doc.metadata = {**doc.metadata,
                            "law_number": number_match.group(1)}
        if name_match:
            doc.metadata = {**doc.metadata,
                            "law_name": f"{name_match.group(1)} {name_match.group(2)}"}
def faiss_setup(documents: list) -> None:
    """Embed *documents* into a FAISS-backed index and persist it.

    Bug fix: the index was built but never persisted, so ``faiss_load`` —
    which reads from ``./storage`` — could not find it on disk.

    Args:
        documents: Documents loaded by ``document_prepare``.
    """
    # Dimensionality must match the global OpenAI embedding model.
    d = settings.OPENAI_EMBEDDING_MODEL_DIMS
    faiss_index = faiss.IndexFlatL2(d)
    # Assign FAISS as the vector_store backing the storage context.
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context,
    )
    # Persist to the default "./storage" dir, where faiss_load() reloads it.
    index.storage_context.persist()
def faiss_load(query: str) -> str:
    """Answer *query* against the plain FAISS index persisted in ./storage.

    Fixes: the function was annotated ``-> str`` but returned the raw
    Response object; it now returns ``str(response)``. Also guards the
    debug print against an empty retrieval result.

    Args:
        query: Natural-language question.

    Returns:
        The synthesized answer as a string.
    """
    vector_store = FaissVectorStore.from_persist_dir("./storage")
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store, persist_dir="./storage"
    )
    index = load_index_from_storage(storage_context=storage_context)
    query_engine = index.as_query_engine()
    vector_retriever = index.as_retriever(similarity_top_k=2)
    response = query_engine.query(query)
    # Debug: show the top retrieved node, if any.
    retrieved_nodes = vector_retriever.retrieve(query)
    if retrieved_nodes:
        print(retrieved_nodes[0])
    return str(response)
def get_all_nodes(documents: list):
    """Build base chunks plus smaller sub-chunks for small-to-big retrieval.

    Each base chunk (MAX_NEW_TOKENS tokens, MAX_OVERLAPS overlap) is re-split
    at 1/8, 1/4 and 1/2 of its size; every sub-chunk becomes an IndexNode
    pointing back to its base node so a RecursiveRetriever can resolve a
    small hit to the larger surrounding context.

    Args:
        documents: Documents loaded by ``document_prepare``.

    Returns:
        list: All sub-chunk IndexNodes plus one IndexNode per base chunk.
    """
    node_parser = SimpleNodeParser.from_defaults(
        chunk_size=settings.MAX_NEW_TOKENS,
        chunk_overlap=settings.MAX_OVERLAPS,
    )
    base_nodes = node_parser.get_nodes_from_documents(documents)
    # Deterministic ids so the node_dict built at query time is stable
    # across runs.
    for idx, node in enumerate(base_nodes):
        node.id_ = f"node-{idx}"
    # Bug fix: "/" produced floats (e.g. 1024/8 == 128.0) that were then
    # passed as chunk_size/chunk_overlap, which must be ints; use integer
    # division instead.
    divisors = (8, 4, 2)
    sub_chunk_sizes = [settings.MAX_NEW_TOKENS // d for d in divisors]
    sub_overlap_sizes = [settings.MAX_OVERLAPS // d for d in divisors]
    sub_node_parsers = [
        SimpleNodeParser.from_defaults(chunk_size=c, chunk_overlap=o)
        for c, o in zip(sub_chunk_sizes, sub_overlap_sizes)
    ]
    all_nodes = []
    for base_node in base_nodes:
        for parser in sub_node_parsers:
            sub_nodes = parser.get_nodes_from_documents([base_node])
            all_nodes.extend(
                IndexNode.from_text_node(sn, base_node.node_id) for sn in sub_nodes
            )
        # Also index the original (large) chunk itself.
        all_nodes.append(IndexNode.from_text_node(base_node, base_node.node_id))
    return all_nodes
def sub_chunk_setup(all_nodes: list) -> None:
    """Embed *all_nodes* into a FAISS index and persist it to ./storage.

    Consistency fix: the FAISS dimensionality now uses
    ``OPENAI_EMBEDDING_MODEL_DIMS``, matching the globally configured
    OpenAIEmbedding and the other setup functions in this file. The previous
    ``EMBEDDING_MODEL_DIMENSIONS`` (HuggingFace) constant was a leftover from
    the old embedder — FAISS rejects vectors whose size differs from ``d``.

    Args:
        all_nodes: Nodes produced by ``get_all_nodes``.
    """
    d = settings.OPENAI_EMBEDDING_MODEL_DIMS
    faiss_index = faiss.IndexFlatL2(d)
    # Assign FAISS as the vector_store backing the storage context.
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex(
        all_nodes,
        storage_context=storage_context,
        show_progress=True,
    )
    print('done setup')
    # Persist to the default "./storage" dir, where sub_chunk_query reloads it.
    index.storage_context.persist()
def sub_chunk_query(all_nodes: list, query: str) -> str:
    """Answer *query* via recursive (small-to-big) retrieval.

    Small sub-chunks are retrieved first, then resolved to their parent
    chunks through *all_nodes* before synthesis.
    """
    # Map node ids to nodes so retrieved IndexNodes can be dereferenced.
    node_lookup = {node.node_id: node for node in all_nodes}
    vector_store = FaissVectorStore.from_persist_dir("./storage")
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store, persist_dir="./storage"
    )
    index = load_index_from_storage(storage_context=storage_context)
    chunk_retriever = RecursiveRetriever(
        "vector",
        retriever_dict={"vector": index.as_retriever(similarity_top_k=3)},
        node_dict=node_lookup,
        verbose=True,
    )
    # Debug: show what the recursive retriever resolved for this query.
    for retrieved in chunk_retriever.retrieve(QueryBundle(query)):
        display_source_node(retrieved, source_length=2000)
    engine = RetrieverQueryEngine.from_args(
        chunk_retriever, storage_context=storage_context
    )
    return str(engine.query(f"{query}"))
if __name__ == "__main__":
    # Load the raw legal corpus; index-building steps below are one-time
    # and left commented — uncomment to (re)build the ./storage index.
    documents = document_prepare(settings.RAW_DATA_DIR)
    # all_nodes = get_all_nodes(documents)
    # faiss_setup(documents)
    # sub_chunk_setup(all_nodes)
    # windows_parser(documents)
    # Earlier evaluation questions, kept for reference:
    # examples=[
    # 'Chào bán cổ phần cho cổ đông hiện hữu của công ty cổ phần không phải là công ty đại chúng được thực hiện ra sao ?',
    # 'Quyền của doanh nghiệp là những quyền nào?',
    # 'Các trường hợp nào được coi là tên gây nhầm lẫn ?',
    # 'Các quy định về chào bán trái phiếu riêng lẻ',
    # 'Doanh nghiệp có quyền và nghĩa vụ như thế nào?',
    # 'Thành lập công ty TNHH thì quy trình như thế nào?'
    # ]
    # Current smoke-test queries (Vietnamese company-law definitions).
    examples = [
        "Công ty cổ phần là gì?",
        "Định nghĩa về “góp vốn” trong Luật Doanh nghiệp là gì?",
        "Khái niệm “cổ đông” được hiểu như thế nào?",
        "Thế nào là “vốn điều lệ” trong doanh nghiệp?",
        "“Doanh nghiệp có vốn đầu tư nước ngoài” là gì?"
    ]
    # Run every query through the sentence-window pipeline; other pipelines
    # (plain FAISS, sub-chunk recursive) are left commented for comparison.
    for example in examples:
        # query = examples[3]
        query = example
        print("///////////////////////////////")
        print(query)
        # print(faiss_load(query))
        # print(sub_chunk_query(all_nodes, query))
        print("Answer:", window_query(query))
        print("\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\")