File size: 5,292 Bytes
be54038
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""
prompts.py — Versioned prompt registry for the UK Motor Insurance IDP pipeline.

Loads prompt text from prompts.yaml so prompts can be updated, versioned, and
reviewed without touching Python source code.

Usage
-----
    registry = PromptRegistry()                         # uses active_version from YAML
    registry = PromptRegistry(version="v2")             # pin to a specific version
    registry = PromptRegistry(config_path="custom.yaml")

    system_prompt = registry.get(DocumentType.SCHEDULE)
    print(registry.active_version)   # → "v1"
    print(registry.available_versions)  # → ["v1"]
"""
from __future__ import annotations

import logging
from pathlib import Path
from typing import Optional

import yaml

from schema import DocumentType

logger = logging.getLogger(__name__)

# YAML key of the catch-all prompt, used when a document type has no
# dedicated entry of its own.
_GENERIC_KEY = "_generic"

# Default prompt file: config/prompts.yaml under the parent of this module's
# directory (i.e. <project_root>/config/prompts.yaml when this module lives
# in <project_root>/src/).
_DEFAULT_CONFIG = Path(__file__).parent.parent.joinpath("config", "prompts.yaml")

# Lookup table: DocumentType member → prompt key inside a version section of
# the YAML file.  UNKNOWN documents are routed straight to the generic prompt.
_DOC_TYPE_TO_KEY: dict[DocumentType, str] = {
    DocumentType.SCHEDULE: "Schedule",
    DocumentType.CERTIFICATE: "Certificate",
    DocumentType.STATEMENT_OF_FACT: "StatementOfFact",
    DocumentType.POLICY_BOOKLET: "PolicyBooklet",
    DocumentType.UNKNOWN: _GENERIC_KEY,
}


class PromptRegistry:
    """
    Loads versioned prompts from a YAML file and resolves them by DocumentType.

    Parameters
    ----------
    config_path : str | Path | None
        Path to the YAML file.  Defaults to ``src/prompts.yaml`` (sibling of
        this module).
    version : str | None
        Prompt version to activate (e.g. ``"v1"``, ``"v2"``).
        Defaults to the ``active_version`` key in the YAML file.
    """

    def __init__(
        self,
        config_path: Optional[str | Path] = None,
        version: Optional[str] = None,
    ) -> None:
        self._config_path = Path(config_path) if config_path else _DEFAULT_CONFIG
        self._raw = self._load_yaml()
        self._active_version = version or self._raw.get("active_version", "v1")
        self._prompts = self._resolve_version(self._active_version)

        logger.info(
            "PromptRegistry loaded: version=%s, path=%s",
            self._active_version,
            self._config_path,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @property
    def active_version(self) -> str:
        """The currently active prompt version string."""
        return self._active_version

    @property
    def available_versions(self) -> list[str]:
        """All version keys defined in the YAML file."""
        return list(self._raw.get("prompts", {}).keys())

    def get(self, doc_type: DocumentType) -> str:
        """
        Return the system prompt for a given DocumentType.

        Falls back to the ``_generic`` prompt if the specific key is missing.
        Raises ``KeyError`` if ``_generic`` is also absent (misconfigured YAML).
        """
        key = _DOC_TYPE_TO_KEY.get(doc_type, _GENERIC_KEY)
        prompt = self._prompts.get(key) or self._prompts.get(_GENERIC_KEY)
        if not prompt:
            raise KeyError(
                f"No prompt found for DocumentType '{doc_type.value}' in version "
                f"'{self._active_version}' of {self._config_path}. "
                f"Ensure '{key}' or '{_GENERIC_KEY}' is defined."
            )
        return prompt.strip()

    def reload(self) -> None:
        """
        Hot-reload prompts from disk without restarting the process.

        Useful in long-running services when prompts.yaml is updated in place.
        """
        self._raw = self._load_yaml()
        self._prompts = self._resolve_version(self._active_version)
        logger.info("PromptRegistry reloaded from %s", self._config_path)

    def switch_version(self, version: str) -> None:
        """
        Switch the active prompt version at runtime.

        Parameters
        ----------
        version : str
            Must be a key present under ``prompts:`` in the YAML file.
        """
        self._prompts = self._resolve_version(version)
        self._active_version = version
        logger.info("PromptRegistry switched to version '%s'", version)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _load_yaml(self) -> dict:
        if not self._config_path.exists():
            raise FileNotFoundError(
                f"Prompt configuration not found: {self._config_path}"
            )
        with self._config_path.open(encoding="utf-8") as fh:
            return yaml.safe_load(fh) or {}

    def _resolve_version(self, version: str) -> dict[str, str]:
        versions = self._raw.get("prompts", {})
        if version not in versions:
            available = list(versions.keys())
            raise ValueError(
                f"Prompt version '{version}' not found in {self._config_path}. "
                f"Available versions: {available}"
            )
        return versions[version]