File size: 5,418 Bytes
3107242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8af85fa
3107242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1709069
 
 
 
3107242
 
1709069
3107242
 
 
 
 
 
 
 
 
 
 
1709069
3107242
1709069
3107242
6437579
3107242
 
 
 
1709069
 
 
 
 
 
 
 
 
 
 
 
 
3107242
 
 
 
 
 
 
 
 
 
 
1709069
 
3107242
 
 
1709069
 
 
3107242
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import sys
import os
sys.path.append(os.getcwd())

from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from langchain_core.messages import SystemMessage
from langgraph.prebuilt import tools_condition
from langchain_core.tools import tool
from src.rag.documents_rag_pipeline import RAGPipeline
from src.rag.youtube_rag_pipeline import YouTubeRAGPipeline
from src.models.llm_wrapper import GeminiWrapper
from src.utils.search_docs_utils import search_relevant_documents
from src.configs.config import LOG_DIR
import logging
import os

# LOGGING SETUP
LOG_FILE = os.path.join(LOG_DIR, "Agents.log")
# FileHandler does not create missing directories, so make sure the log
# directory exists before basicConfig opens the file (skipped when LOG_DIR is
# empty, in which case the file is created in the current working directory).
if LOG_DIR:
    os.makedirs(LOG_DIR, exist_ok=True)
logging.basicConfig(
    filename=LOG_FILE,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    # Log records contain accented French text, an emoji and raw tool output;
    # without an explicit encoding the file uses the locale default (e.g.
    # cp1252 on Windows) and such records fail to be written.
    encoding='utf-8',
)

# PIPELINE INIT
# Both retrieval back-ends are built once, at import time, and shared by the
# two tool functions below (documents first, then YouTube transcripts).
doc_rag_pipeline, ytb_rag_pipeline = RAGPipeline(), YouTubeRAGPipeline()

# TOOLS
@tool
def search_legal_documents(query: str) -> str:
    """Search Moroccan **legal codes, decrees, and PDFs**. Use this tool when user asks about Moroccan **laws, articles, codes, legal references**, or related content."""
    # NOTE: with @tool, the docstring above is the description shown to the
    # LLM, so it is kept verbatim; implementation notes live in comments.
    # Returns one "<chunk text>\n(Source: <pdf link>)" entry per retrieved
    # chunk, newline-joined, or a French error message if retrieval fails.
    logging.info(f"Tool call: search_legal_documents with input: {query}")
    try:
        # Document-level selection first; top_k=1 presumably keeps only the
        # single best-matching document (TODO confirm in search_docs_utils).
        relevant_docs = search_relevant_documents(query, top_k=1)
        # Then chunk-level retrieval restricted to the selected document(s).
        results = doc_rag_pipeline.get_top_docs_chunks_for_query(query, relevant_docs['results'])
        output = "\n".join(f"{r['text']}\n(Source: {r['pdf_link']})" for r in results)
        logging.info(f"Tool result (search_legal_documents): {output}")
        return output
    except Exception as e:
        # Degrade to a message the LLM can relay instead of failing the graph,
        # but keep the full traceback in the log (logging.error dropped it).
        logging.exception(f"Error in search_legal_documents: {e}")
        return "Une erreur est survenue lors de la recherche des documents légaux."

@tool
def search_youtube_transcripts(query: str) -> str:
    """Search Moroccan **Parliament YouTube videos** (debates, speeches, questions, sessions). Use this tool when user asks about **video content** or **what someone said**."""
    # NOTE: with @tool, the docstring above is the description shown to the
    # LLM, so it is kept verbatim; implementation notes live in comments.
    # Returns one "<excerpt>\n(Source: <link>)" entry per hit, newline-joined,
    # or a French error message if the search fails.
    logging.info(f"Tool call: search_youtube_transcripts with input: {query}")
    try:
        results = ytb_rag_pipeline.search(query)
        # Each hit is read through its 'texte' (excerpt) and 'lien' (link) keys.
        output = "\n".join(f"{r['texte']}\n(Source: {r['lien']})" for r in results)
        logging.info(f"Tool result (search_youtube_transcripts): {output}")
        return output
    except Exception as e:
        # Degrade to a message the LLM can relay instead of failing the graph,
        # but keep the full traceback in the log (logging.error dropped it).
        logging.exception(f"Error in search_youtube_transcripts: {e}")
        return "Une erreur est survenue lors de la recherche dans les transcriptions YouTube."

# Tools offered to the model: the same list is bound to the LLM here and
# executed by the ToolNode registered in the graph.
tools = [
    search_legal_documents,
    search_youtube_transcripts,
]

# LLM SETUP
llm = GeminiWrapper()
# Tool-aware handle used by the assistant node; it lets the model request
# calls to the functions listed in `tools`.
llm_with_tools = llm.bind_tools(tools)

# LLM NODE
def assistant(state: MessagesState):
    """Run the tool-bound LLM on the conversation and return its reply.

    The system prompt is chosen from the module-level ``invocation_state``
    (set by ``set_invocation_type``): ``'chatbot'`` gets the conversational
    prompt, any other value gets the concise JSON-formatted API prompt. The
    prompt is prepended for this call only and is not stored in the state.

    Args:
        state: Graph state; ``state["messages"]`` holds the conversation.

    Returns:
        ``{"messages": [reply]}`` where ``reply`` is the model output (it may
        contain tool calls), or an ``AIMessage`` carrying a French error text
        if the model call raised.
    """
    last_msg = state["messages"][-1]
    last_type = getattr(last_msg, "type", "unknown")
    # The graph routes "tools" back into this node, so the last message is not
    # always user input; only label it as such when it really is.
    if last_type == "human":
        logging.info(f"User input: {last_msg.content}")
    else:
        logging.info(f"LLM re-invoked after a '{last_type}' message: {last_msg.content}")

    # Pick the system prompt for the current invocation mode.
    if invocation_state.invocation_type == 'chatbot':
        sys_msg = SystemMessage(content="""
        You are a helpful assistant specialized in answering user questions related to Moroccan Parliament YouTube videos and legal documents.
        Your response must be strictly in the same language as the user’s query.
        Provide accurate answers and include relevant sources (YouTube video links or PDF document links) in your response.
        """)
    else:
        sys_msg = SystemMessage(content="""
        You are a helpful assistant for an API that answers user questions related to Moroccan Parliament YouTube videos and legal documents.
        Ensure your responses are concise and formatted appropriately for API output. Your response should be maximum 100 words.
        Your response shoud be in json format like that : {"text response" : "", "sources" : ""}
        """)

    try:
        # Full history at DEBUG level (previously a bare print() to stdout on
        # every call); lazy %-formatting avoids building the string otherwise.
        logging.debug("Messages sent to LLM: %s", state["messages"])
        result = llm_with_tools.invoke([sys_msg] + state["messages"])
        logging.info(f"🤖 Model Output : {result}")
        return {"messages": [result]}
    except Exception as e:
        # Function-scope import: only this error path needs AIMessage.
        from langchain_core.messages import AIMessage

        logging.exception(f"LLM invocation failed: {e}")
        # Report the failure as the assistant's turn. A SystemMessage appended
        # mid-conversation has the wrong role for a reply, and some chat-model
        # integrations reject non-leading system messages if this history is
        # sent back on a later turn.
        return {"messages": [AIMessage(content="Une erreur est survenue avec le modèle.")]}
    
# Graph node that sets the invocation type based on the user's message
def set_invocation_type(state: MessagesState):
    """Decide whether this run serves the API or the chatbot.

    If the last message contains the marker ``apicall`` in any letter case
    (so ``APICALL`` works too), the module-level ``invocation_state`` is set
    to ``'API'``; otherwise it is set to ``'chatbot'``.

    Args:
        state: Graph state; only the last message is inspected.

    Returns:
        Nothing; the node's only effect is on ``invocation_state``.
    """
    content = state["messages"][-1].content
    # Message content may be a list of parts (multimodal messages); fall back
    # to its string form so the marker check cannot crash on .strip().
    user_msg = (content if isinstance(content, str) else str(content)).strip()
    logging.info(f"User input for type update: {user_msg}")

    # Case-insensitive: the previous check only matched lowercase 'apicall',
    # so a message marked 'APICALL' was silently treated as chatbot traffic.
    if 'apicall' in user_msg.lower():
        invocation_state.invocation_type = 'API'
    else:
        invocation_state.invocation_type = 'chatbot'

    logging.info(f"Invocation type set to: {invocation_state.invocation_type}")

class InvocationState:
    """Mutable holder for the mode of the current graph run.

    ``invocation_type`` is ``None`` until ``set_invocation_type`` stores
    ``'API'`` or ``'chatbot'`` in it; ``assistant`` reads it to choose the
    system prompt.

    NOTE(review): a single module-level instance is shared by every run of
    the compiled graph. If the graph is invoked concurrently (e.g. several
    API requests at once), runs could overwrite each other's mode; consider
    carrying this value in the graph state instead.
    """

    def __init__(self) -> None:
        # No mode is known until the first graph node has run.
        self.invocation_type = None


# Module-wide instance read and written by the graph nodes.
invocation_state = InvocationState()

# GRAPH SETUP
# Flow: START -> set_invocation_type -> llm_assistant, then llm_assistant and
# tools alternate until tools_condition stops routing to "tools".
builder = StateGraph(MessagesState)

# Nodes: the mode detector, the LLM node, and the executor for the bound tools.
for _node_name, _node_action in (
    ("set_invocation_type", set_invocation_type),
    ("llm_assistant", assistant),
    ("tools", ToolNode(tools)),
):
    builder.add_node(_node_name, _node_action)
del _node_name, _node_action

# Edges: fixed entry path, conditional exit from the assistant, and a loop
# from the tool executor back into the assistant with the tool results.
builder.add_edge(START, "set_invocation_type")
builder.add_edge("set_invocation_type", "llm_assistant")
builder.add_conditional_edges("llm_assistant", tools_condition)
builder.add_edge("tools", "llm_assistant")

rag_graph = builder.compile()