# Author: freddyaboulton
# Commit: "Upload folder using huggingface_hub" (4a093d2, verified)
from transformers.agents import Agent, agent_types
from pydantic import Field
from gradio.data_classes import GradioModel, FileData, GradioRootModel
from typing import Literal, List, Generator, Optional, Union
from threading import Thread
import time
# Extra information attached to an assistant "thought" message.
class ThoughtMetadata(GradioModel):
    # Name of the tool the thought relates to; None when no tool is involved.
    tool_name: Union[str, None] = None
    # Set to True when the thought carries an error report.
    error: bool = False
# Fields shared by every chatbot message: the sender role and thought metadata.
class Message(GradioModel):
    # Who sent the message.
    role: Literal["user", "assistant"]
    # A fresh ThoughtMetadata is built per message, so instances never share it.
    thought_metadata: ThoughtMetadata = Field(
        default_factory=ThoughtMetadata,
    )
# A chat message whose payload is text.
class ChatMessage(Message):
    # Text body of the message.
    content: str
# A chat message whose payload is a file (e.g. an image or audio clip).
class ChatFileMessage(Message):
    # The attached file.
    file: FileData
    # Optional alternative text describing the file.
    alt_text: Union[str, None] = None
# Root model holding the full conversation as an ordered list of messages.
class ChatbotData(GradioRootModel):
    root: list[Union[ChatMessage, ChatFileMessage]]
def pull_messages(new_messages: List[dict]) -> Generator[ChatMessage, None, None]:
    """Convert agent log entries into assistant "thought" chat messages.

    Each entry is a dict. The keys read here are ``rationale``,
    ``tool_call`` (a dict with ``tool_name`` and ``tool_arguments``),
    ``observation`` and ``error``. For every entry, each of those keys that is
    present and truthy produces one ``ChatMessage``, in that order. Empty
    entries and missing/falsy keys produce nothing.

    Calls to the ``"code interpreter"`` tool are wrapped in a ```` ```py ````
    fence, and errors are flagged with ``ThoughtMetadata(error=True)``.

    Args:
        new_messages: Log entries to convert.

    Yields:
        One ``ChatMessage`` per recognised, non-empty field.
    """
    for message in new_messages:
        if not message:
            continue

        if message.get("rationale"):
            yield ChatMessage(
                role="assistant", content=str(message["rationale"]), thought=True
            )

        tool_call = message.get("tool_call")
        if tool_call:
            tool_name = tool_call["tool_name"]
            # ``content`` is a str field; tool arguments may arrive as a
            # non-string (e.g. a dict), which would fail validation.
            content = str(tool_call["tool_arguments"])
            if tool_name == "code interpreter":
                content = f"```py\n{content}\n```"
            yield ChatMessage(
                role="assistant",
                thought_metadata=ThoughtMetadata(tool_name=tool_name),
                content=content,
                thought=True,
            )

        if message.get("observation"):
            yield ChatMessage(
                role="assistant", content=str(message["observation"]), thought=True
            )

        if message.get("error"):
            yield ChatMessage(
                role="assistant",
                content=str(message["error"]),
                thought=True,
                thought_metadata=ThoughtMetadata(error=True),
            )
def _final_output_message(output) -> Union[ChatMessage, ChatFileMessage]:
    """Convert the value returned by ``agent.run`` into one assistant message.

    ``AgentText`` becomes a text message. ``AgentImage`` / ``AgentAudio``
    become file messages pointing at the path returned by ``to_string()``.
    Any other value is rendered with ``str()``, because ``content`` is a str
    field.
    """
    if isinstance(output, agent_types.AgentText):
        return ChatMessage(
            role="assistant", content=output.to_string(), thought=True
        )
    if isinstance(output, agent_types.AgentImage):
        return ChatFileMessage(
            role="assistant",
            file=FileData(path=output.to_string(), mime_type="image/png"),
            content="",
            thought=True,
        )
    if isinstance(output, agent_types.AgentAudio):
        return ChatFileMessage(
            role="assistant",
            file=FileData(path=output.to_string(), mime_type="audio/wav"),
            content="",
            thought=True,
        )
    return ChatMessage(role="assistant", content=str(output), thought=True)


def stream_from_transformers_agent(
    agent: Agent, prompt: str
) -> Generator[Union[ChatMessage, ChatFileMessage], None, None]:
    """Run ``agent`` on ``prompt`` in a background thread and stream its progress.

    While the agent runs, ``agent.logs`` is polled every 0.1 s. Each newly
    completed log entry is converted with ``pull_messages``. When the run
    ends, the remaining entries are flushed and the agent's final output is
    yielded as the last message.

    Args:
        agent: Agent to run. Its ``logs`` list is read while ``agent.run``
            executes in another thread.
        prompt: Task passed to ``agent.run``.

    Yields:
        ``ChatMessage`` / ``ChatFileMessage`` objects in production order.

    Raises:
        Exception: Whatever ``agent.run`` raised. It is re-raised after the
            messages logged before the failure have been yielded.
    """
    outcome: dict = {}

    def run_agent() -> None:
        try:
            outcome["output"] = agent.run(prompt)
        except Exception as exc:
            # Exceptions in a worker thread are otherwise lost to the caller;
            # stash it so the generator can re-raise it below.
            outcome["error"] = exc

    thread = Thread(target=run_agent)
    # Number of ``agent.logs`` entries (not chat messages) already converted.
    num_consumed = 0

    # NOTE(review): if ``agent`` is reused across calls, ``agent.logs`` may
    # still hold the previous run's entries until ``agent.run`` resets it;
    # confirm how the agent resets its logs.
    thread.start()
    while thread.is_alive():
        logs = agent.logs
        # Hold back the newest entry while the agent is running: it may still
        # be filled in (``pull_messages`` skips empty entries, which suggests
        # entries are appended before being populated; TODO confirm against
        # the transformers version in use).
        ready = len(logs) - 1
        if ready > num_consumed:
            yield from pull_messages(logs[num_consumed:ready])
            num_consumed = ready
        time.sleep(0.1)
    thread.join()

    # The run is over, so every remaining entry, including the last, is final.
    yield from pull_messages(agent.logs[num_consumed:])

    if "error" in outcome:
        raise outcome["error"]
    yield _final_output_message(outcome.get("output"))