# orca_pdf / app.py
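"""Gradio app for question answering over Vietnam's Civil Procedure Code.

Relevant articles are retrieved with a BM25 + FAISS ensemble retriever and passed,
together with the user's question, to Open-Orca/Mistral-7B-OpenOrca (loaded in 4-bit)
for answer generation.
"""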
import gradio as gr
import string
import os
import torch
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.vectorstores import FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
from langchain.document_loaders import PyPDFLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model  # used only by the optional LoRA block below
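# inference.txt holds one example question per line; these populate the Gradio "Examples" panel.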
INFERENCE_PATH = "inference.txt"
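# Load the statute PDF and split it into 1000-character chunks.
# Note: `texts` is not consumed by the retrievers below, which are built from full_rules.txt instead.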
loader = PyPDFLoader("rule2015.pdf")
pages = loader.load_and_split()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(pages)
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
hf_embeddings = HuggingFaceEmbeddings(model_name=MODEL_NAME)
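# Build a hybrid retriever over the rules: lexical BM25 plus dense FAISS search with
# MiniLM sentence embeddings, combined 50/50 by EnsembleRetriever.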
rules = []
with open('full_rules.txt', 'r', encoding='utf-8') as file:
    for line in file:
        rules.append(line.strip())
bm25_retriever = BM25Retriever.from_texts(rules)
bm25_retriever.k = 3  # number of BM25 hits to return
faiss_vectorstore = FAISS.from_texts(rules, hf_embeddings)
faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 3})  # number of dense hits to return
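# The FAISS index above is rebuilt (and the rules re-embedded) on every start.
# A possible optimization -- a sketch only, not part of the original app; "faiss_rules_index"
# is an arbitrary cache directory name -- is to persist the index on disk:
#
#     if os.path.isdir("faiss_rules_index"):
#         faiss_vectorstore = FAISS.load_local("faiss_rules_index", hf_embeddings)
#     else:
#         faiss_vectorstore = FAISS.from_texts(rules, hf_embeddings)
#         faiss_vectorstore.save_local("faiss_rules_index")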
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
)
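# Load Open-Orca/Mistral-7B-OpenOrca in 4-bit (NF4, double quantization, bfloat16 compute)
# onto GPU 0 via bitsandbytes.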
model_path = "Open-Orca/Mistral-7B-OpenOrca"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=bnb_config, device_map={"":0})
model = torch.compile(model)
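# torch.compile requires PyTorch 2.x; the first generation call triggers compilation and can be slow.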
# Optional LoRA adapter configuration (disabled). Enable to wrap the base model with PEFT.
# config = LoraConfig(
#     r=32,
#     lora_alpha=64,
#     target_modules=["q_proj", "k_proj", "v_proj"],
#     lora_dropout=0.05,
#     bias="none",
#     task_type="CAUSAL_LM",
# )
# model = get_peft_model(model, config)
def get_relevant(question):
    """Retrieve the rules most relevant to `question` and join them into one string."""
    docs = ensemble_retriever.get_relevant_documents(question)
    doc_relevant = ""
    for doc in docs:
        doc_relevant += doc.page_content + ".\n"
    return doc_relevant
def write_prompt(doc_relevant, question):
    """Build the Vietnamese instruction prompt: relevant rules, the question, then answer guidance."""
    PROMPT = "### Instruction:\n{instruction1}\n{doc_relevant}\n{instruction2}\n{question}\n{instruction3}\n### Response:"
    input_prompt = PROMPT.format_map({
        "instruction1": "Dưới đây là một vài điều luật có liên quan đến câu hỏi về bộ luật tố tụng dân sự tại Việt Nam: ",  # "Below are some articles related to a question about Vietnam's Civil Procedure Code:"
        "doc_relevant": doc_relevant,
        "instruction2": "Hãy trích xuất trong những điều luật đó về nội dung có liên quan đến câu hỏi sau: ",  # "From those articles, extract the content relevant to the following question:"
        "question": question,
        "instruction3": "Hãy đưa ra câu trả lời đầy đủ, ngắn gọn và chính xác nhất",  # "Give the most complete, concise, and accurate answer."
    })
    return input_prompt
def generate(input_prompt):
    """Tokenize the prompt and generate a completion on the GPU."""
    device = "cuda:0"  # the model is placed on GPU 0 via device_map
    input_ids = tokenizer(input_prompt, return_tensors="pt")
    outputs = model.generate(
        inputs=input_ids["input_ids"].to(device),
        attention_mask=input_ids["attention_mask"].to(device),
        # Greedy decoding; uncomment to sample instead:
        # do_sample=True, temperature=0.5, top_k=50, top_p=0.9,
        max_new_tokens=1024,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    return outputs
def answer(question):
    """Full pipeline: retrieve relevant rules, build the prompt, generate, and return the response text."""
    doc_relevant = get_relevant(question)
    input_prompt = write_prompt(doc_relevant, question)
    outputs = generate(input_prompt)
    response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    response = response.split("### Response:")[1]
    return response.strip()
# Example questions shown in the Gradio UI.
examples = []
with open(INFERENCE_PATH, 'r', encoding='utf-8') as file:
    for line in file:
        examples.append(line.strip())
demo = gr.Interface(
fn=answer,
inputs="text",
outputs="text",
examples=examples,
)
# Queue requests so long generations do not block the Gradio worker.
demo.queue().launch(share=True)