import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch_geometric.nn import GCNConv, global_mean_pool

# ── Rotary Embedding ─────────────────────────────────────────
class RotaryEmb(nn.Module):
    """Rotary position embedding (RoPE), applied per attention head.

    cos/sin tables are built lazily and cached up to the longest
    sequence length seen so far.
    """

    def __init__(self, dim):
        super().__init__()
        inv = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv", inv)
        self._cos = self._sin = None
        self._len = 0

    def _build(self, L, dev):
        if L <= self._len: return
        t = torch.arange(L, device=dev, dtype=torch.float32)
        e = torch.cat([torch.outer(t, self.inv)] * 2, dim=-1)
        self._cos = e.cos()[None, None]
        self._sin = e.sin()[None, None]
        self._len = L

    def forward(self, q, k):
        self._build(q.shape[2], q.device)
        c = self._cos[:, :, :q.shape[2]].to(q.device)
        s = self._sin[:, :, :q.shape[2]].to(q.device)
        def rot(x):
            h = x.shape[-1] // 2
            return torch.cat([-x[..., h:], x[..., :h]], dim=-1)
        return (q * c) + (rot(q) * s), (k * c) + (rot(k) * s)
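
# Usage sketch (illustrative; shapes match how TBlock below calls it): RoPE
# is applied per head, so `dim` is the head dimension and q/k arrive as
# [B, heads, T, head_dim].
#
#   rope = RotaryEmb(dim=48)
#   q = torch.randn(2, 4, 16, 48)   # [B, heads, T, head_dim]
#   k = torch.randn(2, 4, 16, 48)
#   q_rot, k_rot = rope(q, k)       # shapes unchanged; phases depend on position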


# ── Transformer Block ─────────────────────────────────────────
class TBlock(nn.Module):
    """Post-LN transformer block: rotary multi-head self-attention + GELU FFN."""

    def __init__(self, d, h, ff):
        super().__init__()
        self.d, self.h, self.dh = d, h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.out  = nn.Linear(d, d, bias=False)
        self.n1   = nn.LayerNorm(d)
        self.n2   = nn.LayerNorm(d)
        self.ff   = nn.Sequential(nn.Linear(d, ff), nn.GELU(), nn.Linear(ff, d))
        self.rope = RotaryEmb(self.dh)
        self.drop = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        B, T, _ = x.shape
        # Project to q/k/v and split heads: [B, T, 3d] -> 3 x [B, h, T, dh].
        q, k, v = self.qkv(x).reshape(B, T, 3, self.h, self.dh).permute(2, 0, 3, 1, 4)
        q, k    = self.rope(q, k)
        a = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh)
        if mask is not None:
            # mask is a [B, T] key-padding mask, True at padded positions.
            a = a.masked_fill(mask[:, None, None, :], float("-inf"))
        a = self.drop(F.softmax(a, dim=-1))
        o = (a @ v).transpose(1, 2).reshape(B, T, self.d)   # [B, T, d]
        x = self.n1(x + self.drop(self.out(o)))
        return self.n2(x + self.drop(self.ff(x)))


# ── Java Encoder ─────────────────────────────────────────────
class JavaEncoder(nn.Module):
    """Transformer encoder over Java token IDs (0 = padding).

    `encode` returns the masked mean-pooled hidden state [B, d];
    `forward` adds the projection head and returns [B, d_proj].
    """

    def __init__(self, vocab=8000, d=192, h=4, n=4, ff=384, d_proj=48):
        super().__init__()
        self.emb    = nn.Embedding(vocab, d, padding_idx=0)
        self.blocks = nn.ModuleList([TBlock(d, h, ff) for _ in range(n)])
        self.norm   = nn.LayerNorm(d)
        self.proj   = nn.Sequential(nn.Linear(d, d // 2), nn.GELU(), nn.Linear(d // 2, d_proj))
        self.drop   = nn.Dropout(0.1)

    def encode(self, t):
        mask = (t == 0)
        x    = self.drop(self.emb(t))
        for b in self.blocks:
            x = b(x, mask)
        x = self.norm(x)
        v = (~mask).unsqueeze(-1).float()
        return (x * v).sum(1) / v.sum(1).clamp(min=1)   # [B, D]

    def forward(self, t):
        return self.proj(self.encode(t))                  # [B, D_PROJ]
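
# Usage sketch (illustrative): token IDs are expected as [B, T] with 0 as
# the padding index.
#
#   enc = JavaEncoder()
#   ids = torch.randint(1, 8000, (2, 64))
#   h   = enc.encode(ids)   # [2, 192] mean-pooled sequence state
#   z   = enc(ids)          # [2, 48]  projection-head output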


# ── XML Encoder ──────────────────────────────────────────────
class XMLEncoder(nn.Module):
    """GCN encoder over XML node-feature graphs.

    `encode` returns the mean-pooled graph state [B, hidden];
    `forward` adds the projection head and returns [B, d_proj].
    """

    def __init__(self, node_dim=64, hidden=192, n=3, d_proj=48):
        super().__init__()
        self.inp   = nn.Linear(node_dim, hidden)
        self.convs = nn.ModuleList([GCNConv(hidden, hidden) for _ in range(n)])
        self.norms = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(n)])
        self.drop  = nn.Dropout(0.1)
        self.proj  = nn.Sequential(
            nn.Linear(hidden, hidden // 2), nn.GELU(),
            nn.Linear(hidden // 2, d_proj)
        )

    def encode(self, data):
        x, ei, b = data.x, data.edge_index, data.batch
        x = F.relu(self.inp(x))
        for conv, norm in zip(self.convs, self.norms):
            x = F.relu(self.drop(norm(conv(x, ei))))
        return global_mean_pool(x, b)   # [B, D]

    def forward(self, data):
        return self.proj(self.encode(data))   # [B, D_PROJ]
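
# Usage sketch (illustrative; a toy chain graph, not a real APK XML graph):
# the encoder expects a torch_geometric batch with 64-d node features.
#
#   from torch_geometric.data import Data, Batch
#   g = Data(x=torch.randn(5, 64),
#            edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]))
#   batch = Batch.from_data_list([g, g])
#   h = XMLEncoder().encode(batch)   # [2, 192]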


# ── Fusion Model ─────────────────────────────────────────────
class FusionModel(nn.Module):
    """Fuses the two 192-d encoder states into a joint APK embedding."""

    def __init__(self, d=192, d_proj=48):
        super().__init__()
        self.fusion = nn.Sequential(
            nn.Linear(d * 2, d), nn.GELU(),
            nn.Linear(d, d),     nn.LayerNorm(d),
        )
        self.proj = nn.Sequential(
            nn.Linear(d, d // 2), nn.GELU(),
            nn.Linear(d // 2, d_proj),
        )

    def fuse(self, j, x):
        return self.fusion(torch.cat([j, x], dim=-1))   # [B, 192]

    def forward(self, j, x):
        return self.proj(self.fuse(j, x))               # [B, 48]
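
# Usage sketch (illustrative): fusion consumes the 192-d `encode` outputs of
# the two towers, not their 48-d projection-head outputs.
#
#   fusion = FusionModel()
#   fused  = fusion(torch.randn(2, 192), torch.randn(2, 192))   # [2, 48]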


# ── Full APK-BERT Model ──────────────────────────────────────
class APKBert(nn.Module):
    """Two-tower APK encoder: Java transformer + XML GCN + fusion head."""

    def __init__(self, vocab=8000):
        super().__init__()
        self.java_encoder = JavaEncoder(vocab=vocab)
        self.xml_encoder  = XMLEncoder()
        self.fusion       = FusionModel()

    def forward(self, java_tokens, xml_graph):
        """
        Forward pass: takes Java token IDs and XML graph data,
        returns the fused APK representation [B, 48].

        Fusion consumes the 192-d `encode` states of both towers (the 48-d
        per-modality projection heads are bypassed here), matching the
        d * 2 = 384-d input of FusionModel's first linear layer.
        """
        j = self.java_encoder.encode(java_tokens)   # [B, 192]
        x = self.xml_encoder.encode(xml_graph)      # [B, 192]
        fused = self.fusion(j, x)                   # [B, 48]
        return fused

    @classmethod
    def from_pretrained(cls, model_dir, vocab=8000, device="cpu"):
        model = cls(vocab=vocab)
        model.java_encoder.load_state_dict(
            torch.load(f"{model_dir}/java_encoder.pt", map_location=device))
        model.xml_encoder.load_state_dict(
            torch.load(f"{model_dir}/xml_encoder.pt",  map_location=device))
        model.fusion.load_state_dict(
            torch.load(f"{model_dir}/fusion.pt",       map_location=device))
        model = model.to(device)
        model.eval()
        print(f"βœ… APKBert loaded from {model_dir}")
        return model
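

# ── Smoke test ───────────────────────────────────────────────
# Minimal sketch with dummy inputs (toy shapes and random chain graphs,
# assumed here for illustration; real token/graph preprocessing lives
# elsewhere in the pipeline).
if __name__ == "__main__":
    from torch_geometric.data import Batch, Data

    model = APKBert(vocab=8000)
    model.eval()

    # Fake Java token IDs: batch of 2, length 32, with 0 used as padding.
    java_tokens = torch.randint(1, 8000, (2, 32))
    java_tokens[:, 24:] = 0

    # Fake XML graphs: small chains with 64-d random node features.
    def toy_graph(n):
        src = torch.arange(n - 1)
        dst = torch.arange(1, n)
        return Data(x=torch.randn(n, 64),
                    edge_index=torch.stack([src, dst]))

    xml_graph = Batch.from_data_list([toy_graph(5), toy_graph(7)])

    with torch.no_grad():
        fused = model(java_tokens, xml_graph)
    print(fused.shape)   # torch.Size([2, 48])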