File size: 3,662 Bytes
c1d186b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from transformers.agents import Agent, agent_types
from pydantic import Field
from gradio.data_classes import GradioModel, FileData, GradioRootModel
from typing import Literal, List, Generator, Optional, Union
from threading import Thread
import time


class ThoughtMetadata(GradioModel):
    # Metadata attached to a chat message (see Message.thought_metadata).
    # Name of the tool a message refers to; pull_messages sets it on the
    # messages it builds from a log entry's "tool_call".
    tool_name: Optional[str] = None
    # True when the message reports an error; pull_messages sets it on the
    # messages it builds from a log entry's "error".
    error: bool = False


class Message(GradioModel):
    # Fields shared by every chat message (text or file).
    # Who produced the message.
    role: Literal["user", "assistant"]
    # default_factory builds a fresh ThoughtMetadata (tool_name=None,
    # error=False) for each message that does not supply one.
    thought_metadata: ThoughtMetadata = Field(default_factory=ThoughtMetadata)


class ChatMessage(Message):
    # A chat message whose payload is text. pull_messages may put a
    # Markdown ```py fence in here for code-interpreter tool calls.
    content: str


class ChatFileMessage(Message):
    # A chat message whose payload is a file; stream_from_transformers_agent
    # uses it for image ("image/png") and audio ("audio/wav") results.
    file: FileData
    # Optional alternative text describing the file (not set by the code in
    # this module).
    alt_text: Optional[str] = None


class ChatbotData(GradioRootModel):
    # Root model: the whole conversation as an ordered list whose items are
    # either text messages or file messages.
    root: List[Union[ChatMessage, ChatFileMessage]]


def pull_messages(new_messages: List[dict]) -> Generator[ChatMessage, None, None]:
    """Convert agent log entries into assistant ``ChatMessage`` objects.

    Each log entry is a dict that may yield up to four messages, always in
    this order:

    1. its ``"rationale"``;
    2. its ``"tool_call"`` -- the tool arguments, wrapped in a Markdown
       ``py`` code fence when the tool is ``"code interpreter"``, with
       ``thought_metadata.tool_name`` set to the tool's name;
    3. its ``"observation"``;
    4. its ``"error"``, flagged with ``ThoughtMetadata(error=True)``.

    Empty entries and missing/falsy keys are skipped.

    Args:
        new_messages: Log entries to convert, e.g. a slice of ``agent.logs``.

    Yields:
        ``ChatMessage`` objects with ``role="assistant"``.
    """
    for message in new_messages:
        # Entries can be empty (or falsy); there is nothing to show for them.
        if not message:
            continue
        if message.get("rationale"):
            yield ChatMessage(
                role="assistant", content=str(message["rationale"]), thought=True
            )
        tool_call = message.get("tool_call")
        if tool_call:
            tool_name = tool_call["tool_name"]
            content = tool_call["tool_arguments"]
            # ChatMessage.content is declared `str`, so non-string arguments
            # would fail validation. NOTE(review): presumably a dict of
            # arguments for non-code tools -- confirm with the agent type.
            if not isinstance(content, str):
                content = str(content)
            if tool_name == "code interpreter":
                content = f"```py\n{content}\n```"
            yield ChatMessage(
                role="assistant",
                thought_metadata=ThoughtMetadata(tool_name=tool_name),
                content=content,
                thought=True,
            )
        if message.get("observation"):
            yield ChatMessage(
                role="assistant", content=str(message["observation"]), thought=True
            )
        if message.get("error"):
            yield ChatMessage(
                role="assistant",
                content=str(message["error"]),
                thought=True,
                thought_metadata=ThoughtMetadata(error=True),
            )


def stream_from_transformers_agent(
    agent: Agent, prompt: str
) -> Generator[Union[ChatMessage, ChatFileMessage], None, None]:
    """Run an agent on ``prompt`` and stream its progress as chat messages.

    ``agent.run(prompt)`` executes in a background thread. While it runs,
    ``agent.logs`` is polled every 0.1 s and finished log entries are
    converted with ``pull_messages``. After the thread exits, the remaining
    entries are flushed and the agent's return value is yielded last:

    * ``AgentText``  -> ``ChatMessage`` holding ``output.to_string()``
    * ``AgentImage`` -> ``ChatFileMessage`` for the file (``image/png``)
    * ``AgentAudio`` -> ``ChatFileMessage`` for the file (``audio/wav``)
    * anything else  -> ``ChatMessage`` holding ``str(output)``

    Args:
        agent: The agent to run; its ``logs`` list is read while it runs.
        prompt: Task passed to ``agent.run``.

    Yields:
        ``ChatMessage`` / ``ChatFileMessage`` objects in production order.

    Raises:
        Exception: Whatever ``agent.run`` raised is re-raised here, after the
            log entries recorded before the failure have been streamed.
    """
    # Written by the worker thread; read only after the thread has finished.
    result: dict = {}

    def run_agent():
        try:
            result["output"] = agent.run(prompt)
        except Exception as exc:
            # An exception would otherwise die with the thread; keep it so
            # the consumer of this generator sees the real failure.
            result["error"] = exc

    thread = Thread(target=run_agent)
    # Index of the first agent.logs entry not yet converted. This counts log
    # ENTRIES, not yielded messages: one entry can yield zero to four messages.
    consumed = 0

    thread.start()
    while thread.is_alive():
        # NOTE(review): the newest entry appears to be filled in while its
        # step is still running (pull_messages tolerates empty entries), so
        # it is held back until a later entry exists or the agent finishes.
        # Confirm against the Agent implementation.
        complete = len(agent.logs) - 1
        if complete > consumed:
            yield from pull_messages(agent.logs[consumed:complete])
            consumed = complete
        time.sleep(0.1)

    thread.join()

    # The agent has finished, so every remaining entry is final.
    yield from pull_messages(agent.logs[consumed:])

    if "error" in result:
        raise result["error"]

    output = result.get("output")
    if isinstance(output, agent_types.AgentText):
        yield ChatMessage(role="assistant", content=output.to_string(), thought=True)
    elif isinstance(output, agent_types.AgentImage):
        yield ChatFileMessage(
            role="assistant",
            file=FileData(path=output.to_string(), mime_type="image/png"),
            content="",
            thought=True,
        )
    elif isinstance(output, agent_types.AgentAudio):
        yield ChatFileMessage(
            role="assistant",
            file=FileData(path=output.to_string(), mime_type="audio/wav"),
            content="",
            thought=True,
        )
    else:
        # Must be `yield`: a `return` in a generator ends it without emitting
        # the answer. ChatMessage.content is a str, so convert the output.
        yield ChatMessage(role="assistant", content=str(output), thought=True)