# TRaw's picture
# Upload 297 files
# 3d3d712
from __future__ import annotations
from typing import List
from taskweaver.memory.conversation import Conversation
from taskweaver.memory.round import Round
from taskweaver.memory.type_vars import RoleName
class Memory:
    """Session-scoped store for the conversation.

    An instance should be created together with its session. It keeps the
    session id and a single ``Conversation`` to which rounds are appended.
    """

    def __init__(self, session_id: str) -> None:
        self.session_id = session_id
        self.conversation = Conversation.init()

    def create_round(self, user_query: str) -> Round:
        """Start a new round for ``user_query`` and record it in the conversation.

        Args:
            user_query: the query that opens the round.

        Returns:
            The newly created round, already added to the conversation.
        """
        new_round = Round.create(user_query=user_query)
        self.conversation.add_round(new_round)
        return new_round

    def get_role_rounds(self, role: RoleName, include_failure_rounds: bool = False) -> List[Round]:
        """Return a per-role view of every round in the conversation.

        For each stored round a fresh ``Round`` is created with the same
        ``user_query``, ``id`` and ``state``; only the posts sent from or to
        ``role`` are added to it. The post objects are handed to ``add_post``
        as-is.

        A round whose state is ``"failed"`` is still present in the result,
        but it receives no posts unless ``include_failure_rounds`` is true.

        TODO: better do cache here to avoid recreating the round list (new object) every time.

        Args:
            role: the role whose posts should be kept.
            include_failure_rounds: whether to keep the posts of failed rounds.

        Returns:
            One new round per stored round, in conversation order.
        """
        role_view: List[Round] = []
        for source_round in self.conversation.rounds:
            filtered_round = Round.create(
                user_query=source_round.user_query,
                id=source_round.id,
                state=source_round.state,
            )
            # The failure check does not depend on the post, so evaluate it once per round.
            hide_posts = source_round.state == "failed" and not include_failure_rounds
            for post in source_round.post_list:
                if not hide_posts and (post.send_from == role or post.send_to == role):
                    filtered_round.add_post(post)
            # Appended unconditionally: failed rounds stay in the list, just without posts.
            role_view.append(filtered_round)
        return role_view