Commit 7770bcb
Parent(s): e0f03d9

Changed the model to mixtral-8x7b-32768 for groq

Files changed: backend/app.py (+29 -16)
backend/app.py CHANGED
@@ -276,21 +276,6 @@ def ask_question_endpoint():
     with open(metadata_path, "rb") as f: metadata = pickle.load(f)
     chunks, chunk_metadata = metadata["chunks"], metadata["chunk_metadata"]
 
-    # --- 2. Retrieve Context ---
-    query_vec = get_embedding_model().encode([query], show_progress_bar=False)
-    distances, indices = index.search(np.array(query_vec), k=3)
-    context = "\n".join([chunks[i] for i in indices[0]])
-    sources = list(set([chunk_metadata[i]['title'] for i in indices[0] if i < len(chunk_metadata)]))
-
-    # --- 3. Generate Prompt ---
-    prompt_template = (
-        "Using ONLY the information from the following context, answer the question. "
-        "Do not mention the context in your answer. Be concise.\n\n"
-        "Context:\n{context}\n\n"
-        "Question: {query}\n\nAnswer:"
-    )
-    prompt = prompt_template.format(context=context, query=query)
-
     # --- 4. Generate Response (with API/Local Fallback) ---
     response = ""
     use_api = GROQ_API_KEY and GROQ_API_KEY != "your_groq_api_key_here"
@@ -299,11 +284,25 @@ def ask_question_endpoint():
         # --- Primary Method: Try Groq API ---
         global api_llm
         if api_llm is None: api_llm = Groq(api_key=GROQ_API_KEY)
+
+        # Retrieve top 3 chunks for the powerful API model
+        query_vec = get_embedding_model().encode([query], show_progress_bar=False)
+        distances, indices = index.search(np.array(query_vec), k=3)
+        context = "\n".join([chunks[i] for i in indices[0]])
+        sources = list(set([chunk_metadata[i]['title'] for i in indices[0] if i < len(chunk_metadata)]))
+
+        prompt_template = ("Using ONLY the information from the following context, answer the question. "
+                           "Do not mention the context in your answer. Be concise.\n\n"
+                           "Context:\n{context}\n\nQuestion: {query}\n\nAnswer:")
+        prompt = prompt_template.format(context=context, query=query)
+
         logging.info("Attempting to generate response with Groq API...")
         try:
             chat_completion = api_llm.chat.completions.create(
                 messages=[{"role": "user", "content": prompt}],
-
+                # Use a stable, powerful model like Mixtral.
+                # Other options include 'gemma-7b-it'.
+                model="mixtral-8x7b-32768",
                 temperature=0.5, max_tokens=250
             )
             response = chat_completion.choices[0].message.content
@@ -314,10 +313,24 @@ def ask_question_endpoint():
 
     if not use_api:
         # --- Fallback Method: Use Local Model ---
+        # The local model has a small context window, so we only use the single most relevant chunk (k=1).
+        query_vec = get_embedding_model().encode([query], show_progress_bar=False)
+        distances, indices = index.search(np.array(query_vec), k=1)
+        context = "\n".join([chunks[i] for i in indices[0]])
+        sources = list(set([chunk_metadata[i]['title'] for i in indices[0] if i < len(chunk_metadata)]))
+
+        prompt_template = ("Using ONLY the information from the following context, answer the question. "
+                           "Do not mention the context in your answer. Be concise.\n\n"
+                           "Context:\n{context}\n\nQuestion: {query}\n\nAnswer:")
+        prompt = prompt_template.format(context=context, query=query)
+
         logging.info("Generating response with local model.")
         global local_llm
         if local_llm is None: local_llm = get_local_llm()
         response = local_llm.generate(prompt, max_tokens=250, temp=0.5)
+        # Check if the local model returned an error message instead of an answer
+        if "LLaMA ERROR" in response:
+            raise RuntimeError("The local model failed to generate a response due to context size limitations.")
 
     return jsonify({"answer": response, "sources": sources})
 
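For context outside the diff, here is a minimal, self-contained sketch of the retrieve-then-generate flow the Groq path ends up with after this commit. It assumes a FAISS index and chunk list loaded elsewhere, and uses all-MiniLM-L6-v2 as a stand-in for the app's get_embedding_model(); the function and variable names below are illustrative, not the app's exact code.

# Minimal sketch (assumptions noted inline), mirroring the Groq path this commit sets up.
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from groq import Groq

def answer_with_groq(query, index, chunks, api_key, k=3):
    # Embed the query and fetch the k nearest chunks from the FAISS index.
    # Assumption: the index was built with the same embedding model.
    embedder = SentenceTransformer("all-MiniLM-L6-v2")  # stand-in; the app's get_embedding_model() may differ
    query_vec = embedder.encode([query], show_progress_bar=False)
    _, indices = index.search(np.array(query_vec), k=k)
    context = "\n".join(chunks[i] for i in indices[0])

    # Same prompt shape as the endpoint: answer from the context only, concisely.
    prompt = (
        "Using ONLY the information from the following context, answer the question. "
        "Do not mention the context in your answer. Be concise.\n\n"
        f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    )

    client = Groq(api_key=api_key)
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="mixtral-8x7b-32768",  # the model this commit switches to
        temperature=0.5,
        max_tokens=250,
    )
    return chat_completion.choices[0].message.content

The local fallback in the diff follows the same shape but retrieves only the single best chunk (k=1), since the local model's context window is much smaller than Mixtral's 32k tokens.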
|