""" Correção automática de erros comuns em código gerado. Sistema híbrido: corrige erros estruturais automaticamente antes de retry. """ import re from typing import Tuple, List, Optional class CodeCorrector: """Corrige erros comuns automaticamente""" def __init__(self): self.corrections_applied = [] def correct(self, code: str, task_description: Optional[str] = None) -> Tuple[str, List[str]]: """ Corrige erros comuns no código Args: code: Código a ser corrigido task_description: Descrição da tarefa para correções contextuais Returns: (corrected_code, list_of_corrections_applied) """ self.corrections_applied = [] corrected = code # 1. Corrigir imports duplicados corrected = self._fix_duplicate_imports(corrected) # 2. Corrigir base_pose incorreto corrected = self._fix_base_pose(corrected) # 3. Corrigir utils.apply() sem rotação corrected = self._fix_utils_apply(corrected) # 4. Corrigir def init para __init__ corrected = self._fix_init_method(corrected) # 5. Corrigir loops sem variável corrected = self._fix_for_loops(corrected) # 6. Corrigir quantidade de objetos if task_description: corrected = self._fix_quantity(corrected, task_description) # 7. Adicionar additional_reset() se faltar corrected = self._fix_additional_reset(corrected) # 8. Corrigir p.getQuaternionFromEuler corrected = self._fix_quaternion(corrected) # 9. Corrigir step_max_reward para goal único corrected = self._fix_step_max_reward(corrected) return corrected, self.corrections_applied def _fix_duplicate_imports(self, code: str) -> str: """Remove imports duplicados""" lines = code.split('\n') seen_imports = set() result = [] for line in lines: stripped = line.strip() if stripped.startswith('import ') or stripped.startswith('from '): normalized = re.sub(r'\s+', ' ', stripped) if normalized not in seen_imports: seen_imports.add(normalized) result.append(line) else: self.corrections_applied.append(f"Removed duplicate import: {stripped}") else: result.append(line) return '\n'.join(result) def _fix_base_pose(self, code: str) -> str: """Corrige base_pose = (x, y, z) para base_pose = ((x, y, z), (0, 0, 0, 1))""" # Padrão: base_pose = (0.5, 0, 0.02) sem rotação pattern = r'base_pose\s*=\s*\(([^)]+)\)\s*(?=\n|#|$)' def replace_pose(match): coords = match.group(1).strip() # Verificar se já é uma pose completa if coords.count('(') > 0 and coords.count(')') > 0: return match.group(0) # Já está correto # Extrair coordenadas coords_match = re.match(r'([0-9.eE+-]+)\s*,\s*([0-9.eE+-]+)\s*,\s*([0-9.eE+-]+)', coords) if coords_match: x, y, z = coords_match.groups() self.corrections_applied.append(f"Fixed base_pose: added rotation tuple") return f'base_pose = (({x}, {y}, {z}), (0, 0, 0, 1))' return match.group(0) corrected = re.sub(pattern, replace_pose, code) return corrected def _fix_utils_apply(self, code: str) -> str: """Corrige utils.apply() sem rotação""" # Padrão: targs = [utils.apply(base_pose, offset) for offset in offsets] # Deve ser: targs = [(utils.apply(base_pose, offset), base_pose[1]) for offset in offsets] pattern = r'(\w+)\s*=\s*\[utils\.apply\(([^)]+)\)\s+for\s+([^\]]+)\]' def replace_apply(match): var_name = match.group(1) apply_args = match.group(2) loop_vars = match.group(3) # Verificar se já tem rotação context = code[max(0, code.find(match.group(0))-100):code.find(match.group(0))+100] if 'base_pose[1]' in context or '(0, 0, 0, 1)' in context: return match.group(0) # Já está correto self.corrections_applied.append(f"Fixed utils.apply(): added rotation") return f'{var_name} = [(utils.apply({apply_args}), base_pose[1]) for {loop_vars}]' corrected = re.sub(pattern, replace_apply, code) return corrected def _fix_init_method(self, code: str) -> str: """Corrige def init(self): para def __init__(self):""" if 'def init(self):' in code and '__init__' not in code: self.corrections_applied.append("Fixed: def init() -> def __init__()") code = code.replace('def init(self):', 'def __init__(self):') return code def _fix_for_loops(self, code: str) -> str: """Corrige for in range para for _ in range""" pattern = r'for\s+in\s+range' if re.search(pattern, code): self.corrections_applied.append("Fixed: for loop missing variable") code = re.sub(r'for\s+in\s+range', 'for _ in range', code) return code def _fix_quantity(self, code: str, task_description: str) -> str: """Corrige quantidade de objetos baseado na descrição""" # Extrair número esperado numbers = re.findall(r'\b(\d+)\s+(?:blocos?|blocks?|objetos?|objects?)\b', task_description.lower()) if not numbers: numbers = re.findall(r'\b(?:uma|um|one|a)\s+(?:fileira|row)\s+de\s+(\d+)', task_description.lower()) if not numbers: return code expected_count = int(numbers[0]) # Encontrar e corrigir range() incorreto def replace_range(match): current_count = int(match.group(1)) if current_count != expected_count: self.corrections_applied.append( f"Fixed quantity: range({current_count}) -> range({expected_count})" ) return f'range({expected_count})' return match.group(0) code = re.sub(r'range\((\d+)\)', replace_range, code) return code def _fix_additional_reset(self, code: str) -> str: """Adiciona self.additional_reset() se faltar no __init__""" if '__init__' in code and 'additional_reset()' not in code: # Encontrar fim do __init__ init_match = re.search(r'def __init__\(self\):.*?(?=\n def |\nclass |\Z)', code, re.DOTALL) if init_match: init_body = init_match.group(0) # Adicionar antes do fim if 'self.additional_reset()' not in init_body: # Adicionar após super().__init__() if 'super().__init__()' in init_body: code = code.replace( 'super().__init__()', 'super().__init__()\n self.additional_reset()' ) self.corrections_applied.append("Added self.additional_reset() to __init__") return code def _fix_quaternion(self, code: str) -> str: """Substitui p.getQuaternionFromEuler por (0, 0, 0, 1)""" if 'p.getQuaternionFromEuler' in code: # Substituir chamadas simples code = re.sub( r'p\.getQuaternionFromEuler\(\[0,\s*0,\s*0\]\)', '(0, 0, 0, 1)', code ) self.corrections_applied.append("Replaced p.getQuaternionFromEuler with (0, 0, 0, 1)") return code def _fix_step_max_reward(self, code: str) -> str: """Corrige step_max_reward para goal único""" # Contar número de add_goal() calls goal_count = len(re.findall(r'\.add_goal\(', code)) if goal_count == 1: # Se há apenas 1 goal, step_max_reward deve ser 1.0 pattern = r'step_max_reward\s*=\s*1\s*/\s*\d+' if re.search(pattern, code): code = re.sub( r'step_max_reward\s*=\s*1\s*/\s*\d+', 'step_max_reward=1.0', code ) self.corrections_applied.append("Fixed step_max_reward: 1/N -> 1.0 for single goal") return code def auto_correct_code(code: str, task_description: Optional[str] = None) -> Tuple[str, List[str]]: """ Função de conveniência para correção automática Args: code: Código a ser corrigido task_description: Descrição da tarefa (opcional) Returns: (corrected_code, corrections_applied) """ corrector = CodeCorrector() return corrector.correct(code, task_description)