import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch_geometric.nn import GCNConv, global_mean_pool
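# APK-BERT: dual-encoder model for Android APKs.
#   - JavaEncoder: rotary-attention Transformer over Java token IDs (192-d pooled output).
#   - XMLEncoder: GCN over an XML structure graph via torch_geometric (192-d pooled output).
#   - FusionModel: concatenates both 192-d embeddings and projects to a 48-d APK representation.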
# ── Rotary Embedding ─────────────────────────────────────────
class RotaryEmb(nn.Module):
def __init__(self, dim):
super().__init__()
inv = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv", inv)
self._cos = self._sin = None
self._len = 0
def _build(self, L, dev):
if L <= self._len: return
t = torch.arange(L, device=dev, dtype=torch.float32)
e = torch.cat([torch.outer(t, self.inv)] * 2, dim=-1)
self._cos = e.cos()[None, None]
self._sin = e.sin()[None, None]
self._len = L
def forward(self, q, k):
self._build(q.shape[2], q.device)
c = self._cos[:, :, :q.shape[2]].to(q.device)
s = self._sin[:, :, :q.shape[2]].to(q.device)
def rot(x):
h = x.shape[-1] // 2
return torch.cat([-x[..., h:], x[..., :h]], dim=-1)
return (q * c) + (rot(q) * s), (k * c) + (rot(k) * s)
# ── Transformer Block ────────────────────────────────────────
class TBlock(nn.Module):
def __init__(self, d, h, ff):
super().__init__()
self.dh = d // h; self.h = h; self.d = d
self.qkv = nn.Linear(d, 3 * d, bias=False)
self.out = nn.Linear(d, d, bias=False)
self.n1 = nn.LayerNorm(d)
self.n2 = nn.LayerNorm(d)
self.ff = nn.Sequential(nn.Linear(d, ff), nn.GELU(), nn.Linear(ff, d))
self.rope = RotaryEmb(self.dh)
self.drop = nn.Dropout(0.1)
def forward(self, x, mask=None):
B, T, _ = x.shape
q, k, v = self.qkv(x).reshape(B, T, 3, self.h, self.dh).permute(2, 0, 3, 1, 4)
q, k = self.rope(q, k)
a = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh)
if mask is not None:
a = a.masked_fill(mask[:, None, None, :], float("-inf"))
a = self.drop(F.softmax(a, dim=-1))
o = (a @ v).transpose(1, 2).reshape(B, T, self.d)
x = self.n1(x + self.drop(self.out(o)))
return self.n2(x + self.drop(self.ff(x)))
# ── Java Encoder ─────────────────────────────────────────────
class JavaEncoder(nn.Module):
def __init__(self, vocab=8000, d=192, h=4, n=4, ff=384, d_proj=48):
super().__init__()
self.emb = nn.Embedding(vocab, d, padding_idx=0)
self.blocks = nn.ModuleList([TBlock(d, h, ff) for _ in range(n)])
self.norm = nn.LayerNorm(d)
self.proj = nn.Sequential(nn.Linear(d, d // 2), nn.GELU(), nn.Linear(d // 2, d_proj))
self.drop = nn.Dropout(0.1)
def encode(self, t):
mask = (t == 0)
x = self.drop(self.emb(t))
for b in self.blocks:
x = b(x, mask)
x = self.norm(x)
v = (~mask).unsqueeze(-1).float()
return (x * v).sum(1) / v.sum(1).clamp(min=1) # [B, D]
def forward(self, t):
return self.proj(self.encode(t)) # [B, D_PROJ]
# ── XML Encoder ──────────────────────────────────────────────
class XMLEncoder(nn.Module):
def __init__(self, node_dim=64, hidden=192, n=3, d_proj=48):
super().__init__()
self.inp = nn.Linear(node_dim, hidden)
self.convs = nn.ModuleList([GCNConv(hidden, hidden) for _ in range(n)])
self.norms = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(n)])
self.drop = nn.Dropout(0.1)
self.proj = nn.Sequential(
nn.Linear(hidden, hidden // 2), nn.GELU(),
nn.Linear(hidden // 2, d_proj)
)
def encode(self, data):
x, ei, b = data.x, data.edge_index, data.batch
x = F.relu(self.inp(x))
for conv, norm in zip(self.convs, self.norms):
x = F.relu(self.drop(norm(conv(x, ei))))
return global_mean_pool(x, b) # [B, D]
def forward(self, data):
return self.proj(self.encode(data)) # [B, D_PROJ]
# ── Fusion Model ─────────────────────────────────────────────
class FusionModel(nn.Module):
def __init__(self, d=192, d_proj=48):
super().__init__()
self.fusion = nn.Sequential(
nn.Linear(d * 2, d), nn.GELU(),
nn.Linear(d, d), nn.LayerNorm(d),
)
self.proj = nn.Sequential(
nn.Linear(d, d // 2), nn.GELU(),
nn.Linear(d // 2, d_proj),
)
def fuse(self, j, x):
return self.fusion(torch.cat([j, x], dim=-1)) # [B, 192]
def forward(self, j, x):
return self.proj(self.fuse(j, x)) # [B, 48]
# ── Full APK-BERT Model ──────────────────────────────────────
class APKBert(nn.Module):
def __init__(self, vocab=8000):
super().__init__()
self.java_encoder = JavaEncoder(vocab=vocab)
self.xml_encoder = XMLEncoder()
self.fusion = FusionModel()
def forward(self, java_tokens, xml_graph):
"""
Forward pass: takes Java token IDs and XML graph data,
returns fused APK representation [B, 48].
"""
        # Use the unprojected 192-d encoder outputs; FusionModel expects a 384-d concatenation.
        j = self.java_encoder.encode(java_tokens)  # [B, 192]
        x = self.xml_encoder.encode(xml_graph)     # [B, 192]
        fused = self.fusion(j, x)                  # [B, 48]
return fused
@classmethod
def from_pretrained(cls, model_dir, vocab=8000, device="cpu"):
model = cls(vocab=vocab)
model.java_encoder.load_state_dict(
torch.load(f"{model_dir}/java_encoder.pt", map_location=device))
model.xml_encoder.load_state_dict(
torch.load(f"{model_dir}/xml_encoder.pt", map_location=device))
model.fusion.load_state_dict(
torch.load(f"{model_dir}/fusion.pt", map_location=device))
model = model.to(device)
model.eval()
print(f"β
APKBert loaded from {model_dir}")
return model
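

# ── Usage sketch ─────────────────────────────────────────────
# Minimal smoke test with random placeholder inputs; the batch size, sequence
# length, and graph sizes below are illustrative assumptions, not values fixed
# by the model. A real pipeline would supply tokenized Java code and a
# pre-built XML graph instead.
if __name__ == "__main__":
    from torch_geometric.data import Batch, Data

    model = APKBert(vocab=8000)
    model.eval()

    # Java token IDs: [B=2, T=64], with 0 reserved for padding.
    java_tokens = torch.randint(1, 8000, (2, 64))

    # Two toy XML graphs with 64-d node features (matches node_dim=64).
    graphs = [
        Data(x=torch.randn(10, 64), edge_index=torch.randint(0, 10, (2, 20))),
        Data(x=torch.randn(7, 64), edge_index=torch.randint(0, 7, (2, 12))),
    ]
    xml_graph = Batch.from_data_list(graphs)  # batched graph with a .batch vector

    with torch.no_grad():
        emb = model(java_tokens, xml_graph)
    print(emb.shape)  # torch.Size([2, 48])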