File size: 4,475 Bytes
f0e2e50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
Code validation and repair utilities for LLM-generated Manim code.
"""

import re
import ast
import logging
from typing import Tuple, Optional, Dict, Any, List

logger = logging.getLogger(__name__)

class CodeValidator:
    """Validate and heuristically repair generated Manim scene code.

    Every method is a ``@staticmethod``; the class only groups related
    helpers and holds shared constants.
    """

    # Manim Community base classes accepted as "a Scene" by has_scene_class.
    # Specialised scenes (3-D, moving camera, zoomed, vector) are just as
    # valid a render target as plain ``Scene``.
    _SCENE_BASES = frozenset({
        'Scene',
        'ThreeDScene',
        'SpecialThreeDScene',
        'MovingCameraScene',
        'ZoomedScene',
        'VectorScene',
        'LinearTransformationScene',
    })

    @staticmethod
    def _is_valid_call_args(args: str) -> bool:
        """Return True if *args* is non-blank and ``f(<args>)`` parses.

        Used to tell an already-valid ``Text(...)`` argument list (a
        variable, an expression such as ``str(n)``, keyword arguments) apart
        from bare prose such as ``Hello World`` that still needs quoting.
        """
        if not args.strip():
            return False
        try:
            ast.parse(f'f({args})', mode='eval')
        except (SyntaxError, ValueError):
            return False
        return True

    @staticmethod
    def _close_dangling_quote(line: str) -> str:
        """Close the unmatched ``"`` on *line*.

        The caller guarantees an odd number of ``"`` characters. If quotes
        are paired left to right, the last one is then the unmatched opener.
        A closing ``"`` is inserted before the first ``,`` or ``)`` after it.
        If no such character follows, the line is returned unchanged.
        """
        opener = line.rfind('"')
        terminator = re.search(r'[,)]', line[opener + 1:])
        if terminator is None:
            return line
        cut = opener + 1 + terminator.start()
        return line[:cut] + '"' + line[cut:]

    @staticmethod
    def fix_common_issues(code: str) -> str:
        """Heuristically repair common defects in generated Manim code.

        Steps, in order:

        1. Strip Markdown code fences (with any language tag such as
           ``python``, ``py`` or ``python3``), then every remaining backtick.
        2. Quote bare prose passed to ``Text(...)``. For example,
           ``Text(Hello World)`` becomes ``Text("Hello World")`` and
           ``Text(Hello World, font_size=24)`` becomes
           ``Text("Hello World", font_size=24)``. Arguments that already parse
           as Python (``Text(title)``, ``Text(str(n), font_size=24)``) are
           left untouched. An empty ``Text()`` becomes ``Text("")``.
        3. On lines containing ``Text(`` with an odd number of ``"``, close
           the unmatched string before the next ``,`` or ``)``.
        4. Prepend ``from manim import *`` when it is missing.

        Args:
            code: Raw generated source, possibly wrapped in Markdown.

        Returns:
            The repaired source with leading/trailing whitespace stripped.
        """
        # 1. Fences with any language tag, then stray backticks. Outside string
        #    literals, backticks are not valid Python 3 syntax.
        code = re.sub(r'```[\w+.-]*[ \t]*\n?', '', code)
        code = code.replace('`', '')

        # 2. Quote bare Text(...) prose. The keyword form runs first so that
        #    only the text is quoted, not the trailing keyword argument. The
        #    capture excludes newlines and parens so it cannot run across
        #    neighbouring calls.
        code = re.sub(
            r'Text\(([^"\'\n()]*?), font_size',
            lambda m: (
                m.group(0)
                if CodeValidator._is_valid_call_args(m.group(1))
                else f'Text("{m.group(1)}", font_size'
            ),
            code,
        )
        code = re.sub(
            r'Text\(([^"\'()\[\]]*?)\)',
            lambda m: (
                m.group(0)
                if CodeValidator._is_valid_call_args(m.group(1))
                else f'Text("{m.group(1)}")'
            ),
            code,
        )

        # 3. Close one unterminated double-quoted string per Text(...) line.
        code = '\n'.join(
            CodeValidator._close_dangling_quote(line)
            if 'Text(' in line and line.count('"') % 2
            else line
            for line in code.split('\n')
        )

        # 4. Manim scenes rely on the wildcard import for Scene, Text, etc.
        if not re.search(r'from manim import \*', code):
            code = 'from manim import *\n\n' + code

        return code.strip()

    @staticmethod
    def validate_python_syntax(code: str) -> Tuple[bool, Optional[str]]:
        """Check that *code* parses as Python.

        Returns:
            ``(True, None)`` on success. On failure, ``(False, message)`` with
            the parser's error message.
        """
        try:
            ast.parse(code)
        except (SyntaxError, ValueError) as exc:
            # ValueError: ast.parse rejects null bytes this way before 3.12.
            return False, str(exc)
        return True, None

    @staticmethod
    def has_scene_class(code: str) -> bool:
        """Return True if any class directly inherits from a Manim scene class.

        Bases are matched by name, either bare (``Scene``) or as an attribute
        (``manim.MovingCameraScene``), against ``_SCENE_BASES``. Unparseable
        code yields False.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            return False
        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue
            for base in node.bases:
                if isinstance(base, ast.Name):
                    base_name = base.id
                elif isinstance(base, ast.Attribute):
                    base_name = base.attr
                else:
                    continue
                if base_name in CodeValidator._SCENE_BASES:
                    return True
        return False

    @staticmethod
    def has_construct_method(code: str) -> bool:
        """Return True if any class body defines a ``construct`` method.

        Any class is checked, not only Scene subclasses. Unparseable code
        yields False.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            return False
        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef):
                for item in node.body:
                    if isinstance(item, ast.FunctionDef) and item.name == 'construct':
                        return True
        return False

    @staticmethod
    def has_animations(code: str) -> bool:
        """Return True if *code* contains both ``self.play(`` and ``self.wait(``.

        This is a plain regex search over the source text, so occurrences
        inside comments or strings also count.
        """
        return all(
            re.search(pattern, code)
            for pattern in (r'self\.play\(', r'self\.wait\(')
        )

    @staticmethod
    def validate_manim_code(code: str) -> Tuple[bool, Optional[str]]:
        """Check that *code* is a structurally plausible Manim scene.

        Structure is checked first (a scene class, a ``construct`` method,
        ``self.play``/``self.wait`` calls). Then a set of problematic text
        patterns is rejected.

        Returns:
            ``(True, None)`` if all checks pass, otherwise ``(False, reason)``
            for the first failing check.
        """
        if not CodeValidator.has_scene_class(code):
            return False, "Code must contain a class that inherits from Scene"

        if not CodeValidator.has_construct_method(code):
            return False, "Scene class must have a construct method"

        if not CodeValidator.has_animations(code):
            return False, "Code must include at least one animation (self.play) and wait"

        problematic_patterns = [
            r'for\s+\*',      # "for * in ..." (looks like a mangled "for _ in")
            r'\*[ \t]+=',     # "* = ..."; augmented "*=" / "**=" is allowed
            r'^\s*\*\s',      # line starting with a lone "*" (not "*args"/"*[...]")
            r'get\*center',   # get_center with "_" mangled into "*"
            r'^\s*import\s+\*',  # bare "import *"
            r'from\s+(?!manim\s+import\b)\S+\s+import\s+\*',  # wildcard import, manim exempt
        ]

        for pattern in problematic_patterns:
            if re.search(pattern, code, re.MULTILINE):
                return False, f"Code contains problematic pattern: {pattern}"

        return True, None