dv4aby committed on
Commit
17f2e83
·
verified ·
1 Parent(s): 87d2e85

Upload source code structural_encoder_ablation.py

Browse files
Files changed (1) hide show
  1. structural_encoder_ablation.py +83 -0
structural_encoder_ablation.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import List, Optional, Tuple, Any
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+ import pandas as pd
7
+ from torch_geometric.data import Batch
8
+
9
+ # Import builder from dataloader for inference
10
+ from dataloader import CodeGraphBuilder
11
+
12
+ from structural_encoder_v2 import RelationalGraphEncoder, StructuralEncoderV2, GatedFusion
13
+
14
class StructuralEncoderOnlyGraph(nn.Module):
    """
    Ablation variant 1: Pure Structural Encoder.

    Removes GraphCodeBERT and uses only the graph path (R-GNN): every code
    snippet is converted to a graph and embedded by ``RelationalGraphEncoder``
    alone. The ``codes`` argument is kept in ``forward`` only for interface
    parity with the fused model, so callers/benchmarks need not change.
    """

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2, out_dim: int = 768):
        """
        Args:
            device: Device the graph encoder is moved to (string or torch.device).
            graph_hidden_dim: Hidden width of the R-GNN layers.
            graph_layers: Number of R-GNN message-passing layers.
            out_dim: Output embedding dimension (768 matches the text model's
                hidden size so ablation embeddings are comparable).
        """
        super().__init__()
        self.device = torch.device(device)

        # Graph-only path: relational GNN pools node features to a single
        # out_dim vector per graph.
        self.graph_encoder = RelationalGraphEncoder(hidden_dim=graph_hidden_dim, out_dim=out_dim, num_layers=graph_layers)
        self.graph_encoder.to(self.device)

    def forward(self, codes: List[str], graph_batch: Batch) -> torch.Tensor:
        """Embed a batch of code graphs; ``codes`` is ignored (interface parity)."""
        return self.graph_encoder(graph_batch)

    def generate_embeddings(self, df: "pd.DataFrame", batch_size: int = 8, save_path: str | None = None, desc: str = "Structural OnlyGraph embeddings") -> np.ndarray:
        """
        Embed every row of ``df["code"]`` and return a float32 array.

        Args:
            df: DataFrame with a ``code`` column of source strings.
            batch_size: Number of snippets embedded per step.
            save_path: If given, the array is also written via ``np.save``.
            desc: tqdm progress-bar label.

        Returns:
            ``np.ndarray`` of shape (len(df), out_dim), dtype float32.
        """
        builder = CodeGraphBuilder()
        codes = df["code"].tolist()
        all_embeddings: List[torch.Tensor] = []

        for start in tqdm(range(0, len(codes), batch_size), desc=desc):
            batch_codes = codes[start:start + batch_size]

            # Build one graph per snippet, then collate into a PyG Batch.
            data_list = [builder.build(c) for c in batch_codes]
            graph_batch = Batch.from_data_list(data_list)

            with torch.no_grad():
                # BUG FIX: the original called ``self.forward(graph_batch)``,
                # omitting the required ``codes`` positional argument — a
                # guaranteed TypeError at runtime. Pass both arguments.
                out = self.forward(batch_codes, graph_batch)
                all_embeddings.append(out.cpu())

        embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
        if save_path is not None:
            np.save(save_path, embeddings)
        return embeddings

    def load_checkpoint(self, checkpoint_path: str, map_location: str | torch.device = "cpu", strict: bool = True) -> None:
        """
        Load model weights from ``checkpoint_path``.

        Accepts either a raw state dict or a checkpoint dict wrapping it
        under a ``"state_dict"`` key.

        Raises:
            ValueError: If ``checkpoint_path`` is empty/None.
        """
        if not checkpoint_path:
            raise ValueError("checkpoint_path must be provided")
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (consider weights_only=True).
        state = torch.load(checkpoint_path, map_location=map_location)
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        self.load_state_dict(state, strict=strict)
58
+
59
+
60
class StructuralEncoderConcat(StructuralEncoderV2):
    """
    Ablation variant 2: Concatenation Fusion.

    Keeps both text and graph paths but fuses them via simple concatenation
    + projection instead of Gated Fusion.
    """

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2):
        super().__init__(device, graph_hidden_dim, graph_layers)

        # Both fused inputs are sized by the text model's hidden width.
        # NOTE(review): the graph-side width is also read from the *text*
        # config — presumably the graph encoder's out_dim matches it; confirm
        # against StructuralEncoderV2's construction of graph_encoder.
        width = self.text_model.config.hidden_size
        self.concat_proj = nn.Linear(width + width, width)
        self.concat_proj.to(self.device)

        # The parent's gated-fusion module is unused here; remove it so its
        # parameters are neither trained nor saved in checkpoints.
        del self.fusion

    def forward(self, codes: List[str], graph_batch: Batch) -> torch.Tensor:
        """Fuse text and graph embeddings by concat + linear projection."""
        text_vec = self.encode_text(codes)
        graph_vec = self.graph_encoder(graph_batch)
        fused = torch.cat((text_vec, graph_vec), dim=-1)
        return self.concat_proj(fused)