# NOTE: removed non-Python extraction residue that preceded the module source
# (a "File size" banner, a hash-like token, and a stray line-number gutter).

import networkx as nx
import torch
import re
from transformers import AutoTokenizer, AutoModel
from torch_geometric.data import Data
import config

class GraphExtractor:
    """Turn a Joern-exported GraphML code property graph into PyG ``Data`` objects.

    One ``Data`` object is produced per internal METHOD node: node features are
    GraphCodeBERT [CLS] embeddings of normalized code snippets, and each edge
    carries a one-hot type encoding of size ``config.NUM_EDGE_TYPES``.
    """

    def __init__(self):
        print("⏳ Loading GraphCodeBERT Model...")
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
        self.bert_model = AutoModel.from_pretrained("microsoft/graphcodebert-base")

        # CPU for Hugging Face Spaces (GPU is unavailable on free tier)
        self.device = torch.device("cpu")
        self.bert_model.to(self.device)
        # Inference-only model: make eval mode explicit so dropout stays off
        # even if a future transformers release changes the from_pretrained default.
        self.bert_model.eval()

        # [CRITICAL] V2 Edge Mapping (matches the training config).
        # AST=0, CFG=1, CDG=2, REACHING_DEF=4, DDG=5 (the 'Other' bucket).
        # NOTE(review): index 3 is unassigned here — presumably reserved by the
        # training config; confirm against config.NUM_EDGE_TYPES.
        self.edge_type_map = {
            'AST': 0,
            'CFG': 1,
            'CDG': 2,
            'REACHING_DEF': 4,
            'DDG': 5
        }

        # Node labels whose CODE/NAME text is worth embedding ("signal" nodes).
        self.SIGNAL_TYPES = {'IDENTIFIER', 'LITERAL', 'CALL', 'METHOD_RETURN', 'METHOD', 'SimpleName'}
        # Edge labels followed when crawling a function's subgraph downwards.
        self.TRAVERSAL_EDGES = ['AST', 'CFG', 'CDG', 'REACHING_DEF', 'DDG']

    def normalize_code(self, code_str):
        """Anti-cheat normalizer (white list + regex).

        Replaces every identifier that is not a C keyword / well-known libc or
        socket function with a positional placeholder (``VAR_1``, ``VAR_2``, ...)
        so the embedding cannot latch onto project-specific names.

        Returns the literal string ``"empty"`` for falsy or over-long
        (> 512 chars) input.
        """
        if not code_str or len(code_str) > 512:
            return "empty"

        keep_list = {
            'char', 'int', 'float', 'double', 'void', 'long', 'short', 'unsigned', 'signed',
            'struct', 'union', 'enum', 'const', 'volatile', 'static', 'auto', 'register',
            'if', 'else', 'switch', 'case', 'default', 'while', 'do', 'for', 'goto',
            'continue', 'break', 'return', 'sizeof', 'typedef', 'NULL', 'true', 'false',
            'malloc', 'free', 'memset', 'memcpy', 'strcpy', 'strncpy', 'printf', 'scanf',
            'recv', 'send', 'socket', 'bind', 'listen', 'accept',
            'system', 'popen', 'getenv', 'snprintf', 'fprintf', 'sprintf'
        }

        tokens = re.findall(r'[a-zA-Z_]\w*', code_str)
        var_map = {}
        counter = 1
        new_code = code_str

        # First appearance order determines VAR_ numbering.
        for token in tokens:
            if token not in keep_list and token not in var_map:
                var_map[token] = f"VAR_{counter}"
                counter += 1

        # Longest-first substitution plus \b anchors avoids clobbering
        # identifiers that are substrings of longer ones.
        for original, replacement in sorted(var_map.items(), key=lambda kv: -len(kv[0])):
            pattern = r'\b' + re.escape(original) + r'\b'
            new_code = re.sub(pattern, replacement, new_code)

        return new_code

    def get_code_embedding(self, code_snippets):
        """Return the GraphCodeBERT [CLS] embedding for each snippet.

        Args:
            code_snippets: list[str]; an empty list yields an empty (0, 768) tensor.

        Returns:
            Tensor of shape (len(code_snippets), 768).
        """
        if not code_snippets:
            return torch.empty(0, 768)

        # [IMPROVEMENT] max_length=512 keeps full context for the
        # (already length-capped) normalized snippets.
        inputs = self.tokenizer(code_snippets, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.device)
        with torch.no_grad():
            outputs = self.bert_model(**inputs)
        # Position 0 is the [CLS] token — used as the snippet-level embedding.
        return outputs.last_hidden_state[:, 0, :]

    def extract_function_subgraph(self, G, root_node):
        """'Look down' (AST/CFG/...) and 'look up' (REF) crawler.

        Collects every node reachable from ``root_node`` via forward traversal
        edges, plus direct predecessors that point at a collected node via a
        REF edge (e.g. declarations referenced by identifiers). REF sources are
        included but not traversed further.

        NOTE(review): assumes ``G`` is a MultiDiGraph so ``get_edge_data``
        returns a {key: attrs} mapping — typical for Joern GraphML exports;
        confirm for other producers.

        Returns a detached copy of the induced subgraph.
        """
        nodes_in_func = {root_node}
        stack = [root_node]
        visited = {root_node}

        while stack:
            curr = stack.pop()
            # Look down: follow structural / control / data-flow edges.
            for nbr in G.successors(curr):
                if nbr in visited:
                    continue
                parallel_edges = G.get_edge_data(curr, nbr)
                is_child = any(
                    attr.get('labelE', attr.get('label', '')) in self.TRAVERSAL_EDGES
                    for attr in parallel_edges.values()
                )
                if is_child:
                    visited.add(nbr)
                    nodes_in_func.add(nbr)
                    stack.append(nbr)
            # Look up: pull in REF sources without traversing beyond them.
            for nbr in G.predecessors(curr):
                if nbr in visited:
                    continue
                parallel_edges = G.get_edge_data(nbr, curr)
                is_ref = any(
                    attr.get('labelE', attr.get('label', '')) == 'REF'
                    for attr in parallel_edges.values()
                )
                if is_ref:
                    visited.add(nbr)
                    nodes_in_func.add(nbr)
        return G.subgraph(nodes_in_func).copy()

    def process_graph(self, graphml_path):
        """Parse a GraphML file and return per-function PyG ``Data`` objects.

        Skips external methods, operator/global/init pseudo-methods, functions
        with fewer than 3 nodes, and functions with no mappable edges.

        Returns:
            list[Data]; empty if the file cannot be parsed.
        """
        try:
            G = nx.read_graphml(graphml_path)
        except Exception:
            # Malformed/unreadable GraphML: treat as "no functions found".
            return []

        inference_data_list = []

        for n, node_attrs in G.nodes(data=True):
            if node_attrs.get('labelV') != 'METHOD':
                continue
            func_name = node_attrs.get('NAME', '') or node_attrs.get('METHOD_FULL_NAME', '')
            if str(node_attrs.get('IS_EXTERNAL', 'false')).lower() == 'true':
                continue
            if "<operator>" in func_name or "<global>" in func_name or "<init>" in func_name:
                continue

            subG = self.extract_function_subgraph(G, n)
            if len(subG.nodes) < 3:
                continue  # too small to carry a meaningful signal

            node_ids = list(subG.nodes())
            # (fix) 'gid' instead of shadowing the builtin 'id'.
            node_map = {gid: idx for idx, gid in enumerate(node_ids)}
            x_tensor = torch.zeros((len(node_ids), 768))

            code_batch = []
            valid_indices = []

            for i, nid in enumerate(node_ids):
                d = subG.nodes[nid]
                if d.get('labelV', 'UNKNOWN') in self.SIGNAL_TYPES:
                    raw_code = str(d.get('code') or d.get('CODE') or d.get('name') or d.get('NAME') or "")
                    code_batch.append(self.normalize_code(raw_code))
                    valid_indices.append(i)

            if code_batch:
                embeddings = []
                # Embed in chunks of 32 to bound peak memory on CPU.
                for k in range(0, len(code_batch), 32):
                    batch_text = [t if t.strip() else "empty" for t in code_batch[k:k + 32]]
                    embeddings.append(self.get_code_embedding(batch_text).cpu())
                if embeddings:
                    full_embeddings = torch.cat(embeddings, dim=0)
                    # Non-signal nodes keep their zero rows in x_tensor.
                    for idx, emb in zip(valid_indices, full_embeddings):
                        x_tensor[idx] = emb

            edge_indices = []
            edge_attrs = []

            for src, dst, _key, edata in subG.edges(keys=True, data=True):
                # Defensive: subgraph edges should always map, but keep the guard.
                if src in node_map and dst in node_map:
                    edge_indices.append([node_map[src], node_map[dst]])
                    lbl = edata.get('labelE', edata.get('label', 'AST'))

                    # Use the V2 map; unknown labels fall into bucket 5 (DDG/Other).
                    etype = self.edge_type_map.get(lbl, 5)

                    # One-hot encoding of size config.NUM_EDGE_TYPES.
                    one_hot = [0] * config.NUM_EDGE_TYPES
                    if etype < config.NUM_EDGE_TYPES:
                        one_hot[etype] = 1
                    edge_attrs.append(one_hot)

            if not edge_indices:
                continue

            inference_data_list.append(Data(
                x=x_tensor,
                edge_index=torch.tensor(edge_indices, dtype=torch.long).t().contiguous(),
                edge_attr=torch.tensor(edge_attrs, dtype=torch.float),
                func_name=func_name
            ))

        return inference_data_list