Spaces:
Running
Running
| import networkx as nx | |
| import torch | |
| import re | |
| from transformers import AutoTokenizer, AutoModel | |
| from torch_geometric.data import Data | |
| import config | |
class GraphExtractor:
    """Turns GraphML code-property-graph exports into per-function PyG samples,
    using GraphCodeBERT to embed the textual content of graph nodes."""

    def __init__(self):
        # Load the pretrained GraphCodeBERT encoder used for node text embeddings.
        print("⏳ Loading GraphCodeBERT Model...")
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
        self.bert_model = AutoModel.from_pretrained("microsoft/graphcodebert-base")

        # CPU for Hugging Face Spaces (GPU is unavailable on free tier)
        self.device = torch.device("cpu")
        self.bert_model.to(self.device)

        # [CRITICAL] V2 Edge Mapping (Matches your training config)
        # AST=0, CFG=1, CDG=2, REACHING_DEF=4, DDG=5 (The 'Other' bucket)
        # NOTE(review): index 3 is intentionally skipped — confirm this mirrors training.
        self.edge_type_map = {
            'AST': 0,
            'CFG': 1,
            'CDG': 2,
            'REACHING_DEF': 4,
            'DDG': 5,
        }

        # Node label types whose text is worth embedding ("signal" nodes).
        self.SIGNAL_TYPES = {'IDENTIFIER', 'LITERAL', 'CALL', 'METHOD_RETURN', 'METHOD', 'SimpleName'}
        # Edge labels followed when crawling a function's subgraph.
        self.TRAVERSAL_EDGES = ['AST', 'CFG', 'CDG', 'REACHING_DEF', 'DDG']
| def normalize_code(self, code_str): | |
| """Anti-Cheat Normalizer (White List + Regex)""" | |
| if not code_str or len(code_str) > 512: return "empty" | |
| keep_list = { | |
| 'char', 'int', 'float', 'double', 'void', 'long', 'short', 'unsigned', 'signed', | |
| 'struct', 'union', 'enum', 'const', 'volatile', 'static', 'auto', 'register', | |
| 'if', 'else', 'switch', 'case', 'default', 'while', 'do', 'for', 'goto', | |
| 'continue', 'break', 'return', 'sizeof', 'typedef', 'NULL', 'true', 'false', | |
| 'malloc', 'free', 'memset', 'memcpy', 'strcpy', 'strncpy', 'printf', 'scanf', | |
| 'recv', 'send', 'socket', 'bind', 'listen', 'accept', | |
| 'system', 'popen', 'getenv', 'snprintf', 'fprintf', 'sprintf' | |
| } | |
| tokens = re.findall(r'[a-zA-Z_]\w*', code_str) | |
| var_map = {} | |
| counter = 1 | |
| new_code = code_str | |
| for token in tokens: | |
| if token not in keep_list and token not in var_map: | |
| var_map[token] = f"VAR_{counter}" | |
| counter += 1 | |
| for original, replacement in sorted(var_map.items(), key=lambda x: -len(x[0])): | |
| pattern = r'\b' + re.escape(original) + r'\b' | |
| new_code = re.sub(pattern, replacement, new_code) | |
| return new_code | |
| def get_code_embedding(self, code_snippets): | |
| if not code_snippets: return torch.empty(0, 768) | |
| # [IMPROVEMENT] Increased max_length to 512 for better context | |
| inputs = self.tokenizer(code_snippets, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.device) | |
| with torch.no_grad(): | |
| outputs = self.bert_model(**inputs) | |
| return outputs.last_hidden_state[:, 0, :] | |
| def extract_function_subgraph(self, G, root_node): | |
| """Look Down (AST/CFG) and Look Up (REF) crawler""" | |
| nodes_in_func = {root_node} | |
| stack = [root_node] | |
| visited = {root_node} | |
| while stack: | |
| curr = stack.pop() | |
| # Look Down | |
| for nbr in G.successors(curr): | |
| if nbr in visited: continue | |
| all_edges = G.get_edge_data(curr, nbr) | |
| is_child = False | |
| for _, attr in all_edges.items(): | |
| lbl = attr.get('labelE', attr.get('label', '')) | |
| if lbl in self.TRAVERSAL_EDGES: | |
| is_child = True | |
| break | |
| if is_child: | |
| visited.add(nbr); nodes_in_func.add(nbr); stack.append(nbr) | |
| # Look Up | |
| for nbr in G.predecessors(curr): | |
| if nbr in visited: continue | |
| edge_data = G.get_edge_data(nbr, curr) | |
| is_ref = False | |
| for _, attr in edge_data.items(): | |
| lbl = attr.get('labelE', attr.get('label', '')) | |
| if lbl == 'REF': | |
| is_ref = True | |
| break | |
| if is_ref: | |
| visited.add(nbr); nodes_in_func.add(nbr) | |
| return G.subgraph(nodes_in_func).copy() | |
| def process_graph(self, graphml_path): | |
| try: | |
| G = nx.read_graphml(graphml_path) | |
| except Exception: | |
| return [] | |
| inference_data_list = [] | |
| for n, data in G.nodes(data=True): | |
| if data.get('labelV') == 'METHOD': | |
| func_name = data.get('NAME', '') or data.get('METHOD_FULL_NAME', '') | |
| if str(data.get('IS_EXTERNAL', 'false')).lower() == 'true': continue | |
| if "<operator>" in func_name or "<global>" in func_name or "<init>" in func_name: continue | |
| subG = self.extract_function_subgraph(G, n) | |
| if len(subG.nodes) < 3: continue | |
| node_ids = list(subG.nodes()) | |
| node_map = {id: idx for idx, id in enumerate(node_ids)} | |
| x_tensor = torch.zeros((len(node_ids), 768)) | |
| code_batch = [] | |
| valid_indices = [] | |
| for i, nid in enumerate(node_ids): | |
| d = subG.nodes[nid] | |
| if d.get('labelV', 'UNKNOWN') in self.SIGNAL_TYPES: | |
| raw_code = str(d.get('code') or d.get('CODE') or d.get('name') or d.get('NAME') or "") | |
| clean_code = self.normalize_code(raw_code) | |
| code_batch.append(clean_code) | |
| valid_indices.append(i) | |
| if code_batch: | |
| embeddings = [] | |
| for k in range(0, len(code_batch), 32): | |
| batch_text = [t if t.strip() else "empty" for t in code_batch[k:k+32]] | |
| embeddings.append(self.get_code_embedding(batch_text).cpu()) | |
| if embeddings: | |
| full_embeddings = torch.cat(embeddings, dim=0) | |
| for idx, emb in zip(valid_indices, full_embeddings): | |
| x_tensor[idx] = emb | |
| edge_indices = [] | |
| edge_attrs = [] | |
| for src, dst, key, edata in subG.edges(keys=True, data=True): | |
| if src in node_map and dst in node_map: | |
| edge_indices.append([node_map[src], node_map[dst]]) | |
| lbl = edata.get('labelE', edata.get('label', 'AST')) | |
| # Use the V2 Map. Default to 5 (DDG/Other) | |
| etype = self.edge_type_map.get(lbl, 5) | |
| # One-Hot Encoding (Size 6) | |
| one_hot = [0] * config.NUM_EDGE_TYPES | |
| if etype < config.NUM_EDGE_TYPES: | |
| one_hot[etype] = 1 | |
| edge_attrs.append(one_hot) | |
| if not edge_indices: continue | |
| data_obj = Data( | |
| x=x_tensor, | |
| edge_index=torch.tensor(edge_indices, dtype=torch.long).t().contiguous(), | |
| edge_attr=torch.tensor(edge_attrs, dtype=torch.float), | |
| func_name=func_name | |
| ) | |
| inference_data_list.append(data_obj) | |
| return inference_data_list | |