import os
import pandas as pd
import logging
from datasets import load_dataset
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_chroma import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# ------------------------------------------------------------------
# 1. Load and Prepare the Bank FAQ Dataset
# ------------------------------------------------------------------
# Load the dataset from Hugging Face (Bank FAQs)
ds = load_dataset("maxpro291/bankfaqs_dataset")
train_ds = ds['train']
data = train_ds[:]  # load all examples

# Separate questions and answers from the 'text' field
questions = []
answers = []
for entry in data['text']:
    if entry.startswith("Q:"):
        questions.append(entry)
    elif entry.startswith("A:"):
        answers.append(entry)

# Pair the questions with their answers (the dataset stores them as alternating
# "Q:" / "A:" lines); truncate to the shorter list in case any entry is unpaired.
n_pairs = min(len(questions), len(answers))
Bank_Data = pd.DataFrame({'question': questions[:n_pairs], 'answer': answers[:n_pairs]})

# Build context strings (combining question and answer) for the vector store
context_data = []
for i in range(len(Bank_Data)):
    context = f"Question: {Bank_Data.iloc[i]['question']} Answer: {Bank_Data.iloc[i]['answer']}"
    context_data.append(context)
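
# Optional sanity check (added for debugging convenience, not required by the pipeline):
# log how many Q/A pairs were parsed and one sample context entry, so a malformed or
# re-ordered dataset is caught early.
logging.info("Parsed %d question/answer pairs from the dataset", len(Bank_Data))
if context_data:
    logging.info("Sample context entry: %s", context_data[0])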

# ------------------------------------------------------------------
# 2. Create the Vector Store for Retrieval
# ------------------------------------------------------------------
# Initialize the embedding model
embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Create a Chroma vector store from the context data
vectorstore = Chroma.from_texts(
    texts=context_data,
    embedding=embed_model,
    persist_directory="./chroma_db_bank"
)

# Create a retriever from the vector store
retriever = vectorstore.as_retriever()
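
# Optional retrieval smoke test (the query string below is only illustrative and can be
# removed): verifies that the vector store returns sensible context before the LLM is used.
sample_docs = retriever.invoke("How do I open a savings account?")
logging.info("Retrieval smoke test returned %d documents", len(sample_docs))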

# ------------------------------------------------------------------
# 3. Initialize the LLM for Generation
# ------------------------------------------------------------------
# Note:
# The model "meta-llama/Llama-2-7b-chat-hf" is gated. If you have access,
# authenticate using huggingface-cli login. Otherwise, switch to a public model.
model_name = "gpt2"  # Replace with "meta-llama/Llama-2-7b-chat-hf" if you are authenticated.

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Create a text-generation pipeline.
# max_new_tokens (rather than max_length) keeps the generation budget independent of the
# prompt length, and do_sample=True is required for temperature/top_p to take effect.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.15,
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; avoids a warning
    return_full_text=False  # return only the generated continuation, not the prompt
)

# Wrap the pipeline in LangChain's HuggingFacePipeline
huggingface_model = HuggingFacePipeline(pipeline=pipe)

# ------------------------------------------------------------------
# 4. Build the Retrieval-Augmented Generation (RAG) Chain
# ------------------------------------------------------------------
# Define a prompt template that injects the retrieved context alongside the question
# (the chain below supplies both "context" and "question", so the template must use them)
template = (
    "You are a helpful banking assistant. "
    "Use the provided context if it is relevant to answer the question. "
    "If not, answer using your general banking knowledge.\n"
    "Context: {context}\n"
    "Question: {question}\n"
    "Answer:"
)
rag_prompt = PromptTemplate.from_template(template)

# Helper to flatten the retrieved Documents into a single context string
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# Build the RAG chain: retrieve and format context, fill the prompt, generate, parse to text
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | rag_prompt
    | huggingface_model
    | StrOutputParser()
)
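
# Optional end-to-end smoke test (commented out so it does not slow app start-up;
# the question below is illustrative). Uncomment to try the chain outside Gradio:
# logging.info(rag_chain.invoke("What is a savings account?"))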

# ------------------------------------------------------------------
# 5. Set Up the Gradio Chat Interface
# ------------------------------------------------------------------
def rag_memory_stream(message, history):
    # `history` is supplied by gr.ChatInterface but is not used here; each
    # question is answered independently (no conversational memory).
    partial_text = ""
    # Stream the generated answer as it becomes available
    for new_text in rag_chain.stream(message):
        partial_text += new_text
        yield partial_text

# Example questions
examples = [
    "I want to open an account", 
    "What is a savings account?",
    "How do I use an ATM?",
    "How can I resolve a bank account issue?"
]

title = "Your Personal Banking Assistant 💬"
description = (
    "Welcome! I’m here to answer your questions about banking and related topics. "
    "Ask me anything, and I’ll do my best to assist you."
)

# Create a chat interface using Gradio
demo = gr.ChatInterface(
    fn=rag_memory_stream,
    title=title,
    description=description,
    examples=examples,
    theme="glass",
)

# ------------------------------------------------------------------
# 6. Launch the App
# ------------------------------------------------------------------
if __name__ == "__main__":
    demo.launch(share=True)