dv4aby committed on
Commit
1abb892
·
verified ·
1 Parent(s): 0f0d8c7

Upload source code structural_encoder_v2.py

Browse files
Files changed (1) hide show
  1. structural_encoder_v2.py +184 -0
structural_encoder_v2.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ from collections import defaultdict
3
+ from typing import Dict, List, Tuple, TYPE_CHECKING, Optional
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch_geometric.data import HeteroData, Batch
8
+ from torch_geometric.nn import HeteroConv, GATConv, global_mean_pool
9
+ from transformers import AutoModel, AutoTokenizer
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+
13
+ if TYPE_CHECKING:
14
+ import pandas as pd
15
+
16
+ from dataloader import CodeGraphBuilder
17
+
18
class RelationalGraphEncoder(nn.Module):
    """R-GNN encoder over the AST+CFG heterogeneous graph.

    Embeds three node types (``ast``, ``token``, ``stmt``) with learned
    lookup tables, runs ``num_layers`` rounds of heterogeneous GAT message
    passing over the relations in ``EDGE_TYPES``, mean-pools each node type
    per graph, averages the per-type summaries, and projects to ``out_dim``.
    """

    # (src_type, relation, dst_type) triples covered by every HeteroConv layer.
    EDGE_TYPES = (
        ("ast", "ast_parent_child", "ast"),
        ("ast", "ast_child_parent", "ast"),
        ("ast", "ast_next_sibling", "ast"),
        ("ast", "ast_prev_sibling", "ast"),
        ("token", "token_to_ast", "ast"),
        ("ast", "ast_to_token", "token"),
        ("stmt", "cfg", "stmt"),
        ("stmt", "cfg_rev", "stmt"),
        ("stmt", "stmt_to_ast", "ast"),
        ("ast", "ast_to_stmt", "stmt"),
    )

    def __init__(self, hidden_dim: int = 256, out_dim: int = 768, num_layers: int = 2) -> None:
        """Build the node embeddings, GAT layers and output projection.

        Args:
            hidden_dim: width of node embeddings and GAT outputs.
            out_dim: width of the final pooled graph representation.
            num_layers: number of message-passing rounds.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        # Vocabulary sizes are fixed by the upstream graph builder; node id
        # values in data[...].x must stay below these bounds.
        self.ast_encoder = nn.Embedding(2048, hidden_dim)
        self.token_encoder = nn.Embedding(8192, hidden_dim)
        self.stmt_encoder = nn.Embedding(512, hidden_dim)

        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            hetero_modules = {
                # (-1, -1) lets GATConv lazily infer per-node-type input dims.
                edge_type: GATConv((-1, -1), hidden_dim, add_self_loops=False)
                for edge_type in self.EDGE_TYPES
            }
            self.convs.append(HeteroConv(hetero_modules, aggr="sum"))

        self.output_proj = nn.Linear(hidden_dim, out_dim)

    def _encode_nodes(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Look up initial embeddings for every node type present in ``data``.

        Node types that are missing, or present without an ``x`` id tensor,
        yield an empty (0, hidden_dim) tensor so downstream code can skip
        them uniformly.
        """
        device = self.ast_encoder.weight.device

        def get_embed(node_type: str, encoder: nn.Embedding) -> torch.Tensor:
            if node_type not in data.node_types:
                return torch.zeros((0, self.hidden_dim), device=device)
            x = data[node_type].get('x')
            if x is None:
                return torch.zeros((0, self.hidden_dim), device=device)
            return encoder(x.to(device))

        return {
            "ast": get_embed("ast", self.ast_encoder),
            "token": get_embed("token", self.token_encoder),
            "stmt": get_embed("stmt", self.stmt_encoder),
        }

    def forward(self, data: HeteroData) -> torch.Tensor:
        """Encode a (possibly batched) heterogeneous graph to (batch, out_dim)."""
        device = next(self.parameters()).device
        data = data.to(device)

        x_dict = self._encode_nodes(data)

        # Only forward edge types the model declares; extra relations are ignored.
        edge_index_dict = {
            edge_type: data.edge_index_dict[edge_type]
            for edge_type in self.EDGE_TYPES
            if edge_type in data.edge_index_dict
        }

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: F.relu(x) for key, x in x_dict.items()}

        batch_size = data.num_graphs if hasattr(data, 'num_graphs') else 1

        pooled_embeddings = []
        for key, x in x_dict.items():
            if x.size(0) == 0:
                continue  # node type absent from this batch

            if hasattr(data[key], 'batch') and data[key].batch is not None:
                # size=batch_size keeps graphs that have zero nodes of this
                # type aligned with the rest of the batch.
                pool = global_mean_pool(x, data[key].batch, size=batch_size)
            else:
                # No batch vector: fall back to one global mean row...
                pool = x.mean(dim=0, keepdim=True)
                if pool.size(0) != batch_size:
                    # ...and broadcast it so it can be stacked with the other
                    # node-type summaries. BUGFIX: the mismatch was previously
                    # detected but silently ignored (`pass`), so torch.stack
                    # below raised on batched inputs lacking a batch vector.
                    pool = pool.expand(batch_size, -1)
            pooled_embeddings.append(pool)

        if not pooled_embeddings:
            # Degenerate graph with none of the known node types.
            return torch.zeros((batch_size, self.out_dim), device=device)

        # Average the per-node-type summaries, then project to out_dim.
        graph_repr = torch.stack(pooled_embeddings).mean(dim=0)
        return self.output_proj(graph_repr)
110
+
111
+
112
class GatedFusion(nn.Module):
    """Gated fusion of a text embedding with a projected graph embedding.

    The graph vector is projected into the text space and the two are mixed
    with a learned sigmoid gate: ``g * h_text + (1 - g) * proj(h_graph)``.
    """

    def __init__(self, text_dim: int, graph_dim: int) -> None:
        super().__init__()
        self.graph_proj = nn.Linear(graph_dim, text_dim)
        self.gate = nn.Linear(text_dim * 2, text_dim)

    def forward(self, h_text: torch.Tensor, h_graph: torch.Tensor) -> torch.Tensor:
        # Bring the graph representation into the text embedding space.
        projected = self.graph_proj(h_graph)
        # Per-dimension mixing coefficients in (0, 1), conditioned on both inputs.
        mix_input = torch.cat((h_text, projected), dim=-1)
        g = torch.sigmoid(self.gate(mix_input))
        # Convex combination: g -> 1 keeps the text signal, g -> 0 the graph signal.
        return g * h_text + (1.0 - g) * projected
123
+
124
+
125
class StructuralEncoderV2(nn.Module):
    """Structural encoder that fuses GraphCodeBERT text features with AST+CFG graph context."""

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2):
        """Load GraphCodeBERT and build the graph encoder and fusion head on ``device``.

        Args:
            device: target device for all submodules.
            graph_hidden_dim: hidden width of the relational graph encoder.
            graph_layers: number of message-passing layers in the graph encoder.
        """
        super().__init__()
        self.device = torch.device(device)
        # Tokenizer lives in the dataloader; this module only needs the encoder
        # weights (and their hidden size for the graph/fusion dimensions).
        self.text_model = AutoModel.from_pretrained("microsoft/graphcodebert-base")
        self.text_model.to(self.device)

        hidden = self.text_model.config.hidden_size
        self.graph_encoder = RelationalGraphEncoder(
            hidden_dim=graph_hidden_dim, out_dim=hidden, num_layers=graph_layers
        )
        self.graph_encoder.to(self.device)

        self.fusion = GatedFusion(hidden, hidden)
        self.fusion.to(self.device)

    def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the first-token ([CLS]) embedding per sequence, shape (batch, hidden)."""
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch | HeteroData) -> torch.Tensor:
        """Fuse text and graph embeddings into one (batch, hidden) representation."""
        text_embeddings = self.encode_text(input_ids, attention_mask)
        graph_embeddings = self.graph_encoder(graph_batch)
        return self.fusion(text_embeddings, graph_embeddings)

    def generate_embeddings(self, df: "pd.DataFrame", batch_size: int = 8, save_path: str | None = None, desc: str = "Structural V2 embeddings") -> np.ndarray:
        """Embed every row of ``df["code"]``; optionally save the float32 matrix.

        Runs in eval mode under ``no_grad`` so dropout is disabled and the
        embeddings are deterministic; the caller's train/eval mode is restored
        afterwards.

        Args:
            df: frame with a ``code`` column of source strings.
            batch_size: rows encoded per forward pass.
            save_path: if given, the matrix is written via ``np.save``.
            desc: tqdm progress-bar label.

        Returns:
            float32 array of shape (len(df), hidden_size).
        """
        # Local resources for inference.
        builder = CodeGraphBuilder()
        tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")

        codes = df["code"].tolist()
        if not codes:
            # BUGFIX: torch.cat on an empty list raises; return an empty matrix.
            embeddings = np.zeros((0, self.text_model.config.hidden_size), dtype="float32")
            if save_path is not None:
                np.save(save_path, embeddings)
            return embeddings

        all_embeddings: List[torch.Tensor] = []
        was_training = self.training
        # BUGFIX: dropout was previously left in whatever mode the caller had,
        # making "embeddings" non-deterministic when the model was in train mode.
        self.eval()
        try:
            with torch.no_grad():
                for start in tqdm(range(0, len(codes), batch_size), desc=desc):
                    batch_codes = codes[start:start + batch_size]

                    data_list = [builder.build(c) for c in batch_codes]
                    graph_batch = Batch.from_data_list(data_list)

                    tok = tokenizer(batch_codes, padding=True, truncation=True, max_length=512, return_tensors="pt")

                    fused = self.forward(tok["input_ids"], tok["attention_mask"], graph_batch)
                    all_embeddings.append(fused.cpu())
        finally:
            # Restore the caller's mode even if a batch fails.
            if was_training:
                self.train()

        embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
        if save_path is not None:
            np.save(save_path, embeddings)
        return embeddings

    def load_checkpoint(self, checkpoint_path: str, map_location: str | torch.device = "cpu", strict: bool = True) -> None:
        """Load model weights, unwrapping a ``{"state_dict": ...}`` container if present.

        Args:
            checkpoint_path: path passed to ``torch.load``.
            map_location: device remapping forwarded to ``torch.load``.
            strict: forwarded to ``load_state_dict``.

        Raises:
            ValueError: if ``checkpoint_path`` is empty or None.
        """
        if not checkpoint_path:
            raise ValueError("checkpoint_path must be provided")
        state = torch.load(checkpoint_path, map_location=map_location)
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        self.load_state_dict(state, strict=strict)