File size: 7,573 Bytes
1abb892
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import hashlib
from collections import defaultdict
from typing import Dict, List, Tuple, TYPE_CHECKING, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData, Batch
from torch_geometric.nn import HeteroConv, GATConv, global_mean_pool
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
import numpy as np

if TYPE_CHECKING:
    import pandas as pd

from dataloader import CodeGraphBuilder

class RelationalGraphEncoder(nn.Module):
    """R-GNN encoder over the AST+CFG heterogeneous graph.

    Embeds three node types (``ast`` / ``token`` / ``stmt``) with learned
    embedding tables, runs ``num_layers`` rounds of heterogeneous GAT
    message passing over the fixed ``EDGE_TYPES`` relations, mean-pools
    each node type per graph, averages the per-type pools, and projects
    the result to ``out_dim``.
    """

    # (src_type, relation, dst_type) triples the conv layers operate on.
    # Each forward relation has an explicit reverse so information flows
    # both ways through the AST/CFG structure.
    EDGE_TYPES = (
        ("ast", "ast_parent_child", "ast"),
        ("ast", "ast_child_parent", "ast"),
        ("ast", "ast_next_sibling", "ast"),
        ("ast", "ast_prev_sibling", "ast"),
        ("token", "token_to_ast", "ast"),
        ("ast", "ast_to_token", "token"),
        ("stmt", "cfg", "stmt"),
        ("stmt", "cfg_rev", "stmt"),
        ("stmt", "stmt_to_ast", "ast"),
        ("ast", "ast_to_stmt", "stmt"),
    )

    def __init__(self, hidden_dim: int = 256, out_dim: int = 768, num_layers: int = 2) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        # Per-node-type embedding tables; sizes are vocabulary capacities
        # for the ids produced by the graph builder — presumably fixed by
        # the dataloader side; TODO confirm against CodeGraphBuilder.
        self.ast_encoder = nn.Embedding(2048, hidden_dim)
        self.token_encoder = nn.Embedding(8192, hidden_dim)
        self.stmt_encoder = nn.Embedding(512, hidden_dim)

        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            hetero_modules = {
                # (-1, -1) lets GATConv lazily infer src/dst feature dims.
                # Self loops are disabled because the relation set already
                # encodes the structural connectivity explicitly.
                edge_type: GATConv((-1, -1), hidden_dim, add_self_loops=False)
                for edge_type in self.EDGE_TYPES
            }
            self.convs.append(HeteroConv(hetero_modules, aggr="sum"))

        self.output_proj = nn.Linear(hidden_dim, out_dim)

    def _encode_nodes(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Embed each node type's id tensor; absent/empty types map to (0, hidden_dim)."""
        device = self.ast_encoder.weight.device

        def get_embed(node_type: str, encoder: nn.Embedding) -> torch.Tensor:
            if node_type not in data.node_types:
                return torch.zeros((0, self.hidden_dim), device=device)

            x = data[node_type].get('x')
            if x is None:
                return torch.zeros((0, self.hidden_dim), device=device)

            return encoder(x.to(device))

        return {
            "ast": get_embed("ast", self.ast_encoder),
            "token": get_embed("token", self.token_encoder),
            "stmt": get_embed("stmt", self.stmt_encoder),
        }

    def forward(self, data: HeteroData) -> torch.Tensor:
        """Return a ``(num_graphs, out_dim)`` graph-level embedding for ``data``."""
        device = next(self.parameters()).device
        data = data.to(device)

        x_dict = self._encode_nodes(data)

        # Feed only the relations this model declares; extra edge types in
        # the data are ignored and missing ones are simply skipped.
        edge_index_dict = {
            edge_type: data.edge_index_dict[edge_type]
            for edge_type in self.EDGE_TYPES
            if edge_type in data.edge_index_dict
        }

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: F.relu(x) for key, x in x_dict.items()}

        batch_size = data.num_graphs if hasattr(data, 'num_graphs') else 1

        pooled_embeddings = []
        for key, x in x_dict.items():
            if x.size(0) == 0:
                continue

            if hasattr(data[key], 'batch') and data[key].batch is not None:
                # Batched input: mean-pool per graph; `size` pads graphs
                # that have no nodes of this type with zero rows.
                pool = global_mean_pool(x, data[key].batch, size=batch_size)
            else:
                # Un-batched store: mean over all nodes of this type.
                pool = x.mean(dim=0, keepdim=True)
                if pool.size(0) != batch_size:
                    # BUGFIX: broadcast the single pooled row so every
                    # per-type pool has shape (batch_size, hidden_dim);
                    # previously this branch was a no-op and the stack
                    # below failed on mismatched first dimensions.
                    pool = pool.expand(batch_size, -1)
            pooled_embeddings.append(pool)

        if not pooled_embeddings:
            # Degenerate input with no nodes of any type.
            return torch.zeros((batch_size, self.out_dim), device=device)

        # Average the per-type pools, then project to the output space.
        graph_repr = torch.stack(pooled_embeddings).mean(dim=0)
        return self.output_proj(graph_repr)


class GatedFusion(nn.Module):
    """Sigmoid-gated blend of a text embedding with a projected graph embedding."""

    def __init__(self, text_dim: int, graph_dim: int) -> None:
        super().__init__()
        # Map graph features into the text space so the two can be mixed
        # element-wise; the gate is computed from the concatenated pair.
        self.graph_proj = nn.Linear(graph_dim, text_dim)
        self.gate = nn.Linear(text_dim * 2, text_dim)

    def forward(self, h_text: torch.Tensor, h_graph: torch.Tensor) -> torch.Tensor:
        projected = self.graph_proj(h_graph)
        mix = torch.sigmoid(self.gate(torch.cat((h_text, projected), dim=-1)))
        # mix -> 1 keeps the text signal, mix -> 0 keeps the graph signal.
        return mix * h_text + (1.0 - mix) * projected


class StructuralEncoderV2(nn.Module):
    """Structural encoder that fuses GraphCodeBERT text features with AST+CFG graph context."""

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2):
        super().__init__()
        self.device = torch.device(device)
        # Tokenizer is now in dataloader, but used here for size configs or inference if needed
        self.text_model = AutoModel.from_pretrained("microsoft/graphcodebert-base")
        self.text_model.to(self.device)

        # Graph encoder projects into the text model's hidden size so the
        # two embeddings can be fused directly.
        self.graph_encoder = RelationalGraphEncoder(hidden_dim=graph_hidden_dim, out_dim=self.text_model.config.hidden_size, num_layers=graph_layers)
        self.graph_encoder.to(self.device)

        self.fusion = GatedFusion(self.text_model.config.hidden_size, self.text_model.config.hidden_size)
        self.fusion.to(self.device)

    def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the [CLS] (first-token) hidden state, shape (batch, hidden_size)."""
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch | HeteroData) -> torch.Tensor:
        """Fuse text and graph embeddings into a (batch, hidden_size) representation."""
        text_embeddings = self.encode_text(input_ids, attention_mask)
        graph_embeddings = self.graph_encoder(graph_batch)
        return self.fusion(text_embeddings, graph_embeddings)

    def generate_embeddings(self, df: "pd.DataFrame", batch_size: int = 8, save_path: str | None = None, desc: str = "Structural V2 embeddings") -> np.ndarray:
        """Compute fused embeddings for every row of ``df['code']``.

        Returns a float32 array of shape (len(df), hidden_size); optionally
        saves it to ``save_path`` via ``np.save``.
        """
        # Local resources for inference
        builder = CodeGraphBuilder()
        tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")

        codes = df["code"].tolist()
        all_embeddings: List[torch.Tensor] = []

        # BUGFIX: force eval mode for inference — otherwise dropout (if the
        # module is in training mode) makes embeddings nondeterministic.
        # Restore the previous mode afterwards so training callers are safe.
        was_training = self.training
        self.eval()
        try:
            for start in tqdm(range(0, len(codes), batch_size), desc=desc):
                batch_codes = codes[start:start + batch_size]

                data_list = [builder.build(c) for c in batch_codes]
                graph_batch = Batch.from_data_list(data_list)

                tok = tokenizer(batch_codes, padding=True, truncation=True, max_length=512, return_tensors="pt")

                with torch.no_grad():
                    fused = self.forward(tok["input_ids"], tok["attention_mask"], graph_batch)
                all_embeddings.append(fused.cpu())
        finally:
            if was_training:
                self.train()

        embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
        if save_path is not None:
            np.save(save_path, embeddings)
        return embeddings

    def load_checkpoint(self, checkpoint_path: str, map_location: str | torch.device = "cpu", strict: bool = True) -> None:
        """Load model weights from ``checkpoint_path``.

        Accepts either a bare state dict or a dict wrapping one under the
        ``"state_dict"`` key (e.g. Lightning-style checkpoints).
        """
        if not checkpoint_path:
            raise ValueError("checkpoint_path must be provided")
        # SECURITY NOTE(review): torch.load unpickles arbitrary objects —
        # only load checkpoints from trusted sources, or pass
        # weights_only=True where the checkpoint format allows it.
        state = torch.load(checkpoint_path, map_location=map_location)
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        self.load_state_dict(state, strict=strict)