|
|
from copy import deepcopy |
|
|
import hydra |
|
|
|
|
|
import time |
|
|
|
|
|
from typing import Dict, Optional, Any |
|
|
|
|
|
from flows.base_flows import AtomicFlow |
|
|
from flows.datasets import GenericDemonstrationsDataset |
|
|
|
|
|
from flows.utils import logging |
|
|
from flows.messages.flow_message import UpdateMessage_ChatMessage |
|
|
|
|
|
from flows.prompt_template import JinjaPrompt |
|
|
|
|
|
from backends.llm_lite import LiteLLMBackend |
|
|
|
|
|
# Module-level logger obtained from the project's ``flows.utils.logging`` helper;
# the flow below uses it to report failed backend calls.
log = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
class OpenAIChatAtomicFlow(AtomicFlow):
    """Atomic flow that chats with an LLM through a ``LiteLLMBackend``.

    The conversation history lives in ``self.flow_state["previous_messages"]`` as a
    list of ``{"role": ..., "content": ...}`` dicts. While that history is empty, a
    call first adds the rendered system message and any few-shot demonstrations,
    then the first user message. Once the history is non-empty, a call only
    appends the new user message. The backend's answers are appended with the
    assistant role. Role names are read from the ``system_name``, ``user_name``
    and ``assistant_name`` config entries.
    """

    REQUIRED_KEYS_CONFIG = ["backend"]

    SUPPORTS_CACHING: bool = True

    system_message_prompt_template: JinjaPrompt
    human_message_prompt_template: JinjaPrompt

    backend: LiteLLMBackend
    init_human_message_prompt_template: Optional[JinjaPrompt] = None
    demonstrations: Optional[GenericDemonstrationsDataset] = None
    demonstrations_k: Optional[int] = None
    # Instantiated from config like the other templates and used through
    # ``.input_variables`` / ``.format``, so it is a prompt object, not a ``str``.
    demonstrations_response_prompt_template: Optional[JinjaPrompt] = None

    def __init__(self,
                 system_message_prompt_template,
                 human_message_prompt_template,
                 init_human_message_prompt_template,
                 backend,
                 demonstrations_response_prompt_template=None,
                 demonstrations=None,
                 **kwargs):
        """Store the prompt templates, demonstrations and backend.

        Args:
            system_message_prompt_template: Template for the system message.
            human_message_prompt_template: Template for user messages once the
                conversation has started.
            init_human_message_prompt_template: Template for the first user
                message. It may be None, in which case
                ``human_message_prompt_template`` is used instead.
            backend: Backend called with the message history (``LiteLLMBackend``).
            demonstrations_response_prompt_template: Template rendering the
                assistant side of each demonstration.
            demonstrations: Iterable of demonstration samples, or None.
            **kwargs: Forwarded to ``AtomicFlow.__init__``.
                ``instantiate_from_config`` passes ``flow_config`` this way.

        Raises:
            AssertionError: If the flow name is "system", "user" or "assistant".
        """
        super().__init__(**kwargs)

        self.system_message_prompt_template = system_message_prompt_template
        self.human_message_prompt_template = human_message_prompt_template
        self.init_human_message_prompt_template = init_human_message_prompt_template
        self.demonstrations_response_prompt_template = demonstrations_response_prompt_template
        self.demonstrations = demonstrations
        # Optional cap on how many demonstrations are added; None means all.
        self.demonstrations_k = self.flow_config.get("demonstrations_k", None)
        self.backend = backend

        # These names are reserved: they are the standard chat role names.
        assert self.flow_config["name"] not in [
            "system",
            "user",
            "assistant",
        ], f"Flow name '{self.flow_config['name']}' cannot be 'system', 'user' or 'assistant'"

    def set_up_flow_state(self):
        """Initialise the flow state with an empty conversation history."""
        super().set_up_flow_state()
        self.flow_state["previous_messages"] = []

    @classmethod
    def _set_up_prompts(cls, config):
        """Instantiate the prompt templates (and demonstrations) described in ``config``.

        Returns:
            dict: Constructor kwargs for the prompt-related ``__init__`` parameters.
        """
        kwargs = {}

        kwargs["system_message_prompt_template"] = \
            hydra.utils.instantiate(config['system_message_prompt_template'], _convert_="partial")
        # The initial human template is optional everywhere else in this class
        # (None falls back to ``human_message_prompt_template``). A missing key is
        # therefore treated like an explicit null; hydra's instantiate(None)
        # returns None.
        kwargs["init_human_message_prompt_template"] = \
            hydra.utils.instantiate(config.get('init_human_message_prompt_template'), _convert_="partial")
        kwargs["human_message_prompt_template"] = \
            hydra.utils.instantiate(config['human_message_prompt_template'], _convert_="partial")

        # Demonstrations are only set up together with their response template.
        if "demonstrations_response_prompt_template" in config:
            kwargs["demonstrations_response_prompt_template"] = \
                hydra.utils.instantiate(config['demonstrations_response_prompt_template'], _convert_="partial")
            kwargs["demonstrations"] = GenericDemonstrationsDataset(**config['demonstrations'])

        return kwargs

    @classmethod
    def _set_up_backend(cls, config):
        """Instantiate the backend described by ``config['backend']``.

        Returns:
            dict: ``{"backend": <instantiated backend>}``.
        """
        backend = hydra.utils.instantiate(config['backend'], _convert_="partial")
        return {"backend": backend}

    @classmethod
    def instantiate_from_config(cls, config):
        """Build the flow from a config.

        The config is deep-copied so the flow does not share the caller's object.
        The prompt templates and the backend are then instantiated from the copy.
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs.update(cls._set_up_backend(flow_config))

        return cls(**kwargs)

    def _is_conversation_initialized(self):
        """Return True once at least one message has been added to the history."""
        return len(self.flow_state["previous_messages"]) > 0

    def get_interface_description(self):
        """Describe the expected input/output keys.

        The input interface depends on whether the conversation has started. The
        output interface is the same in both cases.
        """
        if self._is_conversation_initialized():
            input_interface = self.flow_config["input_interface_initialized"]
        else:
            input_interface = self.flow_config["input_interface_non_initialized"]
        return {"input": input_interface,
                "output": self.flow_config["output_interface"]}

    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        """Render ``prompt_template`` with the values of its input variables.

        The values are taken from ``input_data``.

        Raises:
            KeyError: If ``input_data`` lacks one of the template's input variables.
        """
        template_kwargs = {name: input_data[name] for name in prompt_template.input_variables}
        return prompt_template.format(**template_kwargs)

    def _get_initial_user_prompt_template(self):
        """Return the template used for the first user message.

        This is ``init_human_message_prompt_template`` when set, otherwise
        ``human_message_prompt_template``.
        """
        if getattr(self, "init_human_message_prompt_template", None) is not None:
            return self.init_human_message_prompt_template
        return self.human_message_prompt_template

    def _get_demonstration_query_message_content(self, sample_data: Dict):
        """Render the user side of a demonstration.

        It uses the same template as the first real user message, so the
        demonstrations look like the query that follows them.
        """
        return self._get_message(self._get_initial_user_prompt_template(), sample_data)

    def _get_demonstration_response_message_content(self, sample_data: Dict):
        """Render the assistant side of a demonstration."""
        return self._get_message(self.demonstrations_response_prompt_template, sample_data)

    def _add_demonstrations(self):
        """Append demonstration (user query, assistant response) pairs to the history.

        At most ``demonstrations_k`` pairs are added, or all of them when it is
        None. Nothing happens when no demonstrations are configured.
        """
        if self.demonstrations is None:
            return

        for index, example in enumerate(self.demonstrations):
            if self.demonstrations_k is not None and index >= self.demonstrations_k:
                break

            query = self._get_demonstration_query_message_content(example)
            response = self._get_demonstration_response_message_content(example)

            self._state_update_add_chat_message(content=query,
                                                role=self.flow_config["user_name"])
            self._state_update_add_chat_message(content=response,
                                                role=self.flow_config["assistant_name"])

    def _state_update_add_chat_message(self,
                                       role: str,
                                       content: str) -> None:
        """Append a message to the history and log it.

        The message is logged as an ``UpdateMessage_ChatMessage``.

        Raises:
            Exception: If ``role`` is not one of the configured system/user/assistant names.
        """
        acceptable_roles = [self.flow_config["system_name"],
                            self.flow_config["user_name"],
                            self.flow_config["assistant_name"]]
        if role not in acceptable_roles:
            raise Exception(f"Invalid role: `{role}`.\n"
                            f"Role should be one of: "
                            f"`{acceptable_roles}`, ")

        self.flow_state["previous_messages"].append({"role": role, "content": content})

        chat_message = UpdateMessage_ChatMessage(
            created_by=self.flow_config["name"],
            updated_flow=self.flow_config["name"],
            role=role,
            content=content,
        )
        self._log_message(chat_message)

    def _get_previous_messages(self):
        """Return the history to send to the backend, optionally truncated.

        ``previous_messages.first_k`` and ``previous_messages.last_k`` in the config
        keep only the first and/or last k messages. A falsy value disables that
        side. When both are set and the history is shorter than
        ``first_k + last_k``, the windows overlap. In that case the whole history
        is returned instead of sending the overlapping messages twice.
        """
        all_messages = self.flow_state["previous_messages"]
        first_k = self.flow_config["previous_messages"]["first_k"]
        last_k = self.flow_config["previous_messages"]["last_k"]

        if not first_k and not last_k:
            return all_messages
        if first_k and last_k:
            if first_k + last_k >= len(all_messages):
                # Overlapping windows: slicing both ends would repeat messages.
                return all_messages[:]
            return all_messages[:first_k] + all_messages[-last_k:]
        if first_k:
            return all_messages[:first_k]
        return all_messages[-last_k:]

    def _call(self):
        """Send the (possibly truncated) history to the backend, retrying on failure.

        Up to ``n_api_retries`` attempts are made. The flow sleeps
        ``wait_time_between_retries`` seconds only between consecutive attempts,
        not after the last one.

        Returns:
            list: The ``"content"`` of every answer returned by the backend.

        Raises:
            ValueError: If ``n_api_retries`` is smaller than 1.
            Exception: The error from the last attempt, if all attempts fail.
        """
        messages = self._get_previous_messages()
        max_attempts = self.flow_config['n_api_retries']
        if max_attempts < 1:
            # Otherwise no attempt is made and there is no error to re-raise.
            raise ValueError(f"`n_api_retries` must be at least 1, got {max_attempts}")

        last_error = None
        for attempt in range(1, max_attempts + 1):
            try:
                answers = self.backend(messages=messages, mock_response=False)
                return [answer["content"] for answer in answers]
            except Exception as e:
                last_error = e
                if attempt == max_attempts:
                    log.error(
                        f"Error {attempt} in calling backend: {e}. "
                        f"No retries left (n_api_retries={max_attempts})."
                    )
                    break
                wait_time = self.flow_config['wait_time_between_retries']
                log.error(
                    f"Error {attempt} in calling backend: {e}. "
                    f"Retrying in {wait_time} seconds..."
                )
                time.sleep(wait_time)

        raise last_error

    def _initialize_conversation(self, input_data: Dict[str, Any]):
        """Start the history with the rendered system message, then add demonstrations."""
        system_message_content = self._get_message(self.system_message_prompt_template, input_data)
        self._state_update_add_chat_message(content=system_message_content,
                                            role=self.flow_config["system_name"])
        self._add_demonstrations()

    def _process_input(self, input_data: Dict[str, Any]):
        """Render the user message for ``input_data`` and append it to the history.

        On an empty history the conversation is initialised first and the
        first-user-message template is used. Otherwise the regular human
        template is used.
        """
        if self._is_conversation_initialized():
            template = self.human_message_prompt_template
        else:
            self._initialize_conversation(input_data)
            template = self._get_initial_user_prompt_template()

        user_message_content = self._get_message(template, input_data)
        self._state_update_add_chat_message(role=self.flow_config["user_name"],
                                            content=user_message_content)

    def run(self,
            input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Process one user turn and query the backend.

        Args:
            input_data: Values for the prompt templates' input variables.

        Returns:
            ``{"api_output": [...]}`` with the content of each backend answer. Each
            answer is also appended to the history with the assistant role.
        """
        self._process_input(input_data)

        response = self._call()

        for answer in response:
            self._state_update_add_chat_message(
                role=self.flow_config["assistant_name"],
                content=answer
            )

        return {"api_output": response}
|
|
|