# HFAgentsCourse / agent.py
# Author: nicolacaione — "Used to pass the exam" (commit 3874cd4)
from langchain_core.language_models import BaseChatModel
from langchain.chat_models import init_chat_model
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph.message import add_messages
from typing import Annotated, Literal, TypedDict, Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.runnables import Runnable
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import BaseMessage
from langsmith import traceable
import os
import streamlit as st
# Configure LangSmith with API-key validation
def setup_langsmith():
    """Configure LangSmith tracing from the LANGCHAIN_API_KEY env var.

    Validates that the key is present and not the template placeholder.
    On failure, tracing is explicitly disabled instead of erroring out.

    Returns:
        bool: True when tracing was enabled, False when it was disabled.
    """
    api_key = os.getenv("LANGCHAIN_API_KEY", "")
    if not api_key or api_key == "your_langsmith_api_key_here":
        print("⚠️ WARNING: LANGCHAIN_API_KEY not set or using placeholder value")
        print(" LangSmith tracing will be disabled. Set your API key in .env file")
        print(" Get your API key from: https://smith.langchain.com")
        os.environ["LANGCHAIN_TRACING_V2"] = "false"
        return False
    # Valid key: turn tracing on and point at the hosted LangSmith endpoint.
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
    os.environ["LANGCHAIN_API_KEY"] = api_key
    os.environ["LANGCHAIN_PROJECT"] = "gaia-agent-project"
    print("✅ LangSmith tracing configured successfully")
    return True
# Initialize LangSmith at import time (prints status as a side effect).
langsmith_enabled = setup_langsmith()
# Import constants and tools from the proper module
from constants import BASE_PROMPT_GAIA_BENCHMARK, AGENT_TOOLS
# Debug flag for graph streaming; defaults to enabled unless the
# DEBUG_AGENT_IN_CHAIN env var is set to anything other than "true".
DEBUG_AGENT_IN_CHAIN = os.getenv("DEBUG_AGENT_IN_CHAIN", "true").lower() == "true"
class State(TypedDict):
    """Graph agent state: the conversation, accumulated via ``add_messages``."""
    # add_messages merges new messages into the list instead of replacing it.
    messages: Annotated[list, add_messages]
class AgentGraph:
    """GAIA agent built on a LangGraph state graph, with LangSmith tracing."""

    def __init__(self, model_string: str = "gpt-4o", tools: list = AGENT_TOOLS, **kwargs):
        """Build the agent graph.

        Args:
            model_string: Model identifier passed to ``init_chat_model``.
            tools: Tools to bind to the model. Defaults to the shared
                module-level AGENT_TOOLS list; it is only read, never mutated.
            **kwargs: Accepted for forward compatibility; currently unused.
        """
        self._tools = tools
        self._model = init_chat_model(model_string)
        if self._tools:
            self._model = self._model.bind_tools(self._tools)
        # Define the graph. The START -> agent edge is common to both shapes.
        workflow = StateGraph(State)
        workflow.add_node("agent", self.call_model)
        workflow.add_edge(START, "agent")
        if self._tools:
            workflow.add_node("tools", ToolNode(self._tools))
            # tools_condition routes to "tools" when the model requested
            # tool calls, and to END otherwise.
            workflow.add_conditional_edges("agent", tools_condition)
            workflow.add_edge("tools", "agent")
        else:
            # No tools: a single model call ends the run.
            workflow.add_edge("agent", END)
        self.app = workflow.compile()

    @property
    def tools(self):
        """The tool list this agent was constructed with."""
        return self._tools

    @property
    def model(self) -> Union[BaseChatModel, Runnable[LanguageModelInput, BaseMessage]]:
        """The underlying chat model (tool-bound when tools were provided)."""
        return self._model

    @traceable(name="call_model")
    def call_model(self, state: State):
        """Invoke the model on the accumulated messages (traced by LangSmith)."""
        response = self._model.invoke(state['messages'])
        return {"messages": [response]}

    @traceable(name="agent_execution")
    def __call__(self, query: str, max_iterations: int = 15):
        """Run the agent on *query* and return the final answer content.

        Streams graph events, stopping early once a "FINAL ANSWER:" marker
        appears, or after *max_iterations* events as an infinite-loop guard.
        On failure an "Error: ..." string is returned instead of raising.
        """
        state = {
            "messages": [
                {"role": "system", "content": BASE_PROMPT_GAIA_BENCHMARK},
                {"role": "user", "content": query},
            ]
        }
        try:
            agent_responses = []
            iteration_count = 0
            for response in self.execution(state):
                agent_responses.append(response)
                iteration_count += 1
                # Prevent infinite agent/tool loops.
                if iteration_count >= max_iterations:
                    print(f"⚠️ Agent reached maximum iterations ({max_iterations}), stopping execution")
                    break
                # Stop early once the model has produced a final answer.
                messages = response.get('messages', [])
                if messages and hasattr(messages[-1], 'content'):
                    if "FINAL ANSWER:" in str(messages[-1].content).upper():
                        break
            # Extract the final response from the last streamed event.
            if agent_responses and agent_responses[-1].get('messages'):
                final_content = agent_responses[-1]['messages'][-1].content
                print(f"✅ Agent completed in {iteration_count} iterations")
                return final_content
            print("❌ No valid response generated")
            return "No response generated - please try rephrasing your question"
        except Exception as e:
            error_msg = f"Agent execution failed: {str(e)}"
            print(f"❌ {error_msg}")
            return f"Error: {error_msg}"

    @traceable(name="agent_stream_execution")
    def execution(self, state: State | None):
        """Yield streamed graph events for *state* (traced by LangSmith)."""
        yield from self.app.stream(
            state,
            stream_mode="values",
            debug=DEBUG_AGENT_IN_CHAIN,
        )
def add_state_message(prompt, role: Literal["user", "assistant"] = "user"):
    """Append a chat message to the conversation kept in Streamlit session state."""
    session = st.session_state
    # Lazily create the conversation container on first use.
    if "state" not in session:
        session["state"] = {"messages": []}
    session["state"]["messages"].append({"role": role, "content": prompt})
@traceable(name="streamlit_chat_interface")
def main():
    """Streamlit chat interface for the GAIA agent with LangSmith tracing."""
    st.set_page_config(
        page_title="GAIA Agent",
        page_icon="🤖",
        layout="wide"
    )
    st.title("🤖 GAIA Agent con LangSmith Tracing")
    st.markdown("*Un assistente AI intelligente con monitoraggio completo*")

    # Sidebar: model selection and LangSmith info.
    with st.sidebar:
        st.header("⚙️ Configurazione")
        model_choice = st.selectbox(
            "Modello:",
            ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "claude-3-sonnet"]
        )
        st.header("📊 LangSmith")
        st.write("Tracing automatico attivo ✅")
        if st.button("📱 Apri LangSmith Dashboard"):
            st.write("Vai su: https://smith.langchain.com")

    # (Re)create the agent when missing OR when the selected model changed.
    # Previously the first agent was cached forever, so the selectbox had no effect.
    if ('agent' not in st.session_state
            or st.session_state.get('agent_model') != model_choice):
        st.session_state.agent = AgentGraph(model_string=model_choice, tools=AGENT_TOOLS)
        st.session_state.agent_model = model_choice

    # Initialize conversation state with the system prompt.
    if 'state' not in st.session_state:
        st.session_state.state = {
            "messages": [{"role": "system", "content": BASE_PROMPT_GAIA_BENCHMARK}]
        }

    st.header("💬 Conversazione")
    # Container so messages render above the input form.
    chat_container = st.container()

    # User input form (cleared automatically on submit).
    with st.form(key="chat_form", clear_on_submit=True):
        user_input = st.text_area(
            "Inserisci la tua domanda:",
            placeholder="Es: Spiegami come funziona l'intelligenza artificiale...",
            height=100
        )
        col1, col2 = st.columns([1, 5])
        with col1:
            submit = st.form_submit_button("🚀 Invia")
        with col2:
            if st.form_submit_button("🗑️ Pulisci Chat"):
                # Reset the conversation to just the system prompt.
                st.session_state.state = {
                    "messages": [{"role": "system", "content": BASE_PROMPT_GAIA_BENCHMARK}]
                }
                st.rerun()

    # Handle a submitted message.
    if submit and user_input.strip():
        with chat_container:
            # Show the user message.
            with st.chat_message("user"):
                st.write(user_input)
            add_state_message(user_input, "user")
            # Generate the agent response.
            with st.chat_message("assistant"):
                with st.spinner("🤔 GAIA sta pensando..."):
                    try:
                        # Run the agent (automatically traced by LangSmith).
                        response = st.session_state.agent(user_input)
                        st.write(response)
                        add_state_message(response, "assistant")
                    except Exception as e:
                        st.error(f"❌ Errore: {str(e)}")
                        st.info("💡 Controlla la configurazione API nel codice")

    # Render chat history (the system prompt is skipped).
    if len(st.session_state.state["messages"]) > 1:
        with chat_container:
            st.subheader("📜 Cronologia Conversazione")
            for msg in st.session_state.state["messages"]:
                if isinstance(msg, dict) and msg.get("role") != "system":
                    with st.chat_message(msg["role"]):
                        st.write(msg["content"])

    # Footer with LangSmith info (content at column 0 so markdown
    # does not treat indented lines as a code block).
    st.markdown("---")
    st.markdown("""
**🔍 Monitoraggio LangSmith:**
- ✅ Trace automatici per ogni conversazione
- 📊 Metriche di performance e costi
- 🛠️ Debug dettagliato del reasoning
- 📈 Analytics e feedback degli utenti
""")