import os
import pandas as pd
import logging
from datasets import load_dataset
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_chroma import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# ------------------------------------------------------------------
# 1. Load and Prepare the Bank FAQ Dataset
# ------------------------------------------------------------------
# Load the dataset from Hugging Face (Bank FAQs)
ds = load_dataset("maxpro291/bankfaqs_dataset")
train_ds = ds['train']
data = train_ds[:] # load all examples
# Separate questions and answers from the 'text' field
questions = []
answers = []
for entry in data['text']:
    if entry.startswith("Q:"):
        questions.append(entry)
    elif entry.startswith("A:"):
        answers.append(entry)
# Create a DataFrame with questions and answers
Bank_Data = pd.DataFrame({'question': questions, 'answer': answers})
# Build context strings (combining question and answer) for the vector store
context_data = []
for i in range(len(Bank_Data)):
    context = f"Question: {Bank_Data.iloc[i]['question']} Answer: {Bank_Data.iloc[i]['answer']}"
    context_data.append(context)
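# Illustrative sanity check (not part of the original flow): the parsing above
# assumes questions and answers alternate one-to-one in the dataset, so log the
# pair count to surface a mismatch early.
logging.info("Parsed %d question/answer pairs", len(Bank_Data))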
# ------------------------------------------------------------------
# 2. Create the Vector Store for Retrieval
# ------------------------------------------------------------------
# Initialize the embedding model
embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Create a Chroma vector store from the context data
vectorstore = Chroma.from_texts(
    texts=context_data,
    embedding=embed_model,
    persist_directory="./chroma_db_bank"
)
# Create a retriever from the vector store
retriever = vectorstore.as_retriever()
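# Optional smoke test (illustrative query, safe to remove): verify the store
# returns a relevant Q/A pair before wiring up the full chain.
sample_hits = vectorstore.similarity_search("How do I open a savings account?", k=1)
if sample_hits:
    logging.info("Sample retrieval hit: %s", sample_hits[0].page_content[:120])
# Tip: pass search_kwargs={"k": 3} to as_retriever() above to control how many
# context chunks are retrieved per question.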
# ------------------------------------------------------------------
# 3. Initialize the LLM for Generation
# ------------------------------------------------------------------
# Note:
# The model "meta-llama/Llama-2-7b-chat-hf" is gated. If you have access,
# authenticate using huggingface-cli login. Otherwise, switch to a public model.
model_name = "gpt2" # Replace with "meta-llama/Llama-2-7b-chat-hf" if you are authenticated.
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Create a text-generation pipeline.
# max_new_tokens caps only the generated continuation; max_length would also
# count the prompt, which can overflow once retrieved context is prepended.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=True,          # sampling must be enabled for temperature/top_p to apply
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.15,
    return_full_text=False,  # emit only the answer, not the echoed prompt
)
# Wrap the pipeline in LangChain's HuggingFacePipeline
huggingface_model = HuggingFacePipeline(pipeline=pipe)
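# Quick generation check (illustrative, safe to remove): one short completion
# confirms the model loaded correctly before the chain is built.
logging.info("LLM smoke test: %s", huggingface_model.invoke("Banks typically offer")[:120])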
# ------------------------------------------------------------------
# 4. Build the Retrieval-Augmented Generation (RAG) Chain
# ------------------------------------------------------------------
# Define a prompt template that injects the retrieved context alongside the
# question (the chain below supplies a "context" key, so the template must use it)
template = (
    "You are a helpful banking assistant. "
    "Use the provided context if it is relevant to answer the question. "
    "If not, answer using your general banking knowledge.\n"
    "Context: {context}\n"
    "Question: {question}\n"
    "Answer:"
)
rag_prompt = PromptTemplate.from_template(template)
# Join retrieved Documents into plain text; without this, raw Document reprs
# would be interpolated into the prompt.
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# Build the RAG chain by piping the retriever, prompt, LLM, and an output parser
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | rag_prompt
    | huggingface_model
    | StrOutputParser()
)
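# Example usage (illustrative): a one-off, non-streaming query against the chain.
# Uncomment to test outside the Gradio UI:
# print(rag_chain.invoke("What is a savings account?"))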
# ------------------------------------------------------------------
# 5. Set Up the Gradio Chat Interface
# ------------------------------------------------------------------
def rag_memory_stream(message, history):
    partial_text = ""
    # Stream the generated answer. Depending on the installed langchain and
    # transformers versions, chunks may arrive token-by-token or as one block.
    for new_text in rag_chain.stream(message):
        partial_text += new_text
        yield partial_text
# Example questions
examples = [
    "I want to open an account",
    "What is a savings account?",
    "How do I use an ATM?",
    "How can I resolve a bank account issue?"
]
title = "Your Personal Banking Assistant 💬"
description = (
    "Welcome! I’m here to answer your questions about banking and related topics. "
    "Ask me anything, and I’ll do my best to assist you."
)
# Create a chat interface using Gradio
demo = gr.ChatInterface(
    fn=rag_memory_stream,
    title=title,
    description=description,
    examples=examples,
    theme="glass",
)
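# Streaming handlers run through Gradio's queue. Recent Gradio versions enable
# queueing by default; calling it explicitly is harmless and required on older
# releases for generator-based handlers like rag_memory_stream.
demo.queue()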
# ------------------------------------------------------------------
# 6. Launch the App
# ------------------------------------------------------------------
if __name__ == "__main__":
    demo.launch(share=True)  # share=True also serves a temporary public link