'''
Copyright 2024 Infosys Ltd.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE
AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''
import json
import os

# import logging  # stdlib logging replaced by the project CustomLogger below
from typing import List

from llm_explain.config.logger import CustomLogger

from ..language_models import AbstractLanguageModel
from ..operations import GraphOfOperations, Thought
from ..parser import Parser
from ..prompter import Prompter

# NOTE: this deliberately shadows the (commented-out) stdlib `logging` name
# with a project logger instance, preserving existing call sites.
logging = CustomLogger()
class Controller:
    """
    Controller class to manage the execution flow of the Graph of Operations,
    generating the Graph Reasoning State.

    Coordinates a language model, the operations graph, a prompter, and a
    parser: operations are executed in readiness (breadth-first) order, and
    the resulting reasoning graph can be serialized and post-processed.
    """

    def __init__(
        self,
        lm: AbstractLanguageModel,
        graph: GraphOfOperations,
        prompter: Prompter,
        parser: Parser,
        problem_parameters: dict,
    ) -> None:
        """
        Initialize the Controller instance with the language model,
        operations graph, prompter, parser, and problem parameters.

        :param lm: An instance of the AbstractLanguageModel.
        :type lm: AbstractLanguageModel
        :param graph: The Graph of Operations to be executed.
        :type graph: GraphOfOperations
        :param prompter: An instance of the Prompter class, used to generate prompts.
        :type prompter: Prompter
        :param parser: An instance of the Parser class, used to parse responses.
        :type parser: Parser
        :param problem_parameters: Initial parameters/state of the problem.
        :type problem_parameters: dict
        """
        self.logger = CustomLogger()
        self.lm = lm
        self.graph = graph
        self.prompter = prompter
        self.parser = parser
        self.problem_parameters = problem_parameters
        # Set to True once run() completes; guards get_final_thoughts().
        self.run_executed = False

    def run(self) -> None:
        """
        Run the controller and execute the operations from the Graph of
        Operations based on their readiness.

        Ensures the program is in a valid state before execution.

        :raises AssertionError: If the Graph of Operations has no roots.
        :raises AssertionError: If the successor of an operation is not in
            the Graph of Operations.
        """
        assert self.graph.roots is not None, "The operations graph has no root"

        # Seed the queue with every operation that is already ready to run.
        execution_queue = [
            operation
            for operation in self.graph.operations
            if operation.can_be_executed()
        ]
        while execution_queue:
            current_operation = execution_queue.pop(0)
            current_operation.execute(
                self.lm, self.prompter, self.parser, **self.problem_parameters
            )
            # Enqueue any successors that became ready after this execution.
            for operation in current_operation.successors:
                assert (
                    operation in self.graph.operations
                ), "The successor of an operation is not in the operations graph"
                if operation.can_be_executed():
                    execution_queue.append(operation)
        self.run_executed = True

    def get_final_thoughts(self) -> List[List[Thought]]:
        """
        Retrieve the final thoughts after all operations have been executed.

        :return: List of thoughts for each operation in the graph's leaves.
        :rtype: List[List[Thought]]
        :raises AssertionError: If the `run` method hasn't been executed yet.
        """
        assert self.run_executed, "The run method has not been executed"
        return [operation.get_thoughts() for operation in self.graph.leaves]

    def output_graph(self, path: str) -> None:
        """
        Serialize the state and results of the operations graph to a JSON file.

        The final list element summarizes token usage and cost as reported
        by the language model.

        :param path: The path to the output file.
        :type path: str
        """
        output = []
        for operation in self.graph.operations:
            # Hoisted: the original called get_thoughts() repeatedly per check.
            thoughts = operation.get_thoughts()
            operation_serialized = {
                "operation": operation.operation_type.name,
                "thoughts": [thought.state for thought in thoughts],
            }
            if any(thought.scored for thought in thoughts):
                operation_serialized["scored"] = [t.scored for t in thoughts]
                operation_serialized["scores"] = [t.score for t in thoughts]
            if any(thought.validated for thought in thoughts):
                operation_serialized["validated"] = [t.validated for t in thoughts]
                operation_serialized["validity"] = [t.valid for t in thoughts]
            if any(thought.compared_to_ground_truth for thought in thoughts):
                operation_serialized["compared_to_ground_truth"] = [
                    t.compared_to_ground_truth for t in thoughts
                ]
                operation_serialized["problem_solved"] = [t.solved for t in thoughts]
            output.append(operation_serialized)
        output.append(
            {
                "prompt_tokens": self.lm.prompt_tokens,
                "completion_tokens": self.lm.completion_tokens,
                "cost": self.lm.cost,
            }
        )
        with open(path, "w") as file:
            file.write(json.dumps(output, indent=2))

    def format_graph(self, source: str):
        """
        Post-process a serialized operations-graph JSON file (as written by
        :meth:`output_graph`) into a compact, relabelled representation.

        NOTE(review): this assumes the fixed pipeline layout
        [generate, score, keep_best_n, aggregate, score, keep_best_n,
        token-summary] at indices 0-6 of the JSON list -- confirm against
        the graph construction before reusing elsewhere.

        The source file is deleted once processing succeeds.

        :param source: Path to the JSON file produced by output_graph.
        :type source: str
        :return: Tuple of (relabelled graph data with 'score' operations
            removed, mapping of thought label to original thought text).
        """
        import copy

        def count_unique_matches(seq_a, seq_b):
            # Number of distinct tokens the two sequences have in common.
            return len(set(seq_a) & set(seq_b))

        with open(source, "r") as file:
            data = json.load(file)
        data_new = copy.deepcopy(data)

        global_thoughts = []       # unique thought texts seen so far
        global_thoughts_num = []   # parallel list of their labels
        data_thoughts = {}         # label -> original thought text

        # -- generate (index 0): label each generated thought and attach its
        #    score from the following 'score' operation (index 1).
        generated = []
        for i, thought in enumerate(data[0]['thoughts']):
            text = thought['current']
            generated.append(text)
            if text not in global_thoughts:
                global_thoughts.append(text)
                global_thoughts_num.append(f"thought_{i+1}")
            data_new[0]['thoughts'][i]['current'] = f"thought_{i+1}"
            data_new[0]['thoughts'][i]['score'] = data_new[1]['scores'][i]
            data_thoughts[f"thought_{i+1}"] = text

        # -- score (index 1): relabel with the same thought labels.
        for i in range(len(data[1]['thoughts'])):
            data_new[1]['thoughts'][i]['current'] = f"thought_{i+1}"

        # -- keep_best_n (index 2): map surviving thoughts back to labels.
        for i, thought in enumerate(data[2]['thoughts']):
            text = thought['current']
            if text in generated:
                label = f"thought_{generated.index(text) + 1}"
            elif text in global_thoughts:
                # Fix: use the stored label directly; the original prefixed
                # it with 'thought_' a second time ("thought_thought_k").
                label = global_thoughts_num[global_thoughts.index(text)]
            else:
                continue
            data_new[2]['thoughts'][i]['current'] = label
            data_new[2]['thoughts'][i]['score'] = data_new[2]['scores'][i]

        # -- aggregate (index 3): label each aggregated thought after the
        #    kept thought it shares the most unique tokens with.
        aggregated, aggregate_labels = [], []
        for i, thought in enumerate(data[3]['thoughts']):
            text = thought['current']
            aggregated.append(text)
            overlaps = [
                count_unique_matches(kept['current'].split(), text.split())
                for kept in data[2]['thoughts']
            ]
            best = data_new[2]['thoughts'][overlaps.index(max(overlaps))]['current']
            label = f"aggregate_{best}"
            if text not in global_thoughts:
                global_thoughts.append(text)
                global_thoughts_num.append(label)
            data_new[3]['thoughts'][i]['current'] = label
            data_new[3]['thoughts'][i]['score'] = data_new[4]['scores'][i]
            aggregate_labels.append(label)

        # -- score (index 4): mirrors the aggregate thoughts.
        data_new[4]['thoughts'] = data_new[3]['thoughts']

        # -- final keep_best_n (index 5): label survivors and record the
        #    label -> text mapping for the final thoughts.
        for i, thought in enumerate(data[5]['thoughts']):
            # Fix: the original indexed data[5]['thoughts'][0] inside this
            # loop (harmless only when exactly one thought survives).
            text = thought['current']
            if text in aggregated:
                label = aggregate_labels[aggregated.index(text)]
            elif text in global_thoughts:
                label = global_thoughts_num[global_thoughts.index(text)]
            else:
                continue
            data_new[5]['thoughts'][i]['current'] = label
            data_new[5]['thoughts'][i]['score'] = data_new[5]['scores'][i]
            data_thoughts[label] = text

        # -- cleanup: drop 'score' operations (their values were folded into
        #    the thoughts above) and strip score bookkeeping from keep_best_n.
        # Fix: the original deleted entries from the list while iterating it
        # by index, which only worked because each keep_best_n happened to
        # directly follow a score entry.
        cleaned = []
        for entry in data_new:
            op = entry.get('operation')
            if op == 'score':
                continue
            if op == 'keep_best_n':
                entry.pop('scored', None)
                entry.pop('scores', None)
            cleaned.append(entry)

        os.remove(source)
        return cleaned, data_thoughts