QuentinL52 committed on
Commit
6f2a831
·
verified ·
1 Parent(s): 895e94e

Update services/graph_service.py

Browse files
Files changed (1) hide show
  1. services/graph_service.py +34 -11
services/graph_service.py CHANGED
@@ -16,6 +16,7 @@ class AgentState(TypedDict):
16
  messages: Annotated[Sequence[BaseMessage], lambda x, y: x + y]
17
  user_id: str
18
  job_offer_id: str
 
19
 
20
  class GraphInterviewProcessor:
21
  """
@@ -77,6 +78,19 @@ class GraphInterviewProcessor:
77
 
78
  return prompt | llm_with_tools
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def _agent_node(self, state: AgentState):
81
  """Prépare le prompt et appelle le runnable de l'agent."""
82
 
@@ -109,37 +123,45 @@ class GraphInterviewProcessor:
109
  "messages": state["messages"],
110
  "job_description": job_description
111
  })
112
- return {"messages": [response]}
 
 
 
113
 
114
- def _router(self, state: AgentState):
115
- """Décide du chemin à suivre après la réponse de l'agent."""
 
 
 
 
 
116
  last_message = state["messages"][-1]
117
- if hasattr(last_message, 'tool_calls') and last_message.tool_calls and len(last_message.tool_calls) > 0:
 
 
 
118
  return "call_tool"
119
- else:
120
- return "end_turn"
121
 
122
  def _build_graph(self) -> any:
123
  """Construit et compile le graphe d'états."""
124
  tool_node = ToolNode([trigger_interview_analysis])
125
-
126
  graph = StateGraph(AgentState)
127
  graph.add_node("agent", self._agent_node)
128
  graph.add_node("tools", tool_node)
129
-
130
  graph.set_entry_point("agent")
131
-
132
  graph.add_conditional_edges(
133
  "agent",
134
  self._router,
135
  {
136
  "call_tool": "tools",
 
137
  "end_turn": END
138
  }
139
  )
140
-
141
  graph.add_edge("tools", "agent")
142
-
143
  return graph.compile()
144
 
145
  def invoke(self, messages: List[Dict[str, Any]]):
@@ -154,6 +176,7 @@ class GraphInterviewProcessor:
154
  "user_id": self.user_id,
155
  "job_offer_id": self.job_offer_id,
156
  "messages": langchain_messages,
 
157
  }
158
 
159
  final_state = self.graph.invoke(initial_state)
 
16
  messages: Annotated[Sequence[BaseMessage], lambda x, y: x + y]
17
  user_id: str
18
  job_offer_id: str
19
+ job_description: str
20
 
21
  class GraphInterviewProcessor:
22
  """
 
78
 
79
  return prompt | llm_with_tools
80
 
81
+ def _should_continue(self, state: InterviewState) -> str:
82
+ """
83
+ Détermine si l'entretien doit continuer ou se terminer.
84
+ """
85
+ messages = state.get('messages', [])
86
+ last_message = messages[-1]
87
+ if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
88
+ for tool_call in last_message.tool_calls:
89
+ if tool_call.get('name') == 'trigger_interview_analysis':
90
+ print("Condition de fin détectée : appel à trigger_interview_analysis.")
91
+ return "end"
92
+ return "continue"
93
+
94
  def _agent_node(self, state: AgentState):
95
  """Prépare le prompt et appelle le runnable de l'agent."""
96
 
 
123
  "messages": state["messages"],
124
  "job_description": job_description
125
  })
126
+ return {
127
+ "messages": [response],
128
+ "job_description": job_description_str
129
+ }
130
 
131
+ def _router(self, state: AgentState) -> str:
132
+ """
133
+ Route le flux du graphe en fonction de la dernière réponse de l'agent.
134
+ - Si un outil d'analyse final est appelé, termine le graphe.
135
+ - Si un autre outil est appelé, va au noeud d'outils.
136
+ - Sinon, termine le tour de conversation.
137
+ """
138
  last_message = state["messages"][-1]
139
+ if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
140
+ if any(tool_call.get('name') == 'trigger_interview_analysis' for tool_call in last_message.tool_calls):
141
+ print(">>> Routeur : Appel à l'outil final détecté. Terminaison du graphe.")
142
+ return "call_final_tool"
143
  return "call_tool"
144
+ return "end_turn"
 
145
 
146
  def _build_graph(self) -> any:
147
  """Construit et compile le graphe d'états."""
148
  tool_node = ToolNode([trigger_interview_analysis])
 
149
  graph = StateGraph(AgentState)
150
  graph.add_node("agent", self._agent_node)
151
  graph.add_node("tools", tool_node)
152
+ graph.add_node("final_tool_node", tool_node)
153
  graph.set_entry_point("agent")
 
154
  graph.add_conditional_edges(
155
  "agent",
156
  self._router,
157
  {
158
  "call_tool": "tools",
159
+ "call_final_tool": "final_tool_node",
160
  "end_turn": END
161
  }
162
  )
 
163
  graph.add_edge("tools", "agent")
164
+ graph.add_edge("final_tool_node", END)
165
  return graph.compile()
166
 
167
  def invoke(self, messages: List[Dict[str, Any]]):
 
176
  "user_id": self.user_id,
177
  "job_offer_id": self.job_offer_id,
178
  "messages": langchain_messages,
179
+ "job_description": json.dumps(self.job_offer, ensure_ascii=False),
180
  }
181
 
182
  final_state = self.graph.invoke(initial_state)