File size: 2,308 Bytes
96638b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import re
import ast

class QueryValidator:
    """
    Validates generated Python/Pandas code for safety.
    Prevents execution of dangerous operations using AST analysis and keyword blocking.

    The checks are purely static: the code is parsed with ``ast`` and rejected
    if any of the following holds:

    - it imports a blocked module, including via ``from <blocked> import name``;
    - it calls or merely references a blocked builtin (so aliasing such as
      ``f = eval`` is caught too);
    - it references a blocked module, or calls through an attribute chain
      rooted at one (e.g. ``os.path.join``);
    - it touches a dunder name or attribute. This is the usual route out of
      a restricted namespace, e.g. ``().__class__.__subclasses__()``.

    NOTE(review): this is a heuristic filter, not a sandbox. It does not
    restrict pandas/numpy methods themselves (e.g. ``read_*``/``to_*`` file
    I/O). Pair it with a restricted execution environment.
    """

    # Top-level module names that generated code may not import or reference.
    BLOCKED_IMPORTS = ['os', 'sys', 'subprocess', 'shutil', 'pickle', 'importlib']
    # Builtins that generated code may not call or reference.
    BLOCKED_CALLS = ['open', 'eval', 'exec', 'compile', 'input', 'exit', 'quit']
    # NOTE(review): not consulted by any check in this class; kept for
    # callers that may read it.
    ALLOWED_MODULES = ['pd', 'np', 'pandas', 'numpy']

    # Builtins that give indirect access to modules, namespaces or arbitrary
    # attributes and would otherwise bypass the lists above.
    _ESCAPE_BUILTINS = (
        'getattr', 'setattr', 'delattr', 'globals', 'locals', 'vars',
        'breakpoint',
    )

    # Fence opened with a Python language tag: ```python, ```python3 or ```py,
    # in any letter case. The tag must end on a word boundary, so ```pyspark
    # is not mistaken for Python.
    _PYTHON_FENCE_RE = re.compile(
        r"```[ \t]*(?:python3?|py)\b[ \t]*(.*?)```",
        re.DOTALL | re.IGNORECASE,
    )
    # Any fenced block; used only when no Python-tagged fence is present.
    _ANY_FENCE_RE = re.compile(r"```\s*(.*?)```", re.DOTALL)

    @staticmethod
    def validate(code: str) -> "tuple[bool, str]":
        """
        Validates the provided code string.

        Args:
            code: Python source to check (typically produced by an LLM).

        Returns: (is_safe: bool, reason: str)
            ``(True, "Code is safe")`` when no check fires. Otherwise
            ``(False, reason)``, where ``reason`` describes the first
            violation found.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError, RecursionError):
            # ValueError: NUL bytes in the source (Python < 3.12).
            # RecursionError: pathologically nested input. Fail closed.
            return False, "Syntax Error in generated code"

        blocked_modules = set(QueryValidator.BLOCKED_IMPORTS)
        blocked_builtins = set(QueryValidator.BLOCKED_CALLS).union(
            QueryValidator._ESCAPE_BUILTINS
        )

        def is_blocked_module(dotted):
            # `dotted` may be None for relative imports such as `from . import x`.
            return bool(dotted) and dotted.split('.')[0] in blocked_modules

        for node in ast.walk(tree):
            # Check for imports
            if isinstance(node, ast.Import):
                for alias in node.names:
                    if is_blocked_module(alias.name):
                        return False, f"Blocked import: {alias.name}"

            elif isinstance(node, ast.ImportFrom):
                # For `from os import system` the module lives in node.module;
                # node.names only holds `system`.
                if is_blocked_module(node.module):
                    return False, f"Blocked import: {node.module}"
                for alias in node.names:
                    if is_blocked_module(alias.name):
                        return False, f"Blocked import: {alias.name}"

            # Check for function calls
            elif isinstance(node, ast.Call):
                func = node.func
                if isinstance(func, ast.Name):
                    if func.id in blocked_builtins:
                        return False, f"Blocked function call: {func.id}"
                # Check for attribute calls (e.g., os.system, os.path.join) by
                # following the attribute chain down to its root name.
                elif isinstance(func, ast.Attribute):
                    root = func.value
                    while isinstance(root, ast.Attribute):
                        root = root.value
                    if isinstance(root, ast.Name) and root.id in blocked_modules:
                        return False, f"Blocked attribute call on: {root.id}"

            elif isinstance(node, ast.Attribute):
                if node.attr.startswith('__'):
                    return False, f"Blocked dunder attribute: {node.attr}"

            elif isinstance(node, ast.Name):
                # Catches indirect use such as `f = eval; f(...)` or
                # `m = os; m.system(...)`, and names like `__import__` or
                # `__builtins__`. None of these appear as a direct blocked call.
                if node.id.startswith('__'):
                    return False, f"Blocked dunder name: {node.id}"
                if node.id in blocked_builtins:
                    return False, f"Blocked reference: {node.id}"
                if node.id in blocked_modules:
                    return False, f"Blocked module reference: {node.id}"

        return True, "Code is safe"

    @staticmethod
    def clean_code(llm_response: str) -> str:
        """
        Extracts code from markdown blocks if present.

        Preference order:
          1. the first fence tagged as Python (```python, ```python3, ```py;
             any letter case);
          2. the first fence of any kind;
          3. the whole response.

        The returned text has surrounding whitespace stripped.
        """
        code_match = QueryValidator._PYTHON_FENCE_RE.search(llm_response)
        if code_match:
            return code_match.group(1).strip()

        # Fallback: a fence without a Python tag, e.g. a bare ```
        code_match = QueryValidator._ANY_FENCE_RE.search(llm_response)
        if code_match:
            return code_match.group(1).strip()

        return llm_response.strip()