# NOTE(review): removed Hugging Face Spaces file-viewer chrome ("Spaces: Running",
# file size, line-number gutter) that was captured along with this paste — it is
# not part of the source and would break the Python parser.
import networkx as nx
import torch
import re
from transformers import AutoTokenizer, AutoModel
from torch_geometric.data import Data
import config
class GraphExtractor:
    """Turn Joern CPG GraphML exports into PyTorch Geometric ``Data`` samples.

    Pipeline per METHOD node: crawl the function's subgraph, normalize each
    signal node's code snippet (anti-cheat identifier renaming), embed it with
    GraphCodeBERT ([CLS] token), and one-hot encode CPG edge types as edge
    attributes.
    """

    def __init__(self):
        print("⏳ Loading GraphCodeBERT Model...")
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
        self.bert_model = AutoModel.from_pretrained("microsoft/graphcodebert-base")
        # CPU for Hugging Face Spaces (GPU is unavailable on free tier)
        self.device = torch.device("cpu")
        self.bert_model.to(self.device)
        # [FIX] Explicit eval(): from_pretrained already returns eval mode, but
        # being explicit guards against dropout accidentally staying active.
        self.bert_model.eval()
        # [CRITICAL] V2 Edge Mapping (matches the training config).
        # AST=0, CFG=1, CDG=2, REACHING_DEF=4, DDG=5 (the 'Other' bucket).
        # Index 3 is absent here — presumably used by another edge kind at
        # training time; TODO confirm against the training config.
        self.edge_type_map = {
            'AST': 0,
            'CFG': 1,
            'CDG': 2,
            'REACHING_DEF': 4,
            'DDG': 5,
        }
        # Node labels worth embedding — every other node keeps a zero vector.
        self.SIGNAL_TYPES = {'IDENTIFIER', 'LITERAL', 'CALL', 'METHOD_RETURN', 'METHOD', 'SimpleName'}
        # Edge labels followed when crawling "down" from a METHOD root.
        self.TRAVERSAL_EDGES = ['AST', 'CFG', 'CDG', 'REACHING_DEF', 'DDG']

    def normalize_code(self, code_str):
        """Anti-cheat normalizer: rename identifiers to VAR_k (white list + regex).

        Identifiers outside the C keyword/libc white list are renamed VAR_1,
        VAR_2, ... in order of first appearance. Returns ``"empty"`` for
        missing or oversized (> 512 chars) snippets so the tokenizer always
        receives non-empty text.
        """
        if not code_str or len(code_str) > 512:
            return "empty"
        keep_list = {
            'char', 'int', 'float', 'double', 'void', 'long', 'short', 'unsigned', 'signed',
            'struct', 'union', 'enum', 'const', 'volatile', 'static', 'auto', 'register',
            'if', 'else', 'switch', 'case', 'default', 'while', 'do', 'for', 'goto',
            'continue', 'break', 'return', 'sizeof', 'typedef', 'NULL', 'true', 'false',
            'malloc', 'free', 'memset', 'memcpy', 'strcpy', 'strncpy', 'printf', 'scanf',
            'recv', 'send', 'socket', 'bind', 'listen', 'accept',
            'system', 'popen', 'getenv', 'snprintf', 'fprintf', 'sprintf',
        }
        # First-appearance order determines VAR numbering.
        var_map = {}
        counter = 1
        for token in re.findall(r'[a-zA-Z_]\w*', code_str):
            if token not in keep_list and token not in var_map:
                var_map[token] = f"VAR_{counter}"
                counter += 1
        # [FIX] Single-pass substitution instead of one re.sub per identifier:
        # O(len(code)) rather than O(tokens * len(code)), and immune to
        # cascading collisions (an original identifier literally named VAR_k
        # can no longer be rewritten a second time by a later pass).
        return re.sub(
            r'[a-zA-Z_]\w*',
            lambda m: var_map.get(m.group(0), m.group(0)),
            code_str,
        )

    def get_code_embedding(self, code_snippets):
        """Embed a batch of snippets; returns a (len(batch), 768) tensor of
        [CLS] vectors, or an empty (0, 768) tensor for an empty batch."""
        if not code_snippets:
            return torch.empty(0, 768)
        # max_length=512 keeps long snippets' context; padding makes the batch
        # rectangular for a single forward pass.
        inputs = self.tokenizer(
            code_snippets,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.bert_model(**inputs)
        # [CLS] token (position 0) summarizes each snippet.
        return outputs.last_hidden_state[:, 0, :]

    def extract_function_subgraph(self, G, root_node):
        """Look Down (AST/CFG/...) and Look Up (REF) crawler.

        Transitively follows TRAVERSAL_EDGES from ``root_node`` ("down"), and
        pulls in one-hop REF predecessors of every reached node ("up", not
        re-expanded). Returns an independent copy of the induced subgraph.
        """
        nodes_in_func = {root_node}
        visited = {root_node}
        stack = [root_node]
        while stack:
            curr = stack.pop()
            # Look Down: any parallel edge with a traversal label qualifies.
            for nbr in G.successors(curr):
                if nbr in visited:
                    continue
                parallel = G.get_edge_data(curr, nbr)
                is_child = any(
                    attr.get('labelE', attr.get('label', '')) in self.TRAVERSAL_EDGES
                    for attr in parallel.values()
                )
                if is_child:
                    visited.add(nbr)
                    nodes_in_func.add(nbr)
                    stack.append(nbr)
            # Look Up: REF sources are included but never pushed on the stack.
            for nbr in G.predecessors(curr):
                if nbr in visited:
                    continue
                parallel = G.get_edge_data(nbr, curr)
                is_ref = any(
                    attr.get('labelE', attr.get('label', '')) == 'REF'
                    for attr in parallel.values()
                )
                if is_ref:
                    visited.add(nbr)
                    nodes_in_func.add(nbr)
        return G.subgraph(nodes_in_func).copy()

    def process_graph(self, graphml_path):
        """Parse a GraphML CPG and return one PyG ``Data`` per real function.

        Skips external methods, <operator>/<global>/<init> pseudo-methods,
        functions with fewer than 3 nodes, and functions with no edges.
        Returns [] when the file cannot be parsed (best-effort by design).
        """
        try:
            G = nx.read_graphml(graphml_path)
        except Exception:
            # Malformed or partial exports are expected in the wild.
            return []
        inference_data_list = []
        for n, data in G.nodes(data=True):
            if data.get('labelV') != 'METHOD':
                continue
            func_name = data.get('NAME', '') or data.get('METHOD_FULL_NAME', '')
            if str(data.get('IS_EXTERNAL', 'false')).lower() == 'true':
                continue
            if "<operator>" in func_name or "<global>" in func_name or "<init>" in func_name:
                continue
            subG = self.extract_function_subgraph(G, n)
            if len(subG.nodes) < 3:
                continue
            node_ids = list(subG.nodes())
            # [FIX] renamed comprehension variable: was shadowing builtin `id`.
            node_map = {nid: idx for idx, nid in enumerate(node_ids)}
            # Zero vectors for non-signal nodes; 768 = GraphCodeBERT hidden size.
            x_tensor = torch.zeros((len(node_ids), 768))
            code_batch = []
            valid_indices = []
            for i, nid in enumerate(node_ids):
                d = subG.nodes[nid]
                if d.get('labelV', 'UNKNOWN') in self.SIGNAL_TYPES:
                    raw_code = str(d.get('code') or d.get('CODE') or d.get('name') or d.get('NAME') or "")
                    code_batch.append(self.normalize_code(raw_code))
                    valid_indices.append(i)
            if code_batch:
                embeddings = []
                # Chunks of 32 bound peak memory on the CPU-only host.
                for k in range(0, len(code_batch), 32):
                    batch_text = [t if t.strip() else "empty" for t in code_batch[k:k + 32]]
                    embeddings.append(self.get_code_embedding(batch_text).cpu())
                if embeddings:
                    full_embeddings = torch.cat(embeddings, dim=0)
                    for idx, emb in zip(valid_indices, full_embeddings):
                        x_tensor[idx] = emb
            edge_indices = []
            edge_attrs = []
            for src, dst, key, edata in subG.edges(keys=True, data=True):
                if src in node_map and dst in node_map:
                    edge_indices.append([node_map[src], node_map[dst]])
                    lbl = edata.get('labelE', edata.get('label', 'AST'))
                    # Use the V2 map; unknown labels fall into bucket 5 (DDG/Other).
                    etype = self.edge_type_map.get(lbl, 5)
                    # One-hot encoding (size config.NUM_EDGE_TYPES).
                    one_hot = [0] * config.NUM_EDGE_TYPES
                    if etype < config.NUM_EDGE_TYPES:
                        one_hot[etype] = 1
                    edge_attrs.append(one_hot)
            if not edge_indices:
                continue
            inference_data_list.append(Data(
                x=x_tensor,
                edge_index=torch.tensor(edge_indices, dtype=torch.long).t().contiguous(),
                edge_attr=torch.tensor(edge_attrs, dtype=torch.float),
                func_name=func_name,
            ))
        return inference_data_list