final-python-env / utils /ast_parser.py
uvpatel7271's picture
Upload folder using huggingface_hub
019e7db verified
raw
history blame
9.43 kB
"""AST-based parsing helpers for Python code review."""
from __future__ import annotations
import ast
from dataclasses import dataclass, field
from typing import Any
@dataclass(slots=True)
class _StructureVisitor(ast.NodeVisitor):
"""Collect lightweight structural signals from Python source."""
imports: set[str] = field(default_factory=set)
route_decorators: set[str] = field(default_factory=set)
function_names: list[str] = field(default_factory=list)
class_names: list[str] = field(default_factory=list)
code_smells: list[str] = field(default_factory=list)
branch_count: int = 0
max_loop_depth: int = 0
max_nesting_depth: int = 0
current_loop_depth: int = 0
current_nesting_depth: int = 0
recursive_functions: set[str] = field(default_factory=set)
current_function: str | None = None
docstring_total: int = 0
docstring_with_docs: int = 0
backward_calls: int = 0
optimizer_step_calls: int = 0
container_builds: int = 0
def visit_Import(self, node: ast.Import) -> None: # noqa: N802
for alias in node.names:
self.imports.add(alias.name.split(".")[0])
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None: # noqa: N802
if node.module:
self.imports.add(node.module.split(".")[0])
self.generic_visit(node)
def _push_nesting(self) -> None:
self.current_nesting_depth += 1
self.max_nesting_depth = max(self.max_nesting_depth, self.current_nesting_depth)
def _pop_nesting(self) -> None:
self.current_nesting_depth = max(0, self.current_nesting_depth - 1)
def _visit_loop(self, node: ast.AST) -> None:
self.branch_count += 1
self.current_loop_depth += 1
self.max_loop_depth = max(self.max_loop_depth, self.current_loop_depth)
self._push_nesting()
self.generic_visit(node)
self._pop_nesting()
self.current_loop_depth = max(0, self.current_loop_depth - 1)
def visit_For(self, node: ast.For) -> None: # noqa: N802
self._visit_loop(node)
def visit_AsyncFor(self, node: ast.AsyncFor) -> None: # noqa: N802
self._visit_loop(node)
def visit_While(self, node: ast.While) -> None: # noqa: N802
self._visit_loop(node)
def visit_If(self, node: ast.If) -> None: # noqa: N802
self.branch_count += 1
self._push_nesting()
self.generic_visit(node)
self._pop_nesting()
def visit_Try(self, node: ast.Try) -> None: # noqa: N802
self.branch_count += 1
self._push_nesting()
self.generic_visit(node)
self._pop_nesting()
def visit_With(self, node: ast.With) -> None: # noqa: N802
self._push_nesting()
self.generic_visit(node)
self._pop_nesting()
def visit_AsyncWith(self, node: ast.AsyncWith) -> None: # noqa: N802
self._push_nesting()
self.generic_visit(node)
self._pop_nesting()
def visit_comprehension(self, node: ast.comprehension) -> None: # noqa: N802
self._visit_loop(node)
def visit_FunctionDef(self, node: ast.FunctionDef) -> None: # noqa: N802
self.function_names.append(node.name)
self.docstring_total += 1
if ast.get_docstring(node):
self.docstring_with_docs += 1
prior = self.current_function
self.current_function = node.name
for decorator in node.decorator_list:
decorator_name = self._decorator_name(decorator)
if decorator_name in {"get", "post", "put", "patch", "delete"}:
self.route_decorators.add(decorator_name)
self.generic_visit(node)
self.current_function = prior
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: # noqa: N802
self.visit_FunctionDef(node)
def visit_ClassDef(self, node: ast.ClassDef) -> None: # noqa: N802
self.class_names.append(node.name)
self.generic_visit(node)
def visit_Call(self, node: ast.Call) -> None: # noqa: N802
dotted_name = self._call_name(node.func)
if dotted_name.endswith(".backward") or dotted_name == "backward":
self.backward_calls += 1
if dotted_name.endswith(".step") or dotted_name == "step":
if "optimizer" in dotted_name:
self.optimizer_step_calls += 1
if dotted_name in {"list", "dict", "set", "tuple"}:
self.container_builds += 1
if self.current_function and dotted_name == self.current_function:
self.recursive_functions.add(self.current_function)
self.generic_visit(node)
@staticmethod
def _call_name(node: ast.AST) -> str:
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
left = _StructureVisitor._call_name(node.value)
return f"{left}.{node.attr}" if left else node.attr
return ""
@staticmethod
def _decorator_name(node: ast.AST) -> str:
if isinstance(node, ast.Call):
return _StructureVisitor._decorator_name(node.func)
if isinstance(node, ast.Attribute):
return node.attr.lower()
if isinstance(node, ast.Name):
return node.id.lower()
return ""
def _line_smells(lines: list[str]) -> tuple[int, list[int], bool]:
    """Scan source lines for simple formatting problems.

    Returns a 3-tuple:
      * how many lines are longer than 88 characters,
      * the 1-based numbers of lines that end in whitespace,
      * whether any line contains a tab character.
    """
    overlong = 0
    ragged: list[int] = []
    saw_tab = False
    for number, text in enumerate(lines, start=1):
        if len(text) > 88:
            overlong += 1
        if text != text.rstrip():
            ragged.append(number)
        if "\t" in text:
            saw_tab = True
    return overlong, ragged, saw_tab
def parse_code_structure(code: str) -> dict[str, Any]:
    """Extract deterministic syntax, import, and structure signals from Python code.

    Args:
        code: Python source text. A falsy value (``None`` or ``""``) is
            treated as empty source.

    Returns:
        A dict that always carries the same set of keys. The line-based keys
        (``line_count``, ``long_lines``, ``trailing_whitespace_lines``,
        ``tabs_used``) are filled in even when the source does not parse.
        When parsing fails, ``syntax_valid`` is ``False``, ``syntax_error``
        holds the parser message, ``code_smells`` holds two fixed hints, and
        every AST-derived key keeps its default value.
    """
    normalized_code = code or ""
    lines = normalized_code.splitlines()
    long_lines, trailing_whitespace_lines, tabs_used = _line_smells(lines)

    result: dict[str, Any] = {
        "syntax_valid": True,
        "syntax_error": "",
        "line_count": len(lines),
        "imports": [],
        "function_names": [],
        "class_names": [],
        "long_lines": long_lines,
        "trailing_whitespace_lines": trailing_whitespace_lines,
        "tabs_used": tabs_used,
        "docstring_ratio": 0.0,
        "uses_recursion": False,
        "max_loop_depth": 0,
        "max_nesting_depth": 0,
        "route_decorators": [],
        "code_smells": [],
        "uses_pandas": False,
        "uses_numpy": False,
        "uses_torch": False,
        "uses_sklearn": False,
        "uses_fastapi": False,
        "uses_flask": False,
        "uses_pydantic": False,
        "calls_backward": False,
        "calls_optimizer_step": False,
        "branch_count": 0,
        "container_builds": 0,
    }

    try:
        tree = ast.parse(normalized_code or "\n")
    except (SyntaxError, ValueError) as exc:
        # compile() is documented to raise ValueError (not SyntaxError) for
        # source containing null bytes; report that as a parse failure too
        # instead of letting it escape to the caller.
        if isinstance(exc, SyntaxError):
            message = f"{exc.msg} (line {exc.lineno}, column {exc.offset})"
        else:
            message = str(exc)
        result["syntax_valid"] = False
        result["syntax_error"] = message
        result["code_smells"] = ["Code does not parse.", "Fix syntax before deeper review."]
        return result

    visitor = _StructureVisitor()
    visitor.visit(tree)
    imports = sorted(visitor.imports)

    # Library flags: an explicit import, or a textual hint that the
    # conventional alias is used. The pandas hint requires the trailing dot
    # (as the numpy/torch hints do); a bare "pd" substring would also match
    # unrelated words such as "update" or "pdf".
    uses_pandas = "pandas" in imports or "pd." in normalized_code
    uses_numpy = "numpy" in imports or "np." in normalized_code
    uses_torch = "torch" in imports or "torch." in normalized_code
    uses_sklearn = "sklearn" in imports
    uses_fastapi = "fastapi" in imports
    uses_flask = "flask" in imports
    uses_pydantic = "pydantic" in imports or "BaseModel" in normalized_code

    # Start from whatever the visitor collected, then add derived smells.
    code_smells = list(visitor.code_smells)
    if visitor.max_loop_depth >= 2:
        code_smells.append("Nested loops may create avoidable performance pressure.")
    if long_lines:
        code_smells.append("Long lines reduce readability and reviewability.")
    if trailing_whitespace_lines:
        code_smells.append("Trailing whitespace suggests style drift.")
    if visitor.docstring_total and visitor.docstring_with_docs == 0:
        code_smells.append("Public functions are missing docstrings.")
    if not visitor.function_names:
        code_smells.append("Encapsulate behavior in functions for testability.")

    result.update(
        {
            "imports": imports,
            "function_names": visitor.function_names,
            "class_names": visitor.class_names,
            # max(..., 1) avoids dividing by zero when there are no functions.
            "docstring_ratio": round(
                visitor.docstring_with_docs / max(visitor.docstring_total, 1),
                4,
            ),
            "uses_recursion": bool(visitor.recursive_functions),
            "max_loop_depth": visitor.max_loop_depth,
            "max_nesting_depth": visitor.max_nesting_depth,
            "route_decorators": sorted(visitor.route_decorators),
            "code_smells": code_smells,
            "uses_pandas": uses_pandas,
            "uses_numpy": uses_numpy,
            "uses_torch": uses_torch,
            "uses_sklearn": uses_sklearn,
            "uses_fastapi": uses_fastapi,
            "uses_flask": uses_flask,
            "uses_pydantic": uses_pydantic,
            "calls_backward": visitor.backward_calls > 0,
            "calls_optimizer_step": visitor.optimizer_step_calls > 0,
            "branch_count": visitor.branch_count,
            "container_builds": visitor.container_builds,
        }
    )
    return result