import os
from copy import deepcopy
from typing import Any, Dict, List, Optional

import hydra
import jinja2
import numpy as np

from flows.base_flows import AtomicFlow
from flows.prompt_template import JinjaPrompt
from flows.utils import general_helpers
from flows.utils import logging

log = logging.get_logger(__name__)


class DemonstrationsAtomicFlow(AtomicFlow):
    """Atomic flow that adds formatted demonstration pairs to its input.

    Each datapoint in ``self.data`` is indexed with the keys ``"query_data"``
    and ``"response_data"``; these are rendered through the query and response
    prompt templates respectively. ``run`` returns the input data extended
    with a ``"demonstrations"`` list of ``{"idx", "query", "response"}`` dicts.

    Keys read from ``params``:
        demonstrations_k: optional cap on how many datapoints (taken from the
            start of ``self.data``) are turned into demonstrations; when
            missing or ``None`` all datapoints are used.
        data_dir, demonstrations_id: used to build the path
            ``<data_dir>/<demonstrations_id>.jsonl`` when no ``data`` is given.
        ids_to_keep: optional comma-separated string or iterable of ids; when
            truthy, only datapoints whose ``"id"`` is in it are kept on load.
    """

    demonstrations_k: Optional[int] = None
    query_prompt_template: JinjaPrompt
    response_prompt_template: JinjaPrompt
    params: Dict

    def __init__(self, params, query_prompt_template, response_prompt_template,
                 data=None, **kwargs):
        """Store the configuration and load the demonstrations if needed.

        Args:
            params: mapping with the flow parameters (see the class docstring).
            query_prompt_template: template used to render the query side of a
                demonstration.
            response_prompt_template: template used to render the response
                side of a demonstration.
            data: already-loaded datapoints; when ``None`` they are read from
                the jsonl file described by ``params``.
            **kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        self.params = params
        self.data = data
        self.demonstrations_k = self.params.get("demonstrations_k", None)

        # typically the query would be what the user (human) asks the assistant (LLM)
        self.query_prompt_template = query_prompt_template

        # typically the response would be what the assistant (LLM) should answer
        # to the user (human)
        self.response_prompt_template = response_prompt_template

        if self.data is None:
            self._load_data()

    @classmethod
    def _set_up_prompts(cls, config):
        """Instantiate both prompt templates from their config entries.

        Returns:
            A dict with the keys ``"query_prompt_template"`` and
            ``"response_prompt_template"``, ready to be passed to ``__init__``.
        """
        return {
            "query_prompt_template": hydra.utils.instantiate(
                config['query_prompt_template'], _convert_="partial"
            ),
            "response_prompt_template": hydra.utils.instantiate(
                config['response_prompt_template'], _convert_="partial"
            ),
        }

    @classmethod
    def instantiate_from_config(cls, config):
        """Build a flow instance from a config mapping.

        The config is deep-copied, so the caller's object is not mutated by
        this method. ``"params"`` and the two prompt template entries are
        required; ``"data"`` is optional.
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs["params"] = flow_config["params"]
        # ``data`` is optional in ``__init__`` (None means "load from disk"),
        # so a config that omits the key must not fail with a KeyError here.
        kwargs["data"] = flow_config.get("data", None)

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def _get_query_message_content(self, sample_data: Dict):
        """Render the query template with the variables it declares.

        Raises:
            KeyError: if ``sample_data`` lacks one of the template's input
                variables.
        """
        input_variables = self.query_prompt_template.input_variables
        template_kwargs = {k: sample_data[k] for k in input_variables}
        return self.query_prompt_template.format(**template_kwargs)

    def _get_response_message_content(self, sample_data: Dict):
        """Render the response template with the variables it declares.

        Raises:
            KeyError: if ``sample_data`` lacks one of the template's input
                variables.
        """
        input_variables = self.response_prompt_template.input_variables
        template_kwargs = {k: sample_data[k] for k in input_variables}
        return self.response_prompt_template.format(**template_kwargs)

    def _get_io_pair(self, idx):
        """Return the rendered query/response pair for the datapoint at ``idx``."""
        dp = self.data[idx]
        query = self._get_query_message_content(dp["query_data"])
        response = self._get_response_message_content(dp["response_data"])
        return {"idx": idx, "query": query, "response": response}

    def _get_io_pairs(self, input_data: Dict[str, Any]) -> List[Any]:
        """Return the first ``demonstrations_k`` rendered demonstrations.

        ``input_data`` is accepted for interface compatibility and is
        currently not used. When ``demonstrations_k`` is ``None`` every
        datapoint is used.
        """
        available = len(self.data)
        if self.demonstrations_k is None:
            count = available
        else:
            # Never request more demonstrations than are loaded (the data may
            # be small or filtered by ``ids_to_keep``); indexing past the end
            # would raise an IndexError.
            count = min(self.demonstrations_k, available)
        return [self._get_io_pair(idx) for idx in range(count)]

    def _load_data(self):
        """Read the demonstrations jsonl file and apply the optional id filter."""
        demonstrations_file = os.path.join(
            self.params["data_dir"], f"{self.params['demonstrations_id']}.jsonl"
        )
        self.data = general_helpers.read_jsonlines(demonstrations_file)

        if self.params.get("ids_to_keep", False):
            if isinstance(self.params["ids_to_keep"], str):
                # A string is interpreted as a comma-separated list of ids.
                ids_to_keep = set(self.params["ids_to_keep"].split(","))
            else:
                ids_to_keep = set(self.params["ids_to_keep"])

            self.data = [d for d in self.data if d["id"] in ids_to_keep]

        log.info("Loaded the demonstrations for %d datapoints from %s",
                 len(self.data), self.params["data_dir"])

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``input_data`` extended with ``"demonstrations"``.

        An existing ``"demonstrations"`` key in ``input_data`` is overwritten
        in the returned dict; ``input_data`` itself is not modified.
        """
        return {
            **input_data,
            "demonstrations": self._get_io_pairs(input_data=input_data),
        }