Spaces:
Sleeping
Sleeping
Corrige bounds de validação de pose para corresponder ao workspace real do Cliport (X: [0.25, 0.75], Y: [-0.5, 0.5], Z: [0.0, 0.3])
397e692
| import numpy as np | |
| import os | |
| import IPython | |
| from cliport import tasks | |
| from cliport.dataset import RavensDataset | |
| from cliport.environments.environment import Environment | |
| import imageio | |
| from pygments import highlight | |
| from pygments.lexers import PythonLexer | |
| from pygments.formatters import HtmlFormatter, TerminalFormatter | |
| import gradio | |
| import time | |
| import random | |
| import json | |
| import traceback | |
| import re | |
| from gensim.utils import ( | |
| mkdir_if_missing, | |
| save_text, | |
| save_stat, | |
| compute_diversity_score_from_assets, | |
| add_to_txt | |
| ) | |
| import pybullet as p | |
| from gensim.code_fixer import attempt_code_repair | |
# Matches a pose literal of the form ((x, y, z), (qx, qy, qz, qw)) inside
# generated task code, tolerating arbitrary whitespace and scientific
# notation; groups 0-2 are the position, 3-6 the quaternion.
POSE_PATTERN = re.compile(
    r"\(\s*\(\s*([-+0-9.eE]+)\s*,\s*([-+0-9.eE]+)\s*,\s*([-+0-9.eE]+)\s*\)\s*,\s*\(\s*([-+0-9.eE]+)\s*,\s*([-+0-9.eE]+)\s*,\s*([-+0-9.eE]+)\s*,\s*([-+0-9.eE]+)\s*\)\s*\)"
)
class SimulationRunner:
    """Drives the generate -> validate -> simulate loop for generated tasks."""

    def __init__(self, cfg, agent, critic, memory):
        """Store collaborators and set up bookkeeping and pose-validation bounds.

        Args:
            cfg: run configuration mapping (hydra-style dict).
            agent: task-proposing agent.
            critic: reviewing critic.
            memory: run memory exposing ``chat_log`` and ``save_run``.
        """
        self.cfg = cfg
        self.agent = agent
        self.critic = critic
        self.memory = memory
        self.log = ""

        # Pass-rate statistics, incremented across trials.
        self.syntax_pass_rate = 0
        self.runtime_pass_rate = 0
        self.env_pass_rate = 0
        self.curr_trials = 0

        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.chat_log = memory.chat_log
        self.task_asset_logs = []

        # Tasks generated during this run only; the online buffer may also
        # contain tasks loaded from offline storage.
        self.generated_task_assets = []
        self.generated_task_programs = []
        self.generated_task_names = []
        self.generated_tasks = []
        self.passed_tasks = []  # tasks that were ultimately accepted

        self.auto_fix_attempts = cfg.get('auto_fix_attempts', 0)

        pv_cfg = cfg.get('pose_validation', {})
        self.pose_validation_enabled = pv_cfg.get('enabled', False)
        self.pose_validation_max_retries = pv_cfg.get('max_retries', 0)
        if 'x_bounds' in pv_cfg and 'y_bounds' in pv_cfg:
            # Per-axis bounds take precedence when both are present.
            xb = pv_cfg.get('x_bounds', [0.25, 0.75])
            yb = pv_cfg.get('y_bounds', [-0.5, 0.5])
        else:
            # Legacy behavior: a single xy_bounds range applies to both axes.
            xb = yb = pv_cfg.get('xy_bounds', [0.25, 0.75])
        self.pose_x_bounds = (xb[0], xb[1])
        self.pose_y_bounds = (yb[0], yb[1])
        zb = pv_cfg.get('z_bounds', [0.0, 0.3])
        self.pose_z_bounds = (zb[0], zb[1])
| def print_current_stats(self): | |
| """ print the current statistics of the simulation design """ | |
| print("=========================================================") | |
| print(f"{self.cfg['prompt_folder']} Trial {self.curr_trials} SYNTAX_PASS_RATE: {(self.syntax_pass_rate / (self.curr_trials)) * 100:.1f}% RUNTIME_PASS_RATE: {(self.runtime_pass_rate / (self.curr_trials)) * 100:.1f}% ENV_PASS_RATE: {(self.env_pass_rate / (self.curr_trials)) * 100:.1f}%") | |
| print("=========================================================") | |
| def save_stats(self): | |
| """ save the final simulation statistics """ | |
| self.diversity_score = compute_diversity_score_from_assets(self.task_asset_logs, self.curr_trials) | |
| save_stat(self.cfg, self.cfg['model_output_dir'], self.generated_tasks, self.syntax_pass_rate / (self.curr_trials), | |
| self.runtime_pass_rate / (self.data_pathcurr_trials), self.env_pass_rate / (self.curr_trials), self.diversity_score) | |
| print("Model Folder: ", self.cfg['model_output_dir']) | |
| print(f"Total {len(self.generated_tasks)} New Tasks:", [task['task-name'] for task in self.generated_tasks]) | |
| try: | |
| print(f"Added {len(self.passed_tasks)} Tasks:", self.passed_tasks) | |
| except: | |
| pass | |
| def _format_html_traceback(self, trace_str): | |
| return highlight(f"{str(trace_str)}", PythonLexer(), HtmlFormatter()) | |
| def _can_attempt_fix(self, attempt_idx): | |
| return attempt_idx < self.auto_fix_attempts | |
| def _reset_physics(self): | |
| if p.isConnected(): | |
| p.disconnect() | |
| def _apply_auto_fix(self, error_trace, attempt_idx): | |
| fix = attempt_code_repair( | |
| cfg=self.cfg, | |
| task_metadata=self.generated_task, | |
| current_code=self.generated_code, | |
| error_log=error_trace, | |
| attempt_idx=attempt_idx, | |
| interaction_log=self.chat_log, | |
| ) | |
| if not fix: | |
| return False | |
| new_code, new_task_name = fix | |
| if not new_code: | |
| return False | |
| is_valid, reason = self._validate_generated_code(new_code) | |
| if not is_valid: | |
| print("Auto-fix rejected due to invalid pose:", reason) | |
| self.log = self._format_html_traceback(reason) | |
| return False | |
| self.generated_code = new_code | |
| if new_task_name: | |
| self.curr_task_name = new_task_name | |
| if self.generated_task_programs: | |
| self.generated_task_programs[-1] = new_code | |
| save_text( | |
| self.cfg['model_output_dir'], | |
| f"{self.generated_task_name}_autofix_attempt_{attempt_idx + 1}", | |
| new_code, | |
| ) | |
| return True | |
| def _validate_generated_code(self, code: str): | |
| if not self.pose_validation_enabled: | |
| return True, "" | |
| matches = POSE_PATTERN.findall(code) | |
| if not matches: | |
| return True, "" | |
| invalid_entries = [] | |
| for pose in matches: | |
| try: | |
| x, y, z = map(float, pose[:3]) | |
| except ValueError: | |
| continue | |
| if not (self.pose_x_bounds[0] <= x <= self.pose_x_bounds[1]): | |
| invalid_entries.append(f"x={x:.3f} fora dos limites X {self.pose_x_bounds}") | |
| if not (self.pose_y_bounds[0] <= y <= self.pose_y_bounds[1]): | |
| invalid_entries.append(f"y={y:.3f} fora dos limites Y {self.pose_y_bounds}") | |
| if not (self.pose_z_bounds[0] <= z <= self.pose_z_bounds[1]): | |
| invalid_entries.append(f"z={z:.3f} fora dos limites Z {self.pose_z_bounds}") | |
| if len(invalid_entries) >= 3: | |
| break | |
| if invalid_entries: | |
| reason = "Validação de pose falhou: " + "; ".join(invalid_entries) | |
| return False, reason | |
| return True, "" | |
| def example_task_creation(self): | |
| """ create the task through interactions of agent and critic """ | |
| self.task_creation_pass = True | |
| mkdir_if_missing(self.cfg['model_output_dir']) | |
| try: | |
| start_time = time.time() | |
| self.generated_task = {'task-name': 'TASK_NAME_TEMPLATE', 'task-description': 'TASK_STRING_TEMPLATE', 'assets-used': ['ASSET_1', 'ASSET_2', Ellipsis]} | |
| print("generated_task\n", self.generated_task) | |
| yield "Tarefa gerada ==>", "", None, None | |
| self.generated_asset = self.agent.propose_assets() | |
| # self.generated_asset = {} | |
| print("generated_asset\n", self.generated_asset) | |
| yield "Tarefa gerada ==> Asset gerado ==> ", "", None, None | |
| yield "Tarefa gerada ==> Asset gerado ==> API revisada ==> ", "", None, None | |
| yield "Tarefa gerada ==> Asset gerado ==> API revisada ==> Erros revisados ==> ", "", None, None | |
| online_code_buffer = {} | |
| for task_file in json.load(open(os.path.join('prompts/data', "generated_task_codes.json"))): | |
| if os.path.exists("cliport/generated_tasks/" + task_file): | |
| online_code_buffer[task_file] = open("cliport/generated_tasks/" + task_file).read() | |
| random_task_file = random.sample(list(online_code_buffer.keys()), 1)[0] | |
| class_def = [line for line in online_code_buffer[random_task_file].split("\n") if line.startswith('class')] | |
| task_name = class_def[0] | |
| task_name = task_name[task_name.find("class "): task_name.rfind("(Task)")][6:] | |
| self.curr_task_name = self.generated_task_name = task_name | |
| self.generated_code = online_code_buffer[random_task_file] | |
| print("generated_code\n", self.generated_code) | |
| print("curr_task_name\n", self.curr_task_name) | |
| yield "Tarefa gerada ==> Asset gerado ==> API revisada ==> Erros revisados ==> Codigo gerado ==> ", "", self.generated_code, None | |
| self.generated_tasks.append(self.generated_task) | |
| self.generated_task_assets.append(self.generated_asset) | |
| self.generated_task_programs.append(self.generated_code) | |
| self.generated_task_names.append(self.generated_task_name) | |
| except: | |
| to_print = highlight(f"{str(traceback.format_exc())}", PythonLexer(), HtmlFormatter()) | |
| print("Task Creation Exception:", highlight(f"{str(traceback.format_exc())}", PythonLexer(), TerminalFormatter())) | |
| self.log = to_print | |
| yield "Tarefa gerada ==> Asset gerado ==> API revisada ==> Erros revisados ==> Falha na geracao de codigo", self.log, "", None | |
| self.task_creation_pass = False | |
| return | |
| # self.curr_task_name = self.generated_task['task-name'] | |
| print("task creation time {:.3f}".format(time.time() - start_time)) | |
| def task_creation(self): | |
| """ create the task through interactions of agent and critic """ | |
| self.task_creation_pass = True | |
| mkdir_if_missing(self.cfg['model_output_dir']) | |
| try: | |
| start_time = time.time() | |
| self.generated_task = self.agent.propose_task(self.generated_task_names) | |
| # self.generated_task = {'task-name': 'TASK_NAME_TEMPLATE', 'task-description': 'TASK_STRING_TEMPLATE', 'assets-used': ['ASSET_1', 'ASSET_2', Ellipsis]} | |
| print("generated_task\n", self.generated_task) | |
| yield "Tarefa gerada ==>", "", None, None | |
| self.generated_asset = self.agent.propose_assets() | |
| print("generated_asset\n", self.generated_asset) | |
| yield "Tarefa gerada ==> Asset gerado ==> ", "", None, None | |
| self.agent.api_review() | |
| yield "Tarefa gerada ==> Asset gerado ==> API revisada ==> ", "", None, None | |
| self.critic.error_review(self.generated_task) | |
| yield "Tarefa gerada ==> Asset gerado ==> API revisada ==> Erros revisados ==> ", "", None, None | |
| max_pose_retries = self.pose_validation_max_retries if self.pose_validation_enabled else 0 | |
| validation_attempts = 0 | |
| pose_status = "Tarefa gerada ==> Asset gerado ==> API revisada ==> Erros revisados ==> Validacao de pose falhou" | |
| previous_error = None | |
| while True: | |
| self.generated_code, self.curr_task_name = self.agent.implement_task( | |
| previous_error=previous_error, | |
| attempt_number=validation_attempts | |
| ) | |
| is_valid, reason = self._validate_generated_code(self.generated_code) | |
| if is_valid: | |
| break | |
| validation_attempts += 1 | |
| previous_error = reason # Store error for next attempt | |
| reason_html = self._format_html_traceback(reason) | |
| self.log = reason_html | |
| print("Pose validation failed:", reason) | |
| yield pose_status + f" ==> Nova tentativa {validation_attempts}", reason_html, self.generated_code, None | |
| if validation_attempts > max_pose_retries: | |
| raise RuntimeError(reason) | |
| self.log = "" | |
| self.task_asset_logs.append(self.generated_task["assets-used"]) | |
| self.generated_task_name = self.generated_task["task-name"] | |
| print("generated_code\n", self.generated_code) | |
| print("curr_task_name\n", self.curr_task_name) | |
| yield "Tarefa gerada ==> Asset gerado ==> API revisada ==> Erros revisados ==> Codigo gerado ==> ", self.log, self.generated_code, None | |
| self.generated_tasks.append(self.generated_task) | |
| self.generated_task_assets.append(self.generated_asset) | |
| self.generated_task_programs.append(self.generated_code) | |
| self.generated_task_names.append(self.generated_task_name) | |
| except: | |
| to_print = highlight(f"{str(traceback.format_exc())}", PythonLexer(), HtmlFormatter()) | |
| print("Task Creation Exception:", highlight(f"{str(traceback.format_exc())}", PythonLexer(), TerminalFormatter())) | |
| self.log = to_print | |
| yield "Tarefa gerada ==> Asset gerado ==> API revisada ==> Erros revisados ==> Falha na geracao de codigo", self.log, "", None | |
| self.task_creation_pass = False | |
| return | |
| # self.curr_task_name = self.generated_task['task-name'] | |
| print("task creation time {:.3f}".format(time.time() - start_time)) | |
    def setup_env(self):
        """Build the environment, task instance, oracle expert, and dataset
        for the currently generated task.

        Returns:
            (task, dataset, env, expert) tuple ready for a rollout.
        """
        env = Environment(
            self.cfg['assets_root'],
            disp=self.cfg['disp'],
            shared_memory=self.cfg['shared_memory'],
            hz=480,
            record_cfg=self.cfg['record']
        )
        # SECURITY NOTE(review): eval() resolves the generated task class by
        # name from globals(); it relies on simulate_task() having exec()'d
        # the generated code first. Acceptable only because the code comes
        # from our own generation pipeline — never feed untrusted names here.
        task = eval(self.curr_task_name)()
        task.mode = self.cfg['mode']
        record = self.cfg['record']['save_video']
        # NOTE(review): save_data is read from cfg but never used in this
        # method — confirm whether it belongs to a caller or is dead.
        save_data = self.cfg['save_data']
        # Initialize scripted oracle agent and dataset.
        expert = task.oracle(env)
        self.cfg['task'] = self.generated_task["task-name"]
        data_path = os.path.join(self.cfg['data_dir'], "{}-{}".format(self.generated_task["task-name"], task.mode))
        dataset = RavensDataset(data_path, self.cfg, n_demos=0, augment=False)
        print(f"Saving to: {data_path}")
        print(f"Mode: {task.mode}")
        # Start video recording
        if record:
            env.start_rec(f'{dataset.n_episodes+1:06d}')
        return task, dataset, env, expert
| def run_one_episode(self, dataset, expert, env, task, episode, seed): | |
| """ run the new task for one episode """ | |
| add_to_txt( | |
| self.chat_log, f"================= TRIAL: {self.curr_trials}", with_print=True) | |
| record = self.cfg['record']['save_video'] | |
| np.random.seed(seed) | |
| random.seed(seed) | |
| print('Oracle demo: {}/{} | Seed: {}'.format(dataset.n_episodes + 1, self.cfg['n'], seed)) | |
| env.set_task(task) | |
| obs = env.reset() | |
| info = env.info | |
| reward = 0 | |
| total_reward = 0 | |
| # Rollout expert policy | |
| for _ in range(task.max_steps): | |
| act = expert.act(obs, info) | |
| episode.append((obs, act, reward, info)) | |
| lang_goal = info['lang_goal'] | |
| obs, reward, done, info = env.step(act) | |
| total_reward += reward | |
| print(f'Total Reward: {total_reward:.3f} | Done: {done} | Goal: {lang_goal}') | |
| if done: | |
| break | |
| episode.append((obs, None, reward, info)) | |
| return total_reward | |
    def simulate_task(self):
        """Execute the generated task code, run one oracle rollout, and save a
        demonstration video; retries via auto-fix on syntax/runtime failures.

        Generator that yields (status, log_html, code, video_path) tuples.
        """
        # NOTE(review): these counters are initialized but never updated or
        # read in this method — presumably leftovers; confirm before removing.
        total_cnt = 0.
        reset_success_cnt = 0.
        env_success_cnt = 0.
        seed = 123
        self.curr_trials += 1
        # Start every trial from a clean PyBullet state.
        if p.isConnected():
            p.disconnect()
        if not self.task_creation_pass:
            print("task creation failure => count as syntax exceptions.")
            return
        fix_attempts_used = 0
        status_prefix = "Tarefa gerada ==> Asset gerado ==> API revisada ==> Erros revisados ==> Codigo gerado ==> "
        # Retry loop: each iteration (re)executes the generated code; on a
        # failure, one automatic repair may be applied before retrying.
        while True:
            env = None
            try:
                # SECURITY NOTE(review): exec() of LLM-generated code into
                # globals() — acceptable only because the code is produced by
                # our own pipeline; never feed untrusted input here.
                exec(self.generated_code, globals())
                task, dataset, env, expert = self.setup_env()
                self.syntax_pass_rate += 1
            except Exception:
                trace_str = traceback.format_exc()
                html_trace = self._format_html_traceback(trace_str)
                save_text(self.cfg['model_output_dir'], self.generated_task_name + '_error', trace_str)
                print("========================================================")
                print("Syntax Exception:", highlight(f"{trace_str}", PythonLexer(), TerminalFormatter()))
                self.log = html_trace
                # While budget remains, attempt an auto-fix, reset physics,
                # and retry the whole setup; otherwise give up.
                if self._can_attempt_fix(fix_attempts_used) and self._apply_auto_fix(trace_str, fix_attempts_used):
                    fix_attempts_used += 1
                    yield status_prefix + f"Falha de sintaxe do codigo ==> Tentativa automatica {fix_attempts_used}", self.log, self.generated_code, None
                    self._reset_physics()
                    continue
                yield status_prefix + "Falha de sintaxe do codigo", self.log, self.generated_code, None
                return
            try:
                env.generated_code = self.generated_code
                episode = []
                add_to_txt(
                    self.chat_log, f"================= TRIAL: {self.curr_trials}", with_print=True)
                # Seed numpy and random so the oracle rollout is reproducible.
                np.random.seed(seed)
                random.seed(seed)
                print('Oracle demo: {}/{} | Seed: {}'.format(dataset.n_episodes + 1, self.cfg['n'], seed))
                env.set_task(task)
                obs = env.reset()
                info = env.info
                reward = 0
                total_reward = 0
                start_time = time.time()
                print("start sim")
                # Scripted-oracle rollout, capped at the task's step budget.
                for i in range(task.max_steps):
                    act = expert.act(obs, info)
                    episode.append((obs, act, reward, info))
                    lang_goal = info['lang_goal']
                    env.step(act)
                    # NOTE(review): unlike run_one_episode(), this reads the
                    # step results from env.cur_* attributes instead of the
                    # env.step() return value — confirm both stay in sync.
                    obs, reward, done, info = env.cur_obs, env.cur_reward, env.cur_done, env.cur_info
                    total_reward += reward
                    print(f'Total Reward: {total_reward:.3f} | Done: {done} | Goal: {lang_goal}')
                    if done:
                        break
                end_time = time.time()
                print("end sim, time used = ", end_time - start_time)
                # Write the recorded frames out as an H.264 video.
                # NOTE(review): os.mkdir fails when parent dirs are missing;
                # os.makedirs(..., exist_ok=True) would be more robust.
                if not os.path.exists(env.record_cfg['save_video_path']):
                    os.mkdir(env.record_cfg['save_video_path'])
                self.video_path = os.path.join(env.record_cfg['save_video_path'], "123.mp4")
                video_writer = imageio.get_writer(self.video_path,
                    fps=env.record_cfg['fps'],
                    format='FFMPEG',
                    codec='libx264',
                    ffmpeg_params=['-crf', '18', '-preset', 'medium'])  # CRF 18 = high quality
                print(f"has {len(env.curr_video)} frames to save")
                for color in env.curr_video:
                    video_writer.append_data(color)
                video_writer.close()
                print("save video to ", self.video_path)
                yield status_prefix + "Simulacao concluida com sucesso", self.log, self.generated_code, self.video_path
                episode.append((obs, None, reward, info))
                print("Runtime Test Pass!")
                break
            except Exception:
                trace_str = traceback.format_exc()
                html_trace = self._format_html_traceback(trace_str)
                save_text(self.cfg['model_output_dir'], self.generated_task_name + '_error', trace_str)
                print("========================================================")
                print("Runtime Exception:", highlight(f"{trace_str}", PythonLexer(), TerminalFormatter()))
                self.log = html_trace
                if self._can_attempt_fix(fix_attempts_used) and self._apply_auto_fix(trace_str, fix_attempts_used):
                    fix_attempts_used += 1
                    yield status_prefix + f"Falha na execucao da simulacao ==> Tentativa automatica {fix_attempts_used}", self.log, self.generated_code, None
                    self._reset_physics()
                    continue
                yield status_prefix + "Falha na execucao da simulacao", self.log, self.generated_code, None
                return
        # Only reached after a successful rollout (the `break` above).
        self.memory.save_run(self.generated_task)