Spaces:
Build error
Build error
| import numpy as np | |
| import os | |
| import IPython | |
| from cliport import tasks | |
| from cliport.dataset import RavensDataset | |
| from cliport.environments.environment import Environment | |
| import imageio | |
| from pygments import highlight | |
| from pygments.lexers import PythonLexer | |
| from pygments.formatters import HtmlFormatter, TerminalFormatter | |
| import gradio | |
| import time | |
| import random | |
| import json | |
| import traceback | |
| from gensim.utils import ( | |
| mkdir_if_missing, | |
| save_text, | |
| save_stat, | |
| compute_diversity_score_from_assets, | |
| add_to_txt | |
| ) | |
| import pybullet as p | |
class SimulationRunner:
    """The main class that runs the simulation loop.

    Drives the ``agent`` and ``critic`` to propose a task and its code, then
    executes that code in a cliport ``Environment`` and records a video.
    The generator methods (``task_creation``, ``example_task_creation``,
    ``simulate_task``) yield ``(status, log_html, code, video_path)`` tuples so
    a caller (presumably the gradio UI -- TODO confirm) can stream progress.
    """

    def __init__(self, cfg, agent, critic, memory):
        self.cfg = cfg
        self.agent = agent
        self.critic = critic
        self.memory = memory
        self.log = ""

        # Statistics. Despite the ``*_pass_rate`` names these are raw counts;
        # they are divided by ``curr_trials`` when reported.
        self.syntax_pass_rate = 0
        self.runtime_pass_rate = 0
        self.env_pass_rate = 0
        self.curr_trials = 0

        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.chat_log = memory.chat_log
        self.task_asset_logs = []

        # All the generated tasks in this run.
        # Different from the ones in online buffer that can load from offline.
        self.generated_task_assets = []
        self.generated_task_programs = []
        self.generated_task_names = []
        self.generated_tasks = []
        self.passed_tasks = []  # accepted ones

    # ------------------------------------------------------------------ helpers

    def _rate(self, count):
        """Return ``count / curr_trials``, or 0.0 before any trial has run."""
        if self.curr_trials == 0:
            return 0.0
        return count / self.curr_trials

    @staticmethod
    def _format_exception():
        """Return the exception being handled as ``(raw, html, terminal)`` text.

        Must be called from inside an ``except`` clause.
        """
        trace = traceback.format_exc()
        return (trace,
                highlight(trace, PythonLexer(), HtmlFormatter()),
                highlight(trace, PythonLexer(), TerminalFormatter()))

    @staticmethod
    def _load_stored_task_codes():
        """Map each task file listed in ``prompts/data/generated_task_codes.json``
        that exists under ``cliport/generated_tasks/`` to its source text."""
        with open(os.path.join('prompts/data', "generated_task_codes.json")) as f:
            task_files = json.load(f)
        code_buffer = {}
        for task_file in task_files:
            code_path = "cliport/generated_tasks/" + task_file
            if os.path.exists(code_path):
                with open(code_path) as f:
                    code_buffer[task_file] = f.read()
        return code_buffer

    @staticmethod
    def _extract_task_class_name(code):
        """Return the name of the first ``class X(Task):`` defined at column 0 of ``code``.

        Raises IndexError when ``code`` has no top-level ``class`` line.
        """
        class_def = [line for line in code.split("\n") if line.startswith('class')]
        task_name = class_def[0]
        return task_name[task_name.find("class "): task_name.rfind("(Task)")][6:]

    def _record_generated_task(self):
        """Append the current task, assets, code and name to the per-run history."""
        self.generated_tasks.append(self.generated_task)
        self.generated_task_assets.append(self.generated_asset)
        self.generated_task_programs.append(self.generated_code)
        self.generated_task_names.append(self.generated_task_name)

    # ------------------------------------------------------------- statistics

    def print_current_stats(self):
        """Print the current syntax / runtime / environment pass rates."""
        print("=========================================================")
        print(f"{self.cfg['prompt_folder']} Trial {self.curr_trials} "
              f"SYNTAX_PASS_RATE: {self._rate(self.syntax_pass_rate) * 100:.1f}% "
              f"RUNTIME_PASS_RATE: {self._rate(self.runtime_pass_rate) * 100:.1f}% "
              f"ENV_PASS_RATE: {self._rate(self.env_pass_rate) * 100:.1f}%")
        print("=========================================================")

    def save_stats(self):
        """Compute the asset diversity score and save the final statistics."""
        self.diversity_score = compute_diversity_score_from_assets(self.task_asset_logs, self.curr_trials)
        save_stat(self.cfg, self.cfg['model_output_dir'], self.generated_tasks,
                  self._rate(self.syntax_pass_rate),
                  self._rate(self.runtime_pass_rate),
                  self._rate(self.env_pass_rate),
                  self.diversity_score)
        print("Model Folder: ", self.cfg['model_output_dir'])
        print(f"Total {len(self.generated_tasks)} New Tasks:", [task['task-name'] for task in self.generated_tasks])
        print(f"Added {len(self.passed_tasks)} Tasks:", self.passed_tasks)

    # ---------------------------------------------------------- task creation

    def example_task_creation(self):
        """Demo counterpart of ``task_creation`` that replays a stored task.

        The task dict is a fixed template and only the assets come from the
        agent; the API/error review stages are reported but not run. The code
        is a randomly chosen file from ``_load_stored_task_codes``. Yields the
        same progress tuples as ``task_creation`` and sets
        ``self.task_creation_pass``.
        """
        self.task_creation_pass = True
        mkdir_if_missing(self.cfg['model_output_dir'])

        try:
            start_time = time.time()
            self.generated_task = {'task-name': 'TASK_NAME_TEMPLATE', 'task-description': 'TASK_STRING_TEMPLATE', 'assets-used': ['ASSET_1', 'ASSET_2', Ellipsis]}
            print("generated_task\n", self.generated_task)
            yield "Task Generated ==>", "", None, None

            self.generated_asset = self.agent.propose_assets()
            print("generated_asset\n", self.generated_asset)
            yield "Task Generated ==> Asset Generated ==> ", "", None, None
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> ", "", None, None
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> Error Reviewed ==> ", "", None, None

            online_code_buffer = self._load_stored_task_codes()
            # random.sample (not choice) keeps the RNG consumption of the original.
            random_task_file = random.sample(list(online_code_buffer.keys()), 1)[0]
            self.generated_code = online_code_buffer[random_task_file]
            self.curr_task_name = self.generated_task_name = self._extract_task_class_name(self.generated_code)
            print("generated_code\n", self.generated_code)
            print("curr_task_name\n", self.curr_task_name)
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> Error Reviewed ==> Code Generated ==> ", "", self.generated_code, None

            self._record_generated_task()
        except Exception:
            # ``Exception`` (not a bare except) so GeneratorExit / KeyboardInterrupt
            # propagate; yielding after GeneratorExit would raise RuntimeError.
            _, html_trace, terminal_trace = self._format_exception()
            print("Task Creation Exception:", terminal_trace)
            self.log = html_trace
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> Error Reviewed ==> Code Generation Failed", self.log, "", None
            self.task_creation_pass = False
            return

        print("task creation time {:.3f}".format(time.time() - start_time))

    def task_creation(self):
        """Create a new task through interactions between the agent and critic.

        Yields ``(status, log_html, code, video_path)`` progress tuples and,
        on success, appends the result to the ``generated_*`` histories. On
        any exception the HTML-highlighted traceback is stored in ``self.log``
        and ``self.task_creation_pass`` is set to False.
        """
        self.task_creation_pass = True
        mkdir_if_missing(self.cfg['model_output_dir'])

        try:
            start_time = time.time()
            self.generated_task = self.agent.propose_task(self.generated_task_names)
            print("generated_task\n", self.generated_task)
            yield "Task Generated ==>", "", None, None

            self.generated_asset = self.agent.propose_assets()
            print("generated_asset\n", self.generated_asset)
            yield "Task Generated ==> Asset Generated ==> ", "", None, None

            self.agent.api_review()
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> ", "", None, None

            self.critic.error_review(self.generated_task)
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> Error Reviewed ==> ", "", None, None

            self.generated_code, self.curr_task_name = self.agent.implement_task()
            self.task_asset_logs.append(self.generated_task["assets-used"])
            self.generated_task_name = self.generated_task["task-name"]
            print("generated_code\n", self.generated_code)
            print("curr_task_name\n", self.curr_task_name)
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> Error Reviewed ==> Code Generated ==> ", self.log, self.generated_code, None

            self._record_generated_task()
        except Exception:
            # See example_task_creation: a bare except would swallow GeneratorExit.
            _, html_trace, terminal_trace = self._format_exception()
            print("Task Creation Exception:", terminal_trace)
            self.log = html_trace
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> Error Reviewed ==> Code Generation Failed", self.log, "", None
            self.task_creation_pass = False
            return

        print("task creation time {:.3f}".format(time.time() - start_time))

    # ------------------------------------------------------------- simulation

    def setup_env(self):
        """Build the environment, task instance, oracle expert and dataset.

        Returns:
            tuple: ``(task, dataset, env, expert)``.
        """
        env = Environment(
            self.cfg['assets_root'],
            disp=self.cfg['disp'],
            shared_memory=self.cfg['shared_memory'],
            hz=480,
            record_cfg=self.cfg['record']
        )

        # SECURITY NOTE: ``curr_task_name`` comes from generated output and is
        # evaluated as an expression. The class it names must already be in
        # this module's globals (simulate_task exec()s the generated code first).
        task = eval(self.curr_task_name)()
        task.mode = self.cfg['mode']
        record = self.cfg['record']['save_video']
        save_data = self.cfg['save_data']  # read but not used below

        # Initialize scripted oracle agent and dataset.
        expert = task.oracle(env)
        self.cfg['task'] = self.generated_task["task-name"]
        data_path = os.path.join(self.cfg['data_dir'], "{}-{}".format(self.generated_task["task-name"], task.mode))
        dataset = RavensDataset(data_path, self.cfg, n_demos=0, augment=False)
        print(f"Saving to: {data_path}")
        print(f"Mode: {task.mode}")

        # Start video recording
        if record:
            env.start_rec(f'{dataset.n_episodes+1:06d}')

        return task, dataset, env, expert

    def run_one_episode(self, dataset, expert, env, task, episode, seed):
        """Roll out the oracle ``expert`` for one episode and return the total reward.

        Appends ``(obs, act, reward, info)`` tuples to ``episode`` in place,
        followed by a final ``(obs, None, reward, info)``. Note that
        ``simulate_task`` does not call this; it inlines its own rollout that
        reads transitions from ``env.cur_*`` instead of ``env.step``'s return.
        """
        add_to_txt(
            self.chat_log, f"================= TRIAL: {self.curr_trials}", with_print=True)
        record = self.cfg['record']['save_video']  # read but not used below
        np.random.seed(seed)
        random.seed(seed)
        print('Oracle demo: {}/{} | Seed: {}'.format(dataset.n_episodes + 1, self.cfg['n'], seed))
        env.set_task(task)
        obs = env.reset()
        info = env.info
        reward = 0
        total_reward = 0

        # Rollout expert policy
        for _ in range(task.max_steps):
            act = expert.act(obs, info)
            episode.append((obs, act, reward, info))
            lang_goal = info['lang_goal']
            obs, reward, done, info = env.step(act)
            total_reward += reward
            print(f'Total Reward: {total_reward:.3f} | Done: {done} | Goal: {lang_goal}')
            if done:
                break

        episode.append((obs, None, reward, info))
        return total_reward

    def simulate_task(self):
        """Execute the generated task code, roll out the oracle, and save a video.

        Yields ``(status, log_html, code, video_path)`` tuples. Increments
        ``curr_trials`` on every call and ``syntax_pass_rate`` when the code
        executes and the environment builds. Tracebacks of failures are saved
        via ``save_text`` under ``model_output_dir``.
        """
        seed = 123
        self.curr_trials += 1

        # Start from a clean pybullet session; a previous run may still be connected.
        if p.isConnected():
            p.disconnect()

        if not self.task_creation_pass:
            print("task creation failure => count as syntax exceptions.")
            return

        # Check syntax and compilation-time error.
        # SECURITY NOTE: exec() runs generated (untrusted) code with this
        # module's globals; it defines the task class that setup_env() evals.
        try:
            exec(self.generated_code, globals())
            task, dataset, env, expert = self.setup_env()
            self.syntax_pass_rate += 1
        except Exception:
            trace, html_trace, terminal_trace = self._format_exception()
            save_text(self.cfg['model_output_dir'], self.generated_task_name + '_error', trace)
            print("========================================================")
            print("Syntax Exception:", terminal_trace)
            self.log = html_trace
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> Error Reviewed ==> Code Generated ==> Code Syntax Parse Failed", self.log, self.generated_code, None
            return

        try:
            # Collect environment and collect data from oracle demonstrations.
            env.generated_code = self.generated_code
            episode = []
            add_to_txt(
                self.chat_log, f"================= TRIAL: {self.curr_trials}", with_print=True)

            # Set seeds.
            np.random.seed(seed)
            random.seed(seed)
            print('Oracle demo: {}/{} | Seed: {}'.format(dataset.n_episodes + 1, self.cfg['n'], seed))
            env.set_task(task)
            obs = env.reset()
            info = env.info
            reward = 0
            total_reward = 0

            # Rollout expert policy. env.step() is called for its side effects;
            # the transition is read back from the env.cur_* attributes.
            start_time = time.time()
            print("start sim")
            for _ in range(task.max_steps):
                act = expert.act(obs, info)
                episode.append((obs, act, reward, info))
                lang_goal = info['lang_goal']
                env.step(act)
                obs, reward, done, info = env.cur_obs, env.cur_reward, env.cur_done, env.cur_info
                total_reward += reward
                print(f'Total Reward: {total_reward:.3f} | Done: {done} | Goal: {lang_goal}')
                if done:
                    break
            print("end sim, time used = ", time.time() - start_time)

            # makedirs(exist_ok=True) also creates missing parents and tolerates races.
            video_dir = env.record_cfg['save_video_path']
            os.makedirs(video_dir, exist_ok=True)
            self.video_path = os.path.join(video_dir, "123.mp4")
            # The context manager closes the writer even if append_data() raises.
            with imageio.get_writer(self.video_path,
                                    fps=env.record_cfg['fps'],
                                    format='FFMPEG',
                                    codec='h264') as video_writer:
                print(f"has {len(env.curr_video)} frames to save")
                for color in env.curr_video:
                    video_writer.append_data(color)
            print("save video to ", self.video_path)
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> Error Reviewed ==> Code Generated ==> Simulation Running completed", self.log, self.generated_code, self.video_path

            episode.append((obs, None, reward, info))
            # NOTE(review): runtime_pass_rate / env_pass_rate are never
            # incremented in this class, so their reported rates stay at 0%.
            print("Runtime Test Pass!")
        except Exception:
            trace, html_trace, terminal_trace = self._format_exception()
            save_text(self.cfg['model_output_dir'], self.generated_task_name + '_error', trace)
            print("========================================================")
            print("Runtime Exception:", terminal_trace)
            self.log = html_trace
            yield "Task Generated ==> Asset Generated ==> API Reviewed ==> Error Reviewed ==> Code Generated ==> Simulation Running Failed", self.log, self.generated_code, None

        # NOTE(review): this also runs after a runtime failure, because the
        # except branch above does not return. Confirm that saving failed runs is intended.
        self.memory.save_run(self.generated_task)