# akseljoonas's picture
# akseljoonas HF Staff
# deploy
# a2f44f6
"""
Context management for conversation history
"""
import os
import zoneinfo
from datetime import datetime
from pathlib import Path
from typing import Any
import yaml
from huggingface_hub import HfApi
from jinja2 import Template
from litellm import Message, acompletion
class ContextManager:
    """Manages conversation context and message history for the agent.

    The history is kept in ``self.items``; the first entry is the rendered
    system prompt. ``self.context_length`` tracks the current size of the
    context in tokens and is compared against ``self.max_context`` to decide
    when :meth:`compact` has to summarize older messages.
    """

    def __init__(
        self,
        max_context: int = 180_000,
        compact_size: float = 0.1,
        untouched_messages: int = 5,
        tool_specs: list[dict[str, Any]] | None = None,
        prompt_file_suffix: str = "system_prompt_v2.yaml",
        hf_token: str | None = None,
    ):
        """Render the system prompt and start a history containing only it.

        Args:
            max_context: Token budget above which :meth:`compact` summarizes.
                A falsy value falls back to 180_000.
            compact_size: Fraction of ``max_context`` used as the completion
                token limit for the summary request.
            untouched_messages: Number of trailing messages that are never
                summarized.
            tool_specs: Tool specifications passed to the prompt template.
            prompt_file_suffix: File name of the prompt YAML inside the
                ``prompts`` directory. A falsy value falls back to the default.
            hf_token: Hugging Face token used to look up the user name; the
                ``HF_TOKEN`` environment variable is used when omitted.
        """
        self.system_prompt = self._load_system_prompt(
            tool_specs or [],
            # Honor the caller's choice instead of a hard-coded file name.
            prompt_file_suffix=prompt_file_suffix or "system_prompt_v2.yaml",
            hf_token=hf_token,
        )
        self.max_context = max_context or 180_000
        self.compact_size = int(self.max_context * compact_size)
        # Initial size estimate; presumably approximates ~4 characters per
        # token. It is overwritten once add_message() gets a real token count.
        self.context_length = len(self.system_prompt) // 4
        self.untouched_messages = untouched_messages
        self.items: list[Message] = [
            Message(role="system", content=self.system_prompt)
        ]

    def _load_system_prompt(
        self,
        tool_specs: list[dict[str, Any]],
        prompt_file_suffix: str = "system_prompt.yaml",
        hf_token: str | None = None,
    ) -> str:
        """Load and render the system prompt from YAML file with Jinja2.

        Args:
            tool_specs: Tool specifications exposed to the template as
                ``tools`` (and their count as ``num_tools``).
            prompt_file_suffix: File name looked up in the ``prompts``
                directory two levels above this module.
            hf_token: Token for the Hugging Face ``whoami`` lookup; falls back
                to the ``HF_TOKEN`` environment variable.

        Returns:
            The rendered system prompt.

        Raises:
            OSError: If the prompt file cannot be opened.
        """
        prompt_file = Path(__file__).parent.parent / "prompts" / prompt_file_suffix
        # Explicit encoding so the prompt reads the same on every platform.
        with open(prompt_file, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty file; treat that as no keys.
            prompt_data = yaml.safe_load(f) or {}
        template_str = prompt_data.get("system_prompt", "")

        # Get current date and time
        tz = zoneinfo.ZoneInfo("Europe/Paris")
        now = datetime.now(tz)
        current_date = now.strftime("%d-%m-%Y")
        current_time = now.strftime("%H:%M:%S.%f")[:-3]  # trim micro- to milliseconds
        utc_offset = now.strftime("%z")  # e.g. "+0100"
        current_timezone = (
            f"{now.strftime('%Z')} (UTC{utc_offset[:3]}:{utc_offset[3:]})"
        )

        # Get HF user info - use provided token or fall back to env var.
        # Neither source may be set, so normalize to an empty string first.
        token = (hf_token or os.environ.get("HF_TOKEN") or "").strip()
        hf_user_info = "user"  # Default; actual user info comes from OAuth
        if token:
            try:
                hf_user_info = HfApi(token=token).whoami().get("name", "user")
            except Exception:
                # Deliberate best-effort: any lookup failure keeps the default.
                pass

        return Template(template_str).render(
            tools=tool_specs,
            num_tools=len(tool_specs),
            current_date=current_date,
            current_time=current_time,
            current_timezone=current_timezone,
            hf_user_info=hf_user_info,
        )

    def add_message(self, message: Message, token_count: int | None = None) -> None:
        """Add a message to the history.

        Args:
            message: The message to append.
            token_count: When given (and non-zero), replaces the tracked
                context length.
        """
        if token_count:
            self.context_length = token_count
        self.items.append(message)

    def get_messages(self) -> list[Message]:
        """Get all messages for sending to LLM (the live list, not a copy)."""
        return self.items

    async def compact(self, model_name: str) -> None:
        """Summarize old messages to keep history under target size.

        Does nothing while ``context_length`` is within ``max_context``.
        Otherwise everything between the system prompt and the most recent
        messages is sent to ``model_name`` for summarization and replaced by
        a single assistant message holding the summary.

        Args:
            model_name: Model identifier passed to ``acompletion``.
        """
        if self.context_length <= self.max_context or not self.items:
            return

        has_system = self.items[0].role == "system"
        system_msg = self.items[0] if has_system else None
        # First index that may be summarized. Without a system message the
        # very first item is regular history and must not be dropped.
        first = 1 if has_system else 0

        # Don't summarize a certain number of just-preceding messages.
        # Clamp the split point into the list: untouched_messages == 0 would
        # otherwise index past the end, and a value larger than the history
        # would produce a negative index that slices from the wrong side.
        keep = max(self.untouched_messages, 0)
        idx = min(len(self.items) - keep, len(self.items) - 1)

        # Walk back to find a user message to make sure we keep an
        # assistant -> user -> assistant general conversation structure
        while idx > first and self.items[idx].role != "user":
            idx -= 1

        # Nothing before the split point: there is nothing to summarize.
        if idx <= first:
            return

        recent_messages = self.items[idx:]
        messages_to_summarize = self.items[first:idx]  # slice copy, safe to extend
        messages_to_summarize.append(
            Message(
                role="user",
                content="Please provide a concise summary of the conversation above, focusing on key decisions, code changes, problems solved, and important context needed for future turns.",
            )
        )

        response = await acompletion(
            model=model_name,
            messages=messages_to_summarize,
            max_completion_tokens=self.compact_size,
        )
        summarized_message = Message(
            role="assistant", content=response.choices[0].message.content
        )

        # Reconstruct: system + summary + recent messages (includes tools)
        head = [system_msg] if system_msg is not None else []
        self.items = head + [summarized_message] + recent_messages

        # Estimate only: system prompt plus summary size. The recent messages
        # are not counted here.
        self.context_length = (
            len(self.system_prompt) // 4 + response.usage.completion_tokens
        )