Spaces:
Runtime error
Runtime error
| from transformers.agents import Agent, agent_types | |
| from pydantic import Field | |
| from gradio.data_classes import GradioModel, FileData, GradioRootModel | |
| from typing import Literal, List, Generator, Optional, Union | |
| from threading import Thread | |
| import time | |
| class ThoughtMetadata(GradioModel): | |
| tool_name: Optional[str] = None | |
| error: bool = False | |
| class Message(GradioModel): | |
| role: Literal["user", "assistant"] | |
| thought_metadata: ThoughtMetadata = Field(default_factory=ThoughtMetadata) | |
| class ChatMessage(Message): | |
| content: str | |
| class ChatFileMessage(Message): | |
| file: FileData | |
| alt_text: Optional[str] = None | |
| class ChatbotData(GradioRootModel): | |
| root: List[Union[ChatMessage, ChatFileMessage]] | |
| def pull_messages(new_messages: List[dict]): | |
| for message in new_messages: | |
| if not len(message): | |
| continue | |
| if message.get("rationale"): | |
| yield ChatMessage( | |
| role="assistant", content=message["rationale"], thought=True | |
| ) | |
| if message.get("tool_call"): | |
| used_code = message["tool_call"]["tool_name"] == "code interpreter" | |
| content = message["tool_call"]["tool_arguments"] | |
| if used_code: | |
| content = f"```py\n{content}\n```" | |
| yield ChatMessage( | |
| role="assistant", | |
| thought_metadata=ThoughtMetadata( | |
| tool_name=message["tool_call"]["tool_name"] | |
| ), | |
| content=content, | |
| thought=True, | |
| ) | |
| if message.get("observation"): | |
| yield ChatMessage( | |
| role="assistant", content=message["observation"], thought=True | |
| ) | |
| if message.get("error"): | |
| yield ChatMessage( | |
| role="assistant", | |
| content=str(message["error"]), | |
| thought=True, | |
| thought_metadata=ThoughtMetadata(error=True), | |
| ) | |
| def stream_from_transformers_agent( | |
| agent: Agent, prompt: str | |
| ) -> Generator[ChatMessage, None, None]: | |
| """Runs an agent with the given prompt and streams the messages from the agent as ChatMessages.""" | |
| class Output: | |
| output: agent_types.AgentType | str = None | |
| def run_agent(): | |
| output = agent.run(prompt) | |
| Output.output = output | |
| thread = Thread(target=run_agent) | |
| num_messages = 0 | |
| # Start thread and pull logs while it runs | |
| thread.start() | |
| while thread.is_alive(): | |
| if len(agent.logs) > num_messages: | |
| new_messages = agent.logs[num_messages:] | |
| for msg in pull_messages(new_messages): | |
| yield msg | |
| num_messages += 1 | |
| time.sleep(0.1) | |
| thread.join(0.1) | |
| if len(agent.logs) > num_messages: | |
| new_messages = agent.logs[num_messages:] | |
| yield from pull_messages(new_messages) | |
| if isinstance(Output.output, agent_types.AgentText): | |
| yield ChatMessage( | |
| role="assistant", content=Output.output.to_string(), thought=True | |
| ) | |
| elif isinstance(Output.output, agent_types.AgentImage): | |
| yield ChatFileMessage( | |
| role="assistant", | |
| file=FileData(path=Output.output.to_string(), mime_type="image/png"), | |
| content="", | |
| thought=True, | |
| ) | |
| elif isinstance(Output.output, agent_types.AgentAudio): | |
| yield ChatFileMessage( | |
| role="assistant", | |
| file=FileData(path=Output.output.to_string(), mime_type="audio/wav"), | |
| content="", | |
| thought=True, | |
| ) | |
| else: | |
| return ChatMessage(role="assistant", content=Output.output, thought=True) | |