amitbhatt6075 committed
Commit 8231bd2 · 1 Parent(s): 269ad2b

refactor(ai): Upgrade langchain to v0.2.x syntax
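
This swaps the deprecated ConversationalRetrievalChain / ConversationBufferMemory
setup for an LCEL pipeline and keeps chat history in a plain per-conversation dict.
A minimal before/after sketch of the chain shape (names as used in this diff):

    # before (v0.1-era chain object)
    qa_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)

    # after (v0.2.x LCEL composition)
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )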

Files changed (2)
  1. core/support_agent.py +94 -52
  2. requirements.txt +0 -0
core/support_agent.py CHANGED
@@ -1,96 +1,138 @@
  import traceback
  from typing import Dict, Any, List
- from llama_cpp import Llama

- from langchain_core.language_models.llms import LLM
- from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
- from langchain.memory import ConversationBufferMemory
- from langchain_community.embeddings import HuggingFaceEmbeddings
  from langchain_community.vectorstores import Chroma
- from langchain_core.prompts import PromptTemplate

  class LlamaLangChain(LLM):
      llama_instance: Llama
      @property
-     def _llm_type(self) -> str: return "custom"
      def _call(self, prompt: str, stop: List[str] | None = None, **kwargs) -> str:
-         response = self.llama_instance(prompt, max_tokens=512, stop=stop, stream=False, echo=False)
-         return response["choices"][0]["text"]
      async def _acall(self, prompt: str, stop: List[str] | None = None, **kwargs) -> str:
          return self._call(prompt, stop, **kwargs)

- def format_docs(docs):
      return "\n\n".join(doc.page_content for doc in docs)

  class SupportAgent:
      def __init__(self, llm_instance: Llama, embedding_path: str, db_path: str):
-         print("--- Initializing Support Agent (Final Version) ---")
-         if llm_instance is None: raise ValueError("SupportAgent received an invalid LLM instance.")
          self.langchain_llm_wrapper = LlamaLangChain(llama_instance=llm_instance)
-         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_path)
-         self.vector_store = Chroma(persist_directory=db_path, embedding_function=self.embeddings)
-         self.conversations: Dict[str, ConversationBufferMemory] = {}
-         print("✅ Agent and core components initialized successfully.")

-     def _get_or_create_memory(self, conversation_id: str) -> ConversationBufferMemory:
-         if conversation_id not in self.conversations:
-             self.conversations[conversation_id] = ConversationBufferMemory(memory_key="chat_history", return_messages=True, input_key="question", output_key='answer')
-         return self.conversations[conversation_id]

      def answer(self, payload: dict, conversation_id: str) -> dict:
          question = payload.get("question", "")
          live_data_context = payload.get("live_data", "")
-         user_role = payload.get("role", "user")

-         memory = self._get_or_create_memory(conversation_id)

          try:
-             # FINAL, POLISHED PROMPT
-             human_friendly_template = """You are Sparky, a helpful AI assistant for Reachify.
  Your job is to provide a direct and concise answer to the user's question.
- Use the Live Data and Context to find the answer. Do not talk about yourself.

- **Live Data (Facts from the user's account):**
  {live_data}

- **Context (General Knowledge):**
  {context}

- **Chat History:**
  {chat_history}

- **User's Question:** {question}

- **Direct Answer:**
  """
-
-             final_prompt = PromptTemplate.from_template(human_friendly_template)
-
-             # SYNTAX FIX: Removed the space between 'as_' and 'retriever()'
-             retriever = self.vector_store.as_retriever()
-
-             qa_chain = ConversationalRetrievalChain.from_llm(
-                 llm=self.langchain_llm_wrapper,
-                 retriever=retriever,
-                 memory=memory,
-                 combine_docs_chain_kwargs={"prompt": final_prompt}
              )
-
-             result = qa_chain.invoke({
-                 "question": question,
-                 "live_data": live_data_context
-             })

-             raw_answer = result.get("answer", "I'm sorry, I could not find an answer.").strip()
-
-             final_answer = raw_answer.split("Answer:")[0].split("Direct Answer:")[0].strip()

-             return {"response": final_answer, "context": format_docs(result.get('source_documents', []))}

          except Exception as e:
              traceback.print_exc()
-             return {"response": "A critical server error occurred in the AI agent.", "context": str(e)}
-

      def generate_caption_variant(self, caption: str, action: str) -> str:
          # Note: You were calling self.llm here but it's defined as self.langchain_llm_wrapper
 
  import traceback
  from typing import Dict, Any, List

+ from llama_cpp import Llama
+ from langchain_core.runnables import RunnablePassthrough
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate
  from langchain_community.vectorstores import Chroma
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_core.language_models.llms import LLM

+ # A custom wrapper to make llama_cpp compatible with LangChain's LLM interface
  class LlamaLangChain(LLM):
      llama_instance: Llama
+
      @property
+     def _llm_type(self) -> str:
+         return "custom-llama-cpp"
+
      def _call(self, prompt: str, stop: List[str] | None = None, **kwargs) -> str:
+         # Some models echo conversational artifacts ("Answer:", "Assistant:");
+         # strip them so callers always get a clean response.
+         unwanted_starters = ["Answer:", "Direct Answer:", "Assistant:"]
+         try:
+             response = self.llama_instance(prompt, max_tokens=512, stop=stop, stream=False, echo=False)
+             text = response["choices"][0]["text"].strip()
+             for starter in unwanted_starters:
+                 if text.lower().startswith(starter.lower()):
+                     text = text[len(starter):].strip()
+             return text
+         except Exception as e:
+             print(f"ERROR during LLM call: {e}")
+             return "Error generating response from the model."
+
      async def _acall(self, prompt: str, stop: List[str] | None = None, **kwargs) -> str:
+         # Simple async wrapper around the synchronous call
          return self._call(prompt, stop, **kwargs)
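+
+ # Hypothetical usage sketch (illustration only, not part of this agent):
+ # LLM subclasses are Runnables in langchain_core, so .invoke() works directly.
+ #   llm = LlamaLangChain(llama_instance=Llama(model_path="models/chat.gguf"))  # illustrative path
+ #   print(llm.invoke("Hello!"))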

+ # Helper function to format retrieved documents
+ def _format_docs_for_context(docs: List[Any]) -> str:
      return "\n\n".join(doc.page_content for doc in docs)

  class SupportAgent:
+     """
+     Modern (LangChain v0.2.x) AI agent using a RAG pipeline built with LCEL.
+     This version replaces the deprecated ConversationalRetrievalChain.
+     """
      def __init__(self, llm_instance: Llama, embedding_path: str, db_path: str):
+         print("--- Initializing Support Agent (LangChain v0.2.x) ---")
+         if llm_instance is None:
+             raise ValueError("SupportAgent received an invalid LLM instance.")
+
          self.langchain_llm_wrapper = LlamaLangChain(llama_instance=llm_instance)
+
+         try:
+             print(f" - Loading embeddings from: {embedding_path}")
+             self.embeddings = HuggingFaceEmbeddings(model_name=embedding_path, model_kwargs={'device': 'cpu'})
+
+             print(f" - Connecting to Vector DB at: {db_path}")
+             self.vector_store = Chroma(persist_directory=db_path, embedding_function=self.embeddings)
+             self.retriever = self.vector_store.as_retriever(search_kwargs={"k": 3})  # top-3 chunks
+
+             # Memory is no longer part of the chain itself in modern LCEL;
+             # history lives in a plain dict of (question, answer) tuples.
+             self.conversations: Dict[str, List[tuple]] = {}

+             print("✅ Agent and core components initialized successfully.")
+         except Exception as e:
+             print(f"CRITICAL ERROR during Support Agent initialization: {e}")
+             traceback.print_exc()
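+             # NB: failures here are only logged, so self.retriever and friends
+             # stay unset and answer() would raise AttributeError later; a
+             # stricter variant (not in this commit) would re-raise instead.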

      def answer(self, payload: dict, conversation_id: str) -> dict:
          question = payload.get("question", "")
          live_data_context = payload.get("live_data", "")

+         # Fetch the chat history for this conversation (empty list if new)
+         chat_history = self.conversations.get(conversation_id, [])

          try:
+             # Prompt for the LangChain Expression Language (LCEL) chain below
+             template = """You are Sparky, a helpful AI assistant for Reachify.
  Your job is to provide a direct and concise answer to the user's question.
+ Use the Live Data and Context provided to find the answer. Do not talk about yourself. If the information isn't in the context, say you don't know.

+ Live Data (Facts from the user's account):
  {live_data}

+ Context (General Knowledge from documents):
  {context}

+ Previous Conversation:
  {chat_history}

+ User's Question: {question}

+ Direct Answer:
  """
+             prompt = ChatPromptTemplate.from_template(template)
+
+             # Manually format the chat history into a readable string
+             formatted_history = "\n".join([f"Human: {q}\nAssistant: {a}" for q, a in chat_history])
+
+             # The LCEL "pipe"
+             rag_chain = (
+                 {
+                     "context": self.retriever | _format_docs_for_context,
+                     "question": RunnablePassthrough(),
+                     "live_data": lambda x: live_data_context,      # pass live data through
+                     "chat_history": lambda x: formatted_history,   # pass history through
+                 }
+                 | prompt
+                 | self.langchain_llm_wrapper
+                 | StrOutputParser()
              )
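+             # The dict above is coerced by LCEL into a RunnableParallel: each
+             # branch receives the chain input (here the raw question string),
+             # and the merged mapping fills the prompt's variables.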

+             print(f" - Invoking RAG chain for question: '{question}'")
+             # Invoke the chain by passing just the question string
+             final_answer = rag_chain.invoke(question)

+             # Update the conversation memory after a successful answer
+             self.conversations[conversation_id] = chat_history + [(question, final_answer)]
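+             # Note: this history list grows without bound; capping it to the
+             # last N turns would be a reasonable follow-up (not done here).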
+
+             # Fetch the documents that were used, for transparency.
+             # get_relevant_documents() is deprecated in v0.2.x, so use
+             # retriever.invoke(); note it runs retrieval a second time.
+             source_docs = self.retriever.invoke(question)
+
+             return {
+                 "response": final_answer,
+                 "context": _format_docs_for_context(source_docs)
+             }

          except Exception as e:
              traceback.print_exc()
+             return {
+                 "response": "A critical server error occurred in the AI agent.",
+                 "context": str(e)
+             }

      def generate_caption_variant(self, caption: str, action: str) -> str:
          # Note: You were calling self.llm here but it's defined as self.langchain_llm_wrapper
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
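Note: the requirements diff renders as binary, so the actual pins aren't visible.
Judging only from the new imports, a plausible (purely illustrative) set would include:

    langchain-core>=0.2,<0.3
    langchain-community>=0.2,<0.3
    langchain-huggingface
    llama-cpp-python
    chromadb
    sentence-transformers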