# Procedural-Reranker-Synthetic / structural_encoder_ablation.py
import torch
import torch.nn as nn
from typing import List
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch_geometric.data import Batch
from transformers import AutoTokenizer
# Import builder from dataloader for inference
from dataloader import CodeGraphBuilder
from structural_encoder_v2 import RelationalGraphEncoder, StructuralEncoderV2, GatedFusion
class StructuralEncoderOnlyGraph(nn.Module):
"""
Ablation variant 1: Pure Structural Encoder.
Removes GraphCodeBERT and uses only the graph path (R-GNN).
"""
def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2, out_dim: int = 768):
super().__init__()
self.device = torch.device(device)
self.graph_encoder = RelationalGraphEncoder(hidden_dim=graph_hidden_dim, out_dim=out_dim, num_layers=graph_layers)
self.graph_encoder.to(self.device)
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch) -> torch.Tensor:
# Ignore text inputs for OnlyGraph
return self.graph_encoder(graph_batch)
def generate_embeddings(self, df: "pd.DataFrame", batch_size: int = 8, save_path: str | None = None, desc: str = "Structural OnlyGraph embeddings") -> np.ndarray:
builder = CodeGraphBuilder()
codes = df["code"].tolist()
batches = range(0, len(codes), batch_size)
        all_embeddings: List[torch.Tensor] = []
        self.eval()  # inference only: make sure dropout in the graph encoder is disabled
        for start in tqdm(batches, desc=desc):
batch_codes = codes[start:start + batch_size]
data_list = [builder.build(c) for c in batch_codes]
            graph_batch = Batch.from_data_list(data_list).to(self.device)
            # Dummy text inputs for signature compatibility; forward() ignores them
            dummy_ids = torch.zeros((1, 1), dtype=torch.long, device=self.device)
            dummy_mask = torch.zeros((1, 1), dtype=torch.long, device=self.device)
            with torch.no_grad():
                out = self.forward(dummy_ids, dummy_mask, graph_batch)
all_embeddings.append(out.cpu())
embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
if save_path is not None:
np.save(save_path, embeddings)
return embeddings
def load_checkpoint(self, checkpoint_path: str, map_location: str | torch.device = "cpu", strict: bool = True) -> None:
if not checkpoint_path:
raise ValueError("checkpoint_path must be provided")
state = torch.load(checkpoint_path, map_location=map_location)
if isinstance(state, dict) and "state_dict" in state:
state = state["state_dict"]
self.load_state_dict(state, strict=strict)
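# Example usage of the graph-only ablation (a sketch; the checkpoint path, the output
# path, and the DataFrame `df` with a "code" column are hypothetical placeholders):
#
#   model = StructuralEncoderOnlyGraph(device="cuda" if torch.cuda.is_available() else "cpu")
#   model.load_checkpoint("structural_onlygraph.pt", map_location=model.device)
#   embeddings = model.generate_embeddings(df, batch_size=16, save_path="onlygraph_embeddings.npy")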
class StructuralEncoderConcat(StructuralEncoderV2):
"""
Ablation variant 2: Concatenation Fusion.
Keeps both text and graph paths but fuses them via simple concatenation + projection
instead of Gated Fusion.
"""
def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2):
super().__init__(device, graph_hidden_dim, graph_layers)
        text_dim = self.text_model.config.hidden_size
        # The graph path in StructuralEncoderV2 is projected to the text model's hidden
        # size, so the graph embeddings share text_dim here.
        graph_dim = self.text_model.config.hidden_size
        self.concat_proj = nn.Linear(text_dim + graph_dim, text_dim)
        self.concat_proj.to(self.device)
        # Drop the inherited GatedFusion module; this variant fuses the two paths by
        # concatenation followed by a linear projection instead.
        del self.fusion
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch) -> torch.Tensor:
text_embeddings = self.encode_text(input_ids, attention_mask)
graph_embeddings = self.graph_encoder(graph_batch)
combined = torch.cat([text_embeddings, graph_embeddings], dim=-1)
return self.concat_proj(combined)
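if __name__ == "__main__":
    # Minimal smoke test for the concatenation-fusion ablation (a sketch, not part of
    # the training pipeline). Assumptions: the text backbone is GraphCodeBERT
    # ("microsoft/graphcodebert-base" is inferred from the docstrings, not read from
    # structural_encoder_v2), and CodeGraphBuilder.build() takes a raw code string,
    # as in StructuralEncoderOnlyGraph.generate_embeddings above.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
    model = StructuralEncoderConcat(device=device)
    model.eval()
    snippets = ["def add(a, b):\n    return a + b"]
    encoded = tokenizer(snippets, padding=True, truncation=True, return_tensors="pt")
    builder = CodeGraphBuilder()
    graph_batch = Batch.from_data_list([builder.build(c) for c in snippets]).to(device)
    with torch.no_grad():
        embeddings = model(
            encoded["input_ids"].to(device),
            encoded["attention_mask"].to(device),
            graph_batch,
        )
    # Expect shape (batch_size, text_hidden_size), e.g. (1, 768) for GraphCodeBERT.
    print(embeddings.shape)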