NSamson1 committed on
Commit
132540f
·
verified ·
1 Parent(s): cc4abf7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -137
app.py CHANGED
@@ -1,173 +1,142 @@
1
  import os
2
  import pandas as pd
3
  import logging
4
- import threading
5
- from fastapi import FastAPI, Header, HTTPException
6
- import uvicorn
7
- import gradio as gr
8
- from langchain_community.embeddings import HuggingFaceEmbeddings
9
- from langchain_community.vectorstores import Chroma
10
  from langchain_core.prompts import PromptTemplate
11
  from langchain_core.output_parsers import StrOutputParser
12
  from langchain_core.runnables import RunnablePassthrough
13
- from datasets import load_dataset
14
- from transformers import (
15
- AutoTokenizer,
16
- AutoModelForCausalLM,
17
- pipeline,
18
- BitsAndBytesConfig
19
- )
20
- import torch # Explicitly imported for CUDA management
21
-
22
- # ====================== CONFIGURATION ======================
23
- API_KEY = "Samson"
24
- MODEL_NAME = "microsoft/phi-2"
25
- EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
26
- # ===========================================================
27
-
28
- # Configure logging
29
- logging.basicConfig(
30
- level=logging.INFO,
31
- format='%(asctime)s - %(levelname)s - %(message)s'
32
- )
33
 
34
- # Clear CUDA cache if using GPU
35
- if torch.cuda.is_available():
36
- torch.cuda.empty_cache()
37
 
38
  # ------------------------------------------------------------------
39
- # 1. Load and Prepare Dataset
40
  # ------------------------------------------------------------------
41
- def load_data():
42
- try:
43
- ds = load_dataset("maxpro291/bankfaqs_dataset")
44
- data = ds['train'][:]
45
- questions = [entry for entry in data['text'] if entry.startswith("Q:")]
46
- answers = [entry for entry in data['text'] if entry.startswith("A:")]
47
- return pd.DataFrame({'question': questions, 'answer': answers})
48
- except Exception as e:
49
- logging.error(f"Error loading dataset: {str(e)}")
50
- raise
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # ------------------------------------------------------------------
53
- # 2. Initialize Embeddings and Vector Store
54
  # ------------------------------------------------------------------
55
- def init_vectordb(data):
56
- try:
57
- embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
58
- texts = [f"Q: {q}\nA: {a}" for q, a in zip(data['question'], data['answer'])]
59
- return Chroma.from_texts(
60
- texts=texts,
61
- embedding=embeddings,
62
- persist_directory="./chroma_db_bank"
63
- )
64
- except Exception as e:
65
- logging.error(f"Error initializing vector store: {str(e)}")
66
- raise
67
 
68
  # ------------------------------------------------------------------
69
- # 3. Initialize LLM with Quantization
70
  # ------------------------------------------------------------------
71
- def load_llm():
72
- try:
73
- quantization_config = BitsAndBytesConfig(
74
- load_in_4bit=True,
75
- bnb_4bit_use_double_quant=True,
76
- bnb_4bit_quant_type="nf4",
77
- bnb_4bit_compute_dtype="float16"
78
- )
79
-
80
- tokenizer = AutoTokenizer.from_pretrained(
81
- MODEL_NAME,
82
- trust_remote_code=True,
83
- padding_side="left" # Critical for phi-2
84
- )
85
-
86
- model = AutoModelForCausalLM.from_pretrained(
87
- MODEL_NAME,
88
- device_map="auto",
89
- trust_remote_code=True,
90
- quantization_config=quantization_config
91
- )
92
-
93
- return pipeline(
94
- "text-generation",
95
- model=model,
96
- tokenizer=tokenizer,
97
- max_new_tokens=512,
98
- temperature=0.7,
99
- top_p=0.95,
100
- do_sample=True
101
- )
102
- except Exception as e:
103
- logging.error(f"Error loading LLM: {str(e)}")
104
- raise
105
-
106
- # Initialize components
107
- bank_data = load_data()
108
- retriever = init_vectordb(bank_data).as_retriever()
109
- llm_pipeline = load_llm()
110
 
111
  # ------------------------------------------------------------------
112
- # 4. Build RAG Chain
113
  # ------------------------------------------------------------------
114
- template = """You are a banking assistant. Use context if relevant:
115
- Context: {context}
116
- Question: {question}
117
- Answer:"""
118
- prompt = PromptTemplate.from_template(template)
 
 
 
 
119
 
 
120
  rag_chain = (
121
  {"context": retriever, "question": RunnablePassthrough()}
122
- | prompt
123
- | llm_pipeline
124
  | StrOutputParser()
125
  )
126
 
127
  # ------------------------------------------------------------------
128
- # 5. FastAPI Setup
129
  # ------------------------------------------------------------------
130
- app = FastAPI()
131
-
132
- def validate_api_key(api_key: str = Header(None)):
133
- if api_key != API_KEY:
134
- raise HTTPException(status_code=401, detail="Invalid API Key")
135
- return True
136
 
137
- @app.post("/chat")
138
- async def chat_endpoint(question: str, authorization: str = Header(None)):
139
- validate_api_key(authorization)
140
- response = ""
141
- for chunk in rag_chain.stream(question):
142
- response += chunk
143
- return {"response": response}
144
 
145
- # ------------------------------------------------------------------
146
- # 6. Gradio Interface
147
- # ------------------------------------------------------------------
148
- def respond(message, history):
149
- return next(rag_chain.stream(message))
150
 
 
151
  demo = gr.ChatInterface(
152
- fn=respond,
153
- title="Banking Assistant 🔒",
154
- examples=[
155
- "How do I open an account?",
156
- "What's the interest rate?",
157
- "How do I apply for a loan?"
158
- ],
159
- theme="glass"
160
  )
161
 
162
  # ------------------------------------------------------------------
163
- # 7. Launch Servers
164
  # ------------------------------------------------------------------
165
  if __name__ == "__main__":
166
- # Start Gradio in separate thread
167
- threading.Thread(
168
- target=demo.launch,
169
- kwargs={"server_name": "0.0.0.0", "server_port": 7860, "share": False}
170
- ).start()
171
-
172
- # Start FastAPI
173
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import os
2
  import pandas as pd
3
  import logging
4
+ from datasets import load_dataset
5
+ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
6
+ from langchain_chroma import Chroma
 
 
 
7
  from langchain_core.prompts import PromptTemplate
8
  from langchain_core.output_parsers import StrOutputParser
9
  from langchain_core.runnables import RunnablePassthrough
10
+ import gradio as gr
11
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
# Configure root logging once at startup: timestamped INFO-level messages.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
 
15
 
16
# ------------------------------------------------------------------
# 1. Load and Prepare the Bank FAQ Dataset
# ------------------------------------------------------------------
# Pull the Bank FAQs dataset from the Hugging Face Hub.
ds = load_dataset("maxpro291/bankfaqs_dataset")
train_ds = ds['train']
data = train_ds[:]  # materialize every example

# Split the 'text' field into question and answer lines.
# NOTE(review): assumes "Q:"/"A:" lines appear in matching order so the
# positional pairing below is correct — confirm against the dataset.
questions = [entry for entry in data['text'] if entry.startswith("Q:")]
answers = [entry for entry in data['text'] if entry.startswith("A:")]

# Pair questions with answers positionally in a DataFrame.
Bank_Data = pd.DataFrame({'question': questions, 'answer': answers})

# Fold each Q/A pair into one context string for the vector store.
context_data = [
    f"Question: {q} Answer: {a}"
    for q, a in zip(Bank_Data['question'], Bank_Data['answer'])
]
41
 
42
# ------------------------------------------------------------------
# 2. Create the Vector Store for Retrieval
# ------------------------------------------------------------------
# Embed every Q/A context string with a small sentence-transformer model.
embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Build a Chroma index over the context strings, persisted on disk so the
# embeddings survive restarts (directory is created if absent).
vectorstore = Chroma.from_texts(
    texts=context_data,
    embedding=embed_model,
    persist_directory="./chroma_db_bank"
)

# Retriever used as the "context" source in the RAG chain below.
# NOTE(review): default search settings (similarity, k=4 in most langchain
# versions) — confirm the desired k for this FAQ corpus.
retriever = vectorstore.as_retriever()
57
 
58
# ------------------------------------------------------------------
# 3. Initialize the LLM for Generation
# ------------------------------------------------------------------
# Note:
# The model "meta-llama/Llama-2-7b-chat-hf" is gated. If you have access,
# authenticate using huggingface-cli login. Otherwise, keep a public model.
model_name = "gpt2"  # Replace with "meta-llama/Llama-2-7b-chat-hf" if you are authenticated.

# Load tokenizer and model weights from the Hub.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Create a text-generation pipeline.
# Fixes vs. the previous revision of this block:
#  - do_sample=True: transformers defaults to greedy decoding, which
#    silently ignores temperature/top_p; sampling must be enabled for
#    those parameters to have any effect.
#  - max_new_tokens instead of max_length: max_length counts the prompt
#    tokens too, so a long retrieval-augmented prompt would leave little
#    or no room for the answer; max_new_tokens bounds only the generated
#    continuation.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.15
)

# Wrap the pipeline in LangChain's HuggingFacePipeline so it can be piped
# into the RAG chain as a runnable LLM.
huggingface_model = HuggingFacePipeline(pipeline=pipe)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
# ------------------------------------------------------------------
# 4. Build the Retrieval-Augmented Generation (RAG) Chain
# ------------------------------------------------------------------
# Prompt template. BUG FIX: the previous template told the model to "use
# the provided context" but never interpolated a {context} placeholder,
# so the documents retrieved by the chain below were silently discarded.
# The template now actually receives the retrieved context.
template = (
    "You are a helpful banking assistant. "
    "Use the provided context if it is relevant to answer the question. "
    "If not, answer using your general banking knowledge.\n"
    "Context: {context}\n"
    "Question: {question}\n"
    "Answer:"
)
rag_prompt = PromptTemplate.from_template(template)

# Build the RAG chain: the retriever fills {context}, the raw user input
# passes straight through to {question}, then prompt -> LLM -> plain text.
rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | rag_prompt
    | huggingface_model
    | StrOutputParser()
)
104
 
105
# ------------------------------------------------------------------
# 5. Set Up the Gradio Chat Interface
# ------------------------------------------------------------------
def rag_memory_stream(message, history):
    """Stream the chain's answer, yielding the growing partial text.

    Gradio re-renders the chat bubble on every yield, producing a
    typewriter effect. ``history`` is accepted because ChatInterface
    passes it, but it is unused: each message is answered statelessly.
    """
    answer_so_far = ""
    for chunk in rag_chain.stream(message):
        answer_so_far += chunk
        yield answer_so_far
114
 
115
# Example questions offered as clickable starters in the chat UI.
examples = [
    "I want to open an account",
    "What is a savings account?",
    "How do I use an ATM?",
    "How can I resolve a bank account issue?"
]

# UI header text shown above the chat widget.
title = "Your Personal Banking Assistant 💬"
description = (
    "Welcome! I’m here to answer your questions about banking and related topics. "
    "Ask me anything, and I’ll do my best to assist you."
)

# Streaming chat interface backed by rag_memory_stream (a generator, so
# replies render incrementally as the model produces text).
demo = gr.ChatInterface(
    fn=rag_memory_stream,
    title=title,
    description=description,
    examples=examples,
    theme="glass",
)
137
 
138
# ------------------------------------------------------------------
# 6. Launch the App
# ------------------------------------------------------------------
if __name__ == "__main__":
    # share=True requests a public *.gradio.live tunnel.
    # NOTE(review): on Hugging Face Spaces share links are unavailable and
    # this flag is ignored with a warning — confirm the target environment.
    demo.launch(share=True)