File size: 5,212 Bytes
8953138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import ast
import re
import javalang

# Java Cleaner
# One alternation so the scanner consumes literals and comments left to
# right: a `//` or `/*` inside a literal is swallowed by the literal branch
# and never mistaken for a comment, and a `//` inside a block comment is
# swallowed by the block-comment branch.
_JAVA_LITERAL_OR_COMMENT = re.compile(
    r'"""(?:\\[\s\S]|[^\\])*?"""'   # text block
    r'|"(?:\\.|[^"\\\n])*"'         # string literal
    r"|'(?:\\.|[^'\\\n])*'"         # char literal
    r'|//[^\n]*'                    # line comment
    r'|/\*[\s\S]*?\*/'              # block comment
)


def _drop_java_comment(match):
    """Keep string/char literals unchanged; replace comments with ''."""
    token = match.group(0)
    return token if token[0] in '"\'' else ''


def clean_java_code(code):
    """Strip comments, import/package statements and blank lines from Java.

    Comments are removed in a single pass that also recognises string
    literals, char literals and text blocks, so comment markers that appear
    inside a literal (for example ``"http://host"``) are preserved, and a
    ``//`` inside a ``/* ... */`` comment does not hide the terminator.

    Args:
        code: Java source text.

    Returns:
        The cleaned source with leading/trailing whitespace stripped and
        runs of blank lines collapsed to a single newline.
    """
    code = _JAVA_LITERAL_OR_COMMENT.sub(_drop_java_comment, code)
    code = re.sub(r'^\s*import\s+.*;', '', code, flags=re.MULTILINE)
    code = re.sub(r'^\s*package\s+.*;', '', code, flags=re.MULTILINE)
    code = re.sub(r'\n\s*\n', '\n', code)
    return code.strip()

# Python Cleaner
# String literals (with optional prefix) and comments are matched in one
# left-to-right pass so a `#` inside a string is consumed as part of the
# string and is never treated as the start of a comment.
_PY_LITERAL_OR_COMMENT = re.compile(
    r'(?P<prefix>\b[rRuUbBfF]{1,2})?'
    r'(?P<literal>'
    r'"""(?:\\[\s\S]|[^\\])*?"""'
    r"|'''(?:\\[\s\S]|[^\\])*?'''"
    r'|"(?:\\[\s\S]|[^"\\\n])*"'
    r"|'(?:\\[\s\S]|[^'\\\n])*'"
    r')'
    r'|(?P<comment>#[^\n]*)'
)


def _strip_python_token(match):
    """Drop comments and standalone triple-quoted strings; keep other literals."""
    if match.group('comment') is not None:
        return ''
    literal = match.group('literal')
    if literal[:3] in ('"""', "'''"):
        text = match.string
        line_start = text.rfind('\n', 0, match.start()) + 1
        line_end = text.find('\n', match.end())
        if line_end == -1:
            line_end = len(text)
        before = text[line_start:match.start()].strip()
        after = text[match.end():line_end].strip()
        # Only a triple-quoted string that sits alone on its line(s) is
        # treated as a docstring; one used as a value (e.g. assigned or
        # passed as an argument on the same line) must stay, otherwise the
        # remaining code no longer parses.
        if not before and (not after or after.startswith('#')):
            return ''
    return match.group(0)


def clean_python_code(code):
    """Strip comments, docstrings, import lines and blank lines from Python.

    Comments and string literals are scanned together, so ``#`` inside a
    string or docstring is not mistaken for a comment. Triple-quoted strings
    are removed only when nothing but whitespace (or a trailing comment)
    shares their line; triple-quoted strings used as values are kept.

    Args:
        code: Python source text.

    Returns:
        The cleaned source with leading/trailing whitespace stripped and
        runs of blank lines collapsed to a single newline.
    """
    code = _PY_LITERAL_OR_COMMENT.sub(_strip_python_token, code)
    code = re.sub(r'^\s*import\s+.*', '', code, flags=re.MULTILINE)
    code = re.sub(r'^\s*from\s+.*import.*', '', code, flags=re.MULTILINE)
    code = re.sub(r'\n\s*\n', '\n', code)
    return code.strip()


def clean_code(code, lang):
    """Clean *code* with the cleaner that matches *lang*.

    ``lang == 'python'`` selects ``clean_python_code``; every other value
    falls through to ``clean_java_code``.
    """
    cleaner = clean_python_code if lang == 'python' else clean_java_code
    return cleaner(code)
    

# Operators whose presence sets the "log math" flag.
_PY_LOG_MATH_OPS = (ast.Div, ast.FloorDiv, ast.RShift, ast.Mult, ast.LShift)


class _PyLoopDepthVisitor(ast.NodeVisitor):
    """Track the deepest nesting of loops and list comprehensions."""

    def __init__(self):
        self.max_depth = 0
        self.current_depth = 0

    def _descend(self, node, levels):
        # Enter `levels` loop levels, visit the children, then leave again.
        self.current_depth += levels
        self.max_depth = max(self.max_depth, self.current_depth)
        self.generic_visit(node)
        self.current_depth -= levels

    def visit_For(self, node):
        self._descend(node, 1)

    def visit_AsyncFor(self, node):
        # Counted as a branch below, so it must count as a loop level too.
        self._descend(node, 1)

    def visit_While(self, node):
        self._descend(node, 1)

    def visit_ListComp(self, node):
        # Each `for` clause of the comprehension is one nesting level.
        self._descend(node, len(node.generators))


def _python_has_recursion(tree):
    """Return True when a function in *tree* can reach itself through calls.

    A call graph keyed by function name is built from plain-name calls
    (``f()``) and ``self.f()`` / ``cls.f()`` calls found inside each function
    definition; a cycle in that graph (direct or mutual) counts as recursion.
    """
    calls = {}
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        callees = calls.setdefault(node.name, set())
        for inner in ast.walk(node):
            if not isinstance(inner, ast.Call):
                continue
            func = inner.func
            if isinstance(func, ast.Name):
                callees.add(func.id)
            elif (isinstance(func, ast.Attribute)
                    and isinstance(func.value, ast.Name)
                    and func.value.id in ('self', 'cls')):
                callees.add(func.attr)

    for start in calls:
        pending = list(calls[start])
        seen = set()
        while pending:
            name = pending.pop()
            if name == start:
                return True
            # Names not defined in this source cannot lead back to `start`.
            if name in seen or name not in calls:
                continue
            seen.add(name)
            pending.extend(calls[name])
    return False


def get_python_features(code):
    """Extract five structural features from Python source.

    Args:
        code: Python source text.

    Returns:
        ``[max_depth, branch_count, has_recursion, has_log_math, has_sort]``:

        * max_depth -- deepest nesting of ``for`` / ``async for`` / ``while``
          loops and list-comprehension ``for`` clauses.
        * branch_count -- number of ``if``, ``while``, ``for``, ``async for``
          and list-comprehension nodes.
        * has_recursion -- 1 if a function calls itself directly or through
          other functions defined in the same source, else 0.
        * has_log_math -- 1 if ``/``, ``//``, ``>>``, ``*`` or ``<<`` appears
          as a binary or augmented-assignment operator, else 0.
        * has_sort -- 1 if ``sorted(...)`` or any ``.sort(...)`` call appears,
          else 0.

        ``[0, 0, 0, 0, 0]`` is returned when the source cannot be parsed.
    """
    try:
        tree = ast.parse(code)
    except Exception:
        # Best effort: unparsable input yields the zero vector. Unlike a bare
        # `except`, KeyboardInterrupt/SystemExit still propagate.
        return [0, 0, 0, 0, 0]

    depth_visitor = _PyLoopDepthVisitor()
    depth_visitor.visit(tree)

    branch_count = 0
    has_log_math = 0
    has_sort = 0

    for node in ast.walk(tree):
        if isinstance(node, (ast.If, ast.While, ast.For, ast.AsyncFor, ast.ListComp)):
            branch_count += 1
        elif isinstance(node, ast.Call):
            func = node.func
            # sorted(arr)
            if isinstance(func, ast.Name) and func.id == 'sorted':
                has_sort = 1
            # arr.sort()
            elif isinstance(func, ast.Attribute) and func.attr == 'sort':
                has_sort = 1
        elif isinstance(node, (ast.BinOp, ast.AugAssign)):
            if isinstance(node.op, _PY_LOG_MATH_OPS):
                has_log_math = 1

    has_recursion = 1 if _python_has_recursion(tree) else 0

    return [depth_visitor.max_depth, branch_count, has_recursion, has_log_math, has_sort]

def _java_has_recursion(tree):
    """Return True when a method in *tree* can reach itself through calls.

    A call graph keyed by method name is built from the ``MethodInvocation``
    nodes found inside each ``MethodDeclaration``; a cycle in that graph
    (direct or mutual) counts as recursion. Qualifiers are ignored, so
    overloads and same-named methods are merged.
    """
    calls = {}
    for _, method in tree.filter(javalang.tree.MethodDeclaration):
        callees = calls.setdefault(method.name, set())
        for _, invocation in method.filter(javalang.tree.MethodInvocation):
            callees.add(invocation.member)

    for start in calls:
        pending = list(calls[start])
        seen = set()
        while pending:
            name = pending.pop()
            if name == start:
                return True
            # Methods not declared in this source cannot lead back to `start`.
            if name in seen or name not in calls:
                continue
            seen.add(name)
            pending.extend(calls[name])
    return False


def get_java_features(code):
    """Extract five structural features from Java source.

    Source that does not contain ``"class "`` is wrapped in a dummy class
    before parsing so that bare methods can be parsed.

    Args:
        code: Java source text.

    Returns:
        ``[max_depth, branch_count, has_recursion, has_log_math, has_sort]``:

        * max_depth -- deepest nesting of ``for`` / ``while`` / ``do`` loops.
        * branch_count -- number of ``if`` statements.
        * has_recursion -- 1 if a declared method calls itself directly or
          through other methods declared in the same source, else 0.
        * has_log_math -- 1 if ``/``, ``*``, ``>>``, ``<<`` or ``>>>`` appears
          as a binary operator or compound assignment, else 0.
        * has_sort -- 1 if any method named ``sort`` is invoked, else 0.

        ``[0, 0, 0, 0, 0]`` is returned when tokenizing or parsing fails.
    """
    try:
        if "class " not in code:
            tokens = javalang.tokenizer.tokenize("class Dummy { " + code + " }")
        else:
            tokens = javalang.tokenizer.tokenize(code)
        parser = javalang.parser.Parser(tokens)
        tree = parser.parse_member_declaration()
    except Exception:
        # Best effort: unparsable input yields the zero vector. Unlike a bare
        # `except`, KeyboardInterrupt/SystemExit still propagate.
        return [0, 0, 0, 0, 0]

    loop_types = (
        javalang.tree.ForStatement,
        javalang.tree.WhileStatement,
        javalang.tree.DoStatement,
    )

    # Max depth: every loop kind counts both as the loop itself and as an
    # enclosing ancestor, so a do-while contributes a level like the others.
    real_max_depth = 0
    for loop_type in loop_types:
        for path, _ in tree.filter(loop_type):
            enclosing = sum(1 for p in path if isinstance(p, loop_types))
            real_max_depth = max(real_max_depth, enclosing + 1)

    # Branch count
    branch_count = sum(1 for _ in tree.filter(javalang.tree.IfStatement))

    # Sorting
    has_sort = 0
    for _, invocation in tree.filter(javalang.tree.MethodInvocation):
        if invocation.member == 'sort':
            has_sort = 1
            break

    # Recursion
    has_recursion = 1 if _java_has_recursion(tree) else 0

    # AST-based log math
    has_log_math = 0
    for _, operation in tree.filter(javalang.tree.BinaryOperation):
        if operation.operator in ['/', '*', '>>', '<<', '>>>']:
            has_log_math = 1
            break

    if not has_log_math:
        for _, assignment in tree.filter(javalang.tree.Assignment):
            if assignment.type in ['/=', '*=', '>>=', '<<=', '>>>=']:
                has_log_math = 1
                break

    return [real_max_depth, branch_count, has_recursion, has_log_math, has_sort]