Praxis / gensim /code_corrector.py
leofeltrin's picture
Implementa sistema híbrido: validação prévia, correção automática e templates flexíveis
13b4740
"""
Correção automática de erros comuns em código gerado.
Sistema híbrido: corrige erros estruturais automaticamente antes de retry.
"""
import re
from typing import Tuple, List, Optional
class CodeCorrector:
    """Apply automatic text-level fixes for common mistakes in generated code.

    Every ``_fix_*`` method receives the code as a string, returns the
    (possibly) rewritten string and appends one human-readable message to
    ``self.corrections_applied`` for each fix it actually performed.
    """

    def __init__(self):
        """Create a corrector with an empty correction log."""
        # Messages describing the fixes made by the most recent ``correct`` call.
        self.corrections_applied: List[str] = []

    def correct(self, code: str, task_description: Optional[str] = None) -> Tuple[str, List[str]]:
        """Run every fixer over ``code`` in a fixed order.

        Args:
            code: Source code to correct.
            task_description: Optional task description; when given (and
                non-empty) it is used to correct object counts.

        Returns:
            ``(corrected_code, list_of_corrections_applied)``.
        """
        self.corrections_applied = []
        corrected = code

        # 1. Drop duplicated imports.
        corrected = self._fix_duplicate_imports(corrected)
        # 2. Turn a position-only base_pose into a (position, rotation) pair.
        corrected = self._fix_base_pose(corrected)
        # 3. Add the rotation to bare utils.apply() list comprehensions.
        corrected = self._fix_utils_apply(corrected)
        # 4. Rename ``def init`` to ``def __init__`` (must precede step 7).
        corrected = self._fix_init_method(corrected)
        # 5. Repair ``for in range`` loops that lack a loop variable.
        corrected = self._fix_for_loops(corrected)
        # 6. Align range() counts with the task description.
        if task_description:
            corrected = self._fix_quantity(corrected, task_description)
        # 7. Insert self.additional_reset() into __init__ when missing.
        corrected = self._fix_additional_reset(corrected)
        # 8. Replace zero-angle p.getQuaternionFromEuler calls.
        corrected = self._fix_quaternion(corrected)
        # 9. Use step_max_reward=1.0 when there is a single goal.
        corrected = self._fix_step_max_reward(corrected)

        return corrected, self.corrections_applied

    def _fix_duplicate_imports(self, code: str) -> str:
        """Remove repeated module-level, single-line import statements.

        Only unindented imports are de-duplicated: an indented import may be
        the sole statement of a function/``try`` body, and deleting it would
        leave invalid code. Lines that open a multi-line import (ending in
        ``(`` or a backslash) are never removed, because their continuation
        lines would be left dangling.
        """
        seen_imports = set()
        kept = []
        for line in code.split('\n'):
            stripped = line.strip()
            normalized = re.sub(r'\s+', ' ', stripped)
            looks_like_import = normalized.startswith('import ') or (
                # Require the ``import`` keyword so prose such as
                # "from the table ..." is not mistaken for an import.
                normalized.startswith('from ') and ' import ' in normalized
            )
            at_module_level = line == line.lstrip()
            single_line = not stripped.endswith(('(', '\\'))

            if not (looks_like_import and at_module_level and single_line):
                kept.append(line)
                continue
            if normalized in seen_imports:
                self.corrections_applied.append(f"Removed duplicate import: {stripped}")
                continue
            seen_imports.add(normalized)
            kept.append(line)
        return '\n'.join(kept)

    def _fix_base_pose(self, code: str) -> str:
        """Rewrite ``base_pose = (x, y, z)`` as ``((x, y, z), (0, 0, 0, 1))``.

        Only a tuple of exactly three numeric literals that ends its line
        (optionally followed by a comment) is rewritten. Trailing spaces and
        the line break after the tuple are preserved.
        """
        # Group 2 captures horizontal whitespace only, so the line ending is
        # never consumed by the match and cannot be lost in the replacement.
        pose_pattern = r'base_pose\s*=\s*\(([^)]+)\)([ \t]*)(?=\r?\n|#|$)'
        number = r'[0-9.eE+-]+'
        triple = re.compile(rf'({number})\s*,\s*({number})\s*,\s*({number})\s*,?')

        def replace_pose(match):
            # fullmatch: tuples with more (or fewer) than three elements are
            # left untouched instead of being silently truncated.
            coords_match = triple.fullmatch(match.group(1).strip())
            if coords_match is None:
                return match.group(0)
            x, y, z = coords_match.groups()
            self.corrections_applied.append("Fixed base_pose: added rotation tuple")
            return f'base_pose = (({x}, {y}, {z}), (0, 0, 0, 1)){match.group(2)}'

        return re.sub(pose_pattern, replace_pose, code)

    def _fix_utils_apply(self, code: str) -> str:
        """Pair each ``utils.apply(...)`` result with ``base_pose[1]``.

        ``name = [utils.apply(base_pose, offset) for offset in offsets]``
        becomes
        ``name = [(utils.apply(base_pose, offset), base_pose[1]) for offset in offsets]``.

        A statement is skipped when ``base_pose[1]`` or ``(0, 0, 0, 1)``
        appears within 100 characters before or after the start of that
        statement (heuristic for "rotation is already handled nearby").
        """
        pattern = r'(\w+)\s*=\s*\[utils\.apply\(([^)]+)\)\s+for\s+([^\]]+)\]'

        def replace_apply(match):
            var_name, apply_args, loop_vars = match.groups()
            # Anchor the window on this match's own position; looking the
            # text up with str.find would always inspect the first identical
            # statement in the file.
            start = match.start()
            context = code[max(0, start - 100):start + 100]
            if 'base_pose[1]' in context or '(0, 0, 0, 1)' in context:
                return match.group(0)
            self.corrections_applied.append("Fixed utils.apply(): added rotation")
            return f'{var_name} = [(utils.apply({apply_args}), base_pose[1]) for {loop_vars}]'

        return re.sub(pattern, replace_apply, code)

    def _fix_init_method(self, code: str) -> str:
        """Rename ``def init(self):`` to ``def __init__(self):``.

        Nothing is changed when the code already defines ``__init__``. The
        check looks for a *definition*, so a ``super().__init__()`` call in
        the body does not prevent the rename.
        """
        has_init_definition = re.search(r'\bdef\s+__init__\b', code) is not None
        if 'def init(self):' in code and not has_init_definition:
            self.corrections_applied.append("Fixed: def init() -> def __init__()")
            code = code.replace('def init(self):', 'def __init__(self):')
        return code

    def _fix_for_loops(self, code: str) -> str:
        """Rewrite ``for in range`` (missing loop variable) as ``for _ in range``."""
        # \b keeps identifiers that merely end in "for" from matching.
        fixed, replaced = re.subn(r'\bfor\s+in\s+range', 'for _ in range', code)
        if replaced:
            self.corrections_applied.append("Fixed: for loop missing variable")
        return fixed

    def _fix_quantity(self, code: str, task_description: str) -> str:
        """Make ``range(<int>)`` calls match the count in the description.

        The expected count is the first number found in phrases such as
        "5 blocks" / "3 objetos" or "a row of 5" / "uma fileira de 5"
        (case-insensitive). Every ``range(N)`` with an integer literal ``N``
        different from that count is rewritten; the code is returned
        unchanged when no count can be extracted.
        """
        description = task_description.lower()
        numbers = re.findall(
            r'\b(\d+)\s+(?:blocos?|blocks?|objetos?|objects?)\b', description
        )
        if not numbers:
            # Accept both the Portuguese "de" and the English "of".
            numbers = re.findall(
                r'\b(?:uma|um|one|a)\s+(?:fileira|row)\s+(?:de|of)\s+(\d+)', description
            )
        if not numbers:
            return code

        expected_count = int(numbers[0])

        def replace_range(match):
            current_count = int(match.group(1))
            if current_count == expected_count:
                return match.group(0)
            self.corrections_applied.append(
                f"Fixed quantity: range({current_count}) -> range({expected_count})"
            )
            return f'range({expected_count})'

        return re.sub(r'range\((\d+)\)', replace_range, code)

    def _fix_additional_reset(self, code: str) -> str:
        """Insert ``self.additional_reset()`` after ``super().__init__()``.

        Applies only when the code mentions ``__init__``, contains no
        ``additional_reset()`` call anywhere, and ``def __init__(self):``
        holds a ``super().__init__()`` call at the start of a line. The new
        call is placed on the next line with the same indentation as the
        ``super()`` call; other ``super().__init__()`` calls are untouched.
        """
        if '__init__' not in code or 'additional_reset()' in code:
            return code

        # The __init__ body runs up to the next def/class line or end of text.
        init_match = re.search(
            r'def __init__\(self\):.*?(?=\n[ \t]*def |\n[ \t]*class |\Z)',
            code,
            re.DOTALL,
        )
        if init_match is None:
            return code

        super_call = re.search(
            r'^([ \t]*)super\(\)\.__init__\(\)', init_match.group(0), re.MULTILINE
        )
        if super_call is None:
            return code

        # Splice at the absolute offset so only this one call is affected.
        insert_at = init_match.start() + super_call.end()
        indent = super_call.group(1)
        self.corrections_applied.append("Added self.additional_reset() to __init__")
        return f'{code[:insert_at]}\n{indent}self.additional_reset(){code[insert_at:]}'

    def _fix_quaternion(self, code: str) -> str:
        """Replace ``p.getQuaternionFromEuler([0, 0, 0])`` with ``(0, 0, 0, 1)``.

        Only the all-zero list form is rewritten; a correction is logged only
        when at least one call was actually replaced.
        """
        fixed, replaced = re.subn(
            r'p\.getQuaternionFromEuler\(\[0,\s*0,\s*0\]\)',
            '(0, 0, 0, 1)',
            code,
        )
        if replaced:
            self.corrections_applied.append("Replaced p.getQuaternionFromEuler with (0, 0, 0, 1)")
        return fixed

    def _fix_step_max_reward(self, code: str) -> str:
        """Set ``step_max_reward`` to ``1.0`` when the code adds a single goal.

        With exactly one ``.add_goal(`` call, ``step_max_reward = 1/N`` (where
        ``1`` may be written ``1.`` / ``1.0`` and ``N`` is an integer or
        decimal literal) is replaced by ``step_max_reward=1.0``.
        """
        goal_count = len(re.findall(r'\.add_goal\(', code))
        if goal_count != 1:
            return code

        # Consume the whole divisor literal (e.g. "3.0"); matching only the
        # integer part would leave ".0" behind and produce "1.0.0".
        fixed, replaced = re.subn(
            r'step_max_reward\s*=\s*1(?:\.0*)?\s*/\s*\d+(?:\.\d*)?',
            'step_max_reward=1.0',
            code,
        )
        if replaced:
            self.corrections_applied.append("Fixed step_max_reward: 1/N -> 1.0 for single goal")
        return fixed
def auto_correct_code(code: str, task_description: Optional[str] = None) -> Tuple[str, List[str]]:
    """Correct ``code`` with a throwaway ``CodeCorrector`` instance.

    Args:
        code: Source code to correct.
        task_description: Optional task description forwarded to the
            corrector.

    Returns:
        ``(corrected_code, corrections_applied)``.
    """
    corrected_code, corrections_applied = CodeCorrector().correct(code, task_description)
    return corrected_code, corrections_applied