Spaces:
Runtime error
Runtime error
| """Purpose of this file: Sanitize the code produced by LLMs for the following reasons. | |
| 1. Vicuna generated code could miss one white space. We fix the white space to make Vicuna more capable. | |
| 2. {Our fault lol.} We look for additional EOF tokens and truncate the messy code that follows them. | |
| """ | |
| import ast | |
| import re | |
| import traceback | |
| from typing import List, Optional | |
def syntax_check(code, verbose=False):
    """Report whether ``code`` can be parsed as Python source.

    Args:
        code: Source text handed to ``ast.parse``.
        verbose: When true, print the traceback of the parse failure.

    Returns:
        ``True`` if ``ast.parse`` accepts the text, ``False`` if parsing
        fails with one of the handled errors.
    """
    try:
        ast.parse(code)
    except (SyntaxError, ValueError, MemoryError, RecursionError):
        # SyntaxError: ordinary malformed code.
        # ValueError: source containing null bytes on Python <= 3.11.
        # MemoryError / RecursionError: input nested too deeply to parse.
        if verbose:
            traceback.print_exc()
        return False
    return True
def remove_unindented_lines(code, protect_before, execeptions, trim_tails):
    """Drop unindented lines from ``code`` and truncate at the first trim tail.

    Lines are examined in order:

    * The first line that starts with ``protect_before`` is always kept.
    * Blank lines and lines starting with any prefix in ``execeptions``
      are kept.
    * A line whose text starts with any prefix in ``trim_tails`` ends the
      scan: that line and everything after it are discarded.
    * Any other line with no leading whitespace is discarded.
    * Every remaining (indented) line is kept.

    Args:
        code: Source text to filter.
        protect_before: Prefix of the line that must be preserved
            (the entry-point ``def`` line when called from ``sanitize``).
        execeptions: Iterable of prefixes whose lines are never removed.
            The misspelled name is kept because callers pass it by keyword.
        trim_tails: Iterable of prefixes that mark the start of trailing
            junk to cut off.

    Returns:
        The kept lines joined with ``"\\n"`` (no trailing newline).
    """
    keep_prefixes = tuple(execeptions)
    tail_prefixes = tuple(trim_tails)

    kept = []
    protected_seen = False
    for line in code.splitlines():
        if not protected_seen and line.startswith(protect_before):
            protected_seen = True
            kept.append(line)
            continue
        if line.strip() == "" or line.startswith(keep_prefixes):
            kept.append(line)
            continue
        if line.rstrip().startswith(tail_prefixes):
            # cut off this line and everything behind it
            break
        if line != line.lstrip():
            # only indented lines survive; unindented ones are dropped
            kept.append(line)

    return "\n".join(kept)
def to_four_space_indents(old_code):
    """Pad lines indented by exactly three characters with one extra space.

    Every line whose leading whitespace is exactly three characters long
    gets a single space prepended; all other lines pass through unchanged.
    Each output line is terminated with ``"\\n"``.
    """
    repaired = []
    for row in old_code.splitlines():
        indent_width = len(row) - len(row.lstrip())
        repaired.append(" " + row if indent_width == 3 else row)
    return "".join(row + "\n" for row in repaired)
def sanitize(
    old_code: str,
    entry_point: str,
    rm_prefix_lines: Optional[str] = None,
    eofs: Optional[List[str]] = None,
):
    """Extract a cleaned-up implementation of ``entry_point`` from LLM output.

    Steps, in order:

    1. Optionally drop every line starting with ``rm_prefix_lines``.
    2. If the text contains fenced code blocks, keep the first chunk that
       mentions ``def <entry_point>``.
    3. Split on ``def <entry_point>(`` and keep only the definitions whose
       text (up to the next ``\\ndef``) contains ``" return "``.
    4. Repair three-space indents, truncate at each string in ``eofs``,
       and strip unindented trailing lines.
    5. Re-attach whatever preceded the first definition, then drop any
       other function that fails a syntax check.

    Args:
        old_code: Raw generated text.
        entry_point: Name of the function to extract.
        rm_prefix_lines: If given, lines starting with this prefix are
            removed before any other processing.
        eofs: Strings at which the extracted code is truncated.

    Returns:
        The sanitized code with surrounding whitespace stripped.
    """
    new_code = old_code
    if rm_prefix_lines is not None:
        new_code = "\n".join(
            line
            for line in old_code.splitlines()
            if not line.startswith(rm_prefix_lines)
        )

    # Prepend a newline so the newline-anchored patterns below can also
    # match content that sits on the very first line.
    new_code = "\n" + new_code
    def_left = "def " + entry_point

    # basic handling of chat output
    new_code = new_code.replace("\n```python\n", "\n```\n")
    for chunk in new_code.split("\n```\n"):
        if def_left in chunk:
            new_code = chunk
            break

    # Raw string avoids invalid "\s" / "\(" escapes; re.escape keeps the
    # entry point literal even if it contains regex metacharacters.
    chunks = re.split(rf"{re.escape(def_left)}\s*\(", new_code)
    # TODO: having return does not mean this is complete
    bodies = [chunk for chunk in chunks[1:] if " return " in chunk.split("\ndef")[0]]
    def_left = def_left + "("
    new_code = def_left + def_left.join(bodies) if bodies else ""  # fn + impl
    new_code = to_four_space_indents(new_code)

    for eof in eofs or []:
        new_code = new_code.split(eof)[0]

    # remove lines starting from the first unindented line after def_left
    new_code = remove_unindented_lines(
        new_code,
        protect_before=def_left,
        execeptions=["def ", "import ", "from "],
        trim_tails=['"""', "if", "print"],
    )
    new_code = chunks[0] + new_code

    # cut all functions that are not syntactically correct && not the entry point
    parts = new_code.split("\ndef ")
    entry_prefixes = (entry_point + " ", entry_point + "(")
    includes = [parts[0]]
    for fn in parts[1:]:
        if fn.strip().startswith(entry_prefixes) or syntax_check("\ndef " + fn):
            includes.append(fn)
    return "\ndef ".join(includes).strip()