File size: 2,697 Bytes
778116a
8073bab
778116a
 
 
 
8073bab
 
778116a
8073bab
 
 
 
 
 
778116a
2e38934
8073bab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e511adb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List

import tiktoken
import yaml
from jinja2 import Environment, BaseLoader


class PromptType(Enum):
    """Kinds of prompt a template can represent.

    Each member's value is the exact string accepted in the ``type`` field of
    a prompt entry in the YAML config; entries without a ``type`` are loaded
    as ``BASE_SYSTEM`` by ``PromptManager.load_prompts_from_config``.
    """

    BASE_SYSTEM = "base_system"
    ANSWER_REFINEMENT = "answer_refinement"
    MEMORY_OPTIMIZATION = "memory_optimization"
    TOOL = "tool"
    SUB_AGENT = "sub_agent"


@dataclass
class PromptTemplate:
    """One named prompt together with the metadata loaded alongside it.

    ``token_estimate`` starts at 0; the loader fills it in after the object
    has been constructed.
    """

    # Key under which the template is registered in the manager.
    name: str
    # Raw Jinja2 template source.
    content: str
    prompt_type: PromptType
    # Variable names the template is expected to use (informational only;
    # rendering does not check them).
    variables: List[str] = field(default_factory=list)
    token_estimate: int = 0
    version: str = "1.0"
    description: str = ""


class PromptManager:
    """Centralized management for the agent's prompt templates.

    Templates are read from a YAML config when the manager is constructed,
    stored by name in ``self.templates``, and rendered on demand with Jinja2.
    """

    def __init__(self, prompt_config_path: str, model_name: str = "gpt-4.1"):
        """Create the manager and load every prompt from ``prompt_config_path``.

        Args:
            prompt_config_path: Path to a ``.yaml``/``.yml`` prompt config.
            model_name: Model whose tiktoken encoding is used for token
                estimates.

        Raises:
            ValueError: If the config is not a YAML file or is malformed
                (see ``load_prompts_from_config``).
        """
        self.templates: Dict[str, PromptTemplate] = {}
        self.jinja_env = Environment(loader=BaseLoader())
        # Resolved once so every estimate uses the same tokenizer.
        self.token_counter = tiktoken.encoding_for_model(model_name)

        # Load prompts from config
        self.load_prompts_from_config(prompt_config_path)

    def load_prompts_from_config(self, config_path: str) -> None:
        """Load prompt templates from a YAML configuration file.

        The file is expected to hold a top-level ``prompts`` mapping of
        template name -> entry. Each entry needs ``content`` and may set
        ``type``, ``variables``, ``version`` and ``description``. Loaded
        templates are added to ``self.templates``, replacing any existing
        template with the same name.

        Raises:
            ValueError: If ``config_path`` lacks a ``.yaml``/``.yml`` suffix,
                the file's top level is not a mapping, or an entry's ``type``
                is not a ``PromptType`` value.
            KeyError: If a prompt entry has no ``content``.
            OSError: If the file cannot be opened.
        """
        path = Path(config_path)

        # Only YAML is supported; reject anything else up front rather than
        # continuing with no config to read.
        if path.suffix.lower() not in ('.yaml', '.yml'):
            raise ValueError(
                f"Unsupported prompt config {config_path!r}: "
                "expected a .yaml or .yml file"
            )

        # Explicit UTF-8 so prompts containing non-ASCII text load the same
        # way on every platform (the default encoding is locale-dependent).
        with open(path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            config = yaml.safe_load(f) or {}

        if not isinstance(config, dict):
            raise ValueError(
                f"Prompt config {config_path!r} must contain a mapping at "
                f"the top level, got {type(config).__name__}"
            )

        # A bare ``prompts:`` key parses as None; treat it as "no prompts".
        prompts = config.get('prompts') or {}
        for name, prompt_data in prompts.items():
            template = PromptTemplate(
                name=name,
                content=prompt_data['content'],
                prompt_type=PromptType(prompt_data.get('type', 'base_system')),
                variables=prompt_data.get('variables', []),
                version=prompt_data.get('version', '1.0'),
                description=prompt_data.get('description', '')
            )
            template.token_estimate = self._estimate_tokens(template.content)
            self.templates[name] = template

    def _estimate_tokens(self, text: str) -> int:
        """Return the number of tokens ``text`` encodes to.

        ``disallowed_special=()`` makes tiktoken encode special-token strings
        such as ``<|endoftext|>`` as ordinary text instead of raising, so a
        prompt that merely mentions one can still be loaded and counted.
        """
        return len(self.token_counter.encode(text, disallowed_special=()))

    def render_template(self, name: str, state: Dict[str, Any]) -> str:
        """Render template ``name`` using ``state`` as its variables.

        Args:
            name: Name of a loaded template.
            state: Mapping of variable names to values exposed to the
                template. A falsy value (e.g. ``None``) renders with no
                variables.

        Returns:
            The rendered prompt text.

        Raises:
            KeyError: If no template called ``name`` has been loaded.
        """
        try:
            template = self.templates[name]
        except KeyError:
            raise KeyError(f"Unknown prompt template: {name!r}") from None

        jinja_template = self.jinja_env.from_string(template.content)
        return jinja_template.render(**(state or {}))


# Module-wide shared manager, built when this module is imported. The config
# path is relative, so it is resolved against the current working directory.
prompt_mgmt = PromptManager(os.path.join("config", "prompts.yaml"))