Spaces:
Build error
Build error
| import re | |
| import traceback | |
| from datatypes import ParseError, StepOutput, TaskState | |
| from tasks.base import Task | |
| from openhands.controller.state.state import State | |
| class SimplifiedEnv: | |
| INVALID_INPUT_MESSAGE = ( | |
| "I don't understand your input. \n" | |
| 'If you want to execute code, please use <execute_ipython> YOUR_CODE_HERE </execute_ipython>.\n' | |
| 'If you want to give me an answer, please use <solution> YOUR_SOLUTION_HERE </solution>.\n' | |
| 'For example: The answer to the question is <solution> 42 </solution>. \n' | |
| ) | |
| def __init__(self, agent_state: State, task: Task, task_config: dict[str, int]): | |
| self.agent_state = agent_state | |
| self.task = task | |
| agent_action_count = { | |
| 'propose_solution': 0, | |
| 'use_tool': 0, | |
| 'invalid_action': 0, | |
| } | |
| # check if agent_state has attribute turn_info set | |
| if hasattr(self.agent_state, 'propose_solution_count'): | |
| agent_action_count['propose_solution'] = ( | |
| self.agent_state.propose_solution_count | |
| ) | |
| self.task_state = TaskState(agent_action_count=agent_action_count) | |
| self.task_config = task_config | |
| def step(self, lm_message: str): | |
| observation = self.handle_propose_solution(lm_message) | |
| self.check_max_iteration() | |
| turn_info = ( | |
| self.task_config['max_iterations'] - self.agent_state.iteration, | |
| self.task_config['max_propose_solution'] | |
| - self.task_state.agent_action_count['propose_solution'], | |
| ) | |
| output = StepOutput( | |
| observation=observation, | |
| success=self.task_state.success, | |
| turn_info=turn_info, | |
| ) | |
| self.agent_state.propose_solution_count = self.task_state.agent_action_count[ | |
| 'propose_solution' | |
| ] | |
| self.log_output(output) | |
| return self.task_state | |
| def handle_propose_solution(self, lm_message) -> str | None: | |
| """Propose answer to check the task success. | |
| It might set self.state.finished = True if the task is successful. | |
| """ | |
| self.task_state.agent_action_count['propose_solution'] += 1 | |
| try: | |
| parsed = self.parse_propose_solution(lm_message) | |
| task_success = self.check_task_success(parsed['answer']) | |
| if task_success: | |
| self.task_state.finished = True | |
| self.task_state.success = True | |
| self.task_state.terminate_reason = 'task_success' | |
| # NOTE: should not return the function now, because we need to log the output | |
| # Set state.finished = True will terminate the episode | |
| except ParseError: | |
| return SimplifiedEnv.INVALID_INPUT_MESSAGE | |
| except Exception: | |
| error_traceback = traceback.format_exc() | |
| return f'{error_traceback}' | |
| def parse_propose_solution(self, lm_message: str) -> dict: | |
| """Define the parsing logic.""" | |
| lm_output = '\n' + lm_message + '\n' | |
| answer = '\n'.join( | |
| [ | |
| i.strip() | |
| for i in re.findall(r'<solution>(.*?)</solution>', lm_output, re.DOTALL) | |
| ] | |
| ) | |
| if answer == '': | |
| raise ParseError('No answer found.') | |
| return {'answer': answer} | |
| def log_output(self, output: StepOutput) -> None: | |
| if self.task_state.finished: | |
| return | |
| content = output.to_str() | |
| self.task_state.latest_output = output.to_dict() | |
| self.task_state.latest_output['content'] = content | |
| def check_task_success(self, answer: str) -> bool: | |
| # log_message.info(f"STUDENT ANSWER: [{answer}]") | |
| # log_message.info(f"REFERENCE ANSWER: [{self.task.reference}]") | |
| return self.task.success(answer) | |
| def check_max_iteration(self): | |
| """Check if the agent has reached the max iteration limit. | |
| It might set self.state.finished = True if the agent has reached the max iteration limit. | |
| """ | |
| if self.task_state.finished: | |
| # ignore if the episode is already finished (e.g., task success) | |
| return | |
| if ( | |
| # propose solution > max output solution | |
| self.task_state.agent_action_count['propose_solution'] | |
| >= self.task_config['max_propose_solution'] | |
| ): | |
| self.task_state.finished = True | |
| self.task_state.success = False | |
| self.task_state.terminate_reason = 'max_propose_steps' | |
| elif self.agent_state.iteration >= self.task_config['max_iterations']: | |
| self.task_state.finished = True | |
| self.task_state.success = False | |
| self.task_state.terminate_reason = 'max_iterations' | |