# ChatFlowModule / OpenAIChatAtomicFlow.py
# nbaldwin - "new backend" (bdc9b47) - 10.5 kB
from copy import deepcopy
import hydra
import time
from typing import Dict, Optional, Any
from flows.base_flows import AtomicFlow
from flows.datasets import GenericDemonstrationsDataset
from flows.utils import logging
from flows.messages.flow_message import UpdateMessage_ChatMessage
from flows.prompt_template import JinjaPrompt
from backends.llm_lite import LiteLLMBackend
# Module-level logger obtained from the flows logging utilities; used by the
# class below to report backend-call failures.
log = logging.get_logger(__name__)
class OpenAIChatAtomicFlow(AtomicFlow):
    """Atomic flow that holds a multi-turn chat with an LLM through a ``LiteLLMBackend``.

    The conversation is stored in ``flow_state["previous_messages"]`` as a list of
    ``{"role": ..., "content": ...}`` dicts. On the first ``run`` the conversation is
    initialized with a system message and, optionally, few-shot demonstrations; every
    ``run`` then appends a user message rendered from ``input_data``, calls the backend
    and appends each returned answer as an assistant message.

    Keys read from ``flow_config``: ``name``, ``system_name``, ``user_name``,
    ``assistant_name``, ``demonstrations_k``, ``previous_messages.first_k`` /
    ``previous_messages.last_k``, ``n_api_retries``, ``wait_time_between_retries``,
    ``input_interface_initialized``, ``input_interface_non_initialized``,
    ``output_interface``, plus the prompt / backend / demonstrations sections consumed
    by ``instantiate_from_config``.
    """

    REQUIRED_KEYS_CONFIG = ["backend"]

    SUPPORTS_CACHING: bool = True

    system_message_prompt_template: JinjaPrompt
    human_message_prompt_template: JinjaPrompt

    backend: LiteLLMBackend
    init_human_message_prompt_template: Optional[JinjaPrompt] = None

    demonstrations: Optional[GenericDemonstrationsDataset] = None
    demonstrations_k: Optional[int] = None
    # Hydra-instantiated prompt (used via `.input_variables` / `.format`), not a plain str.
    demonstrations_response_prompt_template: Optional[JinjaPrompt] = None

    def __init__(self,
                 system_message_prompt_template,
                 human_message_prompt_template,
                 init_human_message_prompt_template,
                 backend,
                 demonstrations_response_prompt_template=None,
                 demonstrations=None,
                 **kwargs):
        """Store the prompt templates, the demonstrations and the backend.

        Args:
            system_message_prompt_template: template for the system message that opens
                the conversation.
            human_message_prompt_template: template for user messages once the
                conversation is initialized.
            init_human_message_prompt_template: template for the first user message;
                when None, ``human_message_prompt_template`` is used for it instead.
            backend: callable LLM backend (a ``LiteLLMBackend``).
            demonstrations_response_prompt_template: template rendering the assistant
                side of each demonstration.
            demonstrations: iterable of demonstration samples, or None.
            **kwargs: forwarded to ``AtomicFlow.__init__``
                (``instantiate_from_config`` passes ``flow_config`` here).

        Raises:
            AssertionError: if the flow is named 'system', 'user' or 'assistant'.
        """
        super().__init__(**kwargs)
        self.system_message_prompt_template = system_message_prompt_template
        self.human_message_prompt_template = human_message_prompt_template
        self.init_human_message_prompt_template = init_human_message_prompt_template
        self.demonstrations_response_prompt_template = demonstrations_response_prompt_template
        self.demonstrations = demonstrations
        # Maximum number of demonstrations to add; None means "all of them".
        self.demonstrations_k = self.flow_config.get("demonstrations_k", None)
        self.backend = backend
        assert self.flow_config["name"] not in [
            "system",
            "user",
            "assistant",
        ], f"Flow name '{self.flow_config['name']}' cannot be 'system', 'user' or 'assistant'"

    def set_up_flow_state(self):
        """Reset the flow state and start with an empty conversation history."""
        super().set_up_flow_state()
        self.flow_state["previous_messages"] = []

    @classmethod
    def _set_up_prompts(cls, config):
        """Instantiate the prompt templates (and demonstrations) described in ``config``.

        Returns:
            dict of constructor keyword arguments. The demonstrations dataset and its
            response template are only included when the config defines
            ``demonstrations_response_prompt_template``.
        """
        kwargs = {}

        kwargs["system_message_prompt_template"] = \
            hydra.utils.instantiate(config['system_message_prompt_template'], _convert_="partial")
        kwargs["init_human_message_prompt_template"] = \
            hydra.utils.instantiate(config['init_human_message_prompt_template'], _convert_="partial")
        kwargs["human_message_prompt_template"] = \
            hydra.utils.instantiate(config['human_message_prompt_template'], _convert_="partial")

        if "demonstrations_response_prompt_template" in config:
            kwargs["demonstrations_response_prompt_template"] = \
                hydra.utils.instantiate(config['demonstrations_response_prompt_template'], _convert_="partial")
            kwargs["demonstrations"] = GenericDemonstrationsDataset(**config['demonstrations'])

        return kwargs

    @classmethod
    def _set_up_backend(cls, config):
        """Instantiate the LLM backend from ``config['backend']`` via hydra."""
        kwargs = {}
        kwargs["backend"] = \
            hydra.utils.instantiate(config['backend'], _convert_="partial")
        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """Build the flow from a config; the config is deep-copied, not mutated."""
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts and backend ~~~
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs.update(cls._set_up_backend(flow_config))

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def _is_conversation_initialized(self):
        """Return True once at least one message has been added to the conversation."""
        return len(self.flow_state["previous_messages"]) > 0

    def get_interface_description(self):
        """Return the input/output interface, which depends on the conversation state."""
        input_key = ("input_interface_initialized"
                     if self._is_conversation_initialized()
                     else "input_interface_non_initialized")
        return {"input": self.flow_config[input_key],
                "output": self.flow_config["output_interface"]}

    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        """Render ``prompt_template`` using only the variables it declares.

        Raises:
            KeyError: if ``input_data`` lacks one of the template's input variables.
        """
        template_kwargs = {name: input_data[name] for name in prompt_template.input_variables}
        return prompt_template.format(**template_kwargs)

    def _get_demonstration_query_message_content(self, sample_data: Dict):
        """Render the user side of a demonstration with the initial human-message template."""
        return self._get_message(self.init_human_message_prompt_template, sample_data)

    def _get_demonstration_response_message_content(self, sample_data: Dict):
        """Render the assistant side of a demonstration."""
        return self._get_message(self.demonstrations_response_prompt_template, sample_data)

    def _add_demonstrations(self):
        """Append (user query, assistant response) pairs for the few-shot demonstrations.

        At most ``demonstrations_k`` pairs are added (all of them when it is None).
        """
        if self.demonstrations is None:
            return

        for n_added, example in enumerate(self.demonstrations):
            if self.demonstrations_k is not None and n_added >= self.demonstrations_k:
                break
            query = self._get_demonstration_query_message_content(example)
            response = self._get_demonstration_response_message_content(example)

            self._state_update_add_chat_message(content=query,
                                                role=self.flow_config["user_name"])
            self._state_update_add_chat_message(content=response,
                                                role=self.flow_config["assistant_name"])

    def _state_update_add_chat_message(self,
                                       role: str,
                                       content: str) -> None:
        """Append a chat message to the conversation and log it as an update message.

        Raises:
            Exception: if ``role`` is not one of the configured system/user/assistant names.
        """
        acceptable_roles = [self.flow_config["system_name"],
                            self.flow_config["user_name"],
                            self.flow_config["assistant_name"]]
        if role not in acceptable_roles:
            raise Exception(f"Invalid role: `{role}`.\n"
                            f"Role should be one of: "
                            f"`{acceptable_roles}`, ")

        self.flow_state["previous_messages"].append({"role": role, "content": content})

        # Log the update to the flow messages list
        chat_message = UpdateMessage_ChatMessage(
            created_by=self.flow_config["name"],
            updated_flow=self.flow_config["name"],
            role=role,
            content=content,
        )
        self._log_message(chat_message)

    def _get_previous_messages(self):
        """Return the part of the conversation to send to the backend.

        ``flow_config["previous_messages"]`` may set ``first_k`` and/or ``last_k``:
        neither -> the whole conversation; only one -> the first ``first_k`` or last
        ``last_k`` messages; both -> the first ``first_k`` followed by the last
        ``last_k``. When both are set and the two windows would overlap (the
        conversation holds at most ``first_k + last_k`` messages), the whole
        conversation is returned so that no message is sent twice.
        """
        all_messages = self.flow_state["previous_messages"]
        first_k = self.flow_config["previous_messages"]["first_k"]
        last_k = self.flow_config["previous_messages"]["last_k"]

        if not first_k and not last_k:
            return all_messages
        if first_k and last_k:
            if first_k + last_k >= len(all_messages):
                # Windows overlap: slicing both ends would duplicate middle messages.
                return all_messages
            return all_messages[:first_k] + all_messages[-last_k:]
        if first_k:
            return all_messages[:first_k]
        return all_messages[-last_k:]

    def _call(self):
        """Send the (windowed) conversation to the backend, retrying on failure.

        Returns:
            list of str: the content of every answer returned by the backend (more
            than one when ``n > 1`` in the generation parameters).

        Raises:
            ValueError: if ``flow_config["n_api_retries"]`` is smaller than 1, so no
                call could be attempted.
            Exception: the error from the last failed attempt, once all
                ``n_api_retries`` attempts have failed.
        """
        messages = self._get_previous_messages()
        n_api_retries = self.flow_config['n_api_retries']
        wait_time = self.flow_config['wait_time_between_retries']

        if n_api_retries < 1:
            # Without this check the old loop never ran and ended in `raise None`.
            raise ValueError(f"n_api_retries must be at least 1, got {n_api_retries}")

        for attempt in range(1, n_api_retries + 1):
            try:
                # set mock_response=True when debugging (fake API request)
                response = self.backend(messages=messages, mock_response=False)
                # one entry per answer, because n in the generation parameters can be > 1
                return [answer["content"] for answer in response]
            except Exception as e:
                if attempt == n_api_retries:
                    # Last attempt: don't sleep or announce a retry that won't happen.
                    log.error(f"Error {attempt} in calling backend: {e}. No retries left.")
                    raise
                log.error(
                    f"Error {attempt} in calling backend: {e}. "
                    f"Retrying in {wait_time} seconds..."
                )
                time.sleep(wait_time)

    def _initialize_conversation(self, input_data: Dict[str, Any]):
        """Start the conversation: add the system message, then any demonstrations."""
        # ~~~ Add the system message ~~~
        system_message_content = self._get_message(self.system_message_prompt_template, input_data)

        self._state_update_add_chat_message(content=system_message_content,
                                            role=self.flow_config["system_name"])

        # ~~~ Add the demonstration query-response tuples (if any) ~~~
        self._add_demonstrations()

    def _process_input(self, input_data: Dict[str, Any]):
        """Add the user message for ``input_data``, initializing the conversation first if needed.

        The first user message uses ``init_human_message_prompt_template`` when one is
        set; every other user message uses ``human_message_prompt_template``.
        """
        if self._is_conversation_initialized():
            template = self.human_message_prompt_template
        else:
            # Initialize the conversation (system message and, potentially, demonstrations)
            self._initialize_conversation(input_data)
            if self.init_human_message_prompt_template is not None:
                template = self.init_human_message_prompt_template
            else:
                template = self.human_message_prompt_template

        user_message_content = self._get_message(template, input_data)
        self._state_update_add_chat_message(role=self.flow_config["user_name"],
                                            content=user_message_content)

    def run(self,
            input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Process one turn: add the user message, query the backend, record the answers.

        Returns:
            ``{"api_output": answers}``, where ``answers`` is the list returned by ``_call``.
        """
        # ~~~ Process input ~~~
        self._process_input(input_data)

        # ~~~ Call ~~~
        response = self._call()

        # Record every answer (there can be several when n > 1 in the generation parameters)
        for answer in response:
            self._state_update_add_chat_message(
                role=self.flow_config["assistant_name"],
                content=answer
            )

        return {"api_output": response}