Spaces:
Build error
Build error
| import numpy as np | |
| import os | |
| import IPython | |
| import random | |
| import json | |
| import traceback | |
| import pybullet as p | |
| from gensim.utils import ( | |
| save_text, | |
| add_to_txt, | |
| extract_code, | |
| extract_dict, | |
| extract_list, | |
| extract_assets, | |
| format_dict_prompt, | |
| sample_list_reference, | |
| generate_feedback, | |
| ) | |
class Agent:
    """Design new simulation tasks and generate their code by prompting an LLM.

    Prompt templates are read from ``prompts/<cfg['prompt_folder']>``, filled
    in, and sent through ``generate_feedback``; the responses are parsed with
    the ``extract_*`` helpers. ``self.chat_log`` (taken from
    ``memory.chat_log``) is passed to ``add_to_txt`` / ``generate_feedback``
    as the interaction log.

    Args:
        cfg: configuration mapping. Keys read by this class:
            ``model_output_dir``, ``prompt_folder``, ``use_template``,
            ``task_asset_candidate_num``, ``task_description_candidate_num``,
            ``target_task_name``, ``gpt_temperature`` and
            ``task_code_candidate_num``.
        memory: object exposing ``chat_log`` and the dict-like buffers
            ``online_asset_buffer``, ``online_task_buffer`` and
            ``online_code_buffer``.
    """

    # Upper bound on how many past task names are listed in the
    # non-template task prompt.
    MAX_PAST_TASKS = 10

    _REFERENCE_SELECTION_TEMPLATE = "cliport_prompt_code_reference_selection_template.txt"

    def __init__(self, cfg, memory):
        self.cfg = cfg
        self.model_output_dir = cfg["model_output_dir"]
        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.memory = memory
        self.chat_log = memory.chat_log
        self.use_template = cfg['use_template']

    # ------------------------------------------------------------ helpers

    def _prompt_path(self, filename):
        """Return the path of ``filename`` inside this agent's prompt folder."""
        return f"{self.prompt_folder}/{filename}"

    def _read_prompt(self, filename):
        """Read a prompt template from the prompt folder, closing the file."""
        with open(self._prompt_path(filename)) as f:
            return f.read()

    @staticmethod
    def _exec_assignment(source, name):
        """Execute generated ``source`` and return the value it binds to ``name``.

        The code runs in a *copy* of this module's globals (so it can still
        refer to module-level names) with any pre-existing ``name`` removed.
        Raises KeyError if ``source`` does not define ``name``, instead of
        silently returning a value left over from an earlier call.

        SECURITY: ``source`` comes from a language-model response and is
        executed with full interpreter privileges. Only use this with a
        trusted model/endpoint.
        """
        namespace = dict(globals())
        namespace.pop(name, None)
        exec(source, namespace)
        return namespace[name]

    @staticmethod
    def _sample_past_tasks(tasks, max_num):
        """Return ``tasks`` itself, or a random dict of ``max_num`` of its items if larger."""
        if len(tasks) <= max_num:
            return tasks
        # random.sample requires a sequence; dict views are rejected on Python 3.11+.
        return dict(random.sample(list(tasks.items()), max_num))

    def _code_reference_prompt(self, keys):
        """Concatenate the fenced code of every key found in ``memory.online_code_buffer``.

        Keys that are not in the buffer are reported with a print and skipped.
        """
        code_buffer = self.memory.online_code_buffer
        parts = []
        for key in keys:
            if key in code_buffer:
                parts.append(f'```\n{code_buffer[key]}\n```\n\n')
            else:
                print("missing task reference code:", key)
        return ''.join(parts)

    # ------------------------------------------------------------ pipeline

    def propose_task(self, proposed_task_names):
        """Ask the LLM to propose a new task (name, descriptions, assets).

        Args:
            proposed_task_names: not used by this method; kept for interface
                compatibility.

        Returns:
            The ``new_task`` dict defined by the response, also stored on
            ``self.new_task``. If the response cannot be executed or does not
            define ``new_task``, a placeholder dict with task name ``"dummy"``
            is stored and returned, and the traceback is printed.
        """
        add_to_txt(self.chat_log, "================= Task and Asset Design!", with_print=True)
        task_prompt_text = self._read_prompt("cliport_prompt_task.txt")

        if self.use_template:
            task_asset_replacement_str = format_dict_prompt(
                self.memory.online_asset_buffer, self.cfg['task_asset_candidate_num'])
            task_prompt_text = task_prompt_text.replace("TASK_ASSET_PROMPT", task_asset_replacement_str)

            task_desc_replacement_str = format_dict_prompt(
                self.memory.online_task_buffer, self.cfg['task_description_candidate_num'])
            print("prompt task description candidates:")
            print(task_desc_replacement_str)
            task_prompt_text = task_prompt_text.replace("TASK_DESCRIPTION_PROMPT", task_desc_replacement_str)

            if self.cfg['target_task_name']:
                task_prompt_text = task_prompt_text.replace("TARGET_TASK_NAME", self.cfg['target_task_name'])
        else:
            print("online_task_buffer size:", len(self.memory.online_task_buffer))
            total_tasks = self._sample_past_tasks(self.memory.online_task_buffer, self.MAX_PAST_TASKS)
            task_prompt_text = task_prompt_text.replace("PAST_TASKNAME_TEMPLATE", format_dict_prompt(total_tasks))

        res = generate_feedback(
            task_prompt_text,
            temperature=self.cfg["gpt_temperature"],
            interaction_txt=self.chat_log,
        )

        # Extract dictionary for task name, descriptions, and assets
        task_def = extract_dict(res, prefix="new_task")
        try:
            self.new_task = self._exec_assignment(task_def, "new_task")
        except Exception:
            # NOTE(review): this placeholder uses "task_descriptions" while the
            # other keys are hyphenated ("task-name", "assets-used"); confirm
            # which key downstream code expects.
            self.new_task = {"task-name": "dummy", "assets-used": [], "task_descriptions": ""}
            print(traceback.format_exc())
        return self.new_task

    def propose_assets(self):
        """Asset generation for ``self.new_task``. Not used for now.

        Returns:
            Whatever ``extract_assets`` parses from the LLM response, or ``{}``
            when the asset prompt template does not exist.
        """
        template = "cliport_prompt_asset_template.txt"
        if not os.path.exists(self._prompt_path(template)):
            return {}

        add_to_txt(self.chat_log, "================= Asset Generation!", with_print=True)
        asset_prompt_text = self._read_prompt(template)
        if self.use_template:
            asset_prompt_text = asset_prompt_text.replace("TASK_NAME_TEMPLATE", self.new_task["task-name"])
            asset_prompt_text = asset_prompt_text.replace("ASSET_STRING_TEMPLATE", str(self.new_task["assets-used"]))
            print("Template Asset PROMPT: ", asset_prompt_text)

        res = generate_feedback(asset_prompt_text, temperature=0, interaction_txt=self.chat_log)
        # The original printed an undefined ``task_name`` here (NameError).
        task_name = self.new_task["task-name"]
        print("Save asset to:", self.model_output_dir, f"{task_name}_asset_output")
        save_text(self.model_output_dir, f"{task_name}_asset_output", res)
        return extract_assets(res)

    def api_review(self):
        """Send the API-review prompt for ``self.new_task`` to the LLM.

        Does nothing when the API template is missing. The response is not
        used; the method returns None.
        """
        template = "cliport_prompt_api_template.txt"
        if not os.path.exists(self._prompt_path(template)):
            return
        add_to_txt(self.chat_log, "================= API Preview!", with_print=True)
        api_prompt_text = self._read_prompt(template)
        if "task-name" in self.new_task:
            api_prompt_text = api_prompt_text.replace("TASK_NAME_TEMPLATE", self.new_task["task-name"])
            api_prompt_text = api_prompt_text.replace("TASK_STRING_TEMPLATE", str(self.new_task))
        generate_feedback(api_prompt_text, temperature=0, interaction_txt=self.chat_log)

    def template_reference_prompt(self):
        """Build the reference-code section of the code-generation prompt.

        If the reference-selection template exists, the LLM chooses reference
        tasks from ``memory.online_code_buffer``. Its response must define a
        ``code_reference`` list, and the code of each listed task is
        concatenated in fenced blocks. Otherwise, up to
        ``cfg['task_code_candidate_num']`` buffer entries are sampled at random.
        """
        if os.path.exists(self._prompt_path(self._REFERENCE_SELECTION_TEMPLATE)):
            self.chat_log = add_to_txt(self.chat_log, "================= Code Reference!", with_print=True)
            code_reference_question = self._read_prompt(self._REFERENCE_SELECTION_TEMPLATE)
            code_reference_question = code_reference_question.replace("TASK_NAME_TEMPLATE", self.new_task["task-name"])
            code_reference_question = code_reference_question.replace(
                "TASK_CODE_LIST_TEMPLATE", str(list(self.memory.online_code_buffer.keys())))
            code_reference_question = code_reference_question.replace("TASK_STRING_TEMPLATE", str(self.new_task))

            res = generate_feedback(code_reference_question, temperature=0., interaction_txt=self.chat_log)
            code_reference_cmd = extract_list(res, prefix='code_reference')
            # Iterate the list the response defines, not the characters of its source text.
            code_reference = self._exec_assignment(code_reference_cmd, "code_reference")
            return self._code_reference_prompt(code_reference)

        # NOTE(review): the original fallback called
        # sample_list_reference(base_task_codes, sample_num=cfg[...]) with two
        # undefined names (NameError). It now samples directly from the online
        # code buffer, clamping the count to [0, len(buffer)]; confirm this
        # matches the intended reference pool.
        keys = list(self.memory.online_code_buffer)
        sample_num = max(0, min(self.cfg['task_code_candidate_num'], len(keys)))
        return self._code_reference_prompt(random.sample(keys, sample_num))

    def implement_task(self):
        """Ask the LLM to write the code for ``self.new_task``.

        Returns:
            ``(code, task_name)`` as parsed by ``extract_code``, after passing
            the code to ``save_text`` under ``<task_name>_code_output``. Returns
            None, and saves nothing, when the parsed task name is empty.
        """
        code_prompt_text = self._read_prompt("cliport_prompt_code_split_template.txt")
        code_prompt_text = code_prompt_text.replace("TASK_NAME_TEMPLATE", self.new_task["task-name"])

        if self.use_template or os.path.exists(self._prompt_path(self._REFERENCE_SELECTION_TEMPLATE)):
            task_code_reference_replace_prompt = self.template_reference_prompt()
            code_prompt_text = code_prompt_text.replace("TASK_CODE_REFERENCE_TEMPLATE", task_code_reference_replace_prompt)
        else:
            # The split template was read above, so the original
            # os.path.exists() re-check on it could only be true.
            self.chat_log = add_to_txt(self.chat_log, "================= Code Generation!", with_print=True)
        # Fill the task description on every path; replace() is a no-op when
        # the template has no such placeholder.
        code_prompt_text = code_prompt_text.replace("TASK_STRING_TEMPLATE", str(self.new_task))

        res = generate_feedback(code_prompt_text, temperature=0, interaction_txt=self.chat_log)
        code, task_name = extract_code(res)
        # Validate before saving so no nameless "_code_output" file is written.
        if not task_name:
            print("empty task name:", task_name)
            return None

        print("Save code to:", self.model_output_dir, task_name + "_code_output")
        save_text(self.model_output_dir, task_name + "_code_output", code)
        return code, task_name