Spaces:
Build error
Build error
Commit ·
c0205c2
1
Parent(s): 61ef0a5
app.py
CHANGED
|
@@ -153,7 +153,7 @@ class EnhancedRetriever:
|
|
| 153 |
return vector_store
|
| 154 |
|
| 155 |
@lru_cache(maxsize=500)
|
| 156 |
-
def retrieve(self, query: str) -> str:
|
| 157 |
try:
|
| 158 |
processed_query = self._preprocess_query(query)
|
| 159 |
expanded_query = self._hyde_expansion(processed_query)
|
|
@@ -163,10 +163,12 @@ class EnhancedRetriever:
|
|
| 163 |
expanded_results = self.bm25.invoke(expanded_query)
|
| 164 |
|
| 165 |
fused_results = self._fuse_results([bm25_results, vector_results, expanded_results])
|
| 166 |
-
|
|
|
|
|
|
|
| 167 |
except Exception as e:
|
| 168 |
logger.error(f"Retrieval Error: {str(e)}")
|
| 169 |
-
return ""
|
| 170 |
|
| 171 |
def _preprocess_query(self, query: str) -> str:
|
| 172 |
return query.lower().strip()
|
|
@@ -228,7 +230,6 @@ SYSTEM_PROMPT = """
|
|
| 228 |
Context: {context}
|
| 229 |
"""
|
| 230 |
|
| 231 |
-
|
| 232 |
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=20))
|
| 233 |
def get_ai_response(query: str, context: str, model: str) -> str:
|
| 234 |
result = "" # Initialize the result variable
|
|
@@ -262,15 +263,14 @@ def get_ai_response(query: str, context: str, model: str) -> str:
|
|
| 262 |
if result is None:
|
| 263 |
result = "Failed to get response from llama3-70b-8192"
|
| 264 |
# Append the model name to the response for clarity
|
| 265 |
-
# get the key name model from model mapping
|
| 266 |
for key, value in model_mapping.items():
|
| 267 |
if value == model:
|
| 268 |
model = key
|
| 269 |
-
result += f"\n\n**Model:** {model}"
|
| 270 |
return result
|
| 271 |
except Exception as e:
|
| 272 |
logger.error(f"Generation Error: {str(e)}")
|
| 273 |
-
return "I'm unable to generate a response right now. Please try again later."
|
| 274 |
|
| 275 |
def _postprocess_response(response: str) -> str:
|
| 276 |
response = re.sub(r"\[(.*?)\]", r"[\1](#)", response)
|
|
@@ -306,20 +306,38 @@ def get_groq_llama3_response(query: str) -> str:
|
|
| 306 |
except requests.exceptions.RequestException as e:
|
| 307 |
logger.error(f"Groq API Error: {str(e)}")
|
| 308 |
return "An error occurred while contacting Groq's Llama 3 model."
|
|
|
|
| 309 |
# --- Pipeline ---
|
| 310 |
documents = load_and_chunk_data(data_file_name)
|
| 311 |
retriever = EnhancedRetriever(documents)
|
| 312 |
|
| 313 |
def generate_response(question: str, model: str) -> str:
|
| 314 |
try:
|
| 315 |
-
|
| 316 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
except Exception as e:
|
| 318 |
logger.error(f"Pipeline Error: {str(e)}")
|
| 319 |
return "An error occurred processing your request."
|
| 320 |
|
| 321 |
# --- Gradio Interface ---
|
| 322 |
-
# Define the mapping from display names to actual model identifiers
|
| 323 |
model_mapping = {
|
| 324 |
"Gemini-2.0-Flash": "gemini-2.0-flash",
|
| 325 |
"Meta-llama-3-70b-instruct(GWDG)": "meta-llama-3-70b-instruct",
|
|
@@ -327,7 +345,7 @@ model_mapping = {
|
|
| 327 |
}
|
| 328 |
|
| 329 |
def chat_interface(question: str, history: List[Tuple[str, str]], display_model: str):
|
| 330 |
-
model = model_mapping.get(display_model, "gemini-2.0-flash")
|
| 331 |
response = generate_response(question, model)
|
| 332 |
return "", history + [(question, response)]
|
| 333 |
|
|
|
|
| 153 |
return vector_store
|
| 154 |
|
| 155 |
@lru_cache(maxsize=500)
|
| 156 |
+
def retrieve(self, query: str) -> Tuple[str, List[Document]]:
|
| 157 |
try:
|
| 158 |
processed_query = self._preprocess_query(query)
|
| 159 |
expanded_query = self._hyde_expansion(processed_query)
|
|
|
|
| 163 |
expanded_results = self.bm25.invoke(expanded_query)
|
| 164 |
|
| 165 |
fused_results = self._fuse_results([bm25_results, vector_results, expanded_results])
|
| 166 |
+
top_docs = fused_results[:5]
|
| 167 |
+
formatted_context = self._format_context(top_docs)
|
| 168 |
+
return formatted_context, top_docs
|
| 169 |
except Exception as e:
|
| 170 |
logger.error(f"Retrieval Error: {str(e)}")
|
| 171 |
+
return "", []
|
| 172 |
|
| 173 |
def _preprocess_query(self, query: str) -> str:
|
| 174 |
return query.lower().strip()
|
|
|
|
| 230 |
Context: {context}
|
| 231 |
"""
|
| 232 |
|
|
|
|
| 233 |
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=20))
|
| 234 |
def get_ai_response(query: str, context: str, model: str) -> str:
|
| 235 |
result = "" # Initialize the result variable
|
|
|
|
| 263 |
if result is None:
|
| 264 |
result = "Failed to get response from llama3-70b-8192"
|
| 265 |
# Append the model name to the response for clarity
|
|
|
|
| 266 |
for key, value in model_mapping.items():
|
| 267 |
if value == model:
|
| 268 |
model = key
|
| 269 |
+
# result += f"\n\n**Model:** {model}"
|
| 270 |
return result
|
| 271 |
except Exception as e:
|
| 272 |
logger.error(f"Generation Error: {str(e)}")
|
| 273 |
+
return "I'm unable to generate a response right now. Please try again later or try another model."
|
| 274 |
|
| 275 |
def _postprocess_response(response: str) -> str:
|
| 276 |
response = re.sub(r"\[(.*?)\]", r"[\1](#)", response)
|
|
|
|
| 306 |
except requests.exceptions.RequestException as e:
|
| 307 |
logger.error(f"Groq API Error: {str(e)}")
|
| 308 |
return "An error occurred while contacting Groq's Llama 3 model."
|
| 309 |
+
|
| 310 |
# --- Pipeline ---
|
| 311 |
documents = load_and_chunk_data(data_file_name)
|
| 312 |
retriever = EnhancedRetriever(documents)
|
| 313 |
|
| 314 |
def generate_response(question: str, model: str) -> str:
|
| 315 |
try:
|
| 316 |
+
formatted_context, retrieved_docs = retriever.retrieve(question)
|
| 317 |
+
if not formatted_context:
|
| 318 |
+
return "No relevant information found."
|
| 319 |
+
response = get_ai_response(question, formatted_context, model)
|
| 320 |
+
# Extract references from retrieved documents whose hyperlinks start with "https://asknature.org"
|
| 321 |
+
ref_links = []
|
| 322 |
+
for doc in retrieved_docs:
|
| 323 |
+
hyperlink = doc.metadata.get("hyperlink", "")
|
| 324 |
+
if hyperlink.startswith("https://asknature.org") and hyperlink not in ref_links:
|
| 325 |
+
ref_links.append(hyperlink)
|
| 326 |
+
if ref_links:
|
| 327 |
+
references_md = "\n\n**References:**\n"
|
| 328 |
+
for i, link in enumerate(ref_links, 1):
|
| 329 |
+
references_md += f"[{i}] {link}\n"
|
| 330 |
+
response += references_md
|
| 331 |
+
for key, value in model_mapping.items():
|
| 332 |
+
if value == model:
|
| 333 |
+
model = key
|
| 334 |
+
response += f"\n\n**Model:** {model}"
|
| 335 |
+
return response
|
| 336 |
except Exception as e:
|
| 337 |
logger.error(f"Pipeline Error: {str(e)}")
|
| 338 |
return "An error occurred processing your request."
|
| 339 |
|
| 340 |
# --- Gradio Interface ---
|
|
|
|
| 341 |
model_mapping = {
|
| 342 |
"Gemini-2.0-Flash": "gemini-2.0-flash",
|
| 343 |
"Meta-llama-3-70b-instruct(GWDG)": "meta-llama-3-70b-instruct",
|
|
|
|
| 345 |
}
|
| 346 |
|
| 347 |
def chat_interface(question: str, history: List[Tuple[str, str]], display_model: str):
|
| 348 |
+
model = model_mapping.get(display_model, "gemini-2.0-flash")
|
| 349 |
response = generate_response(question, model)
|
| 350 |
return "", history + [(question, response)]
|
| 351 |
|