Spaces:
Runtime error
Runtime error
File size: 4,694 Bytes
2f3e169 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# Modified from Dream repos: https://github.com/HKUNLP/Dream
"""Post-processing LLM-generated Python code implemented using tree-sitter."""
import os
import sys
import pathlib
# Absolute path of the directory that contains this module.
ROOT = os.path.dirname(os.path.abspath(__file__))
# Append ROOT's parent and grandparent directories to the import path
# (presumably so sibling packages can be imported -- confirm with callers).
sys.path.extend(os.path.dirname(base) for base in (ROOT, os.path.dirname(ROOT)))
import ast
import traceback
from typing import Dict, List, Optional, Set, Tuple
def refine_text(text: str) -> str:
    """Normalize whitespace in raw model output before parsing it.

    Tabs are expanded to four spaces so that code mixing tab- and
    space-indentation (common in LLM output) ends up with consistent indent
    levels. Windows (``\\r\\n``) and old-Mac (``\\r``) line endings become
    ``\\n``. Leading/trailing whitespace is stripped and exactly one trailing
    newline is added.

    Args:
        text: Raw text, possibly containing code.

    Returns:
        The normalized text, always ending in a single ``"\\n"``.
    """
    # One space per tab would turn "    if x:\n\t\treturn" into mismatched
    # indent levels; four spaces keeps it a valid nested block.
    text = text.replace("\t", "    ")
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    return text.strip() + "\n"
def syntax_check(code: str, verbose: bool = False) -> bool:
    """Return True if *code* parses as Python source, otherwise False.

    Args:
        code: Source text to parse with ``ast.parse``.
        verbose: When True, print the traceback of a failed parse.

    Returns:
        True when parsing succeeds; False when the parser rejects the input.
    """
    try:
        ast.parse(code)
    except (SyntaxError, MemoryError, ValueError, RecursionError):
        # ValueError: null bytes in the source on Python < 3.12 (3.12+ raises
        # SyntaxError instead). RecursionError / MemoryError: pathologically
        # deep or large input. All of these mean "not usable source" here and
        # must not escape into callers that probe many candidate snippets.
        if verbose:
            traceback.print_exc()
        return False
    return True
def extract_longest_valid_code(text: str, max_lines: Optional[int] = 100) -> str:
    """Return the contiguous run of lines that parses and has the most code.

    Every span ``lines[i:j+1]`` among the first *max_lines* lines is a
    candidate. A candidate replaces the current best only if ``syntax_check``
    accepts it AND it has strictly more non-blank lines than the best so far,
    so ties go to the earliest (smallest ``i``, then smallest ``j``) span.

    Args:
        text: Text that may mix prose with Python code.
        max_lines: Only the first *max_lines* lines are searched, because the
            search is quadratic in the number of lines with a parse per
            candidate. ``None`` searches every line.

    Returns:
        The winning span joined with newlines, or ``""`` if no span with at
        least one non-blank line parses.
    """
    lines = text.splitlines()[:max_lines]
    n = len(lines)

    # prefix[k] == number of non-blank lines in lines[:k], so the non-blank
    # count of lines[i:j+1] is prefix[j + 1] - prefix[i] without rescanning.
    prefix = [0]
    for line in lines:
        prefix.append(prefix[-1] + (1 if line.strip() else 0))

    max_valid_lines = 0
    max_valid_snippet = ""
    for i in range(n):
        # Even the longest span starting at i cannot strictly beat the best,
        # and later starts have no more lines available: nothing can improve.
        if prefix[n] - prefix[i] <= max_valid_lines:
            break
        for j in range(i, n):
            count = prefix[j + 1] - prefix[i]
            # Such a span would never be selected, so skip the costly parse.
            if count <= max_valid_lines:
                continue
            snippet = "\n".join(lines[i : j + 1])
            if syntax_check(snippet):
                max_valid_lines = count
                max_valid_snippet = snippet
    return max_valid_snippet
def get_deps(nodes: List[Tuple[str, ast.AST]]) -> Dict[str, Set[str]]:
    """Map each definition name to the identifiers referenced beneath it.

    For every ``(name, node)`` pair, collect the ``id`` of each ``ast.Name``
    and the ``attr`` of each ``ast.Attribute`` found anywhere below *node*
    (the node itself is not inspected). The result over-approximates real
    dependencies: it also contains local variables, builtins and attribute
    names. That is harmless for the reachability analysis ``sanitize`` uses
    it for.

    Args:
        nodes: ``(definition name, AST node)`` pairs. For a repeated name,
            the later pair wins.

    Returns:
        A dict from each definition name to its set of referenced identifiers.
    """
    name2deps: Dict[str, Set[str]] = {}
    for name, node in nodes:
        deps: Set[str] = set()
        for child in ast.iter_child_nodes(node):
            for sub in ast.walk(child):
                if isinstance(sub, ast.Name):
                    deps.add(sub.id)
                elif isinstance(sub, ast.Attribute):
                    # Record the attribute AND keep walking into its value, so
                    # that `helper(x).strip()` still registers `helper` and `x`.
                    deps.add(sub.attr)
        name2deps[name] = deps
    return name2deps
def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
    """Return every name reachable from *entrypoint* in *call_graph*.

    The walk follows edges transitively. Names absent from *call_graph* are
    treated as having no outgoing edges. The result always includes
    *entrypoint* itself.

    Args:
        entrypoint: Name to start the walk from.
        call_graph: Mapping from a name to the set of names it references.

    Returns:
        The set of reachable names, including *entrypoint*.
    """
    reachable = {entrypoint}
    pending = [entrypoint]
    while pending:
        node = pending.pop()
        for dep in call_graph.get(node, set()) - reachable:
            reachable.add(dep)
            pending.append(dep)
    return reachable
def get_definition_name(node: ast.AST) -> Optional[str]:
    """Return the name bound by a top-level definition node, or None.

    A ``def`` or ``class`` yields its declared name. A plain assignment
    yields the identifier of its FIRST target when that target is a simple
    name (so ``x = y = 1`` gives ``"x"``). Any other node, including
    tuple/attribute/subscript targets, yields None.
    """
    if isinstance(node, ast.Assign):
        first_target = node.targets[0] if node.targets else None
        return first_target.id if isinstance(first_target, ast.Name) else None
    if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
        return node.name
    return None
def has_return_statement(node: ast.AST) -> bool:
    """Return True if any ``return`` statement occurs anywhere in *node*.

    The search covers the whole subtree, including nested functions and
    classes, and a bare ``return`` with no value also counts.
    """
    for descendant in ast.walk(node):
        if isinstance(descendant, ast.Return):
            return True
    return False
def sanitize(text: str, entrypoint: Optional[str] = None) -> str:
    """Reduce raw LLM output to the imports and top-level definitions it holds.

    The text is normalized with ``refine_text``, and the longest
    syntactically valid run of lines is kept via
    ``extract_longest_valid_code``. From that module only these top-level
    statements survive:

    * ``import`` / ``from ... import`` statements (always kept, in order);
    * class definitions;
    * function definitions (sync or ``async``) containing a ``return``;
    * assignments whose first target is a plain name, and annotated
      assignments with a value whose target is a plain name.

    If a name is defined more than once, the last definition is kept, in the
    output position of the first.

    Args:
        text: Raw model output, possibly mixing prose and code.
        entrypoint: If given, only definitions transitively referenced from
            this name (per ``get_deps`` / ``get_function_dependency``) are
            kept. Imports are kept regardless.

    Returns:
        The surviving statements re-rendered with ``ast.unparse``, joined by
        newlines. Comments and original formatting are not preserved.
    """
    text = refine_text(text)
    code = extract_longest_valid_code(text)
    tree = ast.parse(code)

    # name -> (kind, node). Re-assigning an existing key keeps its original
    # insertion position, which fixes the output order.
    definitions: Dict[str, Tuple[str, ast.AST]] = {}
    imports: List[ast.AST] = []
    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            imports.append(node)
        elif isinstance(node, ast.ClassDef):
            definitions[node.name] = ("class", node)
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Functions without any return (e.g. empty stubs) are dropped.
            if has_return_statement(node):
                definitions[node.name] = ("function", node)
        elif isinstance(node, ast.Assign):
            name = get_definition_name(node)
            if name:
                definitions[name] = ("variable", node)
        elif (
            isinstance(node, ast.AnnAssign)
            and isinstance(node.target, ast.Name)
            and node.value is not None
        ):
            # e.g. `MOD: int = 10**9 + 7`. A bare annotation (`x: int`) binds
            # nothing and must not replace a real assignment of the same name.
            definitions[node.target.id] = ("variable", node)

    reachable: Set[str] = set()
    if entrypoint:
        name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()])
        reachable = get_function_dependency(entrypoint, name2deps)

    sanitized_output = [ast.unparse(node) for node in imports]
    sanitized_output.extend(
        ast.unparse(node)
        for name, (_, node) in definitions.items()
        if not entrypoint or name in reachable
    )
    return "\n".join(sanitized_output)