amitbhatt6075 committed on
Commit
d8f03cc
·
1 Parent(s): 6cb46f3

refactor(agent): Use human-friendly prompt for reliable chatbot responses

Files changed (1)
  1. core/support_agent.py +69 -23
core/support_agent.py CHANGED

@@ -1,21 +1,14 @@
-
 import traceback
 from typing import Dict, Any, List
 from llama_cpp import Llama
 
-# ✅ THE FIX IS HERE: The new, correct import paths for LangChain
 from langchain_core.language_models.llms import LLM
 from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferMemory
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.vectorstores import Chroma
 from langchain_core.prompts import PromptTemplate
-from langchain_core.output_parsers import StrOutputParser
-from dotenv import load_dotenv
-
-load_dotenv()
 
-# This class allows us to use our already-loaded llama_cpp model with LangChain
 class LlamaLangChain(LLM):
     llama_instance: Llama
 
@@ -23,16 +16,12 @@ class LlamaLangChain(LLM):
     def _llm_type(self) -> str:
         return "custom"
 
-    # Changed stop to List[str] for better type hinting
    def _call(self, prompt: str, stop: List[str] | None = None, **kwargs) -> str:
-        response = self.llama_instance(
-            prompt, max_tokens=256, stop=stop, stream=False
-        )
+        # Give the answer a generous token limit; echo=False keeps the prompt out of the output
+        response = self.llama_instance(prompt, max_tokens=512, stop=stop, stream=False, echo=False)
         return response["choices"][0]["text"]
 
-    # Required for async operations, even if not used, to match the base class
     async def _acall(self, prompt: str, stop: List[str] | None = None, **kwargs) -> str:
-        # For simplicity, we call the sync method. For production, you might want a true async implementation.
         return self._call(prompt, stop, **kwargs)
 
 def format_docs(docs):
@@ -41,23 +30,80 @@ def format_docs(docs):
 class SupportAgent:
     def __init__(self, llm_instance: Llama, embedding_path: str, db_path: str):
         print("--- Initializing Support Agent (Optimized for Low RAM) ---")
-
         if llm_instance is None:
             raise ValueError("SupportAgent received an invalid LLM instance.")
-
-        # This wrapper is correct
         self.langchain_llm_wrapper = LlamaLangChain(llama_instance=llm_instance)
-
         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_path)
         self.vector_store = Chroma(persist_directory=db_path, embedding_function=self.embeddings)
-        self.conversations = {}
-
-        router_template = """Classify: 'live_data' or 'general_knowledge'. Question: {question} Classification:"""
-        self.router_prompt = PromptTemplate.from_template(router_template)
-        self.router_chain = self.router_prompt | self.langchain_llm_wrapper | StrOutputParser()
-
+        self.conversations: Dict[str, ConversationBufferMemory] = {}
         print("✅ Agent and core components initialized successfully.")
 
+    def _get_or_create_memory(self, conversation_id: str) -> ConversationBufferMemory:
+        if conversation_id not in self.conversations:
+            self.conversations[conversation_id] = ConversationBufferMemory(
+                memory_key="chat_history", return_messages=True, input_key="question", output_key="answer"
+            )
+        return self.conversations[conversation_id]
+
+    def answer(self, payload: dict, conversation_id: str) -> dict:
+        question = payload.get("question", "")
+        live_data_context = payload.get("live_data", "")  # Live data supplied by the backend
+        user_role = payload.get("role", "user")
+
+        memory = self._get_or_create_memory(conversation_id)
+
+        try:
+            # Build one plain-language prompt that combines the chat history,
+            # retrieved context, and live user data, replacing the earlier
+            # structured [CONTEXT] blocks.
+            human_friendly_template = """You are a helpful and professional support assistant for the Reachify platform.
+Answer the user's question based on their chat history and the context provided below.
+
+Chat History:
+{chat_history}
+
+Additional Context (if available):
+{context}
+
+Live Data about the User (Role: {role}):
+{live_data}
+
+User's Question: {question}
+
+Your Answer:
+"""
+            # Create a LangChain PromptTemplate from the template string
+            final_prompt = PromptTemplate.from_template(human_friendly_template)
+
+            retriever = self.vector_store.as_retriever()
+
+            # Pass the prompt to the chain's combine-docs step
+            qa_chain = ConversationalRetrievalChain.from_llm(
+                llm=self.langchain_llm_wrapper,
+                retriever=retriever,
+                memory=memory,
+                combine_docs_chain_kwargs={"prompt": final_prompt}
+            )
+
+            # Supply every variable the new prompt expects
+            result = qa_chain.invoke({
+                "question": question,
+                "live_data": live_data_context,
+                "role": user_role
+            })
+
+            final_answer = result.get("answer", "I'm sorry, I could not find an answer.").strip()
+
+            # Final safety check: catch responses that merely echo the template
+            if "[NODE_NAME]" in final_answer or "Your Answer:" in final_answer:
+                return {"response": "I'm having trouble generating a clear response right now. Can you please rephrase the question?", "context": "AI returned a template."}
+
+            return {"response": final_answer, "context": format_docs(result.get("source_documents", []))}
+
+        except Exception as e:
+            traceback.print_exc()
+            return {"response": "A critical server error occurred in the AI agent.", "context": str(e)}
+
 
     def _get_or_create_memory(self, conversation_id: str) -> ConversationBufferMemory:
         if conversation_id not in self.conversations:
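Usage note: a minimal sketch of how the refactored agent might be driven end to end, assuming the backend passes a payload dict as in the diff above. The model path, embedding model name, database path, and sample payload values below are hypothetical placeholders, not part of this commit.

from llama_cpp import Llama
from core.support_agent import SupportAgent

# Hypothetical artifact paths -- substitute your own.
llm = Llama(model_path="models/llm.gguf", n_ctx=2048)
agent = SupportAgent(
    llm_instance=llm,
    embedding_path="sentence-transformers/all-MiniLM-L6-v2",  # hypothetical embedding model
    db_path="chroma_db",  # hypothetical Chroma persist directory
)

# The backend supplies the question plus optional live data and the user's role.
payload = {"question": "How do I reset my password?", "live_data": "Plan: Pro, 3 open tickets", "role": "user"}
result = agent.answer(payload, conversation_id="conv-123")
print(result["response"])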