Spaces:
Sleeping
Sleeping
File size: 9,200 Bytes
13b4740 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
"""
Correção automática de erros comuns em código gerado.
Sistema híbrido: corrige erros estruturais automaticamente antes de retry.
"""
import re
from typing import Tuple, List, Optional
class CodeCorrector:
    """Automatically fix common mistakes in generated task code.

    Each ``_fix_*`` method takes the current code string, returns the
    (possibly rewritten) string, and appends a human-readable message to
    ``self.corrections_applied`` for every fix it actually performs. All fixes
    are text/regex heuristics; the code is never parsed or executed.
    """

    def __init__(self):
        # Messages describing the fixes applied by the most recent correct() call.
        self.corrections_applied: List[str] = []

    def correct(self, code: str, task_description: Optional[str] = None) -> Tuple[str, List[str]]:
        """Apply every automatic fix to ``code``, in a fixed order.

        Args:
            code: Source code to correct.
            task_description: Task description used for context-dependent fixes
                (object counts). When falsy, that fix is skipped.

        Returns:
            ``(corrected_code, corrections_applied)`` where the second item holds
            one message per fix that was performed.
        """
        self.corrections_applied = []
        corrected = code
        # 1. Remove duplicate imports.
        corrected = self._fix_duplicate_imports(corrected)
        # 2. Add the missing rotation to a position-only base_pose.
        corrected = self._fix_base_pose(corrected)
        # 3. Add the missing rotation to utils.apply() results.
        corrected = self._fix_utils_apply(corrected)
        # 4. Rename `def init` to `def __init__` (must run before step 7).
        corrected = self._fix_init_method(corrected)
        # 5. Add the missing variable to `for in range` loops.
        corrected = self._fix_for_loops(corrected)
        # 6. Align object counts with the task description.
        if task_description:
            corrected = self._fix_quantity(corrected, task_description)
        # 7. Add self.additional_reset() to __init__ when missing.
        corrected = self._fix_additional_reset(corrected)
        # 8. Replace identity-rotation p.getQuaternionFromEuler calls.
        corrected = self._fix_quaternion(corrected)
        # 9. Use step_max_reward=1.0 when there is a single goal.
        corrected = self._fix_step_max_reward(corrected)
        return corrected, self.corrections_applied

    def _fix_duplicate_imports(self, code: str) -> str:
        """Remove repeated module-level import statements.

        Only unindented, single-line imports are deduplicated. Indented imports
        are local to a function or block, so dropping a "duplicate" there could
        leave another scope without the name. A line ending in an opening
        parenthesis or a backslash is only the first line of a multi-line
        import, so it is never removed.
        """
        seen_imports = set()
        result = []
        for line in code.split('\n'):
            stripped = line.strip()
            is_import = stripped.startswith(('import ', 'from '))
            is_module_level = line[:1] not in (' ', '\t')
            is_single_line = not stripped.endswith(('(', '\\'))
            if is_import and is_module_level and is_single_line:
                # Collapse internal whitespace so `import  x` equals `import x`.
                normalized = re.sub(r'\s+', ' ', stripped)
                if normalized in seen_imports:
                    self.corrections_applied.append(f"Removed duplicate import: {stripped}")
                    continue
                seen_imports.add(normalized)
            result.append(line)
        return '\n'.join(result)

    def _fix_base_pose(self, code: str) -> str:
        """Rewrite ``base_pose = (x, y, z)`` as ``base_pose = ((x, y, z), (0, 0, 0, 1))``.

        Only a tuple of exactly three numeric literals, followed by end of line
        or a comment, is rewritten. Variables, full poses and tuples with any
        other length are left alone.
        """
        # The lookahead uses [ \t]* (not \s*) so the match never consumes the
        # newline that ends the statement or any spacing before a comment.
        pattern = r'base_pose\s*=\s*\(([^()]+)\)(?=[ \t]*(?:#|\r?\n|$))'
        number = r'[0-9.eE+-]+'
        xyz = re.compile(rf'\s*({number})\s*,\s*({number})\s*,\s*({number})\s*,?\s*')

        def replace_pose(match):
            coords_match = xyz.fullmatch(match.group(1))
            if not coords_match:
                return match.group(0)
            x, y, z = coords_match.groups()
            self.corrections_applied.append("Fixed base_pose: added rotation tuple")
            return f'base_pose = (({x}, {y}, {z}), (0, 0, 0, 1))'

        return re.sub(pattern, replace_pose, code)

    def _fix_utils_apply(self, code: str) -> str:
        """Pair each ``utils.apply(pose, offset)`` result with the pose's rotation.

        ``targs = [utils.apply(pose, o) for o in offsets]`` becomes
        ``targs = [(utils.apply(pose, o), pose[1]) for o in offsets]``. The
        rotation expression is built from the first argument of ``utils.apply``.
        """
        pattern = r'(\w+)\s*=\s*\[utils\.apply\(([^)]+)\)\s+for\s+([^\]]+)\]'

        def replace_apply(match):
            var_name, apply_args, loop_clause = match.groups()
            pose_expr = apply_args.split(',', 1)[0].strip() or 'base_pose'
            rotation = f'{pose_expr}[1]'
            # Heuristic guard: skip if a rotation already appears near this
            # statement. Use the match's own position (not str.find, which
            # returns the first identical occurrence).
            context = code[max(0, match.start() - 100):match.end() + 100]
            if rotation in context or '(0, 0, 0, 1)' in context:
                return match.group(0)
            self.corrections_applied.append("Fixed utils.apply(): added rotation")
            return f'{var_name} = [(utils.apply({apply_args}), {rotation}) for {loop_clause}]'

        return re.sub(pattern, replace_apply, code)

    def _fix_init_method(self, code: str) -> str:
        """Rename a constructor written as ``def init(self...)`` to ``def __init__(self...)``.

        Skipped when a ``def __init__`` already exists, since ``init`` may then be
        an intentional helper. A ``super().__init__()`` call does not block the fix.
        """
        if re.search(r'\bdef\s+__init__\s*\(', code):
            return code
        fixed, count = re.subn(r'\bdef\s+init\s*\(\s*self\b', 'def __init__(self', code)
        if count:
            self.corrections_applied.append("Fixed: def init() -> def __init__()")
        return fixed

    def _fix_for_loops(self, code: str) -> str:
        """Insert the missing loop variable: ``for in range`` -> ``for _ in range``."""
        fixed, count = re.subn(r'\bfor\s+in\s+range', 'for _ in range', code)
        if count:
            self.corrections_applied.append("Fixed: for loop missing variable")
        return fixed

    def _fix_quantity(self, code: str, task_description: str) -> str:
        """Rewrite literal ``range(N)`` calls to the object count in the description.

        The count is the first ``<N> blocks/objects`` (Portuguese or English)
        found in the description, falling back to ``a row de/of <N>``. NOTE: every
        literal ``range(N)`` whose N differs is rewritten, including loops
        unrelated to object counts.
        """
        description = task_description.lower()
        numbers = re.findall(r'\b(\d+)\s+(?:blocos?|blocks?|objetos?|objects?)\b', description)
        if not numbers:
            numbers = re.findall(
                r'\b(?:uma|um|one|a)\s+(?:fileira|row)\s+(?:de|of)\s+(\d+)', description
            )
        if not numbers:
            return code
        expected_count = int(numbers[0])

        def replace_range(match):
            current_count = int(match.group(1))
            if current_count == expected_count:
                return match.group(0)
            self.corrections_applied.append(
                f"Fixed quantity: range({current_count}) -> range({expected_count})"
            )
            return f'range({expected_count})'

        # \b stops e.g. `np.arange(5)` or `my_range(5)` from matching.
        return re.sub(r'\brange\((\d+)\)', replace_range, code)

    def _fix_additional_reset(self, code: str) -> str:
        """Insert ``self.additional_reset()`` right after ``super().__init__()`` in ``__init__``.

        Nothing is changed when ``additional_reset()`` already appears anywhere,
        when there is no ``def __init__(self...)``, or when that method's body has
        no ``super().__init__()`` call. Only that single call is patched, and the
        inserted line reuses the call line's indentation.
        """
        if 'additional_reset()' in code:
            return code
        # __init__ body: up to the next indented `def`, a top-level class/def, or EOF.
        init_match = re.search(
            r'def __init__\(self[^)]*\)[^:\n]*:.*?(?=\n[ \t]+def\s|\n(?:class|def)\s|\Z)',
            code,
            re.DOTALL,
        )
        if not init_match:
            return code
        super_call = 'super().__init__()'
        call_pos = code.find(super_call, init_match.start(), init_match.end())
        if call_pos == -1:
            return code
        line_start = code.rfind('\n', 0, call_pos) + 1
        prefix = code[line_start:call_pos]
        indent = prefix[:len(prefix) - len(prefix.lstrip())]
        insert_at = call_pos + len(super_call)
        self.corrections_applied.append("Added self.additional_reset() to __init__")
        return f'{code[:insert_at]}\n{indent}self.additional_reset(){code[insert_at:]}'

    def _fix_quaternion(self, code: str) -> str:
        """Replace ``p.getQuaternionFromEuler`` of an all-zero Euler angle with ``(0, 0, 0, 1)``.

        Accepts list or tuple arguments whose three items are ``0`` or ``0.0``.
        Other calls are left alone, and a correction is reported only when a
        replacement actually happened.
        """
        zero = r'0(?:\.0*)?'
        triple = rf'\s*{zero}\s*,\s*{zero}\s*,\s*{zero}\s*,?\s*'
        pattern = rf'p\.getQuaternionFromEuler\(\s*(?:\[{triple}\]|\({triple}\))\s*\)'
        fixed, count = re.subn(pattern, '(0, 0, 0, 1)', code)
        if count:
            self.corrections_applied.append("Replaced p.getQuaternionFromEuler with (0, 0, 0, 1)")
        return fixed

    @staticmethod
    def _inside_loop(code: str, pos: int) -> bool:
        """Return True if the line holding ``pos`` is nested under a ``for``/``while`` header.

        Nesting is judged by indentation, walking upward until an enclosing
        ``def``/``class`` header is reached.
        """
        lines = code[:pos].split('\n')
        current = lines[-1]
        indent = len(current) - len(current.lstrip())
        for line in reversed(lines[:-1]):
            stripped = line.strip()
            if not stripped or stripped.startswith('#'):
                continue
            line_indent = len(line) - len(line.lstrip())
            if line_indent >= indent:
                continue
            if re.match(r'(?:async\s+)?(?:for|while)\b', stripped):
                return True
            if re.match(r'(?:async\s+)?def\b|class\b', stripped):
                return False
            # An enclosing non-loop block (if/with/try...): keep walking outward.
            indent = line_indent
        return False

    def _fix_step_max_reward(self, code: str) -> str:
        """Use ``step_max_reward=1.0`` instead of ``1/N`` when there is exactly one goal.

        There must be exactly one textual ``.add_goal(`` call, and it must not sit
        inside a ``for``/``while`` loop. Inside a loop it runs several times, and
        ``1/N`` is then intentional.
        """
        goal_calls = list(re.finditer(r'\.add_goal\(', code))
        if len(goal_calls) != 1 or self._inside_loop(code, goal_calls[0].start()):
            return code
        fixed, count = re.subn(r'step_max_reward\s*=\s*1\s*/\s*\d+', 'step_max_reward=1.0', code)
        if count:
            self.corrections_applied.append("Fixed step_max_reward: 1/N -> 1.0 for single goal")
        return fixed
def auto_correct_code(code: str, task_description: Optional[str] = None) -> Tuple[str, List[str]]:
    """Run a fresh ``CodeCorrector`` over ``code`` (convenience wrapper).

    Args:
        code: Source code to correct.
        task_description: Optional task description, forwarded unchanged to
            ``CodeCorrector.correct`` for context-dependent fixes.

    Returns:
        ``(corrected_code, corrections_applied)`` exactly as produced by
        ``CodeCorrector.correct``.
    """
    return CodeCorrector().correct(code, task_description=task_description)
|