File size: 3,447 Bytes
3d3d712 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
from typing import Callable, List, Set, Tuple
from injector import inject
from taskweaver.config.module_config import ModuleConfig
from taskweaver.llm import LLMApi
from taskweaver.llm.util import format_chat_message
from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Round
class RoundCompressorConfig(ModuleConfig):
    """Configuration for :class:`RoundCompressor`.

    Reads the ``round_compressor`` section and validates that both knobs
    are positive integers.
    """

    def _configure(self) -> None:
        self._set_name("round_compressor")
        # How many of the oldest unprocessed rounds get folded into the summary.
        self.rounds_to_compress = self._get_int("rounds_to_compress", 2)
        # How many of the most recent rounds are always kept verbatim.
        self.rounds_to_retain = self._get_int("rounds_to_retain", 3)
        # Raise instead of assert: asserts are stripped under `python -O`,
        # which would silently disable this validation.
        if self.rounds_to_compress <= 0:
            raise ValueError("rounds_to_compress must be greater than 0")
        if self.rounds_to_retain <= 0:
            raise ValueError("rounds_to_retain must be greater than 0")
class RoundCompressor:
    """Incrementally compresses old conversation rounds into a rolling summary.

    Rounds already folded into the summary are tracked by id in
    ``processed_rounds``; the most recent ``rounds_to_retain`` rounds are
    always returned verbatim so the active context stays intact.
    """

    @inject
    def __init__(
        self,
        llm_api: LLMApi,
        config: RoundCompressorConfig,
        logger: TelemetryLogger,
    ):
        self.config = config
        # Ids of rounds that have already been folded into `previous_summary`.
        self.processed_rounds: Set[str] = set()
        self.rounds_to_compress = self.config.rounds_to_compress
        self.rounds_to_retain = self.config.rounds_to_retain
        # Rolling summary text; the literal "None" is the "no summary yet" sentinel.
        self.previous_summary: str = "None"
        self.llm_api = llm_api
        self.logger = logger

    def compress_rounds(
        self,
        rounds: List[Round],
        rounds_formatter: Callable,
        use_back_up_engine: bool = False,
        prompt_template: str = "{PREVIOUS_SUMMARY}, please compress the following rounds",
    ) -> Tuple[str, List[Round]]:
        """Fold the oldest unprocessed rounds into the rolling summary.

        Args:
            rounds: Full round history, oldest first.
            rounds_formatter: Renders a list of rounds into a prompt string.
            use_back_up_engine: Forwarded to the LLM call as the backup-engine flag.
            prompt_template: Must contain the ``{PREVIOUS_SUMMARY}`` placeholder.

        Returns:
            Tuple of (summary text, rounds to keep verbatim in the prompt).
        """
        # Count the trailing rounds that have not been summarized yet;
        # processed rounds form a prefix, so stop at the first unprocessed one.
        remaining_rounds = len(rounds)
        for _round in rounds:
            if _round.id in self.processed_rounds:
                remaining_rounds -= 1
                continue
            break
        if remaining_rounds == 0:
            # Everything is already summarized. NOTE: the previous code hit
            # `rounds[-0:]` here, which is `rounds[0:]` and returned the ENTIRE
            # history; return an empty tail instead.
            return self.previous_summary, []
        # Positive start index avoids the `-0` slicing pitfall.
        unprocessed = rounds[len(rounds) - remaining_rounds :]
        # Not enough unprocessed rounds accumulated to be worth compressing.
        if remaining_rounds < (self.rounds_to_compress + self.rounds_to_retain):
            return self.previous_summary, unprocessed
        chat_summary = self._summarize(
            unprocessed[: -self.rounds_to_retain],
            rounds_formatter,
            use_back_up_engine=use_back_up_engine,
            prompt_template=prompt_template,
        )
        if len(chat_summary) > 0:  # the compression succeeded
            self.previous_summary = chat_summary
            return chat_summary, rounds[-self.rounds_to_retain :]
        # Compression failed: fall back to the previous summary and return the
        # full unprocessed tail so no information is dropped.
        return self.previous_summary, unprocessed

    def _summarize(
        self,
        rounds: List[Round],
        rounds_formatter: Callable,
        use_back_up_engine: bool = False,
        prompt_template: str = "{PREVIOUS_SUMMARY}, please compress the following rounds",
    ) -> str:
        """Ask the LLM to fold ``rounds`` into the previous summary.

        Returns the new summary text, or ``""`` when formatting or the LLM
        call fails; failures are logged, never raised, so a broken
        compression cannot break the session.
        """
        # Validate explicitly (not via assert, which is stripped under -O).
        if "{PREVIOUS_SUMMARY}" not in prompt_template:
            raise ValueError("Prompt template must contain {PREVIOUS_SUMMARY}")
        try:
            chat_history_str = rounds_formatter(rounds)
            system_instruction = prompt_template.format(
                PREVIOUS_SUMMARY=self.previous_summary,
            )
            prompt = [
                format_chat_message("system", system_instruction),
                format_chat_message("user", chat_history_str),
            ]
            new_summary = self.llm_api.chat_completion(prompt, use_backup_engine=use_back_up_engine)["content"]
            # Mark rounds as processed only after a successful completion, so
            # a failed attempt retries them on the next call.
            self.processed_rounds.update(_round.id for _round in rounds)
            return new_summary
        except Exception as e:
            self.logger.warning(f"Failed to compress rounds: {e}")
            return ""
|