Wenye He commited on
Commit
4429fce
·
verified ·
1 Parent(s): 359f734

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -30
app.py CHANGED
@@ -19,26 +19,26 @@ MODEL_CONFIG = {
19
  "phi-3": {
20
  "model_name": "microsoft/phi-3-mini-4k-instruct",
21
  "template": """<|user|>
22
- Using the following context, please answer the question. If the context doesn't contain relevant information, say so.
23
 
24
  Context:
25
  {context}
26
 
27
  Question: {question}<|end|>
28
  <|assistant|>
29
- Let me help answer your question based on the provided context."""
30
  },
31
  "llama3-8b": {
32
  "model_name": "NousResearch/Meta-Llama-3-8B-Instruct",
33
  "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
34
- Using the following context, please answer the question. If the context doesn't contain relevant information, say so.
35
 
36
  Context:
37
  {context}
38
 
39
  Question: {question}<|eot_id|>
40
  <|start_header_id|>assistant<|end_header_id|>
41
- Let me help answer your question based on the provided context."""
42
  }
43
  }
44
 
@@ -53,7 +53,9 @@ class ChatModel:
53
  def __init__(self):
54
  self.models = {}
55
  self.tokenizers = {}
56
- self.vectorstore = {}
 
 
57
  self.embeddings = HuggingFaceEmbeddings(
58
  model_name="sentence-transformers/all-MiniLM-L6-v2"
59
  )
@@ -62,8 +64,8 @@ class ChatModel:
62
  """Load and cache the model and tokenizer"""
63
  if model_name not in self.models:
64
  logger.info(f"Loading model: {model_name}")
65
- config = MODEL_CONFIG[model_name]
66
  try:
 
67
  tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
68
  tokenizer.pad_token = tokenizer.eos_token
69
  model = AutoModelForCausalLM.from_pretrained(
@@ -80,32 +82,51 @@ class ChatModel:
80
  raise
81
 
82
  def load_vector_store(self, store_name):
83
- """Load and cache vector stores"""
84
- if store_name not in self.vectorstore:
85
- logger.info(f"Loading vector store: {store_name}")
86
- try:
87
- self.vectorstore[store_name] = FAISS.load_local(
88
- f"vector_stores_index/{store_name}",
 
 
 
 
 
 
 
89
  self.embeddings,
90
  allow_dangerous_deserialization=True
91
  )
92
- # Verify vector store content
93
- self.check_vectorstore(store_name)
 
 
94
  logger.info(f"Successfully loaded vector store: {store_name}")
95
- except Exception as e:
96
- logger.error(f"Error loading vector store {store_name}: {str(e)}")
97
- raise
98
- return self.vectorstore[store_name]
99
 
100
- def check_vectorstore(self, store_name):
101
- """Verify vector store content"""
 
 
 
 
 
 
 
102
  try:
103
- vectorstore = self.vectorstore[store_name]
104
- sample_query = "test query"
105
- docs = vectorstore.similarity_search(sample_query, k=1)
106
- logger.info(f"Sample document from {store_name}: {docs[0].page_content[:200]}...")
 
 
 
 
 
107
  except Exception as e:
108
- logger.error(f"Error checking vector store {store_name}: {str(e)}")
109
  raise
110
 
111
  def generate(self, message, model_name, vector_store_name, history):
@@ -120,8 +141,14 @@ class ChatModel:
120
  # Retrieve relevant context
121
  logger.info(f"Retrieving context for query: {message}")
122
  docs = vectorstore.similarity_search(message, k=3)
 
 
 
 
 
 
 
123
  context = "\n\n".join([d.page_content for d in docs])
124
- logger.info(f"Retrieved context: {context[:200]}...")
125
 
126
  # Format prompt
127
  prompt = config["template"].format(
@@ -129,6 +156,8 @@ class ChatModel:
129
  question=message
130
  )
131
 
 
 
132
  # Generate response
133
  pipe = pipeline(
134
  "text-generation",
@@ -173,7 +202,7 @@ def chat(message, history, model_choice, vector_store_choice):
173
  history
174
  )
175
 
176
- # Format response with metrics
177
  formatted_response = (
178
  f"{response}\n\n"
179
  f"⏱️ Response Time: {response_time:.2f}s | "
@@ -189,7 +218,10 @@ def chat(message, history, model_choice, vector_store_choice):
189
 
190
  # Gradio interface
191
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
192
- gr.Markdown("# 🚀 Enhanced RAG Chatbot with Performance Metrics")
 
 
 
193
 
194
  with gr.Row():
195
  model_choice = gr.Dropdown(
@@ -198,9 +230,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
198
  value="phi-3"
199
  )
200
  vector_store_choice = gr.Dropdown(
201
- ["llm", "scoliosis"],
202
  value="scoliosis",
203
- label="Knowledge Base"
 
204
  )
205
 
206
  with gr.Row():
 
19
  "phi-3": {
20
  "model_name": "microsoft/phi-3-mini-4k-instruct",
21
  "template": """<|user|>
22
+ Using only the following context, please provide a relevant answer to the question. If the context doesn't contain relevant information, please say so clearly.
23
 
24
  Context:
25
  {context}
26
 
27
  Question: {question}<|end|>
28
  <|assistant|>
29
+ Based on the provided context, I'll answer your question:"""
30
  },
31
  "llama3-8b": {
32
  "model_name": "NousResearch/Meta-Llama-3-8B-Instruct",
33
  "template": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
34
+ Using only the following context, please provide a relevant answer to the question. If the context doesn't contain relevant information, please say so clearly.
35
 
36
  Context:
37
  {context}
38
 
39
  Question: {question}<|eot_id|>
40
  <|start_header_id|>assistant<|end_header_id|>
41
+ Based on the provided context, I'll answer your question:"""
42
  }
43
  }
44
 
 
53
  def __init__(self):
54
  self.models = {}
55
  self.tokenizers = {}
56
+ self.current_store = None
57
+ self.current_vectorstore = None
58
+ # Use the same embedding model as in vector store creation
59
  self.embeddings = HuggingFaceEmbeddings(
60
  model_name="sentence-transformers/all-MiniLM-L6-v2"
61
  )
 
64
  """Load and cache the model and tokenizer"""
65
  if model_name not in self.models:
66
  logger.info(f"Loading model: {model_name}")
 
67
  try:
68
+ config = MODEL_CONFIG[model_name]
69
  tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
70
  tokenizer.pad_token = tokenizer.eos_token
71
  model = AutoModelForCausalLM.from_pretrained(
 
82
  raise
83
 
84
  def load_vector_store(self, store_name):
85
+ """Load vector store with cache invalidation"""
86
+ try:
87
+ # Check if we need to load a new store
88
+ if self.current_store != store_name:
89
+ logger.info(f"Loading new vector store: {store_name}")
90
+ vector_store_path = f"vector-stores/{store_name}"
91
+
92
+ if not os.path.exists(vector_store_path):
93
+ raise ValueError(f"Vector store not found at: {vector_store_path}")
94
+
95
+ # Load new vector store
96
+ self.current_vectorstore = FAISS.load_local(
97
+ vector_store_path,
98
  self.embeddings,
99
  allow_dangerous_deserialization=True
100
  )
101
+ self.current_store = store_name
102
+
103
+ # Verify the new store
104
+ self.check_vectorstore()
105
  logger.info(f"Successfully loaded vector store: {store_name}")
106
+
107
+ return self.current_vectorstore
 
 
108
 
109
+ except Exception as e:
110
+ logger.error(f"Error loading vector store {store_name}: {str(e)}")
111
+ # Reset state on error
112
+ self.current_store = None
113
+ self.current_vectorstore = None
114
+ raise
115
+
116
+ def check_vectorstore(self):
117
+ """Verify current vector store content"""
118
  try:
119
+ if self.current_vectorstore is None:
120
+ raise ValueError("No vector store currently loaded")
121
+
122
+ # Use a generic query to test retrieval
123
+ sample_query = "what is this document about"
124
+ docs = self.current_vectorstore.similarity_search(sample_query, k=1)
125
+ logger.info(f"Vector store {self.current_store} content sample:")
126
+ logger.info(f"Document content: {docs[0].page_content[:200]}...")
127
+ logger.info(f"Document source: {docs[0].metadata.get('source', 'unknown')}")
128
  except Exception as e:
129
+ logger.error(f"Error checking vector store: {str(e)}")
130
  raise
131
 
132
  def generate(self, message, model_name, vector_store_name, history):
 
141
  # Retrieve relevant context
142
  logger.info(f"Retrieving context for query: {message}")
143
  docs = vectorstore.similarity_search(message, k=3)
144
+
145
+ # Log retrieved documents for debugging
146
+ for i, doc in enumerate(docs):
147
+ logger.info(f"Retrieved document {i + 1}:")
148
+ logger.info(f"Source: {doc.metadata.get('source', 'unknown')}")
149
+ logger.info(f"Content: {doc.page_content[:200]}...")
150
+
151
  context = "\n\n".join([d.page_content for d in docs])
 
152
 
153
  # Format prompt
154
  prompt = config["template"].format(
 
156
  question=message
157
  )
158
 
159
+ logger.info(f"Generated prompt: {prompt[:200]}...")
160
+
161
  # Generate response
162
  pipe = pipeline(
163
  "text-generation",
 
202
  history
203
  )
204
 
205
+ # Format response with metrics and source context
206
  formatted_response = (
207
  f"{response}\n\n"
208
  f"⏱️ Response Time: {response_time:.2f}s | "
 
218
 
219
  # Gradio interface
220
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
221
+ gr.Markdown("""# 🚀 Enhanced RAG Chatbot with Performance Metrics
222
+
223
+ This chatbot uses Retrieval-Augmented Generation (RAG) to provide informed responses based on your documents.
224
+ """)
225
 
226
  with gr.Row():
227
  model_choice = gr.Dropdown(
 
230
  value="phi-3"
231
  )
232
  vector_store_choice = gr.Dropdown(
233
+ ["llm", "scoliosis"], # Update these choices based on your vector stores
234
  value="scoliosis",
235
+ label="Knowledge Base",
236
+ interactive=True
237
  )
238
 
239
  with gr.Row():