Wenye He committed on
Commit 76001f1 · verified · 1 Parent(s): 8f7d010

Update app.py: move RAG context into explicit {context}/{question} prompt placeholders, add logging and error handling, and replace the in-app upload flow with prebuilt vector stores

Files changed (1)
  1. app.py +137 -116
app.py CHANGED
@@ -3,22 +3,42 @@ import gradio as gr
 import torch
 import time
 import os
+import logging
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
 from langchain_community.document_loaders import PyPDFLoader, TextLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
 
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 # Configuration
 MODEL_CONFIG = {
     "phi-3": {
         "model_name": "microsoft/phi-3-mini-4k-instruct",
-        "template": "<|user|>\n{message}<|end|>\n<|assistant|>"
+        "template": """<|user|>
+Using the following context, please answer the question. If the context doesn't contain relevant information, say so.
+
+Context:
+{context}
+
+Question: {question}<|end|>
+<|assistant|>
+Let me help answer your question based on the provided context."""
     },
     "llama3-8b": {
         "model_name": "NousResearch/Meta-Llama-3-8B-Instruct",
         "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
-{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
+Using the following context, please answer the question. If the context doesn't contain relevant information, say so.
+
+Context:
+{context}
+
+Question: {question}<|eot_id|>
+<|start_header_id|>assistant<|end_header_id|>
+Let me help answer your question based on the provided context."""
     }
 }
 
@@ -34,131 +54,142 @@ class ChatModel:
         self.models = {}
         self.tokenizers = {}
         self.vectorstore = {}
+        self.embeddings = HuggingFaceEmbeddings(
+            model_name="BAAI/bge-small-en-v1.5"
+        )
 
     def load_model(self, model_name):
+        """Load and cache the model and tokenizer"""
         if model_name not in self.models:
+            logger.info(f"Loading model: {model_name}")
             config = MODEL_CONFIG[model_name]
-            tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
-            tokenizer.pad_token = tokenizer.eos_token
-            model = AutoModelForCausalLM.from_pretrained(
-                config["model_name"],
-                quantization_config=bnb_config,
-                device_map="auto",
-                torch_dtype=torch.float16,
-            )
-            self.models[model_name] = model
-            self.tokenizers[model_name] = tokenizer
-
-
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
+                tokenizer.pad_token = tokenizer.eos_token
+                model = AutoModelForCausalLM.from_pretrained(
+                    config["model_name"],
+                    quantization_config=bnb_config,
+                    device_map="auto",
+                    torch_dtype=torch.float16,
+                )
+                self.models[model_name] = model
+                self.tokenizers[model_name] = tokenizer
+                logger.info(f"Successfully loaded model: {model_name}")
+            except Exception as e:
+                logger.error(f"Error loading model {model_name}: {str(e)}")
+                raise
 
     def load_vector_store(self, store_name):
-        """Cache vector stores in memory"""
-
+        """Load and cache vector stores"""
        if store_name not in self.vectorstore:
-            embeddings = HuggingFaceEmbeddings(
-                model_name="BAAI/bge-small-en-v1.5"
-            )
-            self.vectorstore[store_name] = FAISS.load_local(
-                f"vector_stores_index/{store_name}",
-                embeddings,
-                allow_dangerous_deserialization=True
-            )
+            logger.info(f"Loading vector store: {store_name}")
+            try:
+                self.vectorstore[store_name] = FAISS.load_local(
+                    f"vector_stores_index/{store_name}",
+                    self.embeddings,
+                    allow_dangerous_deserialization=True
+                )
+                # Verify vector store content
+                self.check_vectorstore(store_name)
+                logger.info(f"Successfully loaded vector store: {store_name}")
+            except Exception as e:
+                logger.error(f"Error loading vector store {store_name}: {str(e)}")
+                raise
         return self.vectorstore[store_name]
 
+    def check_vectorstore(self, store_name):
+        """Verify vector store content"""
+        try:
+            vectorstore = self.vectorstore[store_name]
+            sample_query = "test query"
+            docs = vectorstore.similarity_search(sample_query, k=1)
+            logger.info(f"Sample document from {store_name}: {docs[0].page_content[:200]}...")
+        except Exception as e:
+            logger.error(f"Error checking vector store {store_name}: {str(e)}")
+            raise
 
-
-    def process_documents(self, files, progress=gr.Progress()):
-        """Process uploaded documents into vector embeddings"""
+    def generate(self, message, model_name, vector_store_name, history):
+        """Generate response using RAG"""
+        start_time = time.time()
         try:
-            progress(0, desc="Starting document processing")
-            documents = []
+            # Load model and vector store
+            self.load_model(model_name)
+            vectorstore = self.load_vector_store(vector_store_name)
+            config = MODEL_CONFIG[model_name]
 
-            # Load documents
-            for file_path in progress.tqdm(files, desc="Loading files"):
-                if file_path.endswith(".pdf"):
-                    loader = PyPDFLoader(file_path)
-                elif file_path.endswith(".txt"):
-                    loader = TextLoader(file_path)
-                else:
-                    continue
-                documents.extend(loader.load())
+            # Retrieve relevant context
+            logger.info(f"Retrieving context for query: {message}")
+            docs = vectorstore.similarity_search(message, k=3)
+            context = "\n\n".join([d.page_content for d in docs])
+            logger.info(f"Retrieved context: {context[:200]}...")
 
-            # Split documents
-            progress(0.3, desc="Processing documents")
-            text_splitter = RecursiveCharacterTextSplitter(
-                chunk_size=512,
-                chunk_overlap=50
+            # Format prompt
+            prompt = config["template"].format(
+                context=context,
+                question=message
             )
-            texts = text_splitter.split_documents(documents)
 
-            # Create embeddings
-            progress(0.6, desc="Generating embeddings")
-            embeddings = HuggingFaceEmbeddings(
-                model_name="BAAI/bge-small-en-v1.5"
+            # Generate response
+            pipe = pipeline(
+                "text-generation",
+                model=self.models[model_name],
+                tokenizer=self.tokenizers[model_name],
+                max_new_tokens=384,
+                temperature=0.3,  # Lower temperature for more focused responses
+                top_p=0.9,
+                repetition_penalty=1.1,
+                do_sample=True,
+                return_full_text=False
            )
 
-            # Create vector store
-            progress(0.8, desc="Building vector database")
-            self.vectorstore = FAISS.from_documents(texts, embeddings)
+            response = pipe(prompt)[0]['generated_text']
 
-            return "✅ Documents processed successfully! Ready for queries."
-
-        except Exception as e:
-            return f"❌ Error processing documents: {str(e)}"
+            # Calculate metrics
+            elapsed_time = time.time() - start_time
+            tokens = len(self.tokenizers[model_name].encode(response))
+            tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
+
+            logger.info(f"Generated response in {elapsed_time:.2f}s")
+            return response, elapsed_time, tokens_per_sec
 
-    def generate(self, message, model_name, vector_store_name, history):
-        start_time = time.time()
-        self.load_model(model_name)
-        vectorstore = self.load_vector_store(vector_store_name)
-        config = MODEL_CONFIG[model_name]
-
-        # Retrieve relevant context
-        context = ""
-        # if vectorstore:
-        docs = vectorstore.similarity_search(message, k=3)
-        context = "\n\n".join([d.page_content for d in docs])
-
-        # Format prompt with context
-        prompt = config["template"].format(
-            message=f"Context:\n{context}\n\nQuestion: {message}"
-        )
-
-        # Generate response
-        pipe = pipeline(
-            "text-generation",
-            model=self.models[model_name],
-            tokenizer=self.tokenizers[model_name],
-            max_new_tokens=384,
-            temperature=0.7,
-            top_p=0.9,
-            repetition_penalty=1.1,
-            do_sample=True,
-            return_full_text=False
-        )
-
-        response = pipe(prompt)[0]['generated_text']
-
-        # Calculate metrics
-        elapsed_time = time.time() - start_time
-        tokens = len(self.tokenizers[model_name].encode(response))
-        tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
-
-        return response, elapsed_time, tokens_per_sec
+        except Exception as e:
+            logger.error(f"Error in generate: {str(e)}")
+            raise
 
 # Initialize model handler
 model_handler = ChatModel()
 
 def chat(message, history, model_choice, vector_store_choice):
-    print("vector_store_choice: ", vector_store_choice)
+    """Chat interface function"""
+    logger.info(f"Received message: {message}")
+    logger.info(f"Using model: {model_choice}")
+    logger.info(f"Using vector store: {vector_store_choice}")
+
     try:
-        response, response_time, token_speed = model_handler.generate(message, model_choice, vector_store_choice, history)
-        formatted_response = f"{response}\n\n⏱️ Response Time: {response_time:.2f}s | 🚀 Speed: {token_speed:.2f} tokens/s"
+        response, response_time, token_speed = model_handler.generate(
+            message,
+            model_choice,
+            vector_store_choice,
+            history
+        )
+
+        # Format response with metrics
+        formatted_response = (
+            f"{response}\n\n"
+            f"⏱️ Response Time: {response_time:.2f}s | "
+            f"🚀 Speed: {token_speed:.2f} tokens/s"
+        )
+
         return [(message, formatted_response)]
+
     except Exception as e:
-        return [(message, f"Error: {str(e)}")]
+        logger.error(f"Error in chat: {str(e)}")
+        error_message = f"Error: {str(e)}\n\nPlease try again or contact support if the issue persists."
+        return [(message, error_message)]
 
+# Gradio interface
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🚀 LLM Chatbot with RAG & Performance Metrics")
+    gr.Markdown("# 🚀 Enhanced RAG Chatbot with Performance Metrics")
 
     with gr.Row():
         model_choice = gr.Dropdown(
@@ -173,31 +204,21 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         )
 
     with gr.Row():
-        with gr.Column(scale=1):
-            file_upload = gr.File(
-                label="Upload Documents (PDF/TXT)",
-                file_count="multiple",
-                file_types=[".pdf", ".txt"],
-                type="filepath"
-            )
-            status = gr.Textbox(label="Processing Status", interactive=False)
         with gr.Column(scale=3):
             chatbot = gr.Chatbot(height=500)
-            msg = gr.Textbox(label="Message", placeholder="Type your question here...")
+            msg = gr.Textbox(
+                label="Message",
+                placeholder="Type your question here...",
+                scale=4
+            )
 
     with gr.Row():
         submit_btn = gr.Button("Send", variant="primary")
-        clear_btn = gr.ClearButton([msg, chatbot, file_upload])
+        clear_btn = gr.ClearButton([msg, chatbot])
 
     # Event handlers
-    file_upload.upload(
-        fn=model_handler.process_documents,
-        inputs=file_upload,
-        outputs=status,
-        show_progress="full"
-    )
-
     msg.submit(chat, [msg, chatbot, model_choice, vector_store_choice], chatbot)
     submit_btn.click(chat, [msg, chatbot, model_choice, vector_store_choice], chatbot)
 
-demo.launch()
+if __name__ == "__main__":
+    demo.launch()
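
Note: with process_documents and the upload widget removed, load_vector_store now assumes a FAISS index already exists on disk at vector_stores_index/<store_name> (and bnb_config is defined in an unchanged part of app.py not shown in this diff). A minimal offline build script for such an index might look like the sketch below; it reuses the chunking and embedding settings from the removed in-app pipeline, and the input paths and store name are placeholders, not part of this commit.

# build_vector_store.py — hypothetical helper, not part of this commit
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

def build_store(file_paths, store_name):
    """Build the FAISS index that app.py's load_vector_store() reads back."""
    documents = []
    for path in file_paths:
        if path.endswith(".pdf"):
            loader = PyPDFLoader(path)
        elif path.endswith(".txt"):
            loader = TextLoader(path)
        else:
            continue
        documents.extend(loader.load())

    # Same chunking as the removed process_documents()
    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
    texts = splitter.split_documents(documents)

    # Same embedding model the app now caches in ChatModel.__init__
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")

    # save_local() writes the directory that FAISS.load_local() expects
    FAISS.from_documents(texts, embeddings).save_local(
        f"vector_stores_index/{store_name}"
    )

if __name__ == "__main__":
    build_store(["docs/example.pdf"], "my_store")  # placeholder inputs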