File size: 1,447 Bytes
0646b18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from cuga.backend.activity_tracker.tracker import ActivityTracker, Step
from cuga.backend.cuga_graph.nodes.shared.base_agent import create_partial
from cuga.backend.cuga_graph.nodes.shared.base_node import BaseNode
from cuga.backend.cuga_graph.state.agent_state import AgentState
from cuga.backend.cuga_graph.nodes.task_decomposition_planning.location_resolver_agent.location_resolver_agent import (
    LocationResolverAgent,
)
from langchain_core.messages import AIMessage

# Module-level ActivityTracker instance; LocationSolver.node_handler records a
# Step on it whenever the agent returns non-empty content.
tracker = ActivityTracker()


class LocationSolver(BaseNode):
    """Graph node that runs a ``LocationResolverAgent`` on the state's input.

    The callable stored in ``self.node`` is ``node_handler`` with the agent
    and the node name pre-bound through ``create_partial``.
    """

    def __init__(self, agent: LocationResolverAgent):
        """Store the agent and build the bound node callable.

        Args:
            agent: The agent whose ``run`` coroutine is awaited by the handler.
        """
        super().__init__()
        self.name = "LocationResolver"
        self.agent = agent
        self.node = create_partial(LocationSolver.node_handler, agent=self.agent, name=self.name)

    @staticmethod
    async def node_handler(state: AgentState, agent: LocationResolverAgent, name: str) -> AgentState:
        """Run the agent on ``state.input`` and record its answer on the state.

        ``state.sender`` is always set to ``name``. When the agent's message
        has truthy content, that content is logged as a tracker step and
        written to ``state.task_analyzer_output.resolved_intent``; otherwise
        the state is returned with only ``sender`` changed.

        NOTE: conditional routing to "LocationResolverAgent" or
        "TaskDecompositionAgent" was previously sketched here as commented-out
        code; it is not implemented, and the handler always returns the state.

        Args:
            state: The state object to read the input from and update in place.
            agent: The agent to run.
            name: Node name used for ``state.sender`` and the tracker step.

        Returns:
            The same ``state`` object that was passed in.
        """
        reply: AIMessage = await agent.run(state.input)
        state.sender = name

        resolved = reply.content
        if not resolved:
            # Nothing to record: leave the analyzer output untouched.
            return state

        tracker.collect_step(Step(name=name, data=resolved))
        state.task_analyzer_output.resolved_intent = resolved
        return state