from langchain.chains import RetrievalQA
from langchain import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.document_loaders import PyPDFLoader
from langchain.vectorstores import FAISS
from dotenv import load_dotenv
from glob import glob
from tqdm import tqdm
import gradio as gr
import yaml

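# load_dotenv() below reads a local .env file. The only credential this script
# needs is the Hugging Face Hub token used by HuggingFaceHub. Illustrative .env
# contents (the token value is a placeholder, not a real credential):
#
#   HUGGINGFACEHUB_API_TOKEN=hf_xxxxxxxxxxxxxxxx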
load_dotenv()

def load_config():
    """Read the embedding and text-splitter settings from config.yaml."""
    with open('config.yaml', 'r') as file:
        config = yaml.safe_load(file)
    return config

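# Illustrative config.yaml matching the keys read below (the model name, device,
# and chunk sizes here are example values, not taken from the original project):
#
#   embeddings:
#     name: sentence-transformers/all-MiniLM-L6-v2
#     device: cpu
#   TextSplitter:
#     chunk_size: 1000
#     chunk_overlap: 100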
config = load_config()

def load_embeddings(model_name=config["embeddings"]["name"],
                    model_kwargs={'device': config["embeddings"]["device"]}):
    """Load the HuggingFace sentence-embedding model named in config.yaml."""
    return HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

def load_documents(directory: str):
    """Load every PDF in `directory` and return a list of chunked Document objects.

    Args:
        directory: path to the PDF folder, including a trailing slash, e.g. "data/".
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=config["TextSplitter"]["chunk_size"],
        chunk_overlap=config["TextSplitter"]["chunk_overlap"])
    documents = []
    for item_path in tqdm(glob(directory + "*.pdf")):
        loader = PyPDFLoader(item_path)
        documents.extend(loader.load_and_split(text_splitter=text_splitter))

    return documents

| | template = """Use the following pieces of context to answer the question at the end. |
| | If you don't know the answer, just say that you don't know, don't try to make up an answer. |
| | Use three sentences maximum and keep the answer as concise as possible. |
| | Always say "thanks for asking!" at the end of the answer. |
| | {context} |
| | Question: {question} |
| | Helpful Answer:""" |
| | QA_CHAIN_PROMPT = PromptTemplate.from_template(template) |
| |
|
| | repo_id = "google/flan-t5-xxl" |
| |
|
| |
|
def get_llm():
    """Return a HuggingFace Hub client for the LLM identified by repo_id."""
    llm = HuggingFaceHub(
        repo_id=repo_id, model_kwargs={"temperature": 0.5, "max_length": 200}
    )
    return llm

def answer_question(question: str):
    """Answer a question over the PDFs in data/ with a retrieval-augmented QA chain."""
    # Note: the embeddings and the FAISS index are rebuilt on every call; for larger
    # document sets this setup could be done once at startup instead.
    embedding_function = load_embeddings()
    documents = load_documents("data/")

    # Index the chunks and retrieve the 4 most similar ones for each query.
    db = FAISS.from_documents(documents, embedding_function)
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})

    # "stuff" places all retrieved chunks directly into the prompt.
    qa_chain = RetrievalQA.from_chain_type(
        get_llm(),
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
        return_source_documents=True
    )

    output = qa_chain({"query": question})
    return output["result"]

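# Example direct call (assumes PDFs have already been copied into data/):
#   print(answer_question("What is this document about?"))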
# Minimal Gradio UI: one tab with an answer box, a question box, and a submit button.
with gr.Blocks() as demo:
    with gr.Tab("PdfChat"):
        with gr.Column():
            ans = gr.Textbox(label="Answer", lines=10)
            que = gr.Textbox(label="Ask a Question", lines=2)
            bttn = gr.Button(value="Submit")

            bttn.click(fn=answer_question, inputs=[que], outputs=[ans])

if __name__ == "__main__":
    demo.launch()