Update MemoryReadingAtomicFlow.py
Browse files
MemoryReadingAtomicFlow.py
CHANGED
|
@@ -15,13 +15,15 @@ class MemoryReadingAtomicFlow(AtomicFlow):
|
|
| 15 |
{"plan": "examples/JARVIS/plan.txt"}
|
| 16 |
"""
|
| 17 |
|
| 18 |
-
def __init__(self):
|
| 19 |
-
super().__init__()
|
| 20 |
self.supported_mem_name = ["plan", "logs", "code_library"]
|
| 21 |
|
| 22 |
def _check_input_data(self, input_data: Dict[str, Any]):
|
| 23 |
"""input data sanity check"""
|
| 24 |
-
|
|
|
|
|
|
|
| 25 |
assert mem_name in self.supported_mem_name, (f"{mem_name} is not supported in MemoryReadingAtomicFlow, "
|
| 26 |
f"supported names are: {self.supported_mem_name}")
|
| 27 |
assert os.path.exists(mem_path), f"{mem_path} does not exist."
|
|
@@ -50,7 +52,7 @@ class MemoryReadingAtomicFlow(AtomicFlow):
|
|
| 50 |
input_data: Dict[str, Any]):
|
| 51 |
self._check_input_data(input_data)
|
| 52 |
response = {}
|
| 53 |
-
for mem_name, mem_path in input_data.items():
|
| 54 |
if mem_name in ['plan', 'logs']:
|
| 55 |
response[mem_name] = self._read_text(mem_path)
|
| 56 |
elif mem_name == 'code_library' and mem_path.endswith('.py'):
|
|
|
|
| 15 |
{"plan": "examples/JARVIS/plan.txt"}
|
| 16 |
"""
|
| 17 |
|
| 18 |
+
def __init__(self, **kwargs):
|
| 19 |
+
super().__init__(**kwargs)
|
| 20 |
self.supported_mem_name = ["plan", "logs", "code_library"]
|
| 21 |
|
| 22 |
def _check_input_data(self, input_data: Dict[str, Any]):
|
| 23 |
"""input data sanity check"""
|
| 24 |
+
assert "memory_files" in input_data, "memory_files not passed to MemoryReadingAtomicFlow"
|
| 25 |
+
|
| 26 |
+
for mem_name, mem_path in input_data["memory_files"].items():
|
| 27 |
assert mem_name in self.supported_mem_name, (f"{mem_name} is not supported in MemoryReadingAtomicFlow, "
|
| 28 |
f"supported names are: {self.supported_mem_name}")
|
| 29 |
assert os.path.exists(mem_path), f"{mem_path} does not exist."
|
|
|
|
| 52 |
input_data: Dict[str, Any]):
|
| 53 |
self._check_input_data(input_data)
|
| 54 |
response = {}
|
| 55 |
+
for mem_name, mem_path in input_data["memory_files"].items():
|
| 56 |
if mem_name in ['plan', 'logs']:
|
| 57 |
response[mem_name] = self._read_text(mem_path)
|
| 58 |
elif mem_name == 'code_library' and mem_path.endswith('.py'):
|