Digambar29 committed on
Commit
7770bcb
·
1 Parent(s): e0f03d9

Changed the model to mixtral-8x7b-32768 for Groq

Browse files
Files changed (1) hide show
  1. backend/app.py +29 -16
backend/app.py CHANGED
@@ -276,21 +276,6 @@ def ask_question_endpoint():
276
  with open(metadata_path, "rb") as f: metadata = pickle.load(f)
277
  chunks, chunk_metadata = metadata["chunks"], metadata["chunk_metadata"]
278
 
279
- # --- 2. Retrieve Context ---
280
- query_vec = get_embedding_model().encode([query], show_progress_bar=False)
281
- distances, indices = index.search(np.array(query_vec), k=3)
282
- context = "\n".join([chunks[i] for i in indices[0]])
283
- sources = list(set([chunk_metadata[i]['title'] for i in indices[0] if i < len(chunk_metadata)]))
284
-
285
- # --- 3. Generate Prompt ---
286
- prompt_template = (
287
- "Using ONLY the information from the following context, answer the question. "
288
- "Do not mention the context in your answer. Be concise.\n\n"
289
- "Context:\n{context}\n\n"
290
- "Question: {query}\n\nAnswer:"
291
- )
292
- prompt = prompt_template.format(context=context, query=query)
293
-
294
  # --- 4. Generate Response (with API/Local Fallback) ---
295
  response = ""
296
  use_api = GROQ_API_KEY and GROQ_API_KEY != "your_groq_api_key_here"
@@ -299,11 +284,25 @@ def ask_question_endpoint():
299
  # --- Primary Method: Try Groq API ---
300
  global api_llm
301
  if api_llm is None: api_llm = Groq(api_key=GROQ_API_KEY)
 
 
 
 
 
 
 
 
 
 
 
 
302
  logging.info("Attempting to generate response with Groq API...")
303
  try:
304
  chat_completion = api_llm.chat.completions.create(
305
  messages=[{"role": "user", "content": prompt}],
306
- model="llama3-70b-8192", # Use a current, supported model
 
 
307
  temperature=0.5, max_tokens=250
308
  )
309
  response = chat_completion.choices[0].message.content
@@ -314,10 +313,24 @@ def ask_question_endpoint():
314
 
315
  if not use_api:
316
  # --- Fallback Method: Use Local Model ---
 
 
 
 
 
 
 
 
 
 
 
317
  logging.info("Generating response with local model.")
318
  global local_llm
319
  if local_llm is None: local_llm = get_local_llm()
320
  response = local_llm.generate(prompt, max_tokens=250, temp=0.5)
 
 
 
321
 
322
  return jsonify({"answer": response, "sources": sources})
323
 
 
276
  with open(metadata_path, "rb") as f: metadata = pickle.load(f)
277
  chunks, chunk_metadata = metadata["chunks"], metadata["chunk_metadata"]
278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  # --- 4. Generate Response (with API/Local Fallback) ---
280
  response = ""
281
  use_api = GROQ_API_KEY and GROQ_API_KEY != "your_groq_api_key_here"
 
284
  # --- Primary Method: Try Groq API ---
285
  global api_llm
286
  if api_llm is None: api_llm = Groq(api_key=GROQ_API_KEY)
287
+
288
+ # Retrieve top 3 chunks for the powerful API model
289
+ query_vec = get_embedding_model().encode([query], show_progress_bar=False)
290
+ distances, indices = index.search(np.array(query_vec), k=3)
291
+ context = "\n".join([chunks[i] for i in indices[0]])
292
+ sources = list(set([chunk_metadata[i]['title'] for i in indices[0] if i < len(chunk_metadata)]))
293
+
294
+ prompt_template = ("Using ONLY the information from the following context, answer the question. "
295
+ "Do not mention the context in your answer. Be concise.\n\n"
296
+ "Context:\n{context}\n\nQuestion: {query}\n\nAnswer:")
297
+ prompt = prompt_template.format(context=context, query=query)
298
+
299
  logging.info("Attempting to generate response with Groq API...")
300
  try:
301
  chat_completion = api_llm.chat.completions.create(
302
  messages=[{"role": "user", "content": prompt}],
303
+ # Use a stable, powerful model like Mixtral.
304
+ # Other options include 'gemma-7b-it'.
305
+ model="mixtral-8x7b-32768",
306
  temperature=0.5, max_tokens=250
307
  )
308
  response = chat_completion.choices[0].message.content
 
313
 
314
  if not use_api:
315
  # --- Fallback Method: Use Local Model ---
316
+ # The local model has a small context window, so we only use the single most relevant chunk (k=1).
317
+ query_vec = get_embedding_model().encode([query], show_progress_bar=False)
318
+ distances, indices = index.search(np.array(query_vec), k=1)
319
+ context = "\n".join([chunks[i] for i in indices[0]])
320
+ sources = list(set([chunk_metadata[i]['title'] for i in indices[0] if i < len(chunk_metadata)]))
321
+
322
+ prompt_template = ("Using ONLY the information from the following context, answer the question. "
323
+ "Do not mention the context in your answer. Be concise.\n\n"
324
+ "Context:\n{context}\n\nQuestion: {query}\n\nAnswer:")
325
+ prompt = prompt_template.format(context=context, query=query)
326
+
327
  logging.info("Generating response with local model.")
328
  global local_llm
329
  if local_llm is None: local_llm = get_local_llm()
330
  response = local_llm.generate(prompt, max_tokens=250, temp=0.5)
331
+ # Check if the local model returned an error message instead of an answer
332
+ if "LLaMA ERROR" in response:
333
+ raise RuntimeError("The local model failed to generate a response due to context size limitations.")
334
 
335
  return jsonify({"answer": response, "sources": sources})
336