| from langchain.chains import RetrievalQA |
| from langchain.llms import LlamaCpp |
| from retriever import load_db |
| from huggingface_hub import hf_hub_download |
| from langchain.document_loaders import PyPDFLoader , DirectoryLoader |
| from langchain.text_splitter import RecursiveCharacterTextSplitter , CharacterTextSplitter |
| from langchain.embeddings import HuggingFaceEmbeddings |
| from langchain.vectorstores import FAISS |
| from langchain_community.llms import LlamaCpp |
| from langchain.chains import LLMChain |
| from langchain.prompts import PromptTemplate |
| from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler |
| from langchain.prompts.chat import ( |
| ChatPromptTemplate, |
| SystemMessagePromptTemplate, |
| HumanMessagePromptTemplate |
| ) |
| from langchain.chains import RetrievalQA |
|
|
| |
| MODEL_PATH = hf_hub_download( |
| repo_id="ohaiyo123/SEG_Llama2Lora", |
| filename="llama 2 7b hf chat_Lora.gguf", |
| cache_dir="model_cache" |
| ) |
|
|
| |
| llm = LlamaCpp( |
| model_path= MODEL_PATH, |
| n_gpu_layers= -1, |
| n_batch=512, |
| n_ctx=2048, |
| f16_kv=True, |
| temperature=0.01, |
| callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), |
| verbose=False, |
| ) |
|
|
| def create_db_from_text(): |
| raw_text = """Ngân hàng là một tổ chức tài chính cung cấp các dịch vụ như gửi tiền, cho vay, chuyển khoản và thanh toán. Tại Việt Nam, các ngân hàng thương mại đóng vai trò quan trọng trong việc hỗ trợ doanh nghiệp và cá nhân tiếp cận nguồn vốn. |
| Một số ngân hàng lớn bao gồm Vietcombank, BIDV, VietinBank và Techcombank.""" |
|
|
| text_splitter = CharacterTextSplitter( |
| separator = "\n", |
| chunk_size=500, |
| chunk_overlap=50, |
| length_function=len |
| ) |
| chunks = text_splitter.split_text(raw_text) |
| |
| embbeding_model = HuggingFaceEmbeddings(model_name="VoVanPhuc/sup-SimCSE-VietNamese-phobert-base") |
| db = FAISS.from_texts(chunks, embbeding_model) |
| db.save_local(vector_db_path) |
| return db |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| def create_db_from_file(): |
| loader = DirectoryLoader(dpf_data_path, |
| glob="*.pdf", |
| loader_cls=PyPDFLoader) |
| documents = loader.load() |
|
|
| text_splitter = RecursiveCharacterTextSplitter( |
| separators=[ |
| "\n\n", |
| "\n", |
| " ", |
| ".", |
| ",", |
| "\u200b", |
| "\uff0c", |
| "\u3001", |
| "\uff0e", |
| "\u3002", |
| "", |
| ], |
| chunk_size=500, |
| chunk_overlap=50 |
| ) |
| chunks = text_splitter.split_documents(documents) |
| |
| embbeding_model = HuggingFaceEmbeddings(model_name="VoVanPhuc/sup-SimCSE-VietNamese-phobert-base") |
| db = FAISS.from_documents(chunks,embbeding_model) |
| db.save_local(vector_db_path) |
| return db |
|
|
|
|
|
|
|
|
| |
| def load_llm(model_file): |
|
|
| llm = LlamaCpp( |
| model_path= model_file, |
| n_gpu_layers= -1, |
| n_batch=512, |
| n_ctx=2048, |
| f16_kv=True, |
| temperature=0.01, |
| callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), |
| verbose=True, |
|
|
| ) |
| return llm |
|
|
|
|
|
|
|
|
|
|
| |
| def create_prompt(): |
| system_tm = SystemMessagePromptTemplate.from_template( |
| "Sử dụng thông tin sau đây để trả lời câu hỏi.\n" |
| "Nếu bạn không biết câu trả lời thì hãy nói rằng bạn không biết, đừng cố tạo ra câu trả lời\n\n" |
| "{context}" |
| ) |
| human_tm = HumanMessagePromptTemplate.from_template("{question}") |
| return ChatPromptTemplate.from_messages([system_tm, human_tm]) |
|
|
|
|
|
|
|
|
|
|
|
|
| def create_qna_chain(llm,db): |
| prompt = create_prompt() |
| llm_chain = RetrievalQA.from_chain_type(llm=llm, |
| chain_type="stuff", |
| retriever=db.as_retriever(search_kwargs={"k":3}), |
| return_source_documents=True, |
| chain_type_kwargs={"prompt":prompt} |
| ) |
| return llm_chain |
|
|
|
|
|
|
|
|
|
|
| |
| def read_vectors_db(): |
| |
| embbeding_model = HuggingFaceEmbeddings(model_name="VoVanPhuc/sup-SimCSE-VietNamese-phobert-base") |
| db = FAISS.load_local(vector_db_path,embbeding_model, allow_dangerous_deserialization=True) |
| return db |
|
|
|
|
|
|
|
|
| |
| |
| db = read_vectors_db() |
| |
| llm = load_llm(model_file) |
|
|
|
|
| |
|
|
| llm_chain = create_qna_chain(llm,db) |