|
|
from typing import Callable, List, Set, Tuple |
|
|
|
|
|
from injector import inject |
|
|
|
|
|
from taskweaver.config.module_config import ModuleConfig |
|
|
from taskweaver.llm import LLMApi |
|
|
from taskweaver.llm.util import format_chat_message |
|
|
from taskweaver.logging import TelemetryLogger |
|
|
from taskweaver.memory import Round |
|
|
|
|
|
|
|
|
class RoundCompressorConfig(ModuleConfig):
    """Configuration for the round compressor.

    Read from the ``round_compressor`` section of the module config:

    - ``rounds_to_compress``: minimum number of not-yet-summarized rounds
      (beyond the retained ones) required before a compression is attempted.
    - ``rounds_to_retain``: number of most recent rounds that are always
      kept verbatim and never summarized.

    Raises:
        ValueError: if either value is not a positive integer.
    """

    def _configure(self) -> None:
        self._set_name("round_compressor")

        self.rounds_to_compress = self._get_int("rounds_to_compress", 2)
        self.rounds_to_retain = self._get_int("rounds_to_retain", 3)

        # Raise explicitly instead of using `assert`: asserts are stripped
        # when Python runs with the -O flag, which would silently disable
        # this validation.
        if self.rounds_to_compress <= 0:
            raise ValueError("rounds_to_compress must be greater than 0")
        if self.rounds_to_retain <= 0:
            raise ValueError("rounds_to_retain must be greater than 0")
|
|
|
|
|
|
|
|
class RoundCompressor:
    """Compresses older chat rounds into a rolling text summary.

    The most recent ``rounds_to_retain`` rounds are always kept verbatim.
    Once at least ``rounds_to_compress`` additional not-yet-summarized rounds
    have accumulated, they are folded into ``previous_summary`` via a single
    LLM call. Rounds that have been summarized are tracked by id in
    ``processed_rounds`` so they are never summarized twice.
    """

    @inject
    def __init__(
        self,
        llm_api: LLMApi,
        config: RoundCompressorConfig,
        logger: TelemetryLogger,
    ):
        self.config = config
        # Ids of rounds that have already been folded into the summary.
        self.processed_rounds: Set[str] = set()
        self.rounds_to_compress: int = self.config.rounds_to_compress
        self.rounds_to_retain: int = self.config.rounds_to_retain
        # Rolling summary of all compressed rounds; the literal string
        # "None" signals "no summary yet" to the prompt template.
        self.previous_summary: str = "None"
        self.llm_api = llm_api
        self.logger = logger

    def compress_rounds(
        self,
        rounds: List[Round],
        rounds_formatter: Callable[[List[Round]], str],
        use_back_up_engine: bool = False,
        prompt_template: str = "{PREVIOUS_SUMMARY}, please compress the following rounds",
    ) -> Tuple[str, List[Round]]:
        """Compress eligible rounds and return the summary plus kept rounds.

        Args:
            rounds: full round history; already-processed rounds are assumed
                to form a prefix of this list.
            rounds_formatter: renders a list of rounds into the chat-history
                string sent to the LLM.
            use_back_up_engine: forwarded to the LLM API.
            prompt_template: system-prompt template; must contain the
                ``{PREVIOUS_SUMMARY}`` placeholder.

        Returns:
            A ``(summary, retained_rounds)`` tuple. On failure or when too
            few new rounds exist, the previous summary and all unprocessed
            rounds are returned unchanged.
        """
        # Count rounds not yet summarized; processed rounds form a prefix,
        # so stop at the first unprocessed one.
        remaining_rounds = len(rounds)
        for _round in rounds:
            if _round.id in self.processed_rounds:
                remaining_rounds -= 1
                continue
            break

        # Slice with a non-negative start index: `rounds[-0:]` would return
        # the WHOLE list when remaining_rounds == 0, not the empty suffix.
        unprocessed = rounds[len(rounds) - remaining_rounds :]

        # Not enough new rounds to justify a compression: keep them verbatim.
        if remaining_rounds < (self.rounds_to_compress + self.rounds_to_retain):
            return self.previous_summary, unprocessed

        # Summarize everything except the last `rounds_to_retain` rounds.
        chat_summary = self._summarize(
            unprocessed[: -self.rounds_to_retain],
            rounds_formatter,
            use_back_up_engine=use_back_up_engine,
            prompt_template=prompt_template,
        )

        if len(chat_summary) > 0:
            self.previous_summary = chat_summary
            return chat_summary, rounds[-self.rounds_to_retain :]
        else:
            # Summarization failed (empty result): fall back to the previous
            # summary and keep every unprocessed round verbatim.
            return self.previous_summary, unprocessed

    def _summarize(
        self,
        rounds: List[Round],
        rounds_formatter: Callable[[List[Round]], str],
        use_back_up_engine: bool = False,
        prompt_template: str = "{PREVIOUS_SUMMARY}, please compress the following rounds",
    ) -> str:
        """Ask the LLM to fold `rounds` into the previous summary.

        Marks the rounds as processed only on success. Returns the new
        summary, or an empty string if the LLM call fails (best-effort:
        the failure is logged, not raised).

        Raises:
            ValueError: if ``prompt_template`` lacks ``{PREVIOUS_SUMMARY}``.
        """
        # Validate with an explicit exception instead of `assert`, which is
        # stripped under `python -O`.
        if "{PREVIOUS_SUMMARY}" not in prompt_template:
            raise ValueError("Prompt template must contain {PREVIOUS_SUMMARY}")
        try:
            chat_history_str = rounds_formatter(rounds)
            system_instruction = prompt_template.format(
                PREVIOUS_SUMMARY=self.previous_summary,
            )
            prompt = [
                format_chat_message("system", system_instruction),
                format_chat_message("user", chat_history_str),
            ]
            new_summary = self.llm_api.chat_completion(prompt, use_backup_engine=use_back_up_engine)["content"]
            # Only mark rounds processed after a successful LLM call so a
            # failure can be retried on the next invocation.
            self.processed_rounds.update([_round.id for _round in rounds])
            return new_summary
        except Exception as e:
            # Best-effort: compression failure must not break the session.
            self.logger.warning(f"Failed to compress rounds: {e}")
            return ""
|
|
|