import re

import networkx as nx
import torch
from transformers import AutoTokenizer, AutoModel
from torch_geometric.data import Data

import config


class GraphExtractor:
    """Turn a Joern-style GraphML CPG into per-function PyG ``Data`` objects.

    Each METHOD node found in the graph is expanded into its function
    subgraph, node text is normalized (anti-cheat identifier renaming) and
    embedded with GraphCodeBERT, and edges are one-hot encoded by type.
    """

    def __init__(self):
        print("⏳ Loading GraphCodeBERT Model...")
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
        self.bert_model = AutoModel.from_pretrained("microsoft/graphcodebert-base")
        # CPU for Hugging Face Spaces (GPU is unavailable on free tier)
        self.device = torch.device("cpu")
        self.bert_model.to(self.device)
        # [FIX] eval() disables dropout so inference embeddings are deterministic.
        self.bert_model.eval()

        # [CRITICAL] V2 Edge Mapping (matches the training config)
        # AST=0, CFG=1, CDG=2, REACHING_DEF=4, DDG=5 (the 'Other' bucket)
        self.edge_type_map = {
            'AST': 0,
            'CFG': 1,
            'CDG': 2,
            'REACHING_DEF': 4,
            'DDG': 5,
        }

        # Node labels whose CODE/NAME text carries signal worth embedding.
        self.SIGNAL_TYPES = {
            'IDENTIFIER', 'LITERAL', 'CALL', 'METHOD_RETURN', 'METHOD', 'SimpleName',
        }
        # Edge labels followed when crawling a function subgraph downward.
        self.TRAVERSAL_EDGES = ['AST', 'CFG', 'CDG', 'REACHING_DEF', 'DDG']

    def normalize_code(self, code_str):
        """Anti-cheat normalizer (white list + regex).

        Replaces every identifier that is not a C keyword / well-known libc
        function with a positional placeholder (``VAR_1``, ``VAR_2``, ...) so
        the model cannot key on variable names. Over-long or empty snippets
        collapse to the literal token ``"empty"``.
        """
        if not code_str or len(code_str) > 512:
            return "empty"

        keep_list = {
            'char', 'int', 'float', 'double', 'void', 'long', 'short',
            'unsigned', 'signed', 'struct', 'union', 'enum', 'const',
            'volatile', 'static', 'auto', 'register', 'if', 'else', 'switch',
            'case', 'default', 'while', 'do', 'for', 'goto', 'continue',
            'break', 'return', 'sizeof', 'typedef', 'NULL', 'true', 'false',
            'malloc', 'free', 'memset', 'memcpy', 'strcpy', 'strncpy',
            'printf', 'scanf', 'recv', 'send', 'socket', 'bind', 'listen',
            'accept', 'system', 'popen', 'getenv', 'snprintf', 'fprintf',
            'sprintf',
        }

        tokens = re.findall(r'[a-zA-Z_]\w*', code_str)
        var_map = {}
        counter = 1
        new_code = code_str
        for token in tokens:
            if token not in keep_list and token not in var_map:
                var_map[token] = f"VAR_{counter}"
                counter += 1

        # Substitute longest names first; with \b anchors this avoids one
        # identifier's replacement corrupting a longer identifier it prefixes.
        for original, replacement in sorted(var_map.items(), key=lambda x: -len(x[0])):
            pattern = r'\b' + re.escape(original) + r'\b'
            new_code = re.sub(pattern, replacement, new_code)
        return new_code

    def get_code_embedding(self, code_snippets):
        """Embed a batch of code strings; returns the [CLS] vectors (N, 768)."""
        if not code_snippets:
            return torch.empty(0, 768)
        # [IMPROVEMENT] max_length=512 for better context.
        inputs = self.tokenizer(
            code_snippets,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.bert_model(**inputs)
        # Position 0 is the [CLS] token — used as the sentence embedding.
        return outputs.last_hidden_state[:, 0, :]

    def extract_function_subgraph(self, G, root_node):
        """Look Down (AST/CFG/...) and Look Up (REF) crawler.

        Starting from a METHOD node, collects everything reachable downward
        via TRAVERSAL_EDGES, plus the immediate predecessors connected by a
        REF edge (declarations referenced from inside the function).
        Returns an independent copy of the induced subgraph.
        """
        nodes_in_func = {root_node}
        stack = [root_node]
        visited = {root_node}
        while stack:
            curr = stack.pop()
            # Look Down: follow structural/flow edges transitively.
            for nbr in G.successors(curr):
                if nbr in visited:
                    continue
                all_edges = G.get_edge_data(curr, nbr)
                is_child = False
                for _, attr in all_edges.items():
                    lbl = attr.get('labelE', attr.get('label', ''))
                    if lbl in self.TRAVERSAL_EDGES:
                        is_child = True
                        break
                if is_child:
                    visited.add(nbr)
                    nodes_in_func.add(nbr)
                    stack.append(nbr)
            # Look Up: pull in REF sources (one hop only — not pushed on stack).
            for nbr in G.predecessors(curr):
                if nbr in visited:
                    continue
                edge_data = G.get_edge_data(nbr, curr)
                is_ref = False
                for _, attr in edge_data.items():
                    lbl = attr.get('labelE', attr.get('label', ''))
                    if lbl == 'REF':
                        is_ref = True
                        break
                if is_ref:
                    visited.add(nbr)
                    nodes_in_func.add(nbr)
        return G.subgraph(nodes_in_func).copy()

    def process_graph(self, graphml_path):
        """Parse a GraphML CPG and return a list of PyG ``Data`` objects,
        one per internal (non-external, non-synthetic) METHOD node.

        Returns [] on any read/parse failure.
        """
        try:
            # [FIX] force_multigraph guarantees a Multi(Di)Graph, so
            # get_edge_data() always yields a {key: attr_dict} mapping and
            # edges(keys=True) works even when the file has no parallel edges.
            G = nx.read_graphml(graphml_path, force_multigraph=True)
        except Exception:
            return []

        inference_data_list = []
        for n, data in G.nodes(data=True):
            if data.get('labelV') != 'METHOD':
                continue
            func_name = data.get('NAME', '') or data.get('METHOD_FULL_NAME', '')
            if str(data.get('IS_EXTERNAL', 'false')).lower() == 'true':
                continue
            # [FIX] The previous check compared against empty strings
            # ('"" in s' is always True), which skipped EVERY function and
            # made this method always return []. Filter Joern's synthetic
            # method names instead. TODO(review): confirm these markers
            # match the ones used at training time.
            if any(marker in func_name for marker in ('<operator>', '<global>', '<lambda>')):
                continue

            subG = self.extract_function_subgraph(G, n)
            if len(subG.nodes) < 3:
                continue

            node_ids = list(subG.nodes())
            node_map = {nid: idx for idx, nid in enumerate(node_ids)}

            # Node features: zeros for non-signal nodes, BERT [CLS] otherwise.
            x_tensor = torch.zeros((len(node_ids), 768))
            code_batch = []
            valid_indices = []
            for i, nid in enumerate(node_ids):
                d = subG.nodes[nid]
                if d.get('labelV', 'UNKNOWN') in self.SIGNAL_TYPES:
                    raw_code = str(
                        d.get('code') or d.get('CODE')
                        or d.get('name') or d.get('NAME') or ""
                    )
                    clean_code = self.normalize_code(raw_code)
                    code_batch.append(clean_code)
                    valid_indices.append(i)

            if code_batch:
                embeddings = []
                # Embed in chunks of 32 to bound peak memory on CPU.
                for k in range(0, len(code_batch), 32):
                    batch_text = [
                        t if t.strip() else "empty" for t in code_batch[k:k + 32]
                    ]
                    embeddings.append(self.get_code_embedding(batch_text).cpu())
                if embeddings:
                    full_embeddings = torch.cat(embeddings, dim=0)
                    for idx, emb in zip(valid_indices, full_embeddings):
                        x_tensor[idx] = emb

            edge_indices = []
            edge_attrs = []
            for src, dst, key, edata in subG.edges(keys=True, data=True):
                if src in node_map and dst in node_map:
                    edge_indices.append([node_map[src], node_map[dst]])
                    lbl = edata.get('labelE', edata.get('label', 'AST'))
                    # Use the V2 map; default to 5 (DDG/Other).
                    etype = self.edge_type_map.get(lbl, 5)
                    # One-hot encoding (size config.NUM_EDGE_TYPES).
                    one_hot = [0] * config.NUM_EDGE_TYPES
                    if etype < config.NUM_EDGE_TYPES:
                        one_hot[etype] = 1
                    edge_attrs.append(one_hot)

            if not edge_indices:
                continue

            data_obj = Data(
                x=x_tensor,
                edge_index=torch.tensor(edge_indices, dtype=torch.long).t().contiguous(),
                edge_attr=torch.tensor(edge_attrs, dtype=torch.float),
                func_name=func_name,
            )
            inference_data_list.append(data_obj)

        return inference_data_list