|
|
import os |
|
|
from dataclasses import dataclass, field |
|
|
from enum import Enum |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List |
|
|
|
|
|
import tiktoken |
|
|
import yaml |
|
|
from jinja2 import Environment, BaseLoader |
|
|
|
|
|
|
|
|
class PromptType(Enum):
    """Category of a prompt template.

    Member values are the strings accepted in a prompt entry's ``type`` field
    of the YAML prompt config.
    """

    # Used when a prompt entry does not specify ``type``.
    BASE_SYSTEM = "base_system"
    ANSWER_REFINEMENT = "answer_refinement"
    MEMORY_OPTIMIZATION = "memory_optimization"
    TOOL = "tool"
    SUB_AGENT = "sub_agent"
|
|
|
|
|
|
|
|
@dataclass
class PromptTemplate:
    """A named prompt template together with its descriptive metadata.

    Constructor field order: name, content, prompt_type, variables,
    token_estimate, version, description. Only the first three are required.
    """

    # Key the template is registered under.
    name: str
    # Raw template text.
    content: str
    # Category of the template.
    prompt_type: PromptType
    # Variable names declared for the template; each instance gets its own list.
    variables: List[str] = field(default_factory=list)
    # Token count of ``content``; stays 0 until someone fills it in.
    token_estimate: int = field(default=0)
    # Free-form version label.
    version: str = field(default="1.0")
    # Human-readable summary of the template's purpose.
    description: str = field(default="")
|
|
|
|
|
|
|
|
class PromptManager:
    """Centralized management for the agent's prompts.

    Loads named prompt templates from a YAML config, records a token estimate
    for each one, and renders them with Jinja2 against a caller-supplied
    mapping of variables.
    """

    # Encoding used when tiktoken cannot map ``model_name`` to a tokenizer.
    _FALLBACK_ENCODING = "o200k_base"

    def __init__(self, prompt_config_path: str, model_name: str = "gpt-4.1"):
        """Create the manager and immediately load prompts from disk.

        Args:
            prompt_config_path: Path to a ``.yaml``/``.yml`` prompt config file.
            model_name: Model whose tokenizer is used for token estimates.
        """
        self.templates: Dict[str, PromptTemplate] = {}
        self.jinja_env = Environment(loader=BaseLoader())
        try:
            self.token_counter = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # encoding_for_model raises KeyError for model names the installed
            # tiktoken does not know. Counts are only estimates, so fall back
            # to a general-purpose encoding rather than failing construction.
            self.token_counter = tiktoken.get_encoding(self._FALLBACK_ENCODING)

        self.load_prompts_from_config(prompt_config_path)

    def load_prompts_from_config(self, config_path: str):
        """Load prompt templates from a YAML configuration file.

        Expected layout (each entry under ``prompts`` becomes one template)::

            prompts:
              <name>:
                content: "..."        # required
                type: base_system     # optional, a PromptType value
                variables: [...]      # optional
                version: "1.0"        # optional
                description: "..."    # optional

        Loaded templates are added to ``self.templates``, replacing any
        existing entry with the same name. An empty file, or a ``prompts:``
        key with no value, loads nothing.

        Raises:
            ValueError: the file extension is not ``.yaml``/``.yml``, or a
                ``type`` value is not a valid ``PromptType``.
            TypeError: the document, its ``prompts`` section, or a single
                prompt entry is not a mapping.
            KeyError: a prompt entry has no ``content``.
        """
        path = Path(config_path)
        if path.suffix.lower() not in {".yaml", ".yml"}:
            # Fail loudly: silently skipping left the manager empty and only
            # surfaced later as an opaque KeyError in render_template.
            raise ValueError(
                f"Unsupported prompt config format {path.suffix!r} for "
                f"{str(path)!r}; expected a .yaml or .yml file"
            )

        with open(path, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty document.
            config = yaml.safe_load(f) or {}
        if not isinstance(config, dict):
            raise TypeError(
                f"Prompt config {str(path)!r} must be a mapping at the top "
                f"level, got {type(config).__name__}"
            )

        # ``prompts:`` with no value parses as None; treat it as empty.
        prompts = config.get("prompts") or {}
        if not isinstance(prompts, dict):
            raise TypeError(
                f"'prompts' in {str(path)!r} must be a mapping, "
                f"got {type(prompts).__name__}"
            )

        for name, prompt_data in prompts.items():
            self.templates[name] = self._build_template(name, prompt_data)

    def _build_template(self, name: str, prompt_data: Dict[str, Any]) -> "PromptTemplate":
        """Build one PromptTemplate, including its token estimate, from a config entry."""
        if not isinstance(prompt_data, dict):
            raise TypeError(
                f"Prompt {name!r} must be a mapping, got {type(prompt_data).__name__}"
            )

        def opt(key: str, default: Any) -> Any:
            # A key present with an empty YAML value parses as None; treat it
            # the same as an absent key.
            value = prompt_data.get(key)
            return default if value is None else value

        template = PromptTemplate(
            name=name,
            content=prompt_data["content"],
            prompt_type=PromptType(opt("type", "base_system")),
            variables=list(opt("variables", [])),
            # YAML parses an unquoted ``version: 1.0`` as a float; the field is a str.
            version=str(opt("version", "1.0")),
            description=opt("description", ""),
        )
        template.token_estimate = self._estimate_tokens(template.content)
        return template

    def _estimate_tokens(self, text: str) -> int:
        """Return how many tokens ``self.token_counter`` encodes *text* into.

        ``disallowed_special=()`` makes tiktoken treat special-token strings
        such as ``<|endoftext|>`` as ordinary text instead of raising
        ValueError, so prompts that mention them can still be estimated.
        """
        return len(self.token_counter.encode(text, disallowed_special=()))

    def render_template(self, name: str, state: Dict[str, Any]) -> str:
        """Render the template registered as *name* with variables from *state*.

        Args:
            name: Key of a template loaded into ``self.templates``.
            state: Variables exposed to the Jinja2 template. The environment
                uses Jinja2's default ``Undefined``, so variables missing from
                *state* render as empty strings. NOTE: the template's declared
                ``variables`` list is not checked here.

        Returns:
            The rendered prompt text.

        Raises:
            KeyError: no template named *name* has been loaded.
        """
        try:
            template = self.templates[name]
        except KeyError:
            raise KeyError(
                f"Unknown prompt template {name!r}; loaded templates: {list(self.templates)}"
            ) from None

        jinja_template = self.jinja_env.from_string(template.content)
        # Template.render accepts a mapping positionally (dict-constructor
        # semantics) and builds its own context, so no defensive copy is needed.
        return jinja_template.render(state)
|
|
|
|
|
|
|
|
|
|
|
# Module-level singleton: prompts are read from disk as soon as this module is
# imported. The path is relative, so it resolves against the current working
# directory. NOTE(review): presumably the process runs from the project root;
# confirm before importing from elsewhere.
prompt_mgmt = PromptManager(os.path.join("config", "prompts.yaml"))
|
|
|