File size: 3,372 Bytes
17f2e83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Any
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch_geometric.data import Batch

# Import builder from dataloader for inference
from dataloader import CodeGraphBuilder

from structural_encoder_v2 import RelationalGraphEncoder, StructuralEncoderV2, GatedFusion

class StructuralEncoderOnlyGraph(nn.Module):
    """
    Ablation variant 1: Pure Structural Encoder.

    Removes GraphCodeBERT and uses only the graph path (R-GNN). The interface
    mirrors the fused variants (``forward(codes, graph_batch)``) so callers can
    swap models without changing call sites; ``codes`` is simply ignored here.
    """

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2, out_dim: int = 768):
        """
        Args:
            device: Target device for the graph encoder (e.g. "cuda" or "cpu").
            graph_hidden_dim: Hidden width of the relational GNN layers.
            graph_layers: Number of R-GNN message-passing layers.
            out_dim: Output embedding size (768 matches the text-based variants).
        """
        super().__init__()
        self.device = torch.device(device)
        # Kept so generate_embeddings can shape an empty result without a forward pass.
        self.out_dim = out_dim

        self.graph_encoder = RelationalGraphEncoder(hidden_dim=graph_hidden_dim, out_dim=out_dim, num_layers=graph_layers)
        self.graph_encoder.to(self.device)

    def forward(self, codes: List[str], graph_batch: Batch) -> torch.Tensor:
        """Encode a batch of graphs. ``codes`` is unused (interface parity only)."""
        return self.graph_encoder(graph_batch)

    def generate_embeddings(self, df: "pd.DataFrame", batch_size: int = 8, save_path: str | None = None, desc: str = "Structural OnlyGraph embeddings") -> np.ndarray:
        """
        Embed every row of ``df["code"]`` in mini-batches.

        Args:
            df: DataFrame with a "code" column of source strings.
            batch_size: Number of snippets per forward pass.
            save_path: If given, the embeddings are saved there via ``np.save``.
            desc: Progress-bar label.

        Returns:
            float32 array of shape (len(df), out_dim).
        """
        builder = CodeGraphBuilder()
        codes = df["code"].tolist()
        all_embeddings: List[torch.Tensor] = []

        for start in tqdm(range(0, len(codes), batch_size), desc=desc):
            batch_codes = codes[start:start + batch_size]

            data_list = [builder.build(c) for c in batch_codes]
            # FIX: move the batched graph onto the encoder's device; the
            # original left it on CPU while the encoder may live on GPU.
            graph_batch = Batch.from_data_list(data_list).to(self.device)

            with torch.no_grad():
                # FIX: forward takes (codes, graph_batch); the original called
                # self.forward(graph_batch), which raises a TypeError (missing
                # positional argument) and binds the graph batch to ``codes``.
                out = self.forward(batch_codes, graph_batch)
            all_embeddings.append(out.cpu())

        # FIX: torch.cat on an empty list raises; return a well-shaped empty
        # array for an empty DataFrame instead.
        if not all_embeddings:
            embeddings = np.empty((0, self.out_dim), dtype="float32")
        else:
            embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
        if save_path is not None:
            np.save(save_path, embeddings)
        return embeddings

    def load_checkpoint(self, checkpoint_path: str, map_location: str | torch.device = "cpu", strict: bool = True) -> None:
        """
        Load model weights from ``checkpoint_path``.

        Accepts either a raw state dict or a checkpoint dict containing a
        "state_dict" key (common trainer format).

        Raises:
            ValueError: If ``checkpoint_path`` is empty.
        """
        if not checkpoint_path:
            raise ValueError("checkpoint_path must be provided")
        state = torch.load(checkpoint_path, map_location=map_location)
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        self.load_state_dict(state, strict=strict)


class StructuralEncoderConcat(StructuralEncoderV2):
    """
    Ablation variant 2: Concatenation Fusion.

    Keeps both the text and graph paths of StructuralEncoderV2 but replaces
    the Gated Fusion module with a plain ``[text ; graph] -> Linear`` projection.
    """

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2):
        """
        Args:
            device: Target device, forwarded to the parent encoder.
            graph_hidden_dim: Hidden width of the relational GNN layers.
            graph_layers: Number of R-GNN message-passing layers.
        """
        super().__init__(device, graph_hidden_dim, graph_layers)

        # Both paths emit vectors of the text model's hidden size
        # (presumably the parent's graph encoder projects to it — TODO confirm
        # against StructuralEncoderV2).
        hidden = self.text_model.config.hidden_size

        # Projects the concatenated pair back down to the text hidden size.
        self.concat_proj = nn.Linear(hidden * 2, hidden)
        self.concat_proj.to(self.device)

        # The parent's gated-fusion module is unused in this ablation.
        del self.fusion

    def forward(self, codes: List[str], graph_batch: Batch) -> torch.Tensor:
        """Fuse text and graph embeddings via concatenation + linear projection."""
        txt = self.encode_text(codes)
        gph = self.graph_encoder(graph_batch)
        fused = torch.cat((txt, gph), dim=-1)
        return self.concat_proj(fused)