Update ControllerAtomicFlow.py
Browse files- ControllerAtomicFlow.py +5 -6
ControllerAtomicFlow.py
CHANGED
|
@@ -96,21 +96,19 @@ class ControllerAtomicFlow(OpenAIChatAtomicFlow):
|
|
| 96 |
return ret
|
| 97 |
except FileNotFoundError:
|
| 98 |
return 'There is no function available yet.'
|
| 99 |
-
|
| 100 |
def _get_plan(self):
|
| 101 |
try:
|
| 102 |
with open(self.plan_file_location, 'r') as file:
|
| 103 |
return file.read()
|
| 104 |
except FileNotFoundError:
|
| 105 |
return "There is no plan yet"
|
| 106 |
-
|
| 107 |
-
def
|
| 108 |
if 'goal' in input_data:
|
| 109 |
input_data['goal'] += self.hint_for_model
|
| 110 |
if 'human_feedback' in input_data:
|
| 111 |
input_data['human_feedback'] += self.hint_for_model
|
| 112 |
-
|
| 113 |
-
# self.system_message_prompt_template.template
|
| 114 |
plan_to_append = self._get_plan()
|
| 115 |
function_signatures_to_append = self._get_library_function_signatures()
|
| 116 |
self.system_message_prompt_template.template = \
|
|
@@ -120,8 +118,9 @@ class ControllerAtomicFlow(OpenAIChatAtomicFlow):
|
|
| 120 |
+ plan_to_append + "\n\n" + f"Make sure the plan your write is at {self.plan_file_location}\n" \
|
| 121 |
+ f"Make sure the code you call the code writer to write is at {self.code_file_location}"
|
| 122 |
|
|
|
|
|
|
|
| 123 |
api_output = super().run(input_data)["api_output"].strip()
|
| 124 |
-
|
| 125 |
try:
|
| 126 |
response = json.loads(api_output)
|
| 127 |
return response
|
|
|
|
| 96 |
return ret
|
| 97 |
except FileNotFoundError:
|
| 98 |
return 'There is no function available yet.'
|
| 99 |
+
|
| 100 |
def _get_plan(self):
|
| 101 |
try:
|
| 102 |
with open(self.plan_file_location, 'r') as file:
|
| 103 |
return file.read()
|
| 104 |
except FileNotFoundError:
|
| 105 |
return "There is no plan yet"
|
| 106 |
+
|
| 107 |
+
def _update_prompts_and_input(self, input_data: Dict[str, Any]):
|
| 108 |
if 'goal' in input_data:
|
| 109 |
input_data['goal'] += self.hint_for_model
|
| 110 |
if 'human_feedback' in input_data:
|
| 111 |
input_data['human_feedback'] += self.hint_for_model
|
|
|
|
|
|
|
| 112 |
plan_to_append = self._get_plan()
|
| 113 |
function_signatures_to_append = self._get_library_function_signatures()
|
| 114 |
self.system_message_prompt_template.template = \
|
|
|
|
| 118 |
+ plan_to_append + "\n\n" + f"Make sure the plan your write is at {self.plan_file_location}\n" \
|
| 119 |
+ f"Make sure the code you call the code writer to write is at {self.code_file_location}"
|
| 120 |
|
| 121 |
+
def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 122 |
+
self._update_prompts_and_input(input_data)
|
| 123 |
api_output = super().run(input_data)["api_output"].strip()
|
|
|
|
| 124 |
try:
|
| 125 |
response = json.loads(api_output)
|
| 126 |
return response
|