Spaces:
Sleeping
Sleeping
File size: 5,418 Bytes
3107242 8af85fa 3107242 1709069 3107242 1709069 3107242 1709069 3107242 1709069 3107242 6437579 3107242 1709069 3107242 1709069 3107242 1709069 3107242 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import sys
import os
sys.path.append(os.getcwd())
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from langchain_core.messages import SystemMessage
from langgraph.prebuilt import tools_condition
from langchain_core.tools import tool
from src.rag.documents_rag_pipeline import RAGPipeline
from src.rag.youtube_rag_pipeline import YouTubeRAGPipeline
from src.models.llm_wrapper import GeminiWrapper
from src.utils.search_docs_utils import search_relevant_documents
from src.configs.config import LOG_DIR
import logging
import os
# LOGGING SETUP
# All agent/tool activity is appended to <LOG_DIR>/Agents.log.
LOG_FILE = os.path.join(LOG_DIR, "Agents.log")
# The FileHandler created by basicConfig does not create missing parent
# directories; without this the module fails at import time on a fresh checkout.
if LOG_DIR:
    os.makedirs(LOG_DIR, exist_ok=True)
logging.basicConfig(
    filename=LOG_FILE,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    # Records contain emoji and accented French text; do not depend on the
    # platform's default (locale) encoding, which may not be able to encode them.
    encoding="utf-8",
)
# PIPELINE INIT
# Both retrieval pipelines are constructed once, when this module is imported,
# and reused by every call of the two search tools.
doc_rag_pipeline, ytb_rag_pipeline = RAGPipeline(), YouTubeRAGPipeline()
# TOOLS
# NOTE: with @tool, a function's docstring becomes the tool description sent
# to the LLM, so it is prompt text. Developer notes live in comments instead.
#
# search_legal_documents first calls search_relevant_documents(query, top_k=1)
# (presumably selecting the single best-matching legal document; confirm in
# src.utils.search_docs_utils). It then passes that selection's 'results' to
# the document RAG pipeline, which returns chunks carrying 'text' and
# 'pdf_link'. Each chunk is rendered with its PDF link so the model can cite
# it. Failures are logged with a traceback and reported to the model as a short
# French error string rather than raised.
@tool
def search_legal_documents(query: str) -> str:
    """Search Moroccan **legal codes, decrees, and PDFs**. Use this tool when user asks about Moroccan **laws, articles, codes, legal references**, or related content."""
    logging.info("Tool call: search_legal_documents with input: %s", query)
    try:
        relevant_docs = search_relevant_documents(query, top_k=1)
        results = doc_rag_pipeline.get_top_docs_chunks_for_query(query, relevant_docs['results'])
        output = "\n".join(f"{r['text']}\n(Source: {r['pdf_link']})" for r in results)
    except Exception as e:
        # logging.exception keeps the traceback, which logging.error dropped.
        logging.exception("Error in search_legal_documents: %s", e)
        return "Une erreur est survenue lors de la recherche des documents légaux."
    if not output:
        # An empty tool result gives the model nothing to ground its answer on;
        # state explicitly that nothing was found.
        logging.info("Tool result (search_legal_documents): no results")
        return "Aucun document légal pertinent n'a été trouvé pour cette requête."
    logging.info("Tool result (search_legal_documents): %s", output)
    return output
# search_youtube_transcripts queries the YouTube transcript RAG pipeline. It
# renders each hit's text ('texte') with its video link ('lien') so the model
# can cite it. As with the legal tool, the docstring below is the tool
# description sent to the LLM. Failures are logged with a traceback and
# reported to the model as a short French error string rather than raised.
@tool
def search_youtube_transcripts(query: str) -> str:
    """Search Moroccan **Parliament YouTube videos** (debates, speeches, questions, sessions). Use this tool when user asks about **video content** or **what someone said**."""
    logging.info("Tool call: search_youtube_transcripts with input: %s", query)
    try:
        results = ytb_rag_pipeline.search(query)
        output = "\n".join(f"{r['texte']}\n(Source: {r['lien']})" for r in results)
    except Exception as e:
        # logging.exception keeps the traceback, which logging.error dropped.
        logging.exception("Error in search_youtube_transcripts: %s", e)
        return "Une erreur est survenue lors de la recherche dans les transcriptions YouTube."
    if not output:
        # Tell the model explicitly that nothing matched instead of returning "".
        logging.info("Tool result (search_youtube_transcripts): no results")
        return "Aucune transcription YouTube pertinente n'a été trouvée pour cette requête."
    logging.info("Tool result (search_youtube_transcripts): %s", output)
    return output
# The same tool list is bound to the LLM (so the model can request calls) and
# handed to the graph's ToolNode (which executes those calls).
tools = [
    search_legal_documents,
    search_youtube_transcripts,
]

# LLM SETUP
llm = GeminiWrapper()
llm_with_tools = llm.bind_tools(tools)
# LLM NODE
# System prompts for the two response modes chosen by set_invocation_type.
_CHATBOT_SYSTEM_PROMPT = """
You are a helpful assistant specialized in answering user questions related to Moroccan Parliament YouTube videos and legal documents.
Your response must be strictly in the same language as the user’s query.
Provide accurate answers and include relevant sources (YouTube video links or PDF document links) in your response.
"""
_API_SYSTEM_PROMPT = """
You are a helpful assistant for an API that answers user questions related to Moroccan Parliament YouTube videos and legal documents.
Ensure your responses are concise and formatted appropriately for API output. Your response should be maximum 100 words.
Your response shoud be in json format like that : {"text response" : "", "sources" : ""}
"""


def assistant(state: MessagesState):
    """Run the tool-enabled LLM over the conversation so far.

    The system prompt depends on the module-level
    ``invocation_state.invocation_type``: ``"chatbot"`` selects the
    conversational prompt, and any other value (``"API"``, or ``None`` if
    unset) selects the short JSON-formatted API prompt.

    Args:
        state: Graph state; ``state["messages"]`` is the message history.

    Returns:
        A state update ``{"messages": [reply]}``. ``reply`` is the model's
        message, which may request tool calls. If the model call raises,
        ``reply`` is an ``AIMessage`` carrying a generic French error text, so
        it sits in the assistant's position in the history.
    """
    # Function-scope import keeps this node self-contained; the module is
    # already loaded by the file's top-level imports.
    from langchain_core.messages import AIMessage

    # After a tool round-trip, the last message may be tool output rather than
    # the user's text.
    user_msg = state["messages"][-1].content
    logging.info("User input: %s", user_msg)

    if invocation_state.invocation_type == 'chatbot':
        sys_msg = SystemMessage(content=_CHATBOT_SYSTEM_PROMPT)
    else:
        sys_msg = SystemMessage(content=_API_SYSTEM_PROMPT)

    # Was a bare print() to stdout; keep the full history at DEBUG level instead.
    logging.debug("Messages sent to the LLM: %s", state["messages"])
    try:
        result = llm_with_tools.invoke([sys_msg] + state["messages"])
    except Exception as e:
        logging.exception("LLM invocation failed: %s", e)
        # Reply as the assistant: a SystemMessage in the middle of the history
        # misrepresents the speaker and can be rejected by chat models if the
        # history is sent back on a later turn.
        return {"messages": [AIMessage(content="Une erreur est survenue avec le modèle.")]}
    logging.info(f"🤖 Model Output : {result}")
    return {"messages": [result]}
# Graph node that selects the response mode from the latest user message.
def set_invocation_type(state: MessagesState):
    """Choose between the "API" and "chatbot" response modes.

    A latest message containing the marker ``APICALL`` (matched
    case-insensitively, so ``apicall`` still works) selects ``"API"``.
    Anything else selects ``"chatbot"``. The choice is written to the
    module-level ``invocation_state``, which ``assistant`` reads; this node
    returns nothing.

    NOTE(review): because the mode lives in shared module-level state rather
    than in the graph state, concurrent runs of ``rag_graph`` could overwrite
    each other's mode.

    Args:
        state: Graph state; only ``state["messages"][-1]`` is inspected.
    """
    content = state["messages"][-1].content
    # Message content is normally a str, but may be a list of parts.
    user_msg = content.strip() if isinstance(content, str) else str(content)
    logging.info("User input for type update: %s", user_msg)

    # Previously a case-sensitive check for lowercase 'apicall', so the
    # documented 'APICALL' marker silently fell through to chatbot mode.
    if "apicall" in user_msg.lower():
        invocation_state.invocation_type = 'API'
    else:
        invocation_state.invocation_type = 'chatbot'
    logging.info("Invocation type set to: %s", invocation_state.invocation_type)
class InvocationState:
    """Mutable holder for the response mode of a graph run.

    ``invocation_type`` starts as ``None``. The ``set_invocation_type`` node
    sets it to ``"API"`` or ``"chatbot"``, and ``assistant`` reads it to pick
    a system prompt.
    """

    def __init__(self) -> None:
        # None means "not chosen yet".
        self.invocation_type: "str | None" = None


# Single shared instance used by set_invocation_type and assistant.
# NOTE(review): this is module-level state shared by every run of the graph;
# concurrent runs could overwrite each other's mode.
invocation_state = InvocationState()
# GRAPH SETUP
def _wire_rag_graph(graph: StateGraph) -> StateGraph:
    """Register the agent's nodes and edges on *graph* and return it.

    Flow: START -> set_invocation_type -> llm_assistant. After each model
    reply, tools_condition either routes to "tools" (whose results go back to
    llm_assistant) or ends the run.
    """
    # Nodes: mode selector, model call, tool executor.
    graph.add_node("set_invocation_type", set_invocation_type)
    graph.add_node("llm_assistant", assistant)
    graph.add_node("tools", ToolNode(tools))

    # Edges: fixed entry path, then the model <-> tools loop.
    graph.add_edge(START, "set_invocation_type")
    graph.add_edge("set_invocation_type", "llm_assistant")
    graph.add_conditional_edges("llm_assistant", tools_condition)
    graph.add_edge("tools", "llm_assistant")
    return graph


builder = _wire_rag_graph(StateGraph(MessagesState))
rag_graph = builder.compile()
|