File size: 9,200 Bytes
13b4740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
"""
Correção automática de erros comuns em código gerado.
Sistema híbrido: corrige erros estruturais automaticamente antes de retry.
"""
import re
from typing import Tuple, List, Optional


class CodeCorrector:
    """Apply heuristic, regex-based repairs to generated task code.

    Every ``_fix_*`` method takes the code as a string and returns the
    (possibly) rewritten string.  For each repair it performs, it appends a
    human-readable message to ``self.corrections_applied``.  That list is
    reset at the start of every :meth:`correct` call.
    """

    # A plain numeric literal: 1, -0.5, .25, 1e-3, 2.5E+2
    _NUMBER = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'
    # A literal zero spelled 0, 0. or 0.0
    _ZERO = r'0(?:\.0*)?'

    # `import x ...` or `from x import ...` (not prose that merely starts with "from").
    _IMPORT_RE = re.compile(r'(?:import\s+\S|from\s+\S+\s+import\b)')
    # `base_pose = (a, b, c)` with no nested parentheses, followed by a comment,
    # a newline or end of text.  Group 1 keeps the original "base_pose = " spacing.
    _BASE_POSE_RE = re.compile(
        r'(?<!\w)(base_pose[ \t]*=[ \t]*)\(([^()\n]+)\)(?=[ \t]*(?:#|\r?\n|$))'
    )
    # Exactly three numeric literals (an optional trailing comma is allowed).
    _XYZ_RE = re.compile(rf'\s*({_NUMBER})\s*,\s*({_NUMBER})\s*,\s*({_NUMBER})\s*,?\s*')
    _UTILS_APPLY_RE = re.compile(
        r'(\w+)\s*=\s*\[utils\.apply\(([^)]+)\)\s+for\s+([^\]]+)\]'
    )
    _FOR_NO_VAR_RE = re.compile(r'\bfor\s+in\s+range\b')
    # `\b` keeps `np.arange(N)` from being treated as a loop count.
    _RANGE_LITERAL_RE = re.compile(r'\brange\((\d+)\)')
    # p.getQuaternionFromEuler([0, 0, 0]) / ((0.0, 0, 0)) etc.
    _QUAT_IDENTITY_RE = re.compile(
        rf'p\.getQuaternionFromEuler\(\s*[\[(]\s*{_ZERO}\s*,\s*{_ZERO}\s*,\s*{_ZERO}\s*,?\s*[\])]\s*\)'
    )
    # step_max_reward = 1/N, also accepting 1.0 and a decimal divisor such as 4.0
    _STEP_REWARD_RE = re.compile(r'step_max_reward\s*=\s*1(?:\.0*)?\s*/\s*\d+(?:\.\d*)?')

    def __init__(self):
        # Messages describing every repair made by the most recent correct() call.
        self.corrections_applied = []

    def correct(self, code: str, task_description: Optional[str] = None) -> Tuple[str, List[str]]:
        """
        Run every repair, in a fixed order, on the given code.

        Args:
            code: Code to repair.
            task_description: Task text used by context-dependent repairs
                (currently only the object-count fix); skipped when falsy.

        Returns:
            (corrected_code, list_of_corrections_applied)
        """
        self.corrections_applied = []
        corrected = code

        # 1. Duplicate module-level imports
        corrected = self._fix_duplicate_imports(corrected)

        # 2. Position-only base_pose
        corrected = self._fix_base_pose(corrected)

        # 3. utils.apply() results without a rotation.  This runs after step 2;
        #    _fix_utils_apply must therefore ignore the rotation step 2 adds to
        #    the pose definition.
        corrected = self._fix_utils_apply(corrected)

        # 4. `def init` instead of `def __init__`
        corrected = self._fix_init_method(corrected)

        # 5. `for in range(...)` missing its loop variable
        corrected = self._fix_for_loops(corrected)

        # 6. Object count taken from the task description
        if task_description:
            corrected = self._fix_quantity(corrected, task_description)

        # 7. Missing self.additional_reset() in __init__ (after step 4 renamed it)
        corrected = self._fix_additional_reset(corrected)

        # 8. Identity p.getQuaternionFromEuler(...) calls
        corrected = self._fix_quaternion(corrected)

        # 9. step_max_reward for a single goal
        corrected = self._fix_step_max_reward(corrected)

        return corrected, self.corrections_applied

    def _fix_duplicate_imports(self, code: str) -> str:
        """Drop repeated module-level import lines.

        Only unindented, single-line imports are deduplicated.  Whitespace is
        normalised before comparing.

        Indented imports are kept: identical local imports in two different
        functions are both needed.  Lines ending in ``(`` or ``\\`` are also
        kept, because removing the first line of a multi-line import would
        leave its continuation lines orphaned.
        """
        seen_imports = set()
        result = []

        for line in code.split('\n'):
            stripped = line.strip()
            is_dedupable = (
                not line[:1].isspace()
                and self._IMPORT_RE.match(stripped) is not None
                and not stripped.endswith(('(', '\\'))
            )
            if not is_dedupable:
                result.append(line)
                continue

            normalized = re.sub(r'\s+', ' ', stripped)
            if normalized in seen_imports:
                self.corrections_applied.append(f"Removed duplicate import: {stripped}")
            else:
                seen_imports.add(normalized)
                result.append(line)

        return '\n'.join(result)

    def _fix_base_pose(self, code: str) -> str:
        """Turn ``base_pose = (x, y, z)`` into ``base_pose = ((x, y, z), (0, 0, 0, 1))``.

        Only an assignment of exactly three numeric literals is rewritten.
        Variables, other tuple lengths and already-nested poses are left alone.
        The original spacing around ``=`` and any trailing comment are kept.
        """
        def replace_pose(match):
            coords = self._XYZ_RE.fullmatch(match.group(2))
            if coords is None:
                return match.group(0)
            x, y, z = coords.groups()
            self.corrections_applied.append("Fixed base_pose: added rotation tuple")
            return f'{match.group(1)}(({x}, {y}, {z}), (0, 0, 0, 1))'

        return self._BASE_POSE_RE.sub(replace_pose, code)

    def _fix_utils_apply(self, code: str) -> str:
        """Pair positions built with ``utils.apply`` with a rotation.

        This turns ``targs = [utils.apply(base_pose, o) for o in offsets]`` into
        ``targs = [(utils.apply(base_pose, o), base_pose[1]) for o in offsets]``.

        The rotation is ``<pose>[1]``, where ``<pose>`` is the first
        ``utils.apply`` argument when it is a plain (dotted) name; otherwise
        ``base_pose[1]`` is used.
        """
        def replace_apply(match):
            var_name, apply_args, loop_clause = match.groups()
            first_arg = apply_args.split(',', 1)[0].strip()
            pose_var = first_arg if re.fullmatch(r'[A-Za-z_][\w.]*', first_arg) else 'base_pose'
            rotation = f'{pose_var}[1]'

            # Heuristic: if a rotation is mentioned near this line, assume the
            # positions are already paired with it.  The window is widened to a
            # full first line so that a half-cut line cannot slip through.
            # Lines that (re)assign the pose variable itself are ignored:
            # once _fix_base_pose has run, the pose definition contains
            # "(0, 0, 0, 1)", which would otherwise always suppress this fix.
            lo = code.rfind('\n', 0, max(0, match.start() - 100)) + 1
            window = code[lo:match.end() + 100]
            pose_def = re.compile(rf'\s*(?:self\.)?{re.escape(pose_var)}\s*=(?!=)')
            context = '\n'.join(ln for ln in window.split('\n') if not pose_def.match(ln))
            if rotation in context or '(0, 0, 0, 1)' in context:
                return match.group(0)

            self.corrections_applied.append("Fixed utils.apply(): added rotation")
            return f'{var_name} = [(utils.apply({apply_args}), {rotation}) for {loop_clause}]'

        return self._UTILS_APPLY_RE.sub(replace_apply, code)

    def _fix_init_method(self, code: str) -> str:
        """Rename a misspelt ``def init(self):`` constructor to ``def __init__(self):``.

        The fix is skipped when the code already *defines* ``__init__``.  A
        ``super().__init__()`` call (typically inside the misspelt method
        itself) does not count as a definition.
        """
        if 'def init(self):' not in code or re.search(r'\bdef\s+__init__\s*\(', code):
            return code
        self.corrections_applied.append("Fixed: def init() -> def __init__()")
        return code.replace('def init(self):', 'def __init__(self):')

    def _fix_for_loops(self, code: str) -> str:
        """Rewrite ``for in range`` (loop variable missing) as ``for _ in range``."""
        fixed, count = self._FOR_NO_VAR_RE.subn('for _ in range', code)
        if count:
            self.corrections_applied.append("Fixed: for loop missing variable")
        return fixed

    def _fix_quantity(self, code: str, task_description: str) -> str:
        """Align the literal ``range(N)`` loop count with the count in the description.

        The expected count is the first "<N> blocks/objects" (or
        "a row of/de <N>") found in ``task_description``.  The code is only
        rewritten when every literal ``range(N)`` in it uses the same N.

        When several different literal counts appear (for example a retry loop
        or an xyz loop next to the object loop), the code is left unchanged.
        In that case there is no way to tell which loop creates the objects.
        """
        description = task_description.lower()
        numbers = re.findall(r'\b(\d+)\s+(?:blocos?|blocks?|objetos?|objects?)\b', description)
        if not numbers:
            numbers = re.findall(
                r'\b(?:uma|um|one|a)\s+(?:fileira|row)\s+(?:de|of)\s+(\d+)', description
            )
        if not numbers:
            return code
        expected_count = int(numbers[0])

        literal_counts = {int(n) for n in self._RANGE_LITERAL_RE.findall(code)}
        if len(literal_counts) != 1:
            return code  # no literal loop, or ambiguous
        (current_count,) = literal_counts
        if current_count == expected_count:
            return code

        def replace_range(match):
            self.corrections_applied.append(
                f"Fixed quantity: range({current_count}) -> range({expected_count})"
            )
            return f'range({expected_count})'

        return self._RANGE_LITERAL_RE.sub(replace_range, code)

    def _fix_additional_reset(self, code: str) -> str:
        """Insert ``self.additional_reset()`` after ``super().__init__()`` in ``__init__``.

        Only the first ``def __init__(self):`` body is edited, and only its
        first ``super().__init__()`` line.  The new line reuses that line's
        indentation, so files with any indent width stay valid.
        """
        if '__init__' not in code or 'additional_reset()' in code:
            return code

        # __init__ body: up to the next 4-space-indented method, next class, or EOF.
        init_match = re.search(
            r'def __init__\(self\):.*?(?=\n    def |\nclass |\Z)', code, re.DOTALL
        )
        if not init_match:
            return code

        super_call = re.search(
            r'^([ \t]*).*?super\(\)\.__init__\(\)[^\n]*', init_match.group(0), re.MULTILINE
        )
        if not super_call:
            return code

        insert_at = init_match.start() + super_call.end()  # end of the super() line
        indent = super_call.group(1)
        self.corrections_applied.append("Added self.additional_reset() to __init__")
        return f'{code[:insert_at]}\n{indent}self.additional_reset(){code[insert_at:]}'

    def _fix_quaternion(self, code: str) -> str:
        """Replace ``p.getQuaternionFromEuler`` calls on an all-zero angle list with ``(0, 0, 0, 1)``.

        Calls with non-zero angles are left untouched.  A correction is only
        reported when at least one call was actually replaced.
        """
        fixed, count = self._QUAT_IDENTITY_RE.subn('(0, 0, 0, 1)', code)
        if count:
            self.corrections_applied.append("Replaced p.getQuaternionFromEuler with (0, 0, 0, 1)")
        return fixed

    @staticmethod
    def _is_inside_loop(code: str, pos: int) -> bool:
        """Return True if the line containing offset ``pos`` is nested in a for/while body.

        The method walks upwards through strictly less-indented lines (the
        enclosing block headers).  It stops at the first ``for``/``while``
        header (True) or ``def``/``class`` header (False).  It is purely
        indentation-based and does not account for brackets or multi-line
        strings.
        """
        preceding = code[:pos].split('\n')
        current_line = preceding[-1]
        indent = len(current_line) - len(current_line.lstrip())

        for line in reversed(preceding[:-1]):
            stripped = line.strip()
            if not stripped or stripped.startswith('#'):
                continue
            line_indent = len(line) - len(line.lstrip())
            if line_indent >= indent:
                continue
            if re.match(r'(?:async\s+)?(?:for|while)\b', stripped):
                return True
            if re.match(r'(?:async\s+)?(?:def|class)\b', stripped):
                return False
            indent = line_indent
        return False

    def _fix_step_max_reward(self, code: str) -> str:
        """Set ``step_max_reward=1.0`` when the code adds exactly one goal.

        If there is a single goal, ``step_max_reward`` is expected to be 1.0
        rather than a ``1/N`` fraction.  A lone ``add_goal`` call nested in a
        ``for``/``while`` loop may run several times, so it is not treated as a
        single goal.
        """
        goal_calls = [m.start() for m in re.finditer(r'\.add_goal\(', code)]
        if len(goal_calls) != 1 or self._is_inside_loop(code, goal_calls[0]):
            return code

        # The divisor pattern also consumes decimals (1/4.0) so no stray ".0" remains.
        fixed, count = self._STEP_REWARD_RE.subn('step_max_reward=1.0', code)
        if count:
            self.corrections_applied.append("Fixed step_max_reward: 1/N -> 1.0 for single goal")
        return fixed


def auto_correct_code(code: str, task_description: Optional[str] = None) -> Tuple[str, List[str]]:
    """Run all CodeCorrector repairs on *code* in a single call.

    A new CodeCorrector is built for every call, so the returned list only
    describes the repairs made to this particular *code*.

    Args:
        code: Source text to repair.
        task_description: Optional task text, forwarded unchanged to
            ``CodeCorrector.correct``.

    Returns:
        A ``(corrected_code, corrections_applied)`` tuple.
    """
    return CodeCorrector().correct(code, task_description)