Mohamed284 committed on
Commit
c0205c2
·
1 Parent(s): 61ef0a5
Files changed (1) hide show
  1. app.py +29 -11
app.py CHANGED
@@ -153,7 +153,7 @@ class EnhancedRetriever:
153
  return vector_store
154
 
155
  @lru_cache(maxsize=500)
156
- def retrieve(self, query: str) -> str:
157
  try:
158
  processed_query = self._preprocess_query(query)
159
  expanded_query = self._hyde_expansion(processed_query)
@@ -163,10 +163,12 @@ class EnhancedRetriever:
163
  expanded_results = self.bm25.invoke(expanded_query)
164
 
165
  fused_results = self._fuse_results([bm25_results, vector_results, expanded_results])
166
- return self._format_context(fused_results[:5])
 
 
167
  except Exception as e:
168
  logger.error(f"Retrieval Error: {str(e)}")
169
- return ""
170
 
171
  def _preprocess_query(self, query: str) -> str:
172
  return query.lower().strip()
@@ -228,7 +230,6 @@ SYSTEM_PROMPT = """
228
  Context: {context}
229
  """
230
 
231
-
232
  @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=20))
233
  def get_ai_response(query: str, context: str, model: str) -> str:
234
  result = "" # Initialize the result variable
@@ -262,15 +263,14 @@ def get_ai_response(query: str, context: str, model: str) -> str:
262
  if result is None:
263
  result = "Failed to get response from llama3-70b-8192"
264
  # Append the model name to the response for clarity
265
- # get the key name model from model mapping
266
  for key, value in model_mapping.items():
267
  if value == model:
268
  model = key
269
- result += f"\n\n**Model:** {model}"
270
  return result
271
  except Exception as e:
272
  logger.error(f"Generation Error: {str(e)}")
273
- return "I'm unable to generate a response right now. Please try again later."
274
 
275
  def _postprocess_response(response: str) -> str:
276
  response = re.sub(r"\[(.*?)\]", r"[\1](#)", response)
@@ -306,20 +306,38 @@ def get_groq_llama3_response(query: str) -> str:
306
  except requests.exceptions.RequestException as e:
307
  logger.error(f"Groq API Error: {str(e)}")
308
  return "An error occurred while contacting Groq's Llama 3 model."
 
309
  # --- Pipeline ---
310
  documents = load_and_chunk_data(data_file_name)
311
  retriever = EnhancedRetriever(documents)
312
 
313
  def generate_response(question: str, model: str) -> str:
314
  try:
315
- context = retriever.retrieve(question)
316
- return get_ai_response(question, context, model) if context else "No relevant information found."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  except Exception as e:
318
  logger.error(f"Pipeline Error: {str(e)}")
319
  return "An error occurred processing your request."
320
 
321
  # --- Gradio Interface ---
322
- # Define the mapping from display names to actual model identifiers
323
  model_mapping = {
324
  "Gemini-2.0-Flash": "gemini-2.0-flash",
325
  "Meta-llama-3-70b-instruct(GWDG)": "meta-llama-3-70b-instruct",
@@ -327,7 +345,7 @@ model_mapping = {
327
  }
328
 
329
  def chat_interface(question: str, history: List[Tuple[str, str]], display_model: str):
330
- model = model_mapping.get(display_model, "gemini-2.0-flash") # Default to Gemini if not found
331
  response = generate_response(question, model)
332
  return "", history + [(question, response)]
333
 
 
153
  return vector_store
154
 
155
  @lru_cache(maxsize=500)
156
+ def retrieve(self, query: str) -> Tuple[str, List[Document]]:
157
  try:
158
  processed_query = self._preprocess_query(query)
159
  expanded_query = self._hyde_expansion(processed_query)
 
163
  expanded_results = self.bm25.invoke(expanded_query)
164
 
165
  fused_results = self._fuse_results([bm25_results, vector_results, expanded_results])
166
+ top_docs = fused_results[:5]
167
+ formatted_context = self._format_context(top_docs)
168
+ return formatted_context, top_docs
169
  except Exception as e:
170
  logger.error(f"Retrieval Error: {str(e)}")
171
+ return "", []
172
 
173
  def _preprocess_query(self, query: str) -> str:
174
  return query.lower().strip()
 
230
  Context: {context}
231
  """
232
 
 
233
  @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=20))
234
  def get_ai_response(query: str, context: str, model: str) -> str:
235
  result = "" # Initialize the result variable
 
263
  if result is None:
264
  result = "Failed to get response from llama3-70b-8192"
265
  # Append the model name to the response for clarity
 
266
  for key, value in model_mapping.items():
267
  if value == model:
268
  model = key
269
+ # result += f"\n\n**Model:** {model}"
270
  return result
271
  except Exception as e:
272
  logger.error(f"Generation Error: {str(e)}")
273
+ return "I'm unable to generate a response right now. Please try again later or try another model."
274
 
275
  def _postprocess_response(response: str) -> str:
276
  response = re.sub(r"\[(.*?)\]", r"[\1](#)", response)
 
306
  except requests.exceptions.RequestException as e:
307
  logger.error(f"Groq API Error: {str(e)}")
308
  return "An error occurred while contacting Groq's Llama 3 model."
309
+
310
  # --- Pipeline ---
311
  documents = load_and_chunk_data(data_file_name)
312
  retriever = EnhancedRetriever(documents)
313
 
314
  def generate_response(question: str, model: str) -> str:
315
  try:
316
+ formatted_context, retrieved_docs = retriever.retrieve(question)
317
+ if not formatted_context:
318
+ return "No relevant information found."
319
+ response = get_ai_response(question, formatted_context, model)
320
+ # Extract references from retrieved documents whose hyperlinks start with "https://asknature.org"
321
+ ref_links = []
322
+ for doc in retrieved_docs:
323
+ hyperlink = doc.metadata.get("hyperlink", "")
324
+ if hyperlink.startswith("https://asknature.org") and hyperlink not in ref_links:
325
+ ref_links.append(hyperlink)
326
+ if ref_links:
327
+ references_md = "\n\n**References:**\n"
328
+ for i, link in enumerate(ref_links, 1):
329
+ references_md += f"[{i}] {link}\n"
330
+ response += references_md
331
+ for key, value in model_mapping.items():
332
+ if value == model:
333
+ model = key
334
+ response += f"\n\n**Model:** {model}"
335
+ return response
336
  except Exception as e:
337
  logger.error(f"Pipeline Error: {str(e)}")
338
  return "An error occurred processing your request."
339
 
340
  # --- Gradio Interface ---
 
341
  model_mapping = {
342
  "Gemini-2.0-Flash": "gemini-2.0-flash",
343
  "Meta-llama-3-70b-instruct(GWDG)": "meta-llama-3-70b-instruct",
 
345
  }
346
 
347
  def chat_interface(question: str, history: List[Tuple[str, str]], display_model: str):
348
+ model = model_mapping.get(display_model, "gemini-2.0-flash")
349
  response = generate_response(question, model)
350
  return "", history + [(question, response)]
351