File size: 2,365 Bytes
b19c92c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import ast, astor, javalang, re, yaml, black, isort
from .utils import run_cmd
class RefactorEngine:
    """Apply simple, rule-driven source refactorings for Python, Java and JavaScript.

    Which transformations run is controlled by a YAML rules file holding one
    mapping per language (``python``, ``java``, ``javascript``) of
    ``rule_name: bool``. A missing or empty section disables every rule for
    that language.
    """

    # Header of ``for (int i = 0; i < coll.size(); i++) {`` (also ``++i``, any
    # whitespace). Groups: 1 = index name, 2 = collection name, 3 = whitespace
    # before the opening brace (kept so the original brace style survives).
    _JAVA_INDEX_LOOP = re.compile(
        r'\bfor\s*\(\s*int\s+(\w+)\s*=\s*0\s*;\s*\1\s*<\s*(\w+)\.size\(\)\s*;'
        r'\s*(?:\1\s*\+\+|\+\+\s*\1)\s*\)(\s*)\{'
    )
    # The ``var`` keyword followed by whitespace. The lookbehind prevents a
    # match at the tail of an identifier (``myvar``, ``$var``) or on a property
    # name (``obj.var``).
    _JS_VAR_KEYWORD = re.compile(r'(?<![\w$.])var(?=\s)')
    # Expression nodes that bind at least as tightly as a call, so their source
    # text can replace ``f()`` anywhere without extra parentheses.
    _ATOMIC_EXPRS = (ast.Name, ast.Attribute, ast.Subscript, ast.Call)

    def __init__(self, rules_path='config/refactor_rules.yaml'):
        """Load refactoring rules from the YAML file at *rules_path*.

        An empty rules file yields an empty rule set, so every ``refactor_*``
        call becomes a no-op instead of failing on ``None``.
        """
        with open(rules_path, encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            self.rules = yaml.safe_load(f) or {}

    def format_python(self, code):
        """Return *code* with imports sorted by isort and then formatted by black.

        isort runs first, using its ``black`` profile, so that black has the
        last word and the result is black-formatted. This is best effort: if
        either tool rejects the input (e.g. a syntax error), *code* is returned
        unchanged.
        """
        try:
            return black.format_str(isort.code(code, profile='black'), mode=black.FileMode())
        except Exception:
            # Deliberately broad: formatting is optional polish, never fatal.
            return code

    def remove_unused_imports_python(self, code):
        """Drop module-level import statements whose bound names are never used.

        A statement is kept if any name it binds (the ``as`` alias, or the first
        dotted component) appears as an ``ast.Name`` anywhere in the module, or
        is listed in a literal ``__all__`` list/tuple. ``from __future__`` and
        star imports are always kept. Surviving statements keep their original
        position, so a module docstring and any code that must run before an
        import stay where they were.

        The module is regenerated with ``astor.to_source`` from the AST, which
        does not retain comments. On any failure (e.g. a syntax error) *code*
        is returned unchanged.
        """
        try:
            tree = ast.parse(code)
            used = {n.id for n in ast.walk(tree) if isinstance(n, ast.Name)}
            for node in tree.body:
                # Names re-exported via ``__all__ = [...]`` count as used.
                if (isinstance(node, ast.Assign)
                        and any(isinstance(t, ast.Name) and t.id == '__all__' for t in node.targets)
                        and isinstance(node.value, (ast.List, ast.Tuple))):
                    used.update(e.value for e in node.value.elts
                                if isinstance(e, ast.Constant) and isinstance(e.value, str))

            def is_needed(node):
                if not isinstance(node, (ast.Import, ast.ImportFrom)):
                    return True
                if isinstance(node, ast.ImportFrom) and (
                        node.module == '__future__' or any(a.name == '*' for a in node.names)):
                    # Compiler directives / star imports: usage cannot be judged by name.
                    return True
                return any((a.asname or a.name.split('.')[0]) in used for a in node.names)

            # Filter in place rather than hoisting imports, to preserve statement order.
            tree.body = [node for node in tree.body if is_needed(node)]
            return astor.to_source(tree)
        except Exception:
            # Best effort: an unparseable or un-renderable module is left as is.
            return code

    @staticmethod
    def _is_inlinable(node):
        """Return True for an undecorated, parameterless ``def`` whose body is one ``return <expr>``."""
        if not isinstance(node, ast.FunctionDef) or node.decorator_list:
            return False
        args = node.args
        if (getattr(args, 'posonlyargs', None) or args.args or args.vararg
                or args.kwonlyargs or args.kwarg):
            return False
        return (len(node.body) == 1 and isinstance(node.body[0], ast.Return)
                and node.body[0].value is not None)

    def inline_simple_functions_python(self, code):
        """Replace zero-argument calls ``f()`` with the expression ``f`` returns.

        Only module-level, undecorated ``def f():`` functions that take no
        parameters and consist of a single ``return <expr>`` are inlined.
        Call sites are found through the AST, so the following are left alone:
        the ``def`` line itself, method calls such as ``obj.f()``, and text
        inside strings or comments.

        The inlined text is the original source of ``<expr>``. It is wrapped
        in parentheses unless it is a name, attribute, subscript or call, so
        operator precedence at the call site is preserved. The definitions
        themselves are kept.

        NOTE(review): inlining moves ``<expr>`` into the caller's scope. If a
        caller has a local that shadows a global used by ``<expr>``, the
        inlined expression will read the local instead.

        Returns *code* unchanged if it cannot be parsed.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError, RecursionError):
            return code
        bodies = {}
        for node in tree.body:
            if self._is_inlinable(node):
                value = node.body[0].value
                text = ast.get_source_segment(code, value)
                if text is not None:
                    bodies[node.name] = (text if isinstance(value, self._ATOMIC_EXPRS)
                                         else f'({text})')
        if not bodies:
            return code
        # ast positions are (1-based line, UTF-8 byte column), so splice bytes.
        data = code.encode('utf-8')
        line_starts = [0] + [m.end() for m in re.finditer(rb'\r\n|\r|\n', data)]
        spans = []
        for node in ast.walk(tree):
            if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
                    and node.func.id in bodies and not node.args and not node.keywords):
                start = line_starts[node.lineno - 1] + node.col_offset
                end = line_starts[node.end_lineno - 1] + node.end_col_offset
                spans.append((start, end, bodies[node.func.id]))
        # Splice right-to-left so offsets of earlier spans stay valid.
        for start, end, text in sorted(spans, reverse=True):
            data = data[:start] + text.encode('utf-8') + data[end:]
        return data.decode('utf-8')

    def refactor_python(self, code):
        """Apply the enabled ``python`` rules to *code*.

        The rules run in this order: ``remove_unused_imports``, then
        ``inline_simple_functions``, then ``format``.
        """
        rules = self.rules.get('python') or {}
        if rules.get('remove_unused_imports'):
            code = self.remove_unused_imports_python(code)
        if rules.get('inline_simple_functions'):
            code = self.inline_simple_functions_python(code)
        if rules.get('format'):
            code = self.format_python(code)
        return code

    @staticmethod
    def _find_closing_brace(code, start):
        """Return the index of the ``}`` closing a block opened just before *start*.

        String/char literals and ``//`` / ``/* */`` comments are skipped. Java
        text blocks are not understood. Returns None if the block is
        unterminated.
        """
        depth, i, n = 1, start, len(code)
        while i < n:
            c = code[i]
            if c in '"\'':
                i += 1
                while i < n and code[i] != c:
                    i += 2 if code[i] == '\\' else 1
            elif code.startswith('//', i):
                i = code.find('\n', i)
                if i == -1:
                    return None
            elif code.startswith('/*', i):
                i = code.find('*/', i + 2)
                if i == -1:
                    return None
                i += 1
            elif c == '{':
                depth += 1
            elif c == '}':
                depth -= 1
                if depth == 0:
                    return i
            i += 1
        return None

    @staticmethod
    def _fresh_java_name(code, base='x'):
        """Return *base*, or *base* plus a counter, that does not occur as a word in *code*."""
        name, counter = base, 0
        while re.search(rf'\b{name}\b', code):
            counter += 1
            name = f'{base}{counter}'
        return name

    def convert_java_for_each(self, code):
        """Rewrite index-based loops over a collection into enhanced ``for`` loops.

        ``for (int i = 0; i < items.size(); i++) { ... items.get(i) ... }``
        becomes ``for (var x : items) { ... x ... }``.

        A loop is converted only when both of these hold:

        * it has a braced body;
        * after ``items.get(i)`` is replaced by the element variable, the body
          no longer mentions ``i`` or ``items`` directly. This means the index
          is not needed and the collection is not touched by name inside the
          loop.

        Any other loop is left as is. The element variable is ``x``, or ``x1``,
        ``x2``, ... if that name already occurs in *code*. Loop headers are
        found textually, so a header inside a comment or string is also a
        candidate.
        """
        pos = 0
        while True:
            m = self._JAVA_INDEX_LOOP.search(code, pos)
            if m is None:
                return code
            index, coll, brace_ws = m.groups()
            close = self._find_closing_brace(code, m.end())
            if close is None:
                pos = m.end()
                continue
            elem = self._fresh_java_name(code)
            get_call = re.compile(
                rf'(?<![\w.]){re.escape(coll)}\s*\.\s*get\(\s*{re.escape(index)}\s*\)')
            body = get_call.sub(lambda _m: elem, code[m.end():close])
            if re.search(rf'\b(?:{re.escape(index)}|{re.escape(coll)})\b', body):
                # Index or collection still used directly: converting is unsafe.
                pos = m.end()
                continue
            header = f'for (var {elem} : {coll}){brace_ws}{{'
            code = code[:m.start()] + header + body + code[close:]
            # Resume inside the new body so nested loops are converted too.
            pos = m.start() + len(header)

    def refactor_java(self, code):
        """Apply the enabled ``java`` rules (currently ``convert_for_each``) to *code*."""
        rules = self.rules.get('java') or {}
        if rules.get('convert_for_each'):
            code = self.convert_java_for_each(code)
        return code

    def refactor_javascript(self, code):
        """Apply the enabled ``javascript`` rules (currently ``convert_var_to_let``) to *code*.

        Only the ``var`` keyword is replaced; identifiers ending in ``var`` and
        ``.var`` properties are untouched.

        NOTE(review): the replacement is textual. ``var`` inside string
        literals is also replaced. ``let`` differs from ``var`` in scoping and
        redeclaration rules, and that is not checked here.
        """
        rules = self.rules.get('javascript') or {}
        if rules.get('convert_var_to_let'):
            code = self._JS_VAR_KEYWORD.sub('let', code)
        return code

    def refactor(self, code, lang):
        """Refactor *code* according to the rules for *lang* (case-insensitive).

        Supported languages are ``python``, ``java`` and ``js``/``javascript``.
        Code in any other language is returned unchanged.
        """
        lang = lang.lower()
        if lang == 'python':
            return self.refactor_python(code)
        if lang == 'java':
            return self.refactor_java(code)
        if lang in ('js', 'javascript'):
            return self.refactor_javascript(code)
        return code