Tachi67 commited on
Commit
49a0727
·
1 Parent(s): d5506a4

Update CodeGeneratorAtomicFlow.py

Browse files
Files changed (1) hide show
  1. CodeGeneratorAtomicFlow.py +27 -3
CodeGeneratorAtomicFlow.py CHANGED
@@ -6,9 +6,14 @@ from flow_modules.aiflows.ChatFlowModule import ChatAtomicFlow
6
 
7
 
8
  class CodeGeneratorAtomicFlow(ChatAtomicFlow):
9
- """Generates code and docstrings with given goal (from controller)"""
 
10
  def __init__(self, **kwargs):
11
  super().__init__(**kwargs)
 
 
 
 
12
  self.hint_for_model = """
13
  Make sure your response is in the following format:
14
  Response Format:
@@ -33,12 +38,31 @@ class CodeGeneratorAtomicFlow(ChatAtomicFlow):
33
  # ~~~ Instantiate flow ~~~
34
  return cls(**kwargs)
35
 
36
- def _update_input(self, input_data: Dict[str, Any]):
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  if 'goal' in input_data:
38
  input_data['goal'] += self.hint_for_model
 
 
 
 
 
 
39
 
40
  def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
41
- self._update_input(input_data)
42
  api_output = super().run(input_data)["api_output"].strip()
43
  try:
44
  response = json.loads(api_output)
 
6
 
7
 
8
  class CodeGeneratorAtomicFlow(ChatAtomicFlow):
9
+ """Generates one function with docstrings to finish the given goal (from the controller).
10
+ """
11
  def __init__(self, **kwargs):
12
  super().__init__(**kwargs)
13
+ self.system_message_prompt_template = self.system_message_prompt_template.partial(
14
+ code_library_file_location="no location yet",
15
+ code_library="no code yet"
16
+ )
17
  self.hint_for_model = """
18
  Make sure your response is in the following format:
19
  Response Format:
 
38
  # ~~~ Instantiate flow ~~~
39
  return cls(**kwargs)
40
 
41
+ def _get_code_library_file(self, input_data: Dict[str, Any]):
42
+ assert "memory_files" in input_data, "memory_files not passed to CodeGeneratorAtomicFlow"
43
+ assert "code_library" in input_data['memory_files'], "code_library not in memory_files"
44
+ code_library_file_location = input_data['memory_files']['code_library']
45
+ return code_library_file_location
46
+
47
+ def _get_code_library_content(self, input_data: Dict[str, Any]):
48
+ assert "code_library" in input_data, "code_library not passed to CodeGeneratorAtomicFlow"
49
+ code_library = input_data['code_library']
50
+ if len(code_library) == 0:
51
+ code_library = "No code yet"
52
+ return code_library
53
+
54
+ def _update_prompts_and_input(self, input_data: Dict[str, Any]):
55
  if 'goal' in input_data:
56
  input_data['goal'] += self.hint_for_model
57
+ code_library_file_location = self._get_code_library_file(input_data)
58
+ code_library = self._get_code_library_content(input_data)
59
+ self.system_message_prompt_template = self.system_message_prompt_template.partial(
60
+ code_library_file_location=code_library_file_location,
61
+ code_library=code_library
62
+ )
63
 
64
  def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
65
+ self._update_prompts_and_input(input_data)
66
  api_output = super().run(input_data)["api_output"].strip()
67
  try:
68
  response = json.loads(api_output)