import hashlib
from collections import defaultdict
from typing import Dict, List, Tuple, TYPE_CHECKING
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData, Batch
from torch_geometric.nn import HeteroConv, GATConv, global_mean_pool
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
import numpy as np
if TYPE_CHECKING:
import pandas as pd
# Import Builder from dataloader for inference/eval
from dataloader import CodeGraphBuilder
class RelationalGraphEncoder(nn.Module):
"""R-GNN encoder over the AST+CFG heterogeneous graph."""
EDGE_TYPES = (
("ast", "ast_parent_child", "ast"),
("ast", "ast_child_parent", "ast"),
("ast", "ast_next_sibling", "ast"),
("ast", "ast_prev_sibling", "ast"),
("token", "token_to_ast", "ast"),
("ast", "ast_to_token", "token"),
("stmt", "cfg", "stmt"),
("stmt", "cfg_rev", "stmt"),
("stmt", "stmt_to_ast", "ast"),
("ast", "ast_to_stmt", "stmt"),
)
def __init__(self, hidden_dim: int = 256, out_dim: int = 768, num_layers: int = 2) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.out_dim = out_dim
self.ast_encoder = nn.Embedding(2048, hidden_dim)
self.token_encoder = nn.Embedding(8192, hidden_dim)
self.stmt_encoder = nn.Embedding(512, hidden_dim)
self.convs = nn.ModuleList()
for _ in range(num_layers):
hetero_modules = {
edge_type: GATConv((-1, -1), hidden_dim, add_self_loops=False)
for edge_type in self.EDGE_TYPES
}
hetero_conv = HeteroConv(hetero_modules, aggr="sum")
self.convs.append(hetero_conv)
self.output_proj = nn.Linear(hidden_dim, out_dim)
def _encode_nodes(self, data: HeteroData) -> Dict[str, torch.Tensor]:
device = self.ast_encoder.weight.device
def get_embed(node_type, encoder):
if node_type not in data.node_types:
return torch.zeros((0, self.hidden_dim), device=device)
x = data[node_type].get('x')
if x is None:
return torch.zeros((0, self.hidden_dim), device=device)
x = x.to(device)
return encoder(x)
x_dict = {
"ast": get_embed("ast", self.ast_encoder),
"token": get_embed("token", self.token_encoder),
"stmt": get_embed("stmt", self.stmt_encoder),
}
return x_dict
def forward(self, data: HeteroData) -> torch.Tensor:
device = next(self.parameters()).device
data = data.to(device)
x_dict = self._encode_nodes(data)
edge_index_dict = {}
for edge_type in self.EDGE_TYPES:
if edge_type in data.edge_index_dict:
edge_index_dict[edge_type] = data.edge_index_dict[edge_type]
for conv in self.convs:
x_dict = conv(x_dict, edge_index_dict)
x_dict = {key: F.relu(x) for key, x in x_dict.items()}
# Global Pooling
batch_size = data.num_graphs if hasattr(data, 'num_graphs') else 1
pooled_embeddings = []
for key, x in x_dict.items():
if x.size(0) == 0:
continue
if hasattr(data[key], 'batch') and data[key].batch is not None:
pool = global_mean_pool(x, data[key].batch, size=batch_size)
else:
                # Single graph without a batch vector (e.g. inference on one item):
                # mean over all nodes gives a [1, hidden_dim] pool, matching batch_size == 1.
                pool = x.mean(dim=0, keepdim=True)
pooled_embeddings.append(pool)
if not pooled_embeddings:
return torch.zeros((batch_size, self.out_dim), device=device)
        # Average across node types: stack to [num_types, B, dim], then mean -> [B, dim].
        # Every pool is [B, dim]: global_mean_pool zero-fills graphs that have no nodes
        # of a given type, so the mean across node types is always well defined.
graph_repr = torch.stack(pooled_embeddings).mean(dim=0)
return self.output_proj(graph_repr)
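# Input sketch for RelationalGraphEncoder (assumption: CodeGraphBuilder in dataloader.py
# produces a HeteroData with integer 'x' ids per node type; the ids below are illustrative
# and must stay below the embedding table sizes above: ast < 2048, token < 8192, stmt < 512):
#
#     data = HeteroData()
#     data["ast"].x = torch.tensor([0, 5, 7])        # AST node-type ids
#     data["token"].x = torch.tensor([101, 2043])    # token vocabulary ids
#     data["stmt"].x = torch.tensor([3])             # statement-kind ids
#     data["ast", "ast_parent_child", "ast"].edge_index = torch.tensor([[0, 0], [1, 2]])
#     emb = RelationalGraphEncoder()(data)           # -> shape [1, 768] (default out_dim)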
class GatedFusion(nn.Module):
    """Sigmoid-gated blend of a text embedding with a projected graph embedding.

    Computes g = sigmoid(W [h_text; proj(h_graph)]) and returns
    g * h_text + (1 - g) * proj(h_graph), so the gate decides per dimension
    how much structural context to mix into the text representation.
    """
def __init__(self, text_dim: int, graph_dim: int) -> None:
super().__init__()
self.graph_proj = nn.Linear(graph_dim, text_dim)
self.gate = nn.Linear(text_dim * 2, text_dim)
def forward(self, h_text: torch.Tensor, h_graph: torch.Tensor) -> torch.Tensor:
h_graph_proj = self.graph_proj(h_graph)
joint = torch.cat([h_text, h_graph_proj], dim=-1)
gate = torch.sigmoid(self.gate(joint))
return gate * h_text + (1.0 - gate) * h_graph_proj
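# Usage sketch for GatedFusion (shapes only; values are illustrative):
#
#     fusion = GatedFusion(text_dim=768, graph_dim=256)
#     fused = fusion(torch.randn(4, 768), torch.randn(4, 256))   # -> shape [4, 768]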
class StructuralEncoderV2(nn.Module):
"""Structural encoder that fuses GraphCodeBERT text features with AST+CFG graph context."""
def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2):
super().__init__()
self.device = torch.device(device)
self.text_tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
self.text_model = AutoModel.from_pretrained("microsoft/graphcodebert-base")
self.text_model.to(self.device)
self.graph_encoder = RelationalGraphEncoder(hidden_dim=graph_hidden_dim, out_dim=self.text_model.config.hidden_size, num_layers=graph_layers)
self.graph_encoder.to(self.device)
self.fusion = GatedFusion(self.text_model.config.hidden_size, self.text_model.config.hidden_size)
self.fusion.to(self.device)
def encode_text(self, codes: List[str]) -> torch.Tensor:
inputs = self.text_tokenizer(
codes,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt",
).to(self.device)
outputs = self.text_model(**inputs)
return outputs.last_hidden_state[:, 0, :]
def forward(self, codes: List[str], graph_batch: Batch | HeteroData) -> torch.Tensor:
text_embeddings = self.encode_text(codes)
graph_embeddings = self.graph_encoder(graph_batch)
return self.fusion(text_embeddings, graph_embeddings)
    def generate_embeddings(self, df: "pd.DataFrame", batch_size: int = 8, save_path: str | None = None, desc: str = "Structural V2 embeddings") -> np.ndarray:
        """Encode df["code"] into a [len(df), hidden_size] float32 matrix, optionally saved via np.save."""
        # Create a local builder for inference; graphs are rebuilt on the fly per batch.
        builder = CodeGraphBuilder()
codes = df["code"].tolist()
batches = range(0, len(codes), batch_size)
all_embeddings: List[torch.Tensor] = []
for start in tqdm(batches, desc=desc):
batch_codes = codes[start:start + batch_size]
            # Graphs are built sequentially per batch; parallel construction is only
            # worth adding if this step becomes a bottleneck during evaluation.
data_list = [builder.build(c) for c in batch_codes]
graph_batch = Batch.from_data_list(data_list)
with torch.no_grad():
fused = self.forward(batch_codes, graph_batch)
all_embeddings.append(fused.cpu())
embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
if save_path is not None:
np.save(save_path, embeddings)
return embeddings
    def load_checkpoint(self, checkpoint_path: str, map_location: str | torch.device = "cpu", strict: bool = True) -> None:
        """Load weights from a raw state_dict or from a checkpoint dict wrapping one under "state_dict"."""
        if not checkpoint_path:
raise ValueError("checkpoint_path must be provided")
state = torch.load(checkpoint_path, map_location=map_location)
if isinstance(state, dict) and "state_dict" in state:
state = state["state_dict"]
self.load_state_dict(state, strict=strict)
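if __name__ == "__main__":
    # Minimal smoke test (sketch): assumes dataloader.CodeGraphBuilder is importable and
    # GraphCodeBERT weights can be downloaded; the DataFrame below is illustrative only.
    import pandas as pd

    sample_df = pd.DataFrame({"code": ["def add(a, b):\n    return a + b"]})
    encoder = StructuralEncoderV2(device="cuda" if torch.cuda.is_available() else "cpu")
    encoder.eval()
    embeddings = encoder.generate_embeddings(sample_df, batch_size=4)
    print(embeddings.shape)  # expected: (1, 768), GraphCodeBERT's hidden size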