subashpoudel commited on
Commit
8e422ca
·
1 Parent(s): af92900

Included the alternative retrieval process

Browse files
my_agent/utils/business_interaction.py CHANGED
@@ -5,12 +5,15 @@ from langgraph.checkpoint.memory import MemorySaver
5
  from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
6
  from pydantic import BaseModel, ConfigDict, Field
7
  from typing import Optional, List
8
- from .models_loader import llm
9
- from .prompts import introduction_prompt , business_interaction_prompt
10
  from .tools import retrieve_tool
11
  from langgraph.prebuilt import create_react_agent
12
  from langmem.short_term import SummarizationNode
13
  from langchain_core.messages.utils import count_tokens_approximately
 
 
 
14
 
15
 
16
 
@@ -24,51 +27,81 @@ business_state = State()
24
 
25
  class BusinessInteractionChatbot:
26
  def __init__(self):
 
 
27
  self.react_agent=create_react_agent(
28
  model=llm.bind_tools([retrieve_tool]),
29
  tools=[retrieve_tool]
30
  )
31
- self.summarization_model = llm.bind(max_tokens=400)
32
 
33
  self.summarization_node = SummarizationNode(
34
  token_counter=count_tokens_approximately,
35
  model=self.summarization_model,
36
- max_tokens=256,
37
  max_tokens_before_summary=256,
38
  max_summary_tokens=128,
39
  )
40
-
41
  self.memory = MemorySaver()
42
  # self.llm = ChatGroq(model_name="Gemma2-9b-It")
43
  self.workflow = self._initialize_workflow()
44
  self.interact_agent = self.workflow.compile(checkpointer=self.memory)
45
- self.messages = []
46
 
47
 
48
  def _initialize_workflow(self):
49
  workflow = StateGraph(MessagesState)
50
  workflow.add_node("chatbot", self._call_model)
51
  workflow.add_node("summarize",self.summarization_node)
 
 
52
  workflow.add_edge(START, "summarize")
53
  workflow.add_edge("summarize", "chatbot")
 
54
  workflow.add_edge("chatbot", END)
55
  return workflow
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  def _call_model(self, state):
58
  print('Entered into callmodel')
59
  template = business_interaction_prompt
60
  messages = [SystemMessage(content=template)] + state["messages"]
61
- tool_response = self.react_agent.invoke({'messages':messages})['messages'][-2]
62
- response = self.react_agent.invoke({'messages':messages})['messages'][-1]
63
- print('Tool response:',tool_response)
64
- return {"messages": [response]}
 
 
 
 
 
 
 
 
 
65
 
66
  def chat(self, user_input: str):
67
  print('Entered into chat')
 
68
  self.messages.append({"role": "user", "content": user_input})
69
  config = {"configurable": {"thread_id": "2"}}
70
  response = self.interact_agent.invoke({"messages":self.messages}, config)['messages'][-1].content
71
  print('The response:',response)
72
  self.messages.append({"role": "assistant", "content": response})
 
73
  business_state.interactions.append({'user': user_input, 'agent_response': response})
74
  return response
 
5
  from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
6
  from pydantic import BaseModel, ConfigDict, Field
7
  from typing import Optional, List
8
+ from .models_loader import llm,ST
9
+ from .prompts import introduction_prompt , business_interaction_prompt, business_retrieval_prompt
10
  from .tools import retrieve_tool
11
  from langgraph.prebuilt import create_react_agent
12
  from langmem.short_term import SummarizationNode
13
  from langchain_core.messages.utils import count_tokens_approximately
14
+ from langchain_core.messages import RemoveMessage
15
+ from .data_loader import load_influencer_data
16
+
17
 
18
 
19
 
 
27
 
28
  class BusinessInteractionChatbot:
29
  def __init__(self):
30
+ self.messages = []
31
+
32
  self.react_agent=create_react_agent(
33
  model=llm.bind_tools([retrieve_tool]),
34
  tools=[retrieve_tool]
35
  )
36
+ self.summarization_model = llm.bind(max_tokens=128)
37
 
38
  self.summarization_node = SummarizationNode(
39
  token_counter=count_tokens_approximately,
40
  model=self.summarization_model,
41
+ max_tokens=300,
42
  max_tokens_before_summary=256,
43
  max_summary_tokens=128,
44
  )
 
45
  self.memory = MemorySaver()
46
  # self.llm = ChatGroq(model_name="Gemma2-9b-It")
47
  self.workflow = self._initialize_workflow()
48
  self.interact_agent = self.workflow.compile(checkpointer=self.memory)
 
49
 
50
 
51
  def _initialize_workflow(self):
52
  workflow = StateGraph(MessagesState)
53
  workflow.add_node("chatbot", self._call_model)
54
  workflow.add_node("summarize",self.summarization_node)
55
+ workflow.add_node("remove_message",self.delete_messages)
56
+
57
  workflow.add_edge(START, "summarize")
58
  workflow.add_edge("summarize", "chatbot")
59
+ workflow.add_edge("chatbot","remove_message")
60
  workflow.add_edge("chatbot", END)
61
  return workflow
62
+
63
+ def delete_messages(self,state):
64
+ print('Entered message deletion....')
65
+ if len(self.messages) > 5:
66
+ print('satisfied...')
67
+ self.messages = self.messages[2:]
68
+
69
+ def manual_retrieval(self):
70
+ embedded_query = ST.encode(str([msg['content'] for msg in self.messages if msg['role'] == 'user'])) # Embed each topic
71
+ data = load_influencer_data()
72
+ scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=1)
73
+
74
+ # Construct a list of dictionaries for this topic
75
+ result = [{user: story} for user, story in zip(retrieved_examples['username'], retrieved_examples['agentic_story'])]
76
+ return result
77
+
78
 
79
  def _call_model(self, state):
80
  print('Entered into callmodel')
81
  template = business_interaction_prompt
82
  messages = [SystemMessage(content=template)] + state["messages"]
83
+ # response = self.react_agent.invoke({'messages':messages})['messages'][-2]
84
+ response = self.react_agent.invoke({'messages':messages})['messages']
85
+ if response [-2].name == None:
86
+ print('Entered into manual retrieval')
87
+ retrievals = self.manual_retrieval()
88
+ template = business_retrieval_prompt(retrievals)
89
+ messages = [SystemMessage(content=template)] + state["messages"]
90
+ backup_response = self.react_agent.invoke({'messages':messages})['messages'][-1]
91
+ print('Backup response:',backup_response.content)
92
+ return {"messages": [backup_response.content]}
93
+
94
+ else:
95
+ return {"messages": [response[-1]]}
96
 
97
  def chat(self, user_input: str):
98
  print('Entered into chat')
99
+
100
  self.messages.append({"role": "user", "content": user_input})
101
  config = {"configurable": {"thread_id": "2"}}
102
  response = self.interact_agent.invoke({"messages":self.messages}, config)['messages'][-1].content
103
  print('The response:',response)
104
  self.messages.append({"role": "assistant", "content": response})
105
+ print('The message_history:',self.messages)
106
  business_state.interactions.append({'user': user_input, 'agent_response': response})
107
  return response