"""APKBert: encoders and a fusion head for learning APK representations.

Combines a transformer encoder over Java token sequences with a GCN encoder
over XML graphs, and fuses the two into a single embedding.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool


class RotaryEmb(nn.Module):
    """Rotary position embedding (RoPE) applied to attention queries and keys."""

    def __init__(self, dim):
        super().__init__()
        # Inverse frequencies, one per pair of channels.
        inv = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv", inv)
        # Lazily built cos/sin caches, grown when a longer sequence is seen.
        self._cos = self._sin = None
        self._len = 0

    def _build(self, L, dev):
        if L <= self._len:
            return
        t = torch.arange(L, device=dev, dtype=torch.float32)
        e = torch.cat([torch.outer(t, self.inv)] * 2, dim=-1)
        self._cos = e.cos()[None, None]  # [1, 1, L, dim]
        self._sin = e.sin()[None, None]
        self._len = L

    def forward(self, q, k):
        # q, k: [B, heads, T, head_dim]
        self._build(q.shape[2], q.device)
        c = self._cos[:, :, :q.shape[2]].to(q.device)
        s = self._sin[:, :, :q.shape[2]].to(q.device)

        def rot(x):
            # Rotate-half: swap the two halves of the last dim, negating one.
            h = x.shape[-1] // 2
            return torch.cat([-x[..., h:], x[..., :h]], dim=-1)

        return (q * c) + (rot(q) * s), (k * c) + (rot(k) * s)


class TBlock(nn.Module):
    """Post-norm transformer block: multi-head self-attention with RoPE, then FFN."""

    def __init__(self, d, h, ff):
        super().__init__()
        self.dh = d // h
        self.h = h
        self.d = d
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.out = nn.Linear(d, d, bias=False)
        self.n1 = nn.LayerNorm(d)
        self.n2 = nn.LayerNorm(d)
        self.ff = nn.Sequential(nn.Linear(d, ff), nn.GELU(), nn.Linear(ff, d))
        self.rope = RotaryEmb(self.dh)
        self.drop = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        # x: [B, T, d]; mask: [B, T] bool, True at padding positions.
        B, T, _ = x.shape
        q, k, v = self.qkv(x).reshape(B, T, 3, self.h, self.dh).permute(2, 0, 3, 1, 4)
        q, k = self.rope(q, k)
        a = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh)
        if mask is not None:
            a = a.masked_fill(mask[:, None, None, :], float("-inf"))
        a = self.drop(F.softmax(a, dim=-1))
        o = (a @ v).transpose(1, 2).reshape(B, T, self.d)
        x = self.n1(x + self.drop(self.out(o)))
        return self.n2(x + self.drop(self.ff(x)))


class JavaEncoder(nn.Module):
    """Transformer encoder over Java token-ID sequences (token 0 is padding)."""

    def __init__(self, vocab=8000, d=192, h=4, n=4, ff=384, d_proj=48):
        super().__init__()
        self.emb = nn.Embedding(vocab, d, padding_idx=0)
        self.blocks = nn.ModuleList([TBlock(d, h, ff) for _ in range(n)])
        self.norm = nn.LayerNorm(d)
        self.proj = nn.Sequential(nn.Linear(d, d // 2), nn.GELU(), nn.Linear(d // 2, d_proj))
        self.drop = nn.Dropout(0.1)

    def encode(self, t):
        # t: [B, T] long tensor of token IDs; returns mean-pooled features [B, d].
        mask = (t == 0)
        x = self.drop(self.emb(t))
        for b in self.blocks:
            x = b(x, mask)
        x = self.norm(x)
        v = (~mask).unsqueeze(-1).float()
        return (x * v).sum(1) / v.sum(1).clamp(min=1)

    def forward(self, t):
        # Projection head on top of the pooled representation: [B, d_proj].
        return self.proj(self.encode(t))


class XMLEncoder(nn.Module):
    """GCN encoder over XML graphs given as torch_geometric Data/Batch objects."""

    def __init__(self, node_dim=64, hidden=192, n=3, d_proj=48):
        super().__init__()
        self.inp = nn.Linear(node_dim, hidden)
        self.convs = nn.ModuleList([GCNConv(hidden, hidden) for _ in range(n)])
        self.norms = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(n)])
        self.drop = nn.Dropout(0.1)
        self.proj = nn.Sequential(
            nn.Linear(hidden, hidden // 2), nn.GELU(),
            nn.Linear(hidden // 2, d_proj)
        )

    def encode(self, data):
        # data.x: [N, node_dim], data.edge_index: [2, E], data.batch: [N].
        x, ei, b = data.x, data.edge_index, data.batch
        x = F.relu(self.inp(x))
        for conv, norm in zip(self.convs, self.norms):
            x = F.relu(self.drop(norm(conv(x, ei))))
        return global_mean_pool(x, b)  # [num_graphs, hidden]

    def forward(self, data):
        # Projection head on top of the pooled graph representation: [B, d_proj].
        return self.proj(self.encode(data))


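# A minimal sketch of the graph input XMLEncoder expects: a torch_geometric
# Data/Batch object with 64-dim node features (the default node_dim). How node
# features are extracted from the XML is outside this module; the values below
# are placeholders only.
#
#   from torch_geometric.data import Data, Batch
#   g = Data(x=torch.randn(5, 64),                   # 5 nodes, node_dim = 64
#            edge_index=torch.tensor([[0, 1, 1, 2],  # 4 directed edges
#                                     [1, 2, 3, 4]]))
#   batch = Batch.from_data_list([g, g])             # adds .batch for pooling
#   # XMLEncoder()(batch) -> [2, 48]; XMLEncoder().encode(batch) -> [2, 192]

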
class FusionModel(nn.Module):
    """Fuses the Java and XML encoder outputs (pre-projection, d-dim each)."""

    def __init__(self, d=192, d_proj=48):
        super().__init__()
        self.fusion = nn.Sequential(
            nn.Linear(d * 2, d), nn.GELU(),
            nn.Linear(d, d), nn.LayerNorm(d),
        )
        self.proj = nn.Sequential(
            nn.Linear(d, d // 2), nn.GELU(),
            nn.Linear(d // 2, d_proj),
        )

    def fuse(self, j, x):
        # j, x: [B, d] encoder outputs; returns the fused [B, d] representation.
        return self.fusion(torch.cat([j, x], dim=-1))

    def forward(self, j, x):
        return self.proj(self.fuse(j, x))


class APKBert(nn.Module):
    """End-to-end model: Java token encoder + XML graph encoder + fusion head."""

    def __init__(self, vocab=8000):
        super().__init__()
        self.java_encoder = JavaEncoder(vocab=vocab)
        self.xml_encoder = XMLEncoder()
        self.fusion = FusionModel()

    def forward(self, java_tokens, xml_graph):
        """
        Forward pass: takes Java token IDs and XML graph data,
        returns the fused APK representation [B, 48].
        """
        # Fuse the encoders' pre-projection (192-dim) representations; the
        # FusionModel's first layer expects a concatenated 2 * d = 384-dim input,
        # so the 48-dim projection heads are bypassed here.
        j = self.java_encoder.encode(java_tokens)
        x = self.xml_encoder.encode(xml_graph)
        fused = self.fusion(j, x)
        return fused

    @classmethod
    def from_pretrained(cls, model_dir, vocab=8000, device="cpu"):
        # Load the three sub-modules from separate checkpoint files.
        model = cls(vocab=vocab)
        model.java_encoder.load_state_dict(
            torch.load(f"{model_dir}/java_encoder.pt", map_location=device))
        model.xml_encoder.load_state_dict(
            torch.load(f"{model_dir}/xml_encoder.pt", map_location=device))
        model.fusion.load_state_dict(
            torch.load(f"{model_dir}/fusion.pt", map_location=device))
        model = model.to(device)
        model.eval()
        print(f"APKBert loaded from {model_dir}")
        return model
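
if __name__ == "__main__":
    # Minimal smoke test with random inputs. The dummy tensors below are purely
    # illustrative; real inputs (tokenized Java code and featurized XML graphs)
    # come from preprocessing that lives outside this module.
    from torch_geometric.data import Batch, Data

    model = APKBert(vocab=8000)
    model.eval()

    # Java side: a batch of 2 token-ID sequences of length 128, with 0 as padding.
    java_tokens = torch.randint(1, 8000, (2, 128))
    java_tokens[:, 100:] = 0

    # XML side: two small chain graphs with random 64-dim node features.
    def dummy_graph(n=6):
        edge_index = torch.stack([torch.arange(n - 1), torch.arange(1, n)])
        return Data(x=torch.randn(n, 64), edge_index=edge_index)

    xml_graph = Batch.from_data_list([dummy_graph(), dummy_graph()])

    with torch.no_grad():
        out = model(java_tokens, xml_graph)
    print(out.shape)  # expected: torch.Size([2, 48])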