| from __future__ import annotations | |
| from typing import List | |
| from taskweaver.memory.conversation import Conversation | |
| from taskweaver.memory.round import Round | |
| from taskweaver.memory.type_vars import RoleName | |
class Memory:
    """Store the conversation of a session.

    Memory is used to store all the conversations in the system and should be
    initialized when creating a session.
    """

    def __init__(self, session_id: str) -> None:
        """Record the session id and start a new conversation.

        Args:
            session_id: the id of the session, stored as ``self.session_id``.
        """
        self.session_id = session_id
        self.conversation = Conversation.init()

    def create_round(self, user_query: str) -> Round:
        """Create a round with the given query and add it to the conversation.

        Args:
            user_query: the query the new round is created with.

        Returns:
            The newly created round, after it has been added to the conversation.
        """
        # Named `new_round` (not `round`) so the builtin `round()` is not shadowed.
        new_round = Round.create(user_query=user_query)
        self.conversation.add_round(new_round)
        return new_round

    def get_role_rounds(self, role: RoleName, include_failure_rounds: bool = False) -> List[Round]:
        """Get all the rounds of the given role in the memory.

        Every round of the conversation yields one new ``Round`` object carrying
        the same ``user_query``, ``id`` and ``state``, holding only the posts
        that were sent from or sent to ``role``.

        Note that a round whose state is ``"failed"`` is not dropped from the
        result when ``include_failure_rounds`` is False: it is still returned,
        but without any posts.

        TODO: better do cache here to avoid recreating the round list
        (new object) every time.

        Args:
            role: the role of the memory.
            include_failure_rounds: whether to include the posts of failed rounds.

        Returns:
            A list with one newly created round per round in the conversation,
            in conversation order.
        """
        rounds_from_role: List[Round] = []
        for source_round in self.conversation.rounds:
            role_round = Round.create(
                user_query=source_round.user_query,
                id=source_round.id,
                state=source_round.state,
            )
            # The failure check does not depend on the individual post, so it is
            # decided once per round instead of once per post.
            keep_posts = include_failure_rounds or source_round.state != "failed"
            if keep_posts:
                for post in source_round.post_list:
                    if post.send_from == role or post.send_to == role:
                        role_round.add_post(post)
            # Appended unconditionally: failed rounds stay in the list as empty
            # rounds so the result always has one entry per conversation round.
            rounds_from_role.append(role_round)
        return rounds_from_role