import hashlib
from collections import defaultdict
from typing import Dict, List, Tuple, TYPE_CHECKING, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData, Batch
from torch_geometric.nn import HeteroConv, GATConv, global_mean_pool
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
import numpy as np

if TYPE_CHECKING:
    import pandas as pd
    from dataloader import CodeGraphBuilder

class RelationalGraphEncoder(nn.Module):
    """R-GNN encoder over the AST+CFG heterogeneous graph."""

    EDGE_TYPES = (
        ("ast", "ast_parent_child", "ast"),
        ("ast", "ast_child_parent", "ast"),
        ("ast", "ast_next_sibling", "ast"),
        ("ast", "ast_prev_sibling", "ast"),
        ("token", "token_to_ast", "ast"),
        ("ast", "ast_to_token", "token"),
        ("stmt", "cfg", "stmt"),
        ("stmt", "cfg_rev", "stmt"),
        ("stmt", "stmt_to_ast", "ast"),
        ("ast", "ast_to_stmt", "stmt"),
    )
    def __init__(self, hidden_dim: int = 256, out_dim: int = 768, num_layers: int = 2) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        # Node features are expected to be integer ids (node-type / token / statement
        # vocabulary indices) that index into these embedding tables.
        self.ast_encoder = nn.Embedding(2048, hidden_dim)
        self.token_encoder = nn.Embedding(8192, hidden_dim)
        self.stmt_encoder = nn.Embedding(512, hidden_dim)
        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            # One GATConv per relation; (-1, -1) lets PyG infer the input dims lazily.
            hetero_modules = {
                edge_type: GATConv((-1, -1), hidden_dim, add_self_loops=False)
                for edge_type in self.EDGE_TYPES
            }
            hetero_conv = HeteroConv(hetero_modules, aggr="sum")
            self.convs.append(hetero_conv)
        self.output_proj = nn.Linear(hidden_dim, out_dim)
    def _encode_nodes(self, data: HeteroData) -> Dict[str, torch.Tensor]:
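        """Embed the integer features of each node type.

        Node types that are absent from the batch (or that carry no ``x``
        features) are mapped to empty ``(0, hidden_dim)`` tensors so downstream
        code can treat every node type uniformly.
        """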
        device = self.ast_encoder.weight.device

        def get_embed(node_type, encoder):
            if node_type not in data.node_types:
                return torch.zeros((0, self.hidden_dim), device=device)
            x = data[node_type].get('x')
            if x is None:
                return torch.zeros((0, self.hidden_dim), device=device)
            x = x.to(device)
            return encoder(x)

        x_dict = {
            "ast": get_embed("ast", self.ast_encoder),
            "token": get_embed("token", self.token_encoder),
            "stmt": get_embed("stmt", self.stmt_encoder),
        }
        return x_dict
    def forward(self, data: HeteroData) -> torch.Tensor:
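        """Run message passing over the heterogeneous graph and pool to one vector per graph.

        ``data`` is a (batched) HeteroData with up to three node types
        ("ast", "token", "stmt") and the relations listed in ``EDGE_TYPES``.
        Returns a tensor of shape ``(num_graphs, out_dim)``.
        """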
        device = next(self.parameters()).device
        data = data.to(device)
        x_dict = self._encode_nodes(data)
        edge_index_dict = {}
        for edge_type in self.EDGE_TYPES:
            if edge_type in data.edge_index_dict:
                edge_index_dict[edge_type] = data.edge_index_dict[edge_type]
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        batch_size = data.num_graphs if hasattr(data, 'num_graphs') else 1
        pooled_embeddings = []
        for key, x in x_dict.items():
            if x.size(0) == 0:
                continue
            if hasattr(data[key], 'batch') and data[key].batch is not None:
                pool = global_mean_pool(x, data[key].batch, size=batch_size)
            else:
                pool = x.mean(dim=0, keepdim=True)
            if pool.size(0) != batch_size:
                # Broadcast the single pooled vector so every node type contributes a
                # (batch_size, hidden_dim) tensor before stacking below.
                pool = pool.expand(batch_size, -1)
            pooled_embeddings.append(pool)
        if not pooled_embeddings:
            return torch.zeros((batch_size, self.out_dim), device=device)
        graph_repr = torch.stack(pooled_embeddings).mean(dim=0)
        return self.output_proj(graph_repr)

class GatedFusion(nn.Module):
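    """Gated fusion of text and graph embeddings.

    Projects the graph embedding into the text space, computes a sigmoid gate
    from the concatenated pair, and mixes the two:
    ``g * h_text + (1 - g) * proj(h_graph)`` with ``g = sigmoid(W [h_text; proj(h_graph)])``.
    """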
    def __init__(self, text_dim: int, graph_dim: int) -> None:
        super().__init__()
        self.graph_proj = nn.Linear(graph_dim, text_dim)
        self.gate = nn.Linear(text_dim * 2, text_dim)

    def forward(self, h_text: torch.Tensor, h_graph: torch.Tensor) -> torch.Tensor:
        h_graph_proj = self.graph_proj(h_graph)
        joint = torch.cat([h_text, h_graph_proj], dim=-1)
        gate = torch.sigmoid(self.gate(joint))
        return gate * h_text + (1.0 - gate) * h_graph_proj

class StructuralEncoderV2(nn.Module):
    """Structural encoder that fuses GraphCodeBERT text features with AST+CFG graph context."""

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2):
        super().__init__()
        self.device = torch.device(device)
        # Tokenization for training lives in the dataloader; this module only needs the
        # pretrained text model (a tokenizer is created locally in generate_embeddings
        # for standalone inference).
        self.text_model = AutoModel.from_pretrained("microsoft/graphcodebert-base")
        self.text_model.to(self.device)
        self.graph_encoder = RelationalGraphEncoder(
            hidden_dim=graph_hidden_dim,
            out_dim=self.text_model.config.hidden_size,
            num_layers=graph_layers,
        )
        self.graph_encoder.to(self.device)
        self.fusion = GatedFusion(self.text_model.config.hidden_size, self.text_model.config.hidden_size)
        self.fusion.to(self.device)
    def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
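        """Return the [CLS] (first-token) hidden state from GraphCodeBERT for each sequence."""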
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]
    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch | HeteroData) -> torch.Tensor:
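        """Fuse the [CLS] text embedding with the pooled graph embedding via the learned gate."""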
        text_embeddings = self.encode_text(input_ids, attention_mask)
        graph_embeddings = self.graph_encoder(graph_batch)
        return self.fusion(text_embeddings, graph_embeddings)
    def generate_embeddings(self, df: "pd.DataFrame", batch_size: int = 8, save_path: str | None = None, desc: str = "Structural V2 embeddings") -> np.ndarray:
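        """Embed every row of ``df["code"]`` with the fused text+graph encoder.

        Returns a float32 array of shape ``(len(df), hidden_size)`` and, if
        ``save_path`` is given, also writes it with ``np.save``.
        """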
        # Local resources for inference: the graph builder and tokenizer are created here,
        # and CodeGraphBuilder is imported lazily so the module does not require
        # dataloader at import time.
        from dataloader import CodeGraphBuilder

        builder = CodeGraphBuilder()
        tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
        codes = df["code"].tolist()
        batches = range(0, len(codes), batch_size)
        all_embeddings: List[torch.Tensor] = []
        # Eval mode keeps dropout off so repeated runs produce identical embeddings.
        self.eval()
        for start in tqdm(batches, desc=desc):
            batch_codes = codes[start:start + batch_size]
            data_list = [builder.build(c) for c in batch_codes]
            graph_batch = Batch.from_data_list(data_list)
            tok = tokenizer(batch_codes, padding=True, truncation=True, max_length=512, return_tensors="pt")
            with torch.no_grad():
                fused = self.forward(tok["input_ids"], tok["attention_mask"], graph_batch)
            all_embeddings.append(fused.cpu())
        embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
        if save_path is not None:
            np.save(save_path, embeddings)
        return embeddings
    def load_checkpoint(self, checkpoint_path: str, map_location: str | torch.device = "cpu", strict: bool = True) -> None:
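        """Load weights from ``checkpoint_path``, unwrapping a ``{"state_dict": ...}`` container if present."""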
        if not checkpoint_path:
            raise ValueError("checkpoint_path must be provided")
        state = torch.load(checkpoint_path, map_location=map_location)
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        self.load_state_dict(state, strict=strict)
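

if __name__ == "__main__":
    # Minimal smoke test; an illustrative sketch only, not part of the training pipeline.
    # It assumes the local `dataloader` module provides CodeGraphBuilder and that the
    # microsoft/graphcodebert-base weights can be downloaded.
    import pandas as pd

    demo_df = pd.DataFrame({"code": ["def add(a, b):\n    return a + b\n"]})
    model = StructuralEncoderV2(device="cuda" if torch.cuda.is_available() else "cpu")
    demo_embeddings = model.generate_embeddings(demo_df, batch_size=1)
    print("embedding matrix shape:", demo_embeddings.shape)  # (1, 768) for graphcodebert-base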