Spaces:
Sleeping
Sleeping
| from llama_index.core import StorageContext, load_index_from_storage | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from llama_index.core import Settings | |
| from llama_index.llms.groq import Groq | |
| from llama_index.llms.openai import OpenAI | |
| from core.config import settings | |
| from llama_index.vector_stores.faiss import FaissVectorStore | |
| from prompt.prompt import qa_prompt_tmpl, refine_prompt_tmpl | |
| from IPython.display import Markdown, display | |
| import re | |
| from llama_index.core.retrievers import RecursiveRetriever | |
| import string | |
| from llama_index.postprocessor.cohere_rerank import CohereRerank | |
| from llama_index.core.query_engine import RetrieverQueryEngine | |
| import pickle | |
| from loader import get_all_nodes, document_prepare | |
| from llama_index.embeddings.openai import OpenAIEmbedding | |
| from llama_index.core.query_engine import MultiStepQueryEngine | |
| from llama_index.core.indices.query.query_transform.base import ( | |
| StepDecomposeQueryTransform, | |
| ) | |
| from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor | |
| #Settings | |
| # Settings.embed_model = HuggingFaceEmbedding( | |
| # model_name= settings.EMBEDDING_MODEL | |
| # ) | |
| Settings.embed_model = OpenAIEmbedding( | |
| model_name= settings.OPENAI_EMBEDDING_MODEL | |
| ) | |
| Settings.llm = OpenAI(model = settings.OPENAI_MODEL, | |
| api_key = settings.OPENAI_API_KEY, temperature=0) | |
| step_decompose_transform = StepDecomposeQueryTransform( | |
| llm=Settings.llm, verbose=True | |
| ) | |
| # Settings.llm = Groq(model=settings.MODEL_ID, api_key= settings.MODEL_API_KEY) | |
| # print(Settings.llm.max_tokens) | |
| # all_nodes = get_all_nodes(document_prepare(settings.RAW_DATA_DIR)) | |
| # define prompt viewing function | |
| def display_prompt_dict(prompts_dict): | |
| for k, p in prompts_dict.items(): | |
| text_md = f"**Prompt Key**: {k}<br>" f"**Text:** <br>" | |
| display(Markdown(text_md)) | |
| print(p.get_template()) | |
| display(Markdown("<br><br>")) | |
| def preprocessing_text(query: str) -> str: | |
| text = query | |
| abbreviations = { | |
| 'tnhh': 'Trách nhiệm hữu hạn', # Công ty Trách nhiệm Hữu hạn | |
| 'Tnhh': 'Trách nhiệm hữu hạn', # Công ty Trách nhiệm Hữu hạn | |
| 'TNHH': 'Trách nhiệm hữu hạn', # Công ty Trách nhiệm Hữu hạn | |
| 'cp': 'Cổ phần', # Công ty Cổ phần | |
| 'CP': 'Cổ phần', | |
| 'mtv': 'Một thành viên', # Công ty Trách nhiệm Hữu hạn Một Thành Viên | |
| 'MTV': 'Một thành viên', | |
| 'công ty hd': 'công ty Hợp danh', # Công ty Hợp danh | |
| 'công ty HD': 'công ty Hợp danh', | |
| 'dn': 'doanh nghiệp', # Doanh nghiệp | |
| 'DN': 'Doanh nghiệp', | |
| 'DNTN': 'Doanh nghiệp tư nhân', | |
| 'dntn': 'Doanh nghiệp tư nhân', | |
| 'Dntn': 'Doanh nghiệp tư nhân', | |
| 'vốn đl': 'Vốn điều lệ', # Vốn Điều lệ | |
| 'gpkd': 'Giấy phép kinh doanh', # Giấy Phép Kinh Doanh | |
| 'GPKD': 'Giấy phép kinh doanh', | |
| 'dkdn': 'Đăng ký doanh nghiệp', # Đăng Ký Doanh Nghiệp | |
| 'tldn': 'Thành lập doanh nghiệp', # Thành lập Doanh nghiệp | |
| 'hdqt': 'Hội đồng quản trị', # Hội Đồng Quản Trị | |
| 'vốn góp': 'Vốn góp', # Vốn Góp | |
| 'tct': 'Tổng công ty', # Tổng Công ty | |
| 'kv': 'Khu vực', # Khu Vực | |
| 'htx': 'Hợp tác xã', # Hợp Tác Xã | |
| 'lds': 'Liên doanh', # Liên Doanh | |
| 'sở hđt': 'Sở hữu đầu tư', # Sở Hữu Đầu Tư | |
| 'nlđ': 'Người lao động', # Người Lao Động | |
| 'đt': 'Đầu tư', # Đầu Tư | |
| 'kt': 'Kinh tế', # Kinh Tế | |
| 'kte': 'Kinh tế', | |
| 'hđ': 'hợp đồng', | |
| 'hdong': 'hợp đồng', | |
| 'gd': 'Giám đốc', | |
| 'đtdnnn': 'Đầu tư doanh nghiệp nước ngoài' # Đầu Tư Doanh Nghiệp Nước Ngoài | |
| } | |
| for k,v in abbreviations.items(): | |
| text = text.replace(k,v) | |
| text = re.sub(r'(.)\1{2,}', r'\1', text) #Removes trailing | |
| text = re.sub(r"(\w)\s*([{}])\s*(\w)".format(re.escape(string.punctuation)), r"\1 \3", text) # Removes punctuation after word characters | |
| text = re.sub(r"(\w)([" + string.punctuation + "])", r"\1", text) # Removes punctuation after word characters | |
| text = re.sub(f"([{string.punctuation}])([{string.punctuation}])+", r"\1", text) # Remove repeated consecutive punctuation marks | |
| text = text.strip() # Remove leading and trailing whitespaces | |
| # While loops to remove leading and trailing punctuation and whitespace characters. | |
| while text.endswith(tuple(string.punctuation + string.whitespace)): | |
| text = text[:-1] | |
| while text.startswith(tuple(string.punctuation + string.whitespace)): | |
| text = text[1:] | |
| text = re.sub(r"\s+", " ", text) # Replace multiple consecutive whitespaces with a single space | |
| return text | |
| def response_faiss(query:str, history: str) -> str: | |
| message = preprocessing_text(query) | |
| vector_store = FaissVectorStore.from_persist_dir("./storage") | |
| storage_context = StorageContext.from_defaults( | |
| vector_store=vector_store, persist_dir="./storage" | |
| ) | |
| index = load_index_from_storage(storage_context=storage_context) | |
| vector_retriever = index.as_retriever(similarity_top_k=2) | |
| query_engine = index.as_query_engine() | |
| query_engine.update_prompts( | |
| {"response_synthesizer:text_qa_template": qa_prompt_tmpl, | |
| "response_synthesizer:refine_template": refine_prompt_tmpl,} | |
| ) | |
| # display_prompt_dict(query_engine.get_prompts()) | |
| response = str(query_engine.query(f"{message}")) | |
| retrieved_nodes = vector_retriever.retrieve(message) | |
| print(retrieved_nodes[0].metadata) | |
| print(response) | |
| return response | |
| def sub_chunk_query(query: str, history: str) -> str: | |
| query = preprocessing_text(query) | |
| all_nodes_dict = {n.node_id: n for n in all_nodes} | |
| vector_store = FaissVectorStore.from_persist_dir("./storage") | |
| storage_context = StorageContext.from_defaults( | |
| vector_store=vector_store, persist_dir="./storage" | |
| ) | |
| index = load_index_from_storage(storage_context=storage_context) | |
| vector_retriever_chunk = index.as_retriever(similarity_top_k=2) | |
| retriever_chunk = RecursiveRetriever( | |
| "vector", | |
| retriever_dict={"vector": vector_retriever_chunk}, | |
| node_dict=all_nodes_dict, | |
| verbose=True, | |
| ) | |
| nodes = retriever_chunk.retrieve(query) | |
| print(nodes[0].text[:500]) | |
| query_engine = MultiStepQueryEngine( | |
| retriever_chunk, | |
| storage_context = storage_context, | |
| similarity_top_k=5, | |
| query_transform=step_decompose_transform, | |
| node_postprocessors=[ | |
| CohereRerank(api_key=settings.COHERE_API_KEY, top_n=3) | |
| ], | |
| ) | |
| query_engine.update_prompts( | |
| {"response_synthesizer:text_qa_template": qa_prompt_tmpl, | |
| "response_synthesizer:refine_template": refine_prompt_tmpl,} | |
| ) | |
| response = str(query_engine.query(f"{query}")) | |
| print(query) | |
| print(response) | |
| return response | |
| def window_query(query: str, history: str): | |
| query = preprocessing_text(query) | |
| vector_store = FaissVectorStore.from_persist_dir("./storage") | |
| storage_context = StorageContext.from_defaults( | |
| vector_store=vector_store, persist_dir="./storage" | |
| ) | |
| sentence_index = load_index_from_storage(storage_context=storage_context) | |
| query_engine = sentence_index.as_query_engine( | |
| similarity_top_k=3, | |
| # the target key defaults to `window` to match the node_parser's default | |
| node_postprocessors=[ | |
| MetadataReplacementPostProcessor(target_metadata_key="window"), | |
| CohereRerank(api_key=settings.COHERE_API_KEY, top_n=2), | |
| ], | |
| verbose=True, | |
| ) | |
| query_engine.update_prompts( | |
| {"response_synthesizer:text_qa_template": qa_prompt_tmpl, | |
| "response_synthesizer:refine_template": refine_prompt_tmpl,} | |
| ) | |
| response = query_engine.query(f"{query}") | |
| window = response.source_nodes[0].node.metadata["window"][:500] | |
| sentence = response.source_nodes[0].node.metadata["original_text"][:500] | |
| print(f"Window: {window}") | |
| print("------------------") | |
| print(f"Original Sentence: {sentence}") | |
| return str(response) | |
| examples=[ | |
| 'Chào bán cổ phần cho cổ đông hiện hữu của công ty cổ phần không phải là công ty đại chúng được thực hiện ra sao ?', | |
| 'Quyền của doanh nghiệp là những quyền nào?', | |
| 'Các trường hợp nào được coi là tên gây nhầm lẫn ?', | |
| 'Các quy định về chào bán trái phiếu riêng lẻ', | |
| 'Doanh nghiệp có quyền và nghĩa vụ như thế nào?', | |
| 'Thành lập công ty TNHH thì quy trình như thế nào?' | |
| ] | |
| # query = examples[1] | |
| # print(query) | |
| # print(sub_chunk_query(query, "")) |