rdz-falcon committed on
Commit
cdf0bc6
·
verified ·
1 Parent(s): a889531

Update src/rag.py

Browse files
Files changed (1) hide show
  1. src/rag.py +31 -13
src/rag.py CHANGED
@@ -466,21 +466,39 @@ class AACAssistant:
466
  response = self.chain.invoke(
467
  {"question": user_query, "emotion_analysis": emotion_analysis}
468
  )
469
- raw_full_answer = response.get("answer", "")
470
- assistant_marker = "</s> <|assistant|>"
471
-
472
- # Find the last occurrence of the marker
473
- marker_position = raw_full_answer.rfind(assistant_marker)
 
 
 
 
474
 
 
475
  if marker_position != -1:
476
- # Extract the text *after* the marker
477
- actual_response = raw_full_answer[marker_position + len(assistant_marker):].strip()
 
 
 
 
 
 
 
 
 
478
  else:
479
- # Fallback if the marker is not found in the response.
480
- # This might happen if the LLM's output is unexpected or if the prompt structure changed.
481
- print(f"WARNING: Assistant marker '{assistant_marker}' not found in raw answer. Returning raw answer as fallback.")
482
- actual_response = raw_full_answer.strip() # Or handle as an error
483
-
484
- print(f"DEBUG: process_query - Extracted assistant response: '{actual_response}'")
 
 
 
 
485
  return actual_response
486
  # return response["answer"]
 
466
  response = self.chain.invoke(
467
  {"question": user_query, "emotion_analysis": emotion_analysis}
468
  )
469
+ raw_chain_output_answer = response.get("answer", "")
470
+ prompt_end_marker = "Please generate your response as the AAC user, following the instructions above.</s>\n<|assistant|>"
471
+
472
+ # For debugging, let's print what we're searching for and a snippet of where we're searching
473
+ print(f"DEBUG: process_query - Attempting to find marker: [{prompt_end_marker}]")
474
+ # print(f"DEBUG: process_query - Last 200 chars of raw_chain_output_answer: [...{raw_chain_output_answer[-200:]}]")
475
+
476
+
477
+ marker_position = raw_chain_output_answer.rfind(prompt_end_marker)
478
 
479
+ actual_response = ""
480
  if marker_position != -1:
481
+ # If the marker is found, take everything AFTER it
482
+ actual_response = raw_chain_output_answer[marker_position + len(prompt_end_marker):].strip()
483
+ print(f"DEBUG: process_query - Marker found. Extracted response before cleaning EOS: '{actual_response}'")
484
+
485
+ # Llama 3 models often output an <|eot_id|> at the end of their turn.
486
+ # Let's remove this if present.
487
+ eot_marker = "<|eot_id|>"
488
+ if actual_response.endswith(eot_marker):
489
+ actual_response = actual_response[:-len(eot_marker)].strip()
490
+ print(f"DEBUG: process_query - Cleaned <|eot_id|>, final response: '{actual_response}'")
491
+
492
  else:
493
+ # This block will be hit if the precise prompt_end_marker isn't found.
494
+ # This indicates a mismatch between your defined marker and the actual raw output.
495
+ print(f"ERROR: Precise marker [{prompt_end_marker}] NOT FOUND in raw answer.")
496
+ print(f"DEBUG: process_query - Raw full answer from chain (length {len(raw_chain_output_answer)}):")
497
+ print(f"'''{raw_chain_output_answer}'''") # Print the whole thing for analysis
498
+ actual_response = "Error: Could not parse the assistant's response correctly." # Or return raw_chain_output_answer for debugging in UI
499
+
500
+ # --- END OF CORRECTED PARSING LOGIC ---
501
+
502
+ print(f"DEBUG: process_query - Final extracted assistant response: '{actual_response}'")
503
  return actual_response
504
  # return response["answer"]