File size: 2,697 Bytes
778116a 8073bab 778116a 8073bab 778116a 8073bab 778116a 2e38934 8073bab e511adb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import os
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List
import tiktoken
import yaml
from jinja2 import Environment, BaseLoader
class PromptType(Enum):
    """Category of a prompt template.

    The string values are what the ``type`` key of a prompt entry in the
    YAML config is matched against (``PromptType(value)``); an unknown
    value raises ``ValueError``.
    """

    # Default category used when a prompt entry omits ``type``.
    BASE_SYSTEM = "base_system"
    ANSWER_REFINEMENT = "answer_refinement"
    MEMORY_OPTIMIZATION = "memory_optimization"
    TOOL = "tool"
    SUB_AGENT = "sub_agent"
@dataclass
class PromptTemplate:
    """A single prompt template plus the metadata loaded alongside it."""

    # Key under which the template is registered in the manager.
    name: str
    # Raw Jinja2 source that gets rendered with the caller's state.
    content: str
    # Category of the prompt.
    prompt_type: PromptType
    # Variable names listed in the config; informational only, they are
    # not validated at render time.
    variables: List[str] = field(default_factory=list)
    # Token count of ``content``; 0 until a loader fills it in.
    token_estimate: int = field(default=0)
    version: str = field(default="1.0")
    description: str = field(default="")
class PromptManager:
    """Centralized management for the agent's prompt templates.

    Templates are loaded from a YAML file when the manager is constructed
    and rendered on demand with Jinja2.
    """

    # Encoding used when tiktoken cannot map ``model_name`` to a tokenizer.
    # cl100k_base ships with every tiktoken release; token counts are only
    # estimates, so a close encoding is preferable to failing construction.
    FALLBACK_ENCODING = "cl100k_base"

    def __init__(self, prompt_config_path: str, model_name: str = "gpt-4.1"):
        """Create the manager and load templates from *prompt_config_path*.

        Args:
            prompt_config_path: Path to a ``.yaml``/``.yml`` prompt config.
            model_name: Model whose tokenizer is used for token estimates.
        """
        self.templates: Dict[str, PromptTemplate] = {}
        self.jinja_env = Environment(loader=BaseLoader())
        try:
            self.token_counter = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # tiktoken raises KeyError for model names it cannot map.
            self.token_counter = tiktoken.get_encoding(self.FALLBACK_ENCODING)
        self.load_prompts_from_config(prompt_config_path)

    def load_prompts_from_config(self, config_path: str):
        """Load prompt templates from a YAML configuration file.

        The file is expected to hold a top-level ``prompts`` mapping of
        ``name -> {content, type, variables, version, description}``; only
        ``content`` is required. Loaded templates are added to, or replace
        same-named entries in, ``self.templates``.

        Raises:
            ValueError: if *config_path* lacks a ``.yaml``/``.yml`` suffix,
                or a prompt's ``type`` is not a valid ``PromptType`` value.
            KeyError: if a prompt entry has no ``content``.
        """
        path = Path(config_path)
        if path.suffix.lower() not in (".yaml", ".yml"):
            # Fail loudly rather than silently leaving the manager empty.
            raise ValueError(
                f"Unsupported prompt config format: {config_path!r} "
                "(expected a .yaml or .yml file)"
            )
        with open(path, "r", encoding="utf-8") as f:
            # safe_load returns None for an empty document.
            config = yaml.safe_load(f) or {}
        # ``prompts:`` with no value parses as None.
        prompts = config.get("prompts") or {}
        for name, prompt_data in prompts.items():
            template = PromptTemplate(
                name=name,
                content=prompt_data["content"],
                prompt_type=PromptType(prompt_data.get("type", "base_system")),
                variables=list(prompt_data.get("variables") or []),
                version=prompt_data.get("version", "1.0"),
                description=prompt_data.get("description", ""),
            )
            template.token_estimate = self._estimate_tokens(template.content)
            self.templates[name] = template

    def _estimate_tokens(self, text: str) -> int:
        """Return the number of tokens *text* encodes to with this manager's tokenizer."""
        return len(self.token_counter.encode(text))

    def render_template(self, name: str, state: Dict[str, Any]) -> str:
        """Render the template registered as *name* using *state* as variables.

        Args:
            name: Name of a previously loaded template.
            state: Mapping of template variable names to values; not modified.

        Returns:
            The rendered prompt text.

        Raises:
            KeyError: if no template named *name* has been loaded.
        """
        try:
            template = self.templates[name]
        except KeyError:
            raise KeyError(f"Unknown prompt template: {name!r}") from None
        jinja_template = self.jinja_env.from_string(template.content)
        # Template.render accepts a mapping positionally, like dict().
        return jinja_template.render(state)
# Shared module-level manager, built at import time from config/prompts.yaml.
# The path is relative, so it is resolved against the current working directory.
prompt_mgmt = PromptManager(os.path.join("config", "prompts.yaml"))
|