# graphguard-backend / graph_extractor.py
# Author: Bharateesha lvn
# Commit: Deploy GraphGuard Backend V1 (d511278)
import networkx as nx
import torch
import re
from transformers import AutoTokenizer, AutoModel
from torch_geometric.data import Data
import config
class GraphExtractor:
    """Turn Joern-style GraphML code-property-graphs into PyTorch Geometric
    ``Data`` objects, one per internal function, with GraphCodeBERT node
    embeddings and one-hot edge-type attributes.
    """

    def __init__(self):
        """Load the GraphCodeBERT tokenizer/model (CPU) and configure the
        edge-type map and node/edge label sets used during extraction."""
        print("⏳ Loading GraphCodeBERT Model...")
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
        self.bert_model = AutoModel.from_pretrained("microsoft/graphcodebert-base")
        # CPU for Hugging Face Spaces (GPU is unavailable on free tier)
        self.device = torch.device("cpu")
        self.bert_model.to(self.device)
        # [CRITICAL] V2 edge mapping — must match the training config:
        # AST=0, CFG=1, CDG=2, REACHING_DEF=4, DDG=5 (the 'Other' bucket).
        # NOTE(review): index 3 is deliberately absent here; confirm this
        # matches the label set the model was trained with.
        self.edge_type_map = {
            'AST': 0,
            'CFG': 1,
            'CDG': 2,
            'REACHING_DEF': 4,
            'DDG': 5,
        }
        # Node labels whose `code` text is worth embedding.
        self.SIGNAL_TYPES = {'IDENTIFIER', 'LITERAL', 'CALL', 'METHOD_RETURN', 'METHOD', 'SimpleName'}
        # Edge labels followed when crawling a function's subgraph.
        self.TRAVERSAL_EDGES = ('AST', 'CFG', 'CDG', 'REACHING_DEF', 'DDG')

    def normalize_code(self, code_str):
        """Anti-cheat normalizer (whitelist + regex).

        Renames every identifier that is not a whitelisted C keyword / libc
        name to a positional ``VAR_n`` token, in first-appearance order.
        Returns the literal string "empty" for missing or oversized
        (>512 chars) input so the embedder always receives non-empty text.
        """
        if not code_str or len(code_str) > 512:
            return "empty"
        keep_list = {
            'char', 'int', 'float', 'double', 'void', 'long', 'short', 'unsigned', 'signed',
            'struct', 'union', 'enum', 'const', 'volatile', 'static', 'auto', 'register',
            'if', 'else', 'switch', 'case', 'default', 'while', 'do', 'for', 'goto',
            'continue', 'break', 'return', 'sizeof', 'typedef', 'NULL', 'true', 'false',
            'malloc', 'free', 'memset', 'memcpy', 'strcpy', 'strncpy', 'printf', 'scanf',
            'recv', 'send', 'socket', 'bind', 'listen', 'accept',
            'system', 'popen', 'getenv', 'snprintf', 'fprintf', 'sprintf'
        }
        # Assign VAR_n names in order of first appearance.
        var_map = {}
        for token in re.findall(r'[a-zA-Z_]\w*', code_str):
            if token not in keep_list and token not in var_map:
                var_map[token] = f"VAR_{len(var_map) + 1}"
        # Single pass with a callback: each identifier is rewritten exactly
        # once, so earlier substitutions can never corrupt later ones (and
        # this is O(n) instead of one full-string re.sub per identifier).
        return re.sub(r'[a-zA-Z_]\w*',
                      lambda m: var_map.get(m.group(0), m.group(0)),
                      code_str)

    def get_code_embedding(self, code_snippets):
        """Embed a batch of code strings with GraphCodeBERT.

        Returns a ``(len(code_snippets), 768)`` tensor of [CLS] embeddings;
        an empty ``(0, 768)`` tensor when the batch is empty.
        """
        if not code_snippets:
            return torch.empty(0, 768)
        # max_length=512 keeps full context for long snippets.
        inputs = self.tokenizer(
            code_snippets,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.bert_model(**inputs)
        # [CLS] token (position 0) summarizes the whole snippet.
        return outputs.last_hidden_state[:, 0, :]

    def extract_function_subgraph(self, G, root_node):
        """'Look down' (AST/CFG/CDG/REACHING_DEF/DDG) and 'look up' (REF)
        crawler.

        Walks downward from ``root_node`` along traversal edges, pulls in
        direct REF predecessors (which are NOT expanded further), and
        returns a copy of the induced subgraph.
        """

        def has_label(edge_dict, accepted):
            # Multigraph edge data is {key: attr_dict}; the label may live
            # under 'labelE' or 'label' depending on the GraphML exporter.
            for attr in edge_dict.values():
                if attr.get('labelE', attr.get('label', '')) in accepted:
                    return True
            return False

        member = {root_node}
        stack = [root_node]
        while stack:
            curr = stack.pop()
            # Look down: children reachable via structural/flow edges.
            for nbr in G.successors(curr):
                if nbr not in member and has_label(G.get_edge_data(curr, nbr), self.TRAVERSAL_EDGES):
                    member.add(nbr)
                    stack.append(nbr)
            # Look up: direct REF predecessors (e.g. referenced declarations);
            # these are added but not pushed, so the crawl never climbs past them.
            for nbr in G.predecessors(curr):
                if nbr not in member and has_label(G.get_edge_data(nbr, curr), ('REF',)):
                    member.add(nbr)
        return G.subgraph(member).copy()

    def process_graph(self, graphml_path):
        """Parse a GraphML CPG and return one PyG ``Data`` per internal,
        non-synthetic METHOD node.

        Returns ``[]`` when the file cannot be parsed. Functions with fewer
        than 3 nodes or no edges are skipped.
        """
        try:
            # force_multigraph=True: the keyed edge access used below
            # (edges(keys=True), get_edge_data() keyed by edge id) requires
            # a MultiDiGraph even when the file has no parallel edges.
            G = nx.read_graphml(graphml_path, force_multigraph=True)
        except Exception:
            return []
        inference_data_list = []
        for node_id, data in G.nodes(data=True):
            if data.get('labelV') != 'METHOD':
                continue
            func_name = data.get('NAME', '') or data.get('METHOD_FULL_NAME', '')
            # Skip external stubs and compiler-synthesized pseudo-methods.
            if str(data.get('IS_EXTERNAL', 'false')).lower() == 'true':
                continue
            if any(tag in func_name for tag in ('<operator>', '<global>', '<init>')):
                continue
            subG = self.extract_function_subgraph(G, node_id)
            if len(subG.nodes) < 3:  # too small to be a meaningful function
                continue
            node_ids = list(subG.nodes())
            node_index = {nid: idx for idx, nid in enumerate(node_ids)}
            x_tensor = self._build_node_features(subG, node_ids)
            edge_index, edge_attr = self._build_edges(subG, node_index)
            if edge_index is None:  # edgeless subgraph: nothing for a GNN
                continue
            inference_data_list.append(Data(
                x=x_tensor,
                edge_index=edge_index,
                edge_attr=edge_attr,
                func_name=func_name,
            ))
        return inference_data_list

    def _build_node_features(self, subG, node_ids):
        """Return a (len(node_ids), 768) feature tensor: GraphCodeBERT
        embeddings for signal-type nodes, zeros for everything else."""
        x_tensor = torch.zeros((len(node_ids), 768))
        code_batch, signal_rows = [], []
        for row, nid in enumerate(node_ids):
            d = subG.nodes[nid]
            if d.get('labelV', 'UNKNOWN') not in self.SIGNAL_TYPES:
                continue
            raw_code = str(d.get('code') or d.get('CODE') or d.get('name') or d.get('NAME') or "")
            code_batch.append(self.normalize_code(raw_code))
            signal_rows.append(row)
        if code_batch:
            chunks = []
            for k in range(0, len(code_batch), 32):  # chunk to bound CPU memory
                texts = [t if t.strip() else "empty" for t in code_batch[k:k + 32]]
                chunks.append(self.get_code_embedding(texts).cpu())
            for row, emb in zip(signal_rows, torch.cat(chunks, dim=0)):
                x_tensor[row] = emb
        return x_tensor

    def _build_edges(self, subG, node_index):
        """Return (edge_index, edge_attr) tensors for ``subG``, or
        (None, None) when the subgraph has no edges."""
        edge_pairs, edge_attrs = [], []
        for src, dst, _key, edata in subG.edges(keys=True, data=True):
            if src not in node_index or dst not in node_index:
                continue
            edge_pairs.append([node_index[src], node_index[dst]])
            lbl = edata.get('labelE', edata.get('label', 'AST'))
            # Use the V2 map; unknown labels default to 5 (DDG/Other).
            etype = self.edge_type_map.get(lbl, 5)
            # One-hot encoding (size config.NUM_EDGE_TYPES, expected 6).
            one_hot = [0] * config.NUM_EDGE_TYPES
            if etype < config.NUM_EDGE_TYPES:
                one_hot[etype] = 1
            edge_attrs.append(one_hot)
        if not edge_pairs:
            return None, None
        return (torch.tensor(edge_pairs, dtype=torch.long).t().contiguous(),
                torch.tensor(edge_attrs, dtype=torch.float))