File size: 2,572 Bytes
3d3d712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from __future__ import annotations

import secrets
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Union

from taskweaver.memory.type_vars import RoundState
from taskweaver.utils import create_id

from .post import Post


@dataclass
class Round:
    """A round is the basic unit of conversation in the project, which is a collection of posts.

    Args:
        id: the unique id of the round. Ids generated by ``create`` and
            ``from_dict`` are prefixed with "round-".
        user_query: the user query that started the round.
        state: the current state of the round (a ``RoundState`` value).
        post_list: a list of posts in the round, in the order they were added.
    """

    id: Optional[str]
    user_query: str
    state: RoundState
    post_list: List[Post]

    @staticmethod
    def create(
        user_query: str,
        id: Optional[str] = None,
        state: RoundState = "created",
        post_list: Optional[List[Post]] = None,
    ) -> Round:
        """Create a round with the given user query, id, state, and posts.

        Args:
            user_query: the user query that started the round.
            id: explicit round id; when None, "round-" + create_id() is used.
            state: initial state of the round, "created" by default.
            post_list: initial posts; when None a fresh empty list is used, so
                rounds never share one mutable list.

        Returns:
            The newly constructed Round.
        """
        return Round(
            id="round-" + create_id() if id is None else id,
            user_query=user_query,
            state=state,
            post_list=post_list if post_list is not None else [],
        )

    def __repr__(self):
        """Return a multi-line, human-readable summary of the round and its posts."""
        # Each post is rendered on its own line, indented by two spaces.
        post_list_str = "\n".join([" " * 2 + str(item) for item in self.post_list])
        return "\n".join(
            [
                "Round:",
                f"- Query: {self.user_query}",
                f"- State: {self.state}",
                f"- Post Num:{len(self.post_list)}",
                f"- Post List: \n{post_list_str}\n\n",
            ],
        )

    def __str__(self):
        """Return the same text as ``__repr__``."""
        return self.__repr__()

    def to_dict(self) -> Dict[str, Any]:
        """Convert the round to a dict.

        Returns:
            A dict with keys "id", "user_query", "state" and "post_list", where
            "post_list" holds each post's ``to_dict()`` result.
        """
        return {
            "id": self.id,
            "user_query": self.user_query,
            "state": self.state,
            "post_list": [post.to_dict() for post in self.post_list],
        }

    @staticmethod
    def from_dict(content: Dict[str, Any]) -> Round:
        """Convert the dict to a round. Will assign a new id to the round.

        Any "id" stored in ``content`` is ignored; a fresh id of the form
        "round-" + 12 hex characters is generated instead. A missing or None
        "post_list" entry yields an empty post list.

        Args:
            content: a dict shaped like the output of ``to_dict``.

        Returns:
            The reconstructed Round.

        Raises:
            KeyError: if "user_query" or "state" is missing from ``content``.
        """
        # Use .get so an absent "post_list" is treated like an explicit None.
        raw_posts = content.get("post_list")
        return Round(
            id="round-" + secrets.token_hex(6),
            user_query=content["user_query"],
            state=content["state"],
            post_list=[Post.from_dict(post) for post in raw_posts]
            if raw_posts is not None
            else [],
        )

    def add_post(self, post: Post):
        """Append a post to the end of the post list (mutates this round)."""
        self.post_list.append(post)

    def change_round_state(self, new_state: RoundState):
        """Change the state of the round.

        Args:
            new_state: the new state; typed as ``RoundState`` so it stays in
                sync with the ``state`` field's type.
        """
        self.state = new_state