iajitpanday committed on
Commit
5717062
·
verified ·
1 Parent(s): fee555b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -112
app.py CHANGED
@@ -1,71 +1,71 @@
1
  import gradio as gr
2
- import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
 
 
 
 
 
 
4
  from langchain.embeddings import HuggingFaceEmbeddings
5
  from langchain.vectorstores import FAISS
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
  from langchain.document_loaders import PyPDFLoader, WebBaseLoader
8
  from langchain.chains import RetrievalQA
9
  from langchain.llms import HuggingFacePipeline
10
- import os
11
- import tempfile
12
- from typing import List, Tuple
13
- import requests
14
- from bs4 import BeautifulSoup
15
 
16
- # Initialize the model and tokenizer
17
  class CustomerSupportChatbot:
18
  def __init__(self):
19
- # Initialize embeddings
20
  self.embeddings = HuggingFaceEmbeddings(
21
- model_name="sentence-transformers/all-MiniLM-L6-v2"
 
22
  )
23
 
24
- # Initialize the base language model
25
- model_name = "microsoft/DialoGPT-medium" # You can change this to a different model
26
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
27
- self.model = AutoModelForCausalLM.from_pretrained(model_name)
28
-
29
- # Create text generation pipeline
30
- self.pipe = pipeline(
31
  "text-generation",
32
- model=self.model,
33
- tokenizer=self.tokenizer,
34
- max_length=512,
 
35
  temperature=0.7,
36
- top_p=0.9,
37
- repetition_penalty=1.1
38
  )
39
 
40
- # Initialize HuggingFace pipeline for LangChain
41
- self.llm = HuggingFacePipeline(pipeline=self.pipe)
42
-
43
  # Initialize vector store
44
  self.vector_store = None
45
- self.qa_chain = None
46
 
47
  # Text splitter
48
  self.text_splitter = RecursiveCharacterTextSplitter(
49
- chunk_size=1000,
50
- chunk_overlap=200
 
51
  )
52
 
53
  def process_documents(self, pdf_files, website_urls) -> str:
54
  """Process PDF files and website URLs to create a vector store"""
55
- documents = []
56
 
57
  # Process PDF files
58
  if pdf_files:
59
  for pdf_file in pdf_files:
60
- with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
61
- tmp_file.write(pdf_file.file.read())
62
- tmp_file.flush()
 
 
 
 
 
63
 
64
- loader = PyPDFLoader(tmp_file.name)
65
- pdf_documents = loader.load()
66
- documents.extend(pdf_documents)
67
-
68
- os.unlink(tmp_file.name)
69
 
70
  # Process websites
71
  if website_urls:
@@ -73,86 +73,111 @@ class CustomerSupportChatbot:
73
  url = url.strip()
74
  if url:
75
  try:
76
- loader = WebBaseLoader(url)
77
- web_documents = loader.load()
78
- documents.extend(web_documents)
 
 
 
 
 
 
 
 
 
 
79
  except Exception as e:
80
  print(f"Error loading {url}: {str(e)}")
81
 
82
- if not documents:
83
  return "No documents processed. Please upload PDFs or provide website URLs."
84
 
85
- # Split documents into chunks
86
- texts = self.text_splitter.split_documents(documents)
87
-
88
- # Create vector store
89
- self.vector_store = FAISS.from_documents(texts, self.embeddings)
90
-
91
- # Create QA chain
92
- self.qa_chain = RetrievalQA.from_chain_type(
93
- llm=self.llm,
94
- chain_type="stuff",
95
- retriever=self.vector_store.as_retriever(search_kwargs={"k": 3}),
96
- return_source_documents=True
97
- )
 
 
98
 
99
- return f"Successfully processed {len(documents)} documents and created knowledge base."
 
 
 
 
 
100
 
101
  def chat(self, message: str, history: List[Tuple[str, str]]) -> str:
102
  """Chat function that uses RAG if available"""
103
 
104
- # If we have a knowledge base, use RAG
105
- if self.qa_chain:
106
- try:
107
- # Get relevant context from the knowledge base
108
- result = self.qa_chain({"query": message})
109
-
110
- # Format the response with context
111
- response = result["result"]
112
-
113
- # Add source information if available
114
- if "source_documents" in result and result["source_documents"]:
115
- sources = set()
116
- for doc in result["source_documents"]:
117
- if hasattr(doc, 'metadata') and 'source' in doc.metadata:
118
- sources.add(doc.metadata['source'])
119
-
120
- if sources:
121
- response += "\n\nSources: " + ", ".join(list(sources)[:3])
122
-
123
- return response
124
 
125
- except Exception as e:
126
- print(f"Error using RAG: {str(e)}")
127
- # Fall back to basic chat if RAG fails
128
-
129
- # Basic chat without RAG
130
- # Format conversation history for the model
131
- conversation = ""
132
- for user_msg, bot_msg in history[-5:]: # Use last 5 exchanges
133
- conversation += f"User: {user_msg}\nBot: {bot_msg}\n"
134
-
135
- conversation += f"User: {message}\nBot:"
136
-
137
- # Generate response
138
- response = self.pipe(conversation, max_length=len(conversation) + 100)[0]['generated_text']
139
-
140
- # Extract only the bot's response
141
- bot_response = response.split("Bot:")[-1].strip()
142
 
143
- return bot_response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  # Initialize the chatbot
146
  chatbot = CustomerSupportChatbot()
147
 
148
  # Create the Gradio interface
149
  def create_interface():
150
- with gr.Blocks(title="Customer Support Chatbot with RAG") as demo:
151
- gr.Markdown("# Customer Support Chatbot with RAG")
152
  gr.Markdown("Upload PDFs and/or provide website URLs to create a knowledge base for the chatbot.")
153
 
154
  with gr.Row():
155
  with gr.Column(scale=1):
 
156
  pdf_upload = gr.File(
157
  label="Upload PDF files",
158
  file_count="multiple",
@@ -166,20 +191,40 @@ def create_interface():
166
  )
167
 
168
  process_btn = gr.Button("Process Documents", variant="primary")
169
- status_text = gr.Textbox(label="Status", interactive=False)
170
 
171
  with gr.Column(scale=2):
172
- chatbot_interface = gr.Chatbot(label="Customer Support Chat")
 
 
 
 
 
173
  msg_input = gr.Textbox(
174
  label="Message",
175
  placeholder="Ask a question...",
176
- lines=2
 
177
  )
178
 
179
  with gr.Row():
180
  submit_btn = gr.Button("Send", variant="primary")
181
  clear_btn = gr.Button("Clear Chat")
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  # Event handlers
184
  def process_documents(pdf_files, website_urls):
185
  return chatbot.process_documents(pdf_files, website_urls)
@@ -211,20 +256,10 @@ def create_interface():
211
  )
212
 
213
  clear_btn.click(
214
- fn=lambda: None,
215
- outputs=chatbot_interface
216
- )
217
-
218
- # Add example questions
219
- gr.Examples(
220
- examples=[
221
- "What are your customer support hours?",
222
- "How can I track my order?",
223
- "What is the return policy?",
224
- "How do I contact customer service?",
225
- "What payment methods do you accept?"
226
- ],
227
- inputs=msg_input
228
  )
229
 
230
  return demo
 
1
  import gradio as gr
2
+ import os
3
+ import tempfile
4
+ from typing import List, Tuple
5
+ import requests
6
+ from bs4 import BeautifulSoup
7
+ import json
8
+
9
+ # Instead of using torch/transformers directly, use HuggingFace's Inference API
10
+ from transformers import pipeline
11
  from langchain.embeddings import HuggingFaceEmbeddings
12
  from langchain.vectorstores import FAISS
13
  from langchain.text_splitter import RecursiveCharacterTextSplitter
14
  from langchain.document_loaders import PyPDFLoader, WebBaseLoader
15
  from langchain.chains import RetrievalQA
16
  from langchain.llms import HuggingFacePipeline
17
+ from langchain.schema import Document
 
 
 
 
18
 
19
+ # Initialize the chatbot class
20
  class CustomerSupportChatbot:
21
  def __init__(self):
22
+ # Use a lighter embedding model
23
  self.embeddings = HuggingFaceEmbeddings(
24
+ model_name="all-MiniLM-L6-v2",
25
+ model_kwargs={'device': 'cpu'}
26
  )
27
 
28
+ # Use a simpler model for chat
29
+ self.chat_pipeline = pipeline(
 
 
 
 
 
30
  "text-generation",
31
+ model="microsoft/DialoGPT-small", # Using smaller model
32
+ device_map="auto",
33
+ torch_dtype="auto",
34
+ max_new_tokens=100,
35
  temperature=0.7,
36
+ pad_token_id=50256
 
37
  )
38
 
 
 
 
39
  # Initialize vector store
40
  self.vector_store = None
41
+ self.documents = []
42
 
43
  # Text splitter
44
  self.text_splitter = RecursiveCharacterTextSplitter(
45
+ chunk_size=500,
46
+ chunk_overlap=50,
47
+ length_function=len,
48
  )
49
 
50
  def process_documents(self, pdf_files, website_urls) -> str:
51
  """Process PDF files and website URLs to create a vector store"""
52
+ self.documents = []
53
 
54
  # Process PDF files
55
  if pdf_files:
56
  for pdf_file in pdf_files:
57
+ try:
58
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
59
+ tmp_file.write(pdf_file.read())
60
+ tmp_file.flush()
61
+
62
+ loader = PyPDFLoader(tmp_file.name)
63
+ pdf_documents = loader.load()
64
+ self.documents.extend(pdf_documents)
65
 
66
+ os.unlink(tmp_file.name)
67
+ except Exception as e:
68
+ print(f"Error processing PDF: {str(e)}")
 
 
69
 
70
  # Process websites
71
  if website_urls:
 
73
  url = url.strip()
74
  if url:
75
  try:
76
+ # Simple web scraping
77
+ response = requests.get(url, timeout=10)
78
+ soup = BeautifulSoup(response.content, 'html.parser')
79
+
80
+ # Extract text content
81
+ text = soup.get_text(separator=' ', strip=True)
82
+
83
+ # Create a document
84
+ doc = Document(
85
+ page_content=text,
86
+ metadata={"source": url}
87
+ )
88
+ self.documents.append(doc)
89
  except Exception as e:
90
  print(f"Error loading {url}: {str(e)}")
91
 
92
+ if not self.documents:
93
  return "No documents processed. Please upload PDFs or provide website URLs."
94
 
95
+ try:
96
+ # Split documents into chunks
97
+ texts = self.text_splitter.split_documents(self.documents)
98
+
99
+ # Create vector store
100
+ self.vector_store = FAISS.from_documents(texts, self.embeddings)
101
+
102
+ return f"Successfully processed {len(self.documents)} documents into {len(texts)} chunks."
103
+ except Exception as e:
104
+ return f"Error creating vector store: {str(e)}"
105
+
106
+ def search_documents(self, query: str, k: int = 3) -> List[str]:
107
+ """Search for relevant documents"""
108
+ if not self.vector_store:
109
+ return []
110
 
111
+ try:
112
+ docs = self.vector_store.similarity_search(query, k=k)
113
+ return [doc.page_content for doc in docs]
114
+ except Exception as e:
115
+ print(f"Error searching documents: {str(e)}")
116
+ return []
117
 
118
  def chat(self, message: str, history: List[Tuple[str, str]]) -> str:
119
  """Chat function that uses RAG if available"""
120
 
121
+ # Search for relevant context
122
+ if self.vector_store:
123
+ relevant_docs = self.search_documents(message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ if relevant_docs:
126
+ # Create context from relevant documents
127
+ context = "\n\n".join(relevant_docs[:2]) # Use top 2 documents
128
+
129
+ # Create a prompt with context
130
+ prompt = f"""Based on the following context, please answer the customer's question:
131
+
132
+ Context:
133
+ {context}
134
+
135
+ Customer Question: {message}
136
+
137
+ Answer: """
138
+ else:
139
+ prompt = f"Customer Question: {message}\nAnswer: "
140
+ else:
141
+ prompt = f"Customer Question: {message}\nAnswer: "
142
 
143
+ try:
144
+ # Generate response
145
+ response = self.chat_pipeline(
146
+ prompt,
147
+ max_new_tokens=100,
148
+ do_sample=True,
149
+ temperature=0.7,
150
+ top_p=0.9,
151
+ num_return_sequences=1
152
+ )[0]['generated_text']
153
+
154
+ # Extract just the answer part
155
+ if "Answer: " in response:
156
+ answer = response.split("Answer: ")[-1].strip()
157
+ else:
158
+ answer = response.strip()
159
+
160
+ # Clean up the response
161
+ answer = answer.split("\n")[0].strip() # Take first line only
162
+
163
+ return answer if answer else "I'm here to help! Could you please rephrase your question?"
164
+
165
+ except Exception as e:
166
+ print(f"Error generating response: {str(e)}")
167
+ return "I'm sorry, I encountered an error. Could you please try again?"
168
 
169
# Initialize the chatbot — module-level singleton shared by all Gradio
# event handlers below.
# NOTE(review): constructing this at import time downloads/loads both the
# embedding model and the chat model; presumably intentional for a hosted
# Space, but confirm it is acceptable for local startup latency.
chatbot = CustomerSupportChatbot()
171
 
172
  # Create the Gradio interface
173
  def create_interface():
174
+ with gr.Blocks(title="Customer Support Chatbot with RAG", theme=gr.themes.Soft()) as demo:
175
+ gr.Markdown("# 🤖 Customer Support Chatbot with RAG")
176
  gr.Markdown("Upload PDFs and/or provide website URLs to create a knowledge base for the chatbot.")
177
 
178
  with gr.Row():
179
  with gr.Column(scale=1):
180
+ gr.Markdown("### 📁 Document Upload")
181
  pdf_upload = gr.File(
182
  label="Upload PDF files",
183
  file_count="multiple",
 
191
  )
192
 
193
  process_btn = gr.Button("Process Documents", variant="primary")
194
+ status_text = gr.Textbox(label="Status", interactive=False, show_label=True)
195
 
196
  with gr.Column(scale=2):
197
+ gr.Markdown("### 💬 Chat")
198
+ chatbot_interface = gr.Chatbot(
199
+ label="Customer Support Chat",
200
+ height=400,
201
+ show_label=True
202
+ )
203
  msg_input = gr.Textbox(
204
  label="Message",
205
  placeholder="Ask a question...",
206
+ lines=2,
207
+ show_label=True
208
  )
209
 
210
  with gr.Row():
211
  submit_btn = gr.Button("Send", variant="primary")
212
  clear_btn = gr.Button("Clear Chat")
213
 
214
+ # Example questions section
215
+ gr.Markdown("### 💡 Example Questions")
216
+ gr.Examples(
217
+ examples=[
218
+ "What are your customer support hours?",
219
+ "How can I track my order?",
220
+ "What is the return policy?",
221
+ "How do I contact customer service?",
222
+ "What payment methods do you accept?"
223
+ ],
224
+ inputs=msg_input,
225
+ label="Click on any example to try it:"
226
+ )
227
+
228
  # Event handlers
229
  def process_documents(pdf_files, website_urls):
230
  return chatbot.process_documents(pdf_files, website_urls)
 
256
  )
257
 
258
  clear_btn.click(
259
+ lambda: None,
260
+ None,
261
+ chatbot_interface,
262
+ queue=False
 
 
 
 
 
 
 
 
 
 
263
  )
264
 
265
  return demo