Spaces:
Sleeping
Sleeping
Update src/agents/rag_agent.py
Browse files- src/agents/rag_agent.py +26 -14
src/agents/rag_agent.py
CHANGED
|
@@ -62,21 +62,16 @@ tools = [search_legal_documents, search_youtube_transcripts]
|
|
| 62 |
llm = GeminiWrapper()
|
| 63 |
llm_with_tools = llm.bind_tools(tools)
|
| 64 |
|
| 65 |
-
# SYSTEM MESSAGE
|
| 66 |
-
#sys_msg = SystemMessage(content="""
|
| 67 |
-
#You are a helpful assistant specialized in answering user questions related to Moroccan Parliament YouTube videos and legal documents.
|
| 68 |
-
#Your response must be strictly in the same language as the user’s query.
|
| 69 |
-
#Provide accurate answers and include relevant sources (YouTube video links or PDF document links) in your response.
|
| 70 |
-
#""")
|
| 71 |
-
|
| 72 |
# LLM NODE
|
| 73 |
def assistant(state: MessagesState):
|
| 74 |
# Utiliser l'état d'invocation global
|
| 75 |
-
invocation_type = invocation_state.invocation_type
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
-
print(invocation_type)
|
| 78 |
# Créer le SystemMessage selon le type d'invocation
|
| 79 |
-
if invocation_type ==
|
| 80 |
sys_msg = SystemMessage(content="""
|
| 81 |
You are a helpful assistant specialized in answering user questions related to Moroccan Parliament YouTube videos and legal documents.
|
| 82 |
Your response must be strictly in the same language as the user’s query.
|
|
@@ -88,15 +83,28 @@ def assistant(state: MessagesState):
|
|
| 88 |
Ensure your responses are concise and formatted appropriately for API output. Your response should be maximum 100 words.
|
| 89 |
Your response shoud be in json format like that : {"text response" : "", "sources" : ""}
|
| 90 |
""")
|
| 91 |
-
|
| 92 |
-
logging.info(f"User input: {user_msg}")
|
| 93 |
try:
|
|
|
|
| 94 |
result = llm_with_tools.invoke([sys_msg] + state["messages"])
|
| 95 |
-
logging.info(f"🤖 Model Output: {result}")
|
| 96 |
return {"messages": [result]}
|
| 97 |
except Exception as e:
|
| 98 |
logging.error(f"LLM invocation failed: {e}")
|
| 99 |
return {"messages": [SystemMessage(content="Une erreur est survenue avec le modèle.")]}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
class InvocationState:
|
| 102 |
def __init__(self):
|
|
@@ -108,10 +116,14 @@ invocation_state = InvocationState()
|
|
| 108 |
# GRAPH SETUP
|
| 109 |
builder = StateGraph(MessagesState)
|
| 110 |
|
|
|
|
|
|
|
| 111 |
builder.add_node("llm_assistant", assistant)
|
| 112 |
builder.add_node("tools", ToolNode(tools))
|
| 113 |
|
| 114 |
-
|
|
|
|
|
|
|
| 115 |
builder.add_conditional_edges("llm_assistant", tools_condition)
|
| 116 |
builder.add_edge("tools", "llm_assistant")
|
| 117 |
|
|
|
|
| 62 |
llm = GeminiWrapper()
|
| 63 |
llm_with_tools = llm.bind_tools(tools)
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
# LLM NODE
|
| 66 |
def assistant(state: MessagesState):
|
| 67 |
# Utiliser l'état d'invocation global
|
| 68 |
+
#invocation_type = invocation_state.invocation_type
|
| 69 |
+
|
| 70 |
+
user_msg = state["messages"][-1].content
|
| 71 |
+
logging.info(f"User input: {user_msg}")
|
| 72 |
|
|
|
|
| 73 |
# Créer le SystemMessage selon le type d'invocation
|
| 74 |
+
if invocation_state.invocation_type == 'chatbot':
|
| 75 |
sys_msg = SystemMessage(content="""
|
| 76 |
You are a helpful assistant specialized in answering user questions related to Moroccan Parliament YouTube videos and legal documents.
|
| 77 |
Your response must be strictly in the same language as the user’s query.
|
|
|
|
| 83 |
Ensure your responses are concise and formatted appropriately for API output. Your response should be maximum 100 words.
|
| 84 |
Your response shoud be in json format like that : {"text response" : "", "sources" : ""}
|
| 85 |
""")
|
| 86 |
+
|
|
|
|
| 87 |
try:
|
| 88 |
+
print(state['messages'])
|
| 89 |
result = llm_with_tools.invoke([sys_msg] + state["messages"])
|
| 90 |
+
logging.info(f"🤖 Model Output kkkk: {result}")
|
| 91 |
return {"messages": [result]}
|
| 92 |
except Exception as e:
|
| 93 |
logging.error(f"LLM invocation failed: {e}")
|
| 94 |
return {"messages": [SystemMessage(content="Une erreur est survenue avec le modèle.")]}
|
| 95 |
+
|
| 96 |
+
# Tool to set the invocation type based on the user's message
|
| 97 |
+
def set_invocation_type(state: MessagesState):
|
| 98 |
+
user_msg = state["messages"][-1].content.strip() # Normalize the input
|
| 99 |
+
logging.info(f"User input for type update: {user_msg}")
|
| 100 |
+
|
| 101 |
+
# Set the invocation type based on the presence of 'APICALL'
|
| 102 |
+
if 'apicall' in user_msg:
|
| 103 |
+
invocation_state.invocation_type = 'API'
|
| 104 |
+
else:
|
| 105 |
+
invocation_state.invocation_type = 'chatbot'
|
| 106 |
+
|
| 107 |
+
logging.info(f"Invocation type set to: {invocation_state.invocation_type}")
|
| 108 |
|
| 109 |
class InvocationState:
|
| 110 |
def __init__(self):
|
|
|
|
| 116 |
# GRAPH SETUP
|
| 117 |
builder = StateGraph(MessagesState)
|
| 118 |
|
| 119 |
+
# Add the node to set the invocation type
|
| 120 |
+
builder.add_node("set_invocation_type", set_invocation_type)
|
| 121 |
builder.add_node("llm_assistant", assistant)
|
| 122 |
builder.add_node("tools", ToolNode(tools))
|
| 123 |
|
| 124 |
+
# Define the edges of the graph
|
| 125 |
+
builder.add_edge(START, "set_invocation_type")
|
| 126 |
+
builder.add_edge("set_invocation_type", "llm_assistant")
|
| 127 |
builder.add_conditional_edges("llm_assistant", tools_condition)
|
| 128 |
builder.add_edge("tools", "llm_assistant")
|
| 129 |
|