File size: 6,836 Bytes
24c2665 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import ast
import re
from typing import List
def parse_imports(code_snippet: str) -> List[str]:
    """
    Collect the import statements contained in a code snippet.

    The snippet is parsed with ``ast`` and every ``import`` / ``from ... import``
    node (at any nesting depth) is rendered back as a single-line statement,
    in the order ``ast.walk`` yields the nodes. Relative imports keep their
    leading dots and aliases keep their ``as`` clause.

    If anything goes wrong while parsing or rendering (typically a
    ``SyntaxError`` on an incomplete snippet), the function falls back to a
    line-based regular expression and returns every stripped line that starts
    with ``from`` or ``import``.
    """

    def _render_alias(alias):
        # "name" or "name as alias"
        if alias.asname:
            return alias.name + f" as {alias.asname}"
        return alias.name

    found = []
    try:
        for node in ast.walk(ast.parse(code_snippet)):
            if isinstance(node, ast.Import):
                prefix = "import "
            elif isinstance(node, ast.ImportFrom):
                module = node.module or ""
                # level counts the leading dots of a relative import
                dots = "." * node.level if node.level > 0 else ""
                prefix = f"from {dots}{module} import "
            else:
                continue
            names = ", ".join([_render_alias(alias) for alias in node.names])
            found.append(prefix + names)
    except Exception:
        # Best-effort fallback for snippets that do not parse
        import_pattern = r"^\s*(?:from|import)\s+.*$"
        matches = re.findall(import_pattern, code_snippet, re.MULTILINE)
        found = [match.strip() for match in matches]
    return found
def parse_error(error_message: str) -> str:
    """
    Return the part of an error message that precedes the first colon.

    For a message such as ``"NameError: name 'x' is not defined"`` this is
    the exception name. Surrounding whitespace is stripped; a message with
    no colon is returned stripped but otherwise whole.
    """
    head, _, _ = error_message.partition(':')
    return head.strip()
def replace_main_function_name(code: str, old_name: str, new_name: str) -> str:
    """
    Rename a function and the direct calls to it.

    Every function definition named ``old_name`` (both ``def`` and
    ``async def``) is renamed to ``new_name``, and every call whose callee is
    the bare name ``old_name`` (e.g. recursive calls) is redirected to
    ``new_name``.

    Not renamed: attribute calls such as ``obj.old_name()`` and references
    that are not calls (e.g. passing the function as an argument).

    Args:
        code: Python source to transform.
        old_name: Current name of the function.
        new_name: Name the function should have afterwards.

    Returns:
        The source regenerated with ``ast.unparse``; comments and the
        original formatting are therefore not preserved.

    Raises:
        SyntaxError: If ``code`` cannot be parsed.
    """
    tree = ast.parse(code)
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Coroutines must be renamed too, otherwise their (renamed)
            # calls would point at a name that no longer exists.
            if node.name == old_name:
                node.name = new_name
        elif (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == old_name
        ):
            node.func.id = new_name
    return ast.unparse(tree)
def remove_comments_and_docstrings(code: str) -> str:
    """
    Remove all comments and docstrings from the code.

    The code is parsed, every leading string-literal expression statement is
    dropped from the module, class and function bodies, and the tree is
    rendered back with ``ast.unparse`` (which does not emit comments). Blank
    lines are removed from the result and trailing whitespace is trimmed.

    A class or function whose body consisted only of docstrings gets a
    ``pass`` statement so the output stays valid Python.

    Args:
        code: Python source to clean.

    Returns:
        The cleaned source, or ``code`` unchanged if it cannot be processed
        (for example because it does not parse).
    """
    try:
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if not isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef, ast.ClassDef, ast.Module)):
                continue
            body = node.body
            # Remove all leading docstrings. String literals are ast.Constant
            # nodes; the legacy ast.Str alias is gone in recent Python
            # versions, so it must not be referenced here.
            while (
                body
                and isinstance(body[0], ast.Expr)
                and isinstance(body[0].value, ast.Constant)
                and isinstance(body[0].value.value, str)
            ):
                body.pop(0)
            # A def/class with an empty body would unparse to invalid code.
            if not body and not isinstance(node, ast.Module):
                body.append(ast.Pass())
        # Convert back to code - AST unparse already removes comments
        code_without_docstrings = ast.unparse(tree)
        # Only remove empty lines and trim whitespace
        lines = [
            line.rstrip()
            for line in code_without_docstrings.split('\n')
            if line.strip()
        ]
        return '\n'.join(lines)
    except Exception:
        # Deliberate best-effort: return original code if parsing fails
        return code
def remove_any_not_definition_imports(code: str) -> str:
    """
    Strip a module down to its imports and definitions.

    Only top-level statements are filtered; the bodies of the kept
    definitions are left as they are.

    Kept:
    - ``import`` / ``from ... import`` statements
    - Class definitions
    - Function and async function definitions
    Dropped:
    - Every other top-level statement (assignments, standalone
      expressions, constant declarations, ...)

    The result is regenerated with ``ast.unparse`` and blank lines are
    removed. If the code cannot be processed it is returned unchanged.
    """
    kept_node_types = (
        ast.Import,
        ast.ImportFrom,
        ast.FunctionDef,
        ast.AsyncFunctionDef,
        ast.ClassDef,
    )
    try:
        module = ast.parse(code)
        # Keep only definitions and imports at module level
        module.body = [stmt for stmt in module.body if isinstance(stmt, kept_node_types)]
        ast.fix_missing_locations(module)
        rendered = ast.unparse(module)
    except Exception:
        return code
    # Drop empty lines from the regenerated source
    non_blank = filter(str.strip, rendered.split('\n'))
    return '\n'.join(non_blank)
class PrintRemover(ast.NodeTransformer):
    """
    AST transformer that removes calls to the bare name ``print``.

    - A statement that consists solely of a ``print(...)`` call is deleted.
    - A ``print(...)`` call used inside another expression (assignment value,
      call argument, ...) is replaced by the constant ``None``.
    - Any compound statement whose body becomes empty as a result receives a
      ``pass`` so the tree still unparses to valid Python.

    Only calls whose callee is the plain name ``print`` are recognised;
    attribute calls such as ``builtins.print(...)`` are left alone.
    """

    @staticmethod
    def _is_print_call(node):
        # True for a call expression whose callee is the bare name `print`.
        return (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == 'print'
        )

    def visit_Expr(self, node):
        # Handle standalone print statements
        if self._is_print_call(node.value):
            return None
        # Keep descending so prints nested in other expressions are handled
        return self.generic_visit(node)

    def visit_Call(self, node):
        # Handle print calls in other contexts (like assignments)
        if self._is_print_call(node):
            return ast.Constant(value=None)
        # Arguments of other calls may themselves contain print calls
        return self.generic_visit(node)

    def _handle_block(self, node):
        # Transform the children, then make sure the body is not left empty.
        self.generic_visit(node)
        if not node.body:
            node.body.append(ast.Pass())
        return node

    def visit_For(self, node):
        return self._handle_block(node)

    def visit_AsyncFor(self, node):
        return self._handle_block(node)

    def visit_While(self, node):
        return self._handle_block(node)

    def visit_FunctionDef(self, node):
        return self._handle_block(node)

    def visit_AsyncFunctionDef(self, node):
        return self._handle_block(node)

    def visit_ClassDef(self, node):
        return self._handle_block(node)

    def visit_If(self, node):
        return self._handle_block(node)

    def visit_With(self, node):
        return self._handle_block(node)

    def visit_AsyncWith(self, node):
        return self._handle_block(node)

    def visit_match_case(self, node):
        # Only dispatched on Python versions that have `match` statements.
        return self._handle_block(node)

    def visit_Try(self, node):
        # Remember which optional clauses existed before the prints are
        # stripped; afterwards an emptied clause is indistinguishable from
        # one that was never there.
        had_else = bool(node.orelse)
        had_finally = bool(node.finalbody)
        self.generic_visit(node)
        # Handle main try body
        if not node.body:
            node.body.append(ast.Pass())
        # Handle except handlers
        for handler in node.handlers:
            if not handler.body:
                handler.body.append(ast.Pass())
        # Handle else clause
        if had_else and not node.orelse:
            node.orelse.append(ast.Pass())
        # Handle finally clause (a try/finally without it would be invalid)
        if had_finally and not node.finalbody:
            node.finalbody.append(ast.Pass())
        return node

    # `try ... except* ...` nodes share the same field layout as ast.Try.
    visit_TryStar = visit_Try
def remove_print_statements(code: str) -> str:
    """
    Return ``code`` with its print statements removed.

    The source is parsed, run through ``PrintRemover`` and regenerated with
    ``ast.unparse``, so comments and original formatting are not preserved.
    Parsing errors are not caught here and propagate to the caller.
    """
    parsed = ast.parse(code)
    without_prints = PrintRemover().visit(parsed)
    # Nodes created by the transformer carry no line/column information yet
    return ast.unparse(ast.fix_missing_locations(without_prints))
if __name__ == "__main__":
    # Manual smoke check: print what parse_error extracts from a few
    # typical interpreter error messages.
    sample_messages = (
        "NameError: name 'x' is not defined",
        "TypeError: unsupported operand type(s) for -: 'str' and 'str'",
        "ValueError: invalid literal for int() with base 10: 'x'",
    )
    for sample in sample_messages:
        print(parse_error(sample))
|