Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import TYPE_CHECKING, AsyncIterator | |
| import gradio as gr | |
| from langchain.chat_models import init_chat_model | |
| from langchain_core.messages import AIMessage, BaseMessage, HumanMessage | |
| from langgraph.prebuilt import create_react_agent | |
| if TYPE_CHECKING: | |
| from langchain_core.language_models.chat_models import BaseChatModel | |
| from langgraph.graph.graph import CompiledGraph | |
# Every chat-message shape the helpers below accept: LangChain message
# objects, Gradio ChatMessage objects, and plain dicts read via their
# "role" / "content" keys.
MESSAGE_TYPE = (
    BaseMessage
    | gr.ChatMessage
    | dict[str, str]
)
def create_agent(
    model_name: str, provider: str, api_key: str, tools: list
) -> CompiledGraph:
    """Build a LangGraph ReAct agent around the requested chat model.

    Args:
        model_name: Provider-specific model identifier (without the
            provider prefix).
        provider: Provider label understood by ``_create_model``
            ("Anthropic", "Mistral" or "OpenAI").
        api_key: API key forwarded to the provider's chat model.
        tools: Tools made available to the agent.

    Returns:
        The compiled graph produced by ``create_react_agent``.

    Raises:
        ValueError: If ``provider`` is not supported (raised by
            ``_create_model``).
    """
    chat_model = _create_model(model_name, provider, api_key)
    agent = create_react_agent(chat_model, tools=tools)
    return agent
async def call_agent(
    agent: CompiledGraph, messages: list[MESSAGE_TYPE], prompt: HumanMessage
) -> AsyncIterator[list[MESSAGE_TYPE]]:
    """Run ``agent`` on the chat history and stream the growing history back.

    The agent input is every entry of ``messages`` except the last, converted
    to LangChain messages, followed by ``prompt``. In other words, the final
    history entry is replaced by ``prompt``; presumably it is the user turn
    that ``prompt`` represents (TODO confirm with the caller).

    ``messages`` is mutated in place, and the same list object is yielded
    after every append:

    * each tool result in a ``"tools"`` update is appended as an assistant
      ``gr.ChatMessage`` whose metadata title names the tool;
    * each ``"agent"`` update is appended as an assistant ``gr.ChatMessage``
      holding the text extracted by ``_get_chunk_message_content``.
    """
    agent_input = [_convert_to_langchain_message(entry) for entry in messages[:-1]]
    agent_input.append(prompt)

    async for update in agent.astream({"messages": agent_input}):
        if "tools" in update:
            for tool_msg in update["tools"]["messages"]:
                tool_entry = gr.ChatMessage(
                    role="assistant",
                    content=tool_msg.content,
                    metadata={"title": f"🛠️ Used tool {tool_msg.name}"},
                )
                messages.append(tool_entry)
                yield messages
        if "agent" in update:
            reply_text = _get_chunk_message_content(update)
            messages.append(gr.ChatMessage(role="assistant", content=reply_text))
            yield messages
def _create_model(model_name: str, provider: str, api_key: str) -> BaseChatModel:
    """Instantiate the chat model for ``provider`` / ``model_name``.

    The model string passed to ``init_chat_model`` is the provider prefix
    followed by ``model_name``. The API key is forwarded under the keyword
    argument specific to that provider.

    Raises:
        ValueError: If ``provider`` is not one of the supported labels.
    """
    # provider label -> (init_chat_model prefix, provider-specific API-key kwarg)
    dispatch = {
        "Anthropic": ("anthropic:", "anthropic_api_key"),
        "Mistral": ("mistralai:", "mistral_api_key"),
        "OpenAI": ("openai:", "openai_api_key"),
    }
    spec = dispatch.get(provider)
    if spec is None:
        raise ValueError(f"Unsupported model provider: {provider}")
    prefix, key_kwarg = spec
    return init_chat_model(prefix + model_name, **{key_kwarg: api_key})
def _is_ai_message(message: MESSAGE_TYPE) -> bool:
    """Return True when ``message`` was authored by the assistant.

    LangChain ``AIMessage`` instances always count. Gradio ``ChatMessage``
    objects and dicts count when their role is ``"assistant"``. Anything
    else returns False.
    """
    if isinstance(message, AIMessage):
        return True
    if isinstance(message, gr.ChatMessage):
        role = message.role
    elif isinstance(message, dict):
        role = message.get("role")
    else:
        return False
    return role == "assistant"
def _convert_to_langchain_message(message: MESSAGE_TYPE) -> BaseMessage:
    """Normalise a chat-history entry into a LangChain message.

    Existing ``BaseMessage`` objects pass through unchanged. Gradio
    ``ChatMessage`` objects and dicts become ``AIMessage`` when
    ``_is_ai_message`` says so, and ``HumanMessage`` otherwise. A dict
    without a ``"content"`` key yields empty content.

    Raises:
        ValueError: For any other input type.
    """
    if isinstance(message, BaseMessage):
        return message

    if isinstance(message, gr.ChatMessage):
        text = message.content
    elif isinstance(message, dict):
        text = message.get("content", "")
    else:
        raise ValueError(f"Unsupported message type: {type(message)}")

    message_cls = AIMessage if _is_ai_message(message) else HumanMessage
    return message_cls(content=text)
def _get_chunk_message_content(chunk: dict) -> str:
    """Return displayable text for the message emitted by the ``agent`` node.

    Reads the first message in ``chunk["agent"]["messages"]`` and flattens
    its ``content`` attribute to a string:

    * a plain string is used as-is;
    * a list of content blocks is reduced to the concatenation of every
      string block and every dict block's string ``"text"`` value. Blocks
      without text (e.g. tool-call or reasoning blocks) are skipped;
    * a single dict block contributes its ``"text"`` value.

    Returns ``"Calling tool(s)"`` when no text is found (e.g. the model only
    requested tool calls) or when the node produced no messages.
    """
    fallback = "Calling tool(s)"
    agent_messages = chunk["agent"]["messages"]
    if not agent_messages:
        # Guard against an empty update instead of raising IndexError.
        return fallback

    content = agent_messages[0].content
    if isinstance(content, dict):
        content = [content]
    if isinstance(content, list):
        # Scan every block rather than only the first: a provider may emit a
        # non-text block (e.g. reasoning or tool-use) ahead of the answer
        # text, and the answer itself may be split across several blocks.
        texts = []
        for block in content:
            if isinstance(block, str):
                texts.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                texts.append(block["text"])
        content = "".join(texts)
    return content or fallback