|
|
import ast |
|
|
import re |
|
|
from typing import List |
|
|
|
|
|
|
|
|
def parse_imports(code_snippet: str) -> List[str]:
    """
    Return every import statement found in `code_snippet`, one string each.

    The snippet is parsed with `ast`, and each `import` / `from ... import`
    node (at any nesting depth) is rendered back to a normalized one-line
    statement, keeping `as` aliases and the leading dots of relative imports.
    Statements are listed in `ast.walk` order.

    If the snippet cannot be parsed, a best-effort regex fallback returns the
    stripped text of every line that starts with `from` or `import`.
    """
    try:
        tree = ast.parse(code_snippet)
    except Exception:
        # Deliberate best-effort: the snippet may be partial or invalid code,
        # so fall back to a line-based scan instead of failing.
        import_pattern = r"^\s*(?:from|import)\s+.*$"
        return [
            match.strip()
            for match in re.findall(import_pattern, code_snippet, re.MULTILINE)
        ]

    imports = []
    for node in ast.walk(tree):
        if not isinstance(node, (ast.Import, ast.ImportFrom)):
            continue

        # Render the imported names once; both statement forms share it.
        names = ", ".join(
            f"{alias.name} as {alias.asname}" if alias.asname else alias.name
            for alias in node.names
        )

        if isinstance(node, ast.Import):
            imports.append(f"import {names}")
        else:
            # `module` is None for "from . import x"; `level` counts the dots.
            source = "." * node.level + (node.module or "")
            imports.append(f"from {source} import {names}")

    return imports
|
|
|
|
|
|
|
|
def parse_error(error_message: str) -> str:
    """
    Return the part of `error_message` before its first colon.

    Surrounding whitespace is stripped. When the message contains no colon,
    the whole message (stripped) is returned.
    """
    head, _, _ = error_message.partition(':')
    return head.strip()
|
|
|
|
|
|
|
|
def replace_main_function_name(code: str, old_name: str, new_name: str) -> str:
    """
    Rename the function `old_name` to `new_name` in the code.

    Renames every `def` / `async def` named `old_name` and every direct call
    `old_name(...)` (including recursive calls). Other references to the name
    (e.g. passing the function as a value, or attribute calls such as
    `obj.old_name()`) are left untouched.

    The code is round-tripped through `ast`, so comments and original
    formatting are not preserved. Raises `SyntaxError` if `code` does not
    parse.
    """
    tree = ast.parse(code)
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Async definitions must be renamed too; otherwise the renamed
            # calls below would point at a function that no longer exists.
            if node.name == old_name:
                node.name = new_name
        elif (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == old_name
        ):
            node.func.id = new_name
    return ast.unparse(tree)
|
|
|
|
|
|
|
|
def remove_comments_and_docstrings(code: str) -> str:
    """
    Remove all comments and docstrings from the code.

    The code is parsed and re-emitted with `ast.unparse`, which drops
    comments. Every leading string-literal expression statement in a module,
    class, function or async function body is removed as well. A function or
    class left with no statements gets a `pass` so the output stays valid.
    Blank lines are removed and trailing whitespace is stripped.

    If anything fails (typically because `code` is not valid Python), the
    original `code` is returned unchanged.
    """
    docstring_owners = (ast.AsyncFunctionDef, ast.FunctionDef, ast.ClassDef, ast.Module)

    try:
        tree = ast.parse(code)

        for node in ast.walk(tree):
            if not isinstance(node, docstring_owners):
                continue

            body = node.body
            # Count the leading bare string expressions. `ast.parse` always
            # produces `ast.Constant` for literals, so the deprecated
            # `ast.Str` alias is not needed.
            leading = 0
            for stmt in body:
                if (
                    isinstance(stmt, ast.Expr)
                    and isinstance(stmt.value, ast.Constant)
                    and isinstance(stmt.value.value, str)
                ):
                    leading += 1
                else:
                    break
            del body[:leading]

            # A def/class with an empty body would unparse to invalid code.
            if not body and not isinstance(node, ast.Module):
                body.append(ast.Pass())

        code_without_docstrings = ast.unparse(tree)

        lines = [
            line.rstrip()
            for line in code_without_docstrings.split('\n')
            if line.strip()
        ]
        return '\n'.join(lines)

    except Exception:
        # Deliberate best-effort: hand back the input when it cannot be
        # processed rather than raising.
        return code
|
|
|
|
|
|
|
|
def remove_any_not_definition_imports(code: str) -> str:
    """
    Strip every top-level statement that is not an import or a definition.

    Kept at module level:
    - `import` and `from ... import` statements
    - class definitions
    - function and async function definitions

    Dropped at module level: assignments, bare expressions, and any other
    statement kind. Bodies of the kept definitions are not filtered.

    The result is re-emitted with `ast.unparse` and blank lines are removed.
    If parsing or unparsing fails, the original `code` is returned unchanged.
    """
    kept_kinds = (
        ast.Import,
        ast.ImportFrom,
        ast.FunctionDef,
        ast.AsyncFunctionDef,
        ast.ClassDef,
    )

    try:
        module = ast.parse(code)
        # Only the module's direct children are filtered.
        module.body = [stmt for stmt in module.body if isinstance(stmt, kept_kinds)]
        ast.fix_missing_locations(module)
        rendered = ast.unparse(module)
    except Exception:
        return code

    non_blank = [text for text in rendered.split('\n') if text.strip()]
    return '\n'.join(non_blank)
|
|
|
|
|
|
|
|
class PrintRemover(ast.NodeTransformer): |
|
|
def visit_Expr(self, node): |
|
|
|
|
|
if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name) and node.value.func.id == 'print': |
|
|
return None |
|
|
return node |
|
|
|
|
|
def visit_Call(self, node): |
|
|
|
|
|
if isinstance(node.func, ast.Name) and node.func.id == 'print': |
|
|
return ast.Constant(value=None) |
|
|
return node |
|
|
|
|
|
def _handle_block(self, node): |
|
|
self.generic_visit(node) |
|
|
if not node.body: |
|
|
node.body.append(ast.Pass()) |
|
|
return node |
|
|
|
|
|
def visit_For(self, node): |
|
|
return self._handle_block(node) |
|
|
|
|
|
def visit_While(self, node): |
|
|
return self._handle_block(node) |
|
|
|
|
|
def visit_FunctionDef(self, node): |
|
|
return self._handle_block(node) |
|
|
|
|
|
def visit_AsyncFunctionDef(self, node): |
|
|
return self._handle_block(node) |
|
|
|
|
|
def visit_If(self, node): |
|
|
return self._handle_block(node) |
|
|
|
|
|
def visit_With(self, node): |
|
|
return self._handle_block(node) |
|
|
|
|
|
def visit_Try(self, node): |
|
|
self.generic_visit(node) |
|
|
|
|
|
|
|
|
if not node.body: |
|
|
node.body.append(ast.Pass()) |
|
|
|
|
|
|
|
|
for handler in node.handlers: |
|
|
if not handler.body: |
|
|
handler.body.append(ast.Pass()) |
|
|
|
|
|
|
|
|
if node.orelse and not node.orelse: |
|
|
node.orelse.append(ast.Pass()) |
|
|
|
|
|
|
|
|
if node.finalbody and not node.finalbody: |
|
|
node.finalbody.append(ast.Pass()) |
|
|
|
|
|
return node |
|
|
|
|
|
|
|
|
def remove_print_statements(code: str) -> str:
    """
    Return `code` with its print statements removed.

    The source is parsed, run through `PrintRemover`, and re-emitted with
    `ast.unparse`, so comments and original formatting are not preserved.
    Raises `SyntaxError` if `code` does not parse.
    """
    stripped = PrintRemover().visit(ast.parse(code))
    return ast.unparse(ast.fix_missing_locations(stripped))
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke check: print the result of parse_error for a few messages.
    _sample_messages = (
        "NameError: name 'x' is not defined",
        "TypeError: unsupported operand type(s) for -: 'str' and 'str'",
        "ValueError: invalid literal for int() with base 10: 'x'",
    )
    for _message in _sample_messages:
        print(parse_error(_message))
|
|
|