import gradio as gr
import torch


from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.vectorstores import FAISS


from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import PyPDFLoader


from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model  # used only if the LoRA block below is enabled


INFERENCE_PATH = "inference.txt"


# Load the statute PDF and split it into ~1000-character chunks.
# Note: the retrievers below are built from full_rules.txt, so these chunks
# are loaded but not used by the retrieval pipeline as written.
loader = PyPDFLoader("rule2015.pdf")
pages = loader.load_and_split()

text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(pages)


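# Caveat: CharacterTextSplitter splits on its separator ("\n\n" by default),
# so individual chunks can exceed chunk_size when a source paragraph is longer
# than 1000 characters; RecursiveCharacterTextSplitter is the usual choice if
# a hard size limit matters here.

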
# Sentence-level embedding model for the dense retriever.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
hf_embeddings = HuggingFaceEmbeddings(model_name=MODEL_NAME)


# Read the corpus: one legal article per line.
rules = []
with open('full_rules.txt', 'r', encoding='utf-8') as file:
    for line in file:
        rules.append(line.strip())

# Sparse keyword retriever over the articles, returning the top 3 matches.
bm25_retriever = BM25Retriever.from_texts(rules)
bm25_retriever.k = 3


# Dense semantic retriever over the same articles, also returning the top 3.
faiss_vectorstore = FAISS.from_texts(rules, hf_embeddings)
faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 3})


# Hybrid retrieval: merge BM25 and FAISS results with equal weight.
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
)


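# A minimal sanity check of the hybrid retriever (hypothetical query string;
# uncomment to inspect what the pipeline actually retrieves):
# for doc in ensemble_retriever.get_relevant_documents("thời hiệu khởi kiện"):
#     print(doc.page_content[:100])

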
# Load the 7B model in 4-bit NF4 with double quantization so it fits on a
# single consumer GPU; computation runs in bfloat16.
model_path = "Open-Orca/Mistral-7B-OpenOrca"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)


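# Rough memory budget, assuming ~7B parameters: 4-bit weights take about
# 3.5 GB (plus activations and KV cache), versus roughly 14 GB in fp16, which
# is what makes single-GPU inference feasible here.

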
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, quantization_config=bnb_config, device_map={"": 0}
)
model = torch.compile(model)


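# Note: torch.compile defers compilation to the first forward pass, so the
# first request served by the Gradio app will be noticeably slower than the
# rest while the model warms up.

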
# Optional LoRA adapter configuration, kept for reference; it is only needed
# for fine-tuning, not for the inference path below.
'''
config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)
'''


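# If the block above is re-enabled: r is the adapter rank, lora_alpha the
# scaling factor, and target_modules restricts the adapters to the attention
# projections; get_peft_model leaves the quantized base weights frozen and
# makes only the adapter weights trainable.

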
def get_relevant(question):
    # Retrieve the top articles for the question and join them into a single
    # context string, one article per line.
    docs = ensemble_retriever.get_relevant_documents(question)

    doc_relevant = ""
    for doc in docs:
        doc_relevant += doc.page_content + "." + "\n"

    return doc_relevant


def write_prompt(doc_relevant, question):
    # The instructions are in Vietnamese; roughly: "Below are some legal
    # articles related to a question about the Civil Procedure Code of
    # Vietnam: ... Extract from those articles the content relevant to the
    # following question: ... Give the most complete, concise, and accurate
    # answer."
    PROMPT = (
        "### Instruction:\n{instruction1}\n{doc_relevant}\n"
        "{instruction2}\n{question}\n{instruction3}\n### Response:"
    )
    input_prompt = PROMPT.format_map(
        {
            "instruction1": "Dưới đây là một vài điều luật có liên quan đến câu hỏi về bộ luật tố tụng dân sự tại Việt Nam: ",
            "doc_relevant": doc_relevant,
            "instruction2": "Hãy trích xuất trong những điều luật đó về nội dung có liên quan đến câu hỏi sau: ",
            "question": question,
            "instruction3": "Hãy đưa ra câu trả lời đầy đủ, ngắn gọn và chính xác nhất",
        }
    )
    return input_prompt


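# For illustration, the rendered prompt has an Alpaca-style layout:
#
#   ### Instruction:
#   <instruction 1>
#   <retrieved articles>
#   <instruction 2>
#   <user question>
#   <instruction 3>
#   ### Response:

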
def generate(input_prompt):
    input_ids = tokenizer(input_prompt, return_tensors="pt")

    # The quantized model lives on GPU 0 (see device_map above), so the
    # inputs are moved there as well.
    outputs = model.generate(
        inputs=input_ids["input_ids"].to("cuda:0"),
        attention_mask=input_ids["attention_mask"].to("cuda:0"),
        max_new_tokens=1024,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    return outputs


def answer(question):
    doc_relevant = get_relevant(question)
    input_prompt = write_prompt(doc_relevant, question)
    outputs = generate(input_prompt)

    # The decoded text echoes the prompt, so keep only what follows the
    # "### Response:" marker.
    response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    response = response.split("### Response:")[1]

    return response.strip()


# Example questions shown in the Gradio UI, one per line in the file.
examples = []
with open(INFERENCE_PATH, 'r', encoding='utf-8') as file:
    for line in file:
        examples.append(line.strip())


demo = gr.Interface(
    fn=answer,
    inputs="text",
    outputs="text",
    examples=examples,
)

# queue() serializes requests (generation is GPU-bound); share=True also
# exposes a temporary public URL.
demo.queue().launch(share=True)