File size: 1,646 Bytes
3d3d712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from __future__ import annotations

from typing import List

from taskweaver.memory.conversation import Conversation
from taskweaver.memory.round import Round
from taskweaver.memory.type_vars import RoleName


class Memory:
    """
    Memory is used to store all the conversations in the system,
    which should be initialized when creating a session.
    """

    def __init__(self, session_id: str) -> None:
        """Bind a new memory to a session.

        Args:
            session_id: identifier of the session that owns this memory.
        """
        self.session_id = session_id
        # Rounds are later appended to this conversation by create_round().
        self.conversation = Conversation.init()

    def create_round(self, user_query: str) -> Round:
        """Create a round with the given query, add it to the conversation and return it.

        Args:
            user_query: the user query that starts the round.

        Returns:
            The newly created round; the same object is passed to ``conversation.add_round``.
        """
        # Named `new_round` rather than `round` to avoid shadowing the builtin.
        new_round = Round.create(user_query=user_query)
        self.conversation.add_round(new_round)
        return new_round

    def get_role_rounds(self, role: RoleName, include_failure_rounds: bool = False) -> List[Round]:
        """Get all the rounds of the given role in the memory.
        TODO: better do cache here to avoid recreating the round list (new object) every time.

        For every round in the conversation, a new Round is created with the same
        user query, id and state, and filled with only the posts sent from or to
        ``role``. Posts are passed to ``add_post`` as-is (they are not copied here).

        Args:
            role: the role whose posts should be kept.
            include_failure_rounds: whether to keep the posts of rounds whose state is
                "failed". When False, failed rounds are still returned, but with no posts.

        Returns:
            One filtered round per conversation round, in conversation order.
        """
        rounds_from_role: List[Round] = []
        for src_round in self.conversation.rounds:
            new_round = Round.create(user_query=src_round.user_query, id=src_round.id, state=src_round.state)
            # The failure check depends only on the round, so it is evaluated once per
            # round instead of once per post. A skipped failed round is still appended
            # as an empty shell, preserving the ids/ordering callers may rely on.
            if include_failure_rounds or src_round.state != "failed":
                for post in src_round.post_list:
                    if post.send_from == role or post.send_to == role:
                        new_round.add_post(post)
            rounds_from_role.append(new_round)
        return rounds_from_role