sxid003 committed on
Commit
1709069
·
verified ·
1 Parent(s): b817284

Update src/agents/rag_agent.py

Browse files
Files changed (1) hide show
  1. src/agents/rag_agent.py +26 -14
src/agents/rag_agent.py CHANGED
@@ -62,21 +62,16 @@ tools = [search_legal_documents, search_youtube_transcripts]
62
  llm = GeminiWrapper()
63
  llm_with_tools = llm.bind_tools(tools)
64
 
65
- # SYSTEM MESSAGE
66
- #sys_msg = SystemMessage(content="""
67
- #You are a helpful assistant specialized in answering user questions related to Moroccan Parliament YouTube videos and legal documents.
68
- #Your response must be strictly in the same language as the user’s query.
69
- #Provide accurate answers and include relevant sources (YouTube video links or PDF document links) in your response.
70
- #""")
71
-
72
  # LLM NODE
73
  def assistant(state: MessagesState):
74
  # Utiliser l'état d'invocation global
75
- invocation_type = invocation_state.invocation_type
 
 
 
76
 
77
- print(invocation_type)
78
  # Créer le SystemMessage selon le type d'invocation
79
- if invocation_type == "chatbot":
80
  sys_msg = SystemMessage(content="""
81
  You are a helpful assistant specialized in answering user questions related to Moroccan Parliament YouTube videos and legal documents.
82
  Your response must be strictly in the same language as the user’s query.
@@ -88,15 +83,28 @@ def assistant(state: MessagesState):
88
  Ensure your responses are concise and formatted appropriately for API output. Your response should be maximum 100 words.
89
  Your response shoud be in json format like that : {"text response" : "", "sources" : ""}
90
  """)
91
- user_msg = state["messages"][-1].content
92
- logging.info(f"User input: {user_msg}")
93
  try:
 
94
  result = llm_with_tools.invoke([sys_msg] + state["messages"])
95
- logging.info(f"🤖 Model Output: {result}")
96
  return {"messages": [result]}
97
  except Exception as e:
98
  logging.error(f"LLM invocation failed: {e}")
99
  return {"messages": [SystemMessage(content="Une erreur est survenue avec le modèle.")]}
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  class InvocationState:
102
  def __init__(self):
@@ -108,10 +116,14 @@ invocation_state = InvocationState()
108
  # GRAPH SETUP
109
  builder = StateGraph(MessagesState)
110
 
 
 
111
  builder.add_node("llm_assistant", assistant)
112
  builder.add_node("tools", ToolNode(tools))
113
 
114
- builder.add_edge(START, "llm_assistant")
 
 
115
  builder.add_conditional_edges("llm_assistant", tools_condition)
116
  builder.add_edge("tools", "llm_assistant")
117
 
 
62
  llm = GeminiWrapper()
63
  llm_with_tools = llm.bind_tools(tools)
64
 
 
 
 
 
 
 
 
65
  # LLM NODE
66
  def assistant(state: MessagesState):
67
  # Utiliser l'état d'invocation global
68
+ #invocation_type = invocation_state.invocation_type
69
+
70
+ user_msg = state["messages"][-1].content
71
+ logging.info(f"User input: {user_msg}")
72
 
 
73
  # Créer le SystemMessage selon le type d'invocation
74
+ if invocation_state.invocation_type == 'chatbot':
75
  sys_msg = SystemMessage(content="""
76
  You are a helpful assistant specialized in answering user questions related to Moroccan Parliament YouTube videos and legal documents.
77
  Your response must be strictly in the same language as the user’s query.
 
83
  Ensure your responses are concise and formatted appropriately for API output. Your response should be maximum 100 words.
84
  Your response shoud be in json format like that : {"text response" : "", "sources" : ""}
85
  """)
86
+
 
87
  try:
88
+ print(state['messages'])
89
  result = llm_with_tools.invoke([sys_msg] + state["messages"])
90
+ logging.info(f"🤖 Model Output kkkk: {result}")
91
  return {"messages": [result]}
92
  except Exception as e:
93
  logging.error(f"LLM invocation failed: {e}")
94
  return {"messages": [SystemMessage(content="Une erreur est survenue avec le modèle.")]}
95
+
96
# Node that sets the invocation type based on the user's message
def set_invocation_type(state: MessagesState):
    """Set the global invocation type from the latest user message.

    Inspects the last message in ``state["messages"]`` for the ``APICALL``
    marker and mutates the module-level ``invocation_state`` accordingly:
    ``'API'`` when the marker is present, ``'chatbot'`` otherwise.
    Returns None (side-effect-only graph node).
    """
    user_msg = state["messages"][-1].content.strip()  # Normalize the input
    logging.info(f"User input for type update: {user_msg}")

    # Case-insensitive check: the marker is documented as 'APICALL', but the
    # original compared 'apicall' against the raw text, so an upper-case
    # marker was never matched and the type silently stayed 'chatbot'.
    if 'apicall' in user_msg.lower():
        invocation_state.invocation_type = 'API'
    else:
        invocation_state.invocation_type = 'chatbot'

    logging.info(f"Invocation type set to: {invocation_state.invocation_type}")
108
 
109
  class InvocationState:
110
  def __init__(self):
 
116
  # GRAPH SETUP
117
  builder = StateGraph(MessagesState)
118
 
119
+ # Add the node to set the invocation type
120
+ builder.add_node("set_invocation_type", set_invocation_type)
121
  builder.add_node("llm_assistant", assistant)
122
  builder.add_node("tools", ToolNode(tools))
123
 
124
+ # Define the edges of the graph
125
+ builder.add_edge(START, "set_invocation_type")
126
+ builder.add_edge("set_invocation_type", "llm_assistant")
127
  builder.add_conditional_edges("llm_assistant", tools_condition)
128
  builder.add_edge("tools", "llm_assistant")
129