Spaces:
Paused
Paused
| from abc import ABC, abstractmethod | |
| from typing import Any, Dict, List, Optional, Tuple, TypedDict | |
| from litellm.types.llms.openai import AllMessageValues | |
| from litellm.types.utils import StandardCallbackDynamicParams | |
# Result of compiling a managed prompt: the id, its template messages, the model
# and optional params attached to the template, and the completed message list
# (template plus client messages). The Optional fields may hold None.
PromptManagementClient = TypedDict(
    "PromptManagementClient",
    {
        "prompt_id": str,
        "prompt_template": List[AllMessageValues],
        "prompt_template_model": Optional[str],
        "prompt_template_optional_params": Optional[Dict[str, Any]],
        "completed_messages": Optional[List[AllMessageValues]],
    },
)
class PromptManagementBase(ABC):
    """Abstract base for prompt-management integrations.

    Subclasses provide the integration name, decide whether a prompt id should
    be handled, and compile the prompt template. This base class merges the
    compiled template with the caller's messages and derives the model and
    optional params used for the chat completion call.
    """

    @property
    @abstractmethod
    def integration_name(self) -> str:
        """Name of the integration; ``"<name>/"`` is the model routing prefix."""

    @abstractmethod
    def should_run_prompt_management(
        self,
        prompt_id: str,
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        """Return True if this integration should handle ``prompt_id``."""

    @abstractmethod
    def _compile_prompt_helper(
        self,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> PromptManagementClient:
        """Build the prompt client for ``prompt_id`` using ``prompt_variables``.

        ``completed_messages`` on the returned client is overwritten by
        :meth:`compile_prompt`.
        """

    def merge_messages(
        self,
        prompt_template: List[AllMessageValues],
        client_messages: List[AllMessageValues],
    ) -> List[AllMessageValues]:
        """Return the template messages followed by the client's messages.

        A new list is returned; neither input list is mutated.
        """
        return prompt_template + client_messages

    def compile_prompt(
        self,
        prompt_id: str,
        prompt_variables: Optional[dict],
        client_messages: List[AllMessageValues],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> PromptManagementClient:
        """Compile ``prompt_id`` and merge its template with ``client_messages``.

        Returns:
            The client produced by :meth:`_compile_prompt_helper`, with
            ``completed_messages`` set to the result of :meth:`merge_messages`.

        Raises:
            ValueError: If the compiled template cannot be merged with
                ``client_messages`` (e.g. the helper returned no template).
                The underlying exception is chained as ``__cause__``.
        """
        compiled_prompt_client = self._compile_prompt_helper(
            prompt_id=prompt_id,
            prompt_variables=prompt_variables,
            dynamic_callback_params=dynamic_callback_params,
        )
        try:
            # Go through merge_messages so subclass overrides of the merge
            # strategy are honoured here as well.
            messages = self.merge_messages(
                prompt_template=compiled_prompt_client["prompt_template"],
                client_messages=client_messages,
            )
        except Exception as e:
            # NOTE(review): the message embeds dynamic_callback_params; confirm
            # it cannot carry credentials before this reaches logs.
            raise ValueError(
                f"Error compiling prompt: {e}. Prompt id={prompt_id}, prompt_variables={prompt_variables}, client_messages={client_messages}, dynamic_callback_params={dynamic_callback_params}"
            ) from e
        compiled_prompt_client["completed_messages"] = messages
        return compiled_prompt_client

    def _get_model_from_prompt(
        self, prompt_management_client: PromptManagementClient, model: str
    ) -> str:
        """Choose the model for the completion call.

        The template's own model wins. Otherwise a leading
        ``"<integration_name>/"`` routing prefix is removed from ``model``.
        """
        template_model = prompt_management_client["prompt_template_model"]
        if template_model is not None:
            return template_model

        integration_name = self.integration_name
        # Tolerate subclasses that override integration_name as a plain method
        # rather than a property; formatting the bound method itself would
        # yield a prefix that never matches.
        if callable(integration_name):
            integration_name = integration_name()
        prefix = f"{integration_name}/"
        # Strip only a leading prefix; str.replace would also rewrite matching
        # text later in the model name.
        if model.startswith(prefix):
            return model[len(prefix):]
        return model

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """Apply the managed prompt ``prompt_id`` to a chat completion request.

        Returns:
            ``(model, messages, non_default_params)``. If
            :meth:`should_run_prompt_management` is falsy, the inputs are
            returned unchanged. Otherwise:
            - messages are the compiled ones from :meth:`compile_prompt`
              (falling back to ``messages`` if that result is empty);
            - the template's optional params override ``non_default_params``;
            - the model comes from :meth:`_get_model_from_prompt`.

        Raises:
            ValueError: If ``prompt_id`` is None or prompt compilation fails.
        """
        if prompt_id is None:
            raise ValueError("prompt_id is required for Prompt Management Base class")
        if not self.should_run_prompt_management(
            prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
        ):
            return model, messages, non_default_params

        prompt_template = self.compile_prompt(
            prompt_id=prompt_id,
            prompt_variables=prompt_variables,
            client_messages=messages,
            dynamic_callback_params=dynamic_callback_params,
        )
        # None or an empty list falls back to the caller's messages.
        completed_messages = prompt_template["completed_messages"] or messages
        prompt_template_optional_params = (
            prompt_template["prompt_template_optional_params"] or {}
        )
        # Later keys win: template params take precedence over caller params.
        updated_non_default_params = {
            **non_default_params,
            **prompt_template_optional_params,
        }
        model = self._get_model_from_prompt(
            prompt_management_client=prompt_template, model=model
        )
        return model, completed_messages, updated_non_default_params