Spaces:
Build error
Build error
| import re | |
| import yaml | |
| from dataclasses import dataclass, field | |
| from typing import Dict, List, Set, Optional | |
@dataclass
class PromptTemplate:
    """
    A template class for managing and validating LLM prompts.

    This class handles:
    - Storing system and user prompts
    - Validating required template variables
    - Formatting prompts with provided variables

    Attributes:
        system_prompt (str): The system-level instructions for the LLM
        user_template (str): Template string with variables in {variable} format
        required_variables (Set[str]): Root names of the replacement fields in
            ``user_template``; computed once in ``__post_init__``.
    """

    system_prompt: str
    user_template: str
    # Derived from user_template: not a constructor argument, and excluded from
    # repr/eq so two templates are compared by their text alone.
    required_variables: Set[str] = field(init=False, repr=False, compare=False)

    def __post_init__(self):
        """Initialize the set of required variables from the template."""
        self.required_variables = self._get_required_variables()

    def _get_required_variables(self) -> Set[str]:
        """
        Extract required variable names from the template.

        Uses ``string.Formatter().parse`` -- the parser behind ``str.format`` --
        so escaped braces (``{{literal}}``) are not treated as variables, and
        fields with format specs, conversions, attribute access or indexing
        (``{x:.2f}``, ``{x!r}``, ``{user.name}``, ``{items[0]}``) contribute
        their root name. Replacement fields nested inside a format spec
        (``{x:>{width}}``) are also collected.

        Returns:
            Set[str]: Variable names found in the template. Empty if the
            template is malformed (e.g. an unmatched ``{``); ``format()`` then
            reports the problem as a ValueError when called.

        Example:
            Template "Write about {topic} in {style}" returns {'topic', 'style'}
        """
        from string import Formatter

        formatter = Formatter()
        names: Set[str] = set()
        pending = [self.user_template]
        try:
            while pending:
                text = pending.pop()
                for _literal, field_name, spec, _conversion in formatter.parse(text):
                    # A format spec may itself contain replacement fields.
                    if spec and "{" in spec:
                        pending.append(spec)
                    # None marks trailing literal text; "" is an auto-numbered "{}".
                    if not field_name:
                        continue
                    # Keep only the root of "user.name" / "items[0]".
                    root = re.split(r"[.\[]", field_name, maxsplit=1)[0]
                    if root:
                        names.add(root)
        except ValueError:
            # Malformed template: defer the error to format(), as before.
            return set()
        return names

    def _validate_variables(self, provided_vars: Dict) -> None:
        """
        Ensure all required template variables are provided.

        Args:
            provided_vars: Dictionary of variable names and values

        Raises:
            ValueError: If any required variables are missing. Names in the
                message are sorted so the message is deterministic.
        """
        provided_keys = set(provided_vars)
        missing_vars = self.required_variables - provided_keys
        if missing_vars:
            error_msg = (
                "\nPrompt Template Error:\n"
                f"Missing required variables: {', '.join(sorted(missing_vars))}\n"
                f"Template requires: {', '.join(sorted(self.required_variables))}\n"
                f"You provided: {', '.join(sorted(provided_keys))}\n"
                f"Template string: '{self.user_template}'"
            )
            raise ValueError(error_msg)

    def format(self, **kwargs) -> List[Dict[str, str]]:
        """
        Format the prompt template with provided variables.

        Args:
            **kwargs: Key-value pairs for template variables. Extra keys not
                used by the template are ignored.

        Returns:
            List[Dict[str, str]]: Formatted messages ready for LLM API

        Raises:
            ValueError: If a required variable is missing or ``str.format``
                fails (the original error is chained as ``__cause__``).

        Example:
            template.format(topic="AI", style="academic")
        """
        self._validate_variables(kwargs)
        try:
            formatted_user_message = self.user_template.format(**kwargs)
        except Exception as e:  # str.format or a value's __format__ may raise many types
            raise ValueError(f"Error formatting template: {e}") from e
        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": formatted_user_message},
        ]
def load_prompt(yaml_path: str, version: Optional[str] = None) -> tuple[PromptTemplate, dict]:
    """
    Load prompt configuration from YAML file.

    The file is expected to be a mapping whose keys are version names (plus an
    optional ``current_version`` key naming the default). Each version entry
    must provide ``system_prompt`` and ``user_template`` and may provide
    ``generation_params``.

    Args:
        yaml_path: Path to YAML configuration file
        version: Specific version to load (defaults to 'current_version')

    Returns:
        tuple: (PromptTemplate instance, generation parameters dictionary).
        The parameters dict is empty when the version defines none.

    Raises:
        ValueError: If the file's top level is not a mapping (e.g. empty file).
        KeyError: If no version can be determined, the requested version is
            absent, or a required prompt key is missing.

    Example:
        prompt, params = load_prompt('prompts.yaml', version='v2')
    """
    # safe_load (not yaml.load) so a prompt file cannot construct arbitrary objects.
    with open(yaml_path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)

    # safe_load returns None for an empty file; anything but a mapping is unusable.
    if not isinstance(data, dict):
        raise ValueError(f"Prompt file {yaml_path} must contain a YAML mapping at the top level")

    # Use specified version or fall back to current_version
    version_to_use = version or data.get('current_version')
    if version_to_use is None:
        raise KeyError(f"No version given and no 'current_version' set in {yaml_path}")
    if version_to_use not in data:
        raise KeyError(f"Version '{version_to_use}' not found in {yaml_path}")

    version_data = data[version_to_use]
    prompt = PromptTemplate(
        system_prompt=version_data['system_prompt'],
        user_template=version_data['user_template'],
    )
    # An explicit "generation_params:" with no value loads as None; normalize to {}.
    generation_params = version_data.get('generation_params')
    if generation_params is None:
        generation_params = {}
    return prompt, generation_params