Spaces:
Sleeping
Sleeping
| import os | |
| from langchain_groq import ChatGroq | |
| from langgraph.graph import StateGraph, MessagesState, START, END | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain_core.messages import SystemMessage, HumanMessage, AIMessage | |
| from pydantic import BaseModel, ConfigDict, Field | |
| from typing import Optional, List | |
| from .models_loader import llm | |
| from .prompts import introduction_prompt , business_interaction_prompt | |
| from .tools import retrieve_tool | |
| from langgraph.prebuilt import create_react_agent | |
| from langmem.short_term import SummarizationNode | |
| from langchain_core.messages.utils import count_tokens_approximately | |
# State model
class State(BaseModel):
    """Shared application state accumulated across chat turns."""

    # Rolling log of {'user': ..., 'agent_response': ...} records appended by
    # BusinessInteractionChatbot.chat. Use default_factory instead of a
    # mutable `[]` default literal — the idiomatic pydantic form (pydantic
    # copies defaults per instance, so this is hygiene, not a behavior change).
    interactions: Optional[list] = Field(default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


# Global business state (shared by all chatbot instances in this process)
business_state = State()
class BusinessInteractionChatbot:
    """Business-interaction agent: a LangGraph pipeline that summarizes the
    running conversation, then answers via a ReAct agent armed with a
    retrieval tool. Conversation state is checkpointed in memory.
    """

    def __init__(self):
        # create_react_agent binds `tools` to the model itself; passing a
        # model already wrapped with llm.bind_tools would register the tool
        # twice in every request, so hand it the bare model.
        self.react_agent = create_react_agent(
            model=llm,
            tools=[retrieve_tool],
        )
        # Separate, length-capped model instance used only for summaries.
        self.summarization_model = llm.bind(max_tokens=400)
        self.summarization_node = SummarizationNode(
            token_counter=count_tokens_approximately,
            model=self.summarization_model,
            max_tokens=256,
            max_tokens_before_summary=256,
            max_summary_tokens=128,
        )
        self.memory = MemorySaver()
        self.workflow = self._initialize_workflow()
        self.interact_agent = self.workflow.compile(checkpointer=self.memory)
        # Full transcript (role/content dicts) re-sent to the graph each turn.
        self.messages = []

    def _initialize_workflow(self):
        """Build the linear graph: START -> summarize -> chatbot -> END."""
        workflow = StateGraph(MessagesState)
        workflow.add_node("chatbot", self._call_model)
        workflow.add_node("summarize", self.summarization_node)
        workflow.add_edge(START, "summarize")
        workflow.add_edge("summarize", "chatbot")
        workflow.add_edge("chatbot", END)
        return workflow

    def _call_model(self, state):
        """Graph node: run the ReAct agent over the (summarized) history.

        Returns {"messages": [final AI message]} for LangGraph to merge.
        """
        print('Entered into callmodel')
        template = business_interaction_prompt
        messages = [SystemMessage(content=template)] + state["messages"]
        # Invoke the agent ONCE and slice both messages from that single
        # result. The original called .invoke twice — doubling LLM/tool cost
        # and risking the logged tool output belonging to a different run
        # than the returned answer.
        result = self.react_agent.invoke({'messages': messages})['messages']
        tool_response = result[-2]  # assumes the agent made a tool call — TODO confirm
        response = result[-1]
        print('Tool response:', tool_response)
        return {"messages": [response]}

    def chat(self, user_input: str, thread_id: str = "2"):
        """Send one user turn through the graph and return the reply text.

        Args:
            user_input: The user's message for this turn.
            thread_id: Checkpointer thread key; defaults to the previously
                hard-coded "2" so existing callers behave identically.

        Side effects: appends to self.messages and to the module-level
        business_state.interactions log.
        """
        print('Entered into chat')
        self.messages.append({"role": "user", "content": user_input})
        config = {"configurable": {"thread_id": thread_id}}
        response = self.interact_agent.invoke(
            {"messages": self.messages}, config
        )['messages'][-1].content
        print('The response:', response)
        self.messages.append({"role": "assistant", "content": response})
        business_state.interactions.append({'user': user_input, 'agent_response': response})
        return response