chess-agent/utils/helpers.py
czakop's picture
bump required version of chessboard component
402894f
from __future__ import annotations
from typing import TYPE_CHECKING, AsyncIterator
import gradio as gr
from langchain.chat_models import init_chat_model
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langgraph.prebuilt import create_react_agent
if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.graph.graph import CompiledGraph
# Any chat-message representation accepted by the helpers below: a LangChain
# message, a Gradio ChatMessage, or a plain {"role": ..., "content": ...} dict.
MESSAGE_TYPE = BaseMessage | gr.ChatMessage | dict[str, str]
def create_agent(
    model_name: str, provider: str, api_key: str, tools: list
) -> CompiledGraph:
    """Build a ReAct-style LangGraph agent around the requested chat model.

    Args:
        model_name: Provider-specific model identifier (e.g. a model slug).
        provider: Display name of the provider; must be one supported by
            ``_create_model``.
        api_key: API key forwarded to the underlying chat model.
        tools: Tools the agent may call.

    Returns:
        The compiled LangGraph agent.
    """
    chat_model = _create_model(model_name, provider, api_key)
    return create_react_agent(chat_model, tools=tools)
async def call_agent(
    agent: CompiledGraph, messages: list[MESSAGE_TYPE], prompt: HumanMessage
) -> AsyncIterator[list[MESSAGE_TYPE]]:
    """Stream an agent turn, yielding the growing chat history after each step.

    ``messages`` is mutated in place: every tool result and every assistant
    reply is appended as a ``gr.ChatMessage`` and the whole list is yielded
    so a Gradio UI can re-render incrementally.

    Args:
        agent: Compiled LangGraph agent to run.
        messages: Prior chat history. Only ``messages[:-1]`` is forwarded to
            the agent; ``prompt`` is sent in place of the last entry
            (presumably the last entry is the raw user turn that ``prompt``
            wraps — TODO confirm against the caller).
        prompt: The user message for this turn.

    Yields:
        The same ``messages`` list after each newly appended entry.
    """
    async for chunk in agent.astream(
        {
            # Normalize prior history to LangChain messages; drop the final
            # entry and substitute `prompt`.
            "messages": [_convert_to_langchain_message(msg) for msg in messages[:-1]]
            + [prompt]
        }
    ):
        # Tool-node output: surface each tool result under a labelled title.
        if "tools" in chunk:
            for step in chunk["tools"]["messages"]:
                messages.append(
                    gr.ChatMessage(
                        role="assistant",
                        content=step.content,
                        metadata={"title": f"🛠️ Used tool {step.name}"},
                    )
                )
                yield messages
        # Agent-node output: the assistant's reply (or a tool-call notice).
        if "agent" in chunk:
            messages.append(
                gr.ChatMessage(
                    role="assistant",
                    content=_get_chunk_message_content(chunk),
                )
            )
            yield messages
def _create_model(model_name: str, provider: str, api_key: str) -> BaseChatModel:
"""Get the chat model based on the provider and model name."""
if provider == "Anthropic":
return init_chat_model(
"anthropic:" + model_name,
anthropic_api_key=api_key,
)
elif provider == "Mistral":
return init_chat_model(
"mistralai:" + model_name,
mistral_api_key=api_key,
)
elif provider == "OpenAI":
return init_chat_model(
"openai:" + model_name,
openai_api_key=api_key,
)
else:
raise ValueError(f"Unsupported model provider: {provider}")
def _is_ai_message(message: MESSAGE_TYPE) -> bool:
    """Return True when *message* represents an assistant (AI) turn."""
    if isinstance(message, AIMessage):
        return True
    # Gradio and dict messages carry an explicit role; anything else is
    # treated as non-AI.
    role = None
    if isinstance(message, gr.ChatMessage):
        role = message.role
    elif isinstance(message, dict):
        role = message.get("role")
    return role == "assistant"
def _convert_to_langchain_message(message: MESSAGE_TYPE) -> BaseMessage:
    """Normalize any supported message shape into a LangChain message.

    LangChain messages pass through unchanged; Gradio and plain-dict
    messages become ``AIMessage`` or ``HumanMessage`` based on their role.

    Raises:
        ValueError: For message types outside ``MESSAGE_TYPE``.
    """
    if isinstance(message, BaseMessage):
        return message
    if isinstance(message, gr.ChatMessage):
        content = message.content
    elif isinstance(message, dict):
        content = message.get("content", "")
    else:
        raise ValueError(f"Unsupported message type: {type(message)}")
    message_cls = AIMessage if _is_ai_message(message) else HumanMessage
    return message_cls(content=content)
def _get_chunk_message_content(chunk: dict) -> str:
msg_object = chunk["agent"]["messages"][0]
message = msg_object.content
if isinstance(message, list):
message = message[0] if message else ""
if isinstance(message, dict):
message = message.get("text")
return message or "Calling tool(s)"