"""APK-BERT: a two-branch encoder for Android APKs.

A transformer encoder over Java token sequences and a GCN encoder over
XML (manifest/layout) graphs are each projected to a small embedding,
and a fusion MLP combines the two branch representations into a single
APK-level vector.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool


# ── Rotary Embedding ─────────────────────────────────────────
class RotaryEmb(nn.Module):
    """Rotary positional embedding (half-split / GPT-NeoX style).

    cos/sin tables are lazily built and cached up to the longest
    sequence length seen so far; shorter sequences slice the cache.
    """

    def __init__(self, dim):
        super().__init__()
        # Inverse frequencies, one per even channel index (dim/2 values).
        inv = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv", inv)
        self._cos = self._sin = None  # cached tables, shape [1, 1, L, dim]
        self._len = 0                 # longest length cached so far

    def _build(self, L, dev):
        """Extend the cached cos/sin tables to cover length ``L``."""
        if L <= self._len:
            return  # cache already long enough
        t = torch.arange(L, device=dev, dtype=torch.float32)
        # Duplicate the frequency grid across both halves of the head dim,
        # matching the rotate-half scheme used in forward().
        e = torch.cat([torch.outer(t, self.inv)] * 2, dim=-1)
        self._cos = e.cos()[None, None]  # [1, 1, L, dim]
        self._sin = e.sin()[None, None]
        self._len = L

    def forward(self, q, k):
        """Rotate query/key tensors of shape [B, H, T, dh]; returns (q', k')."""
        self._build(q.shape[2], q.device)
        # .to() is a no-op when the cache already lives on q's device.
        c = self._cos[:, :, :q.shape[2]].to(q.device)
        s = self._sin[:, :, :q.shape[2]].to(q.device)

        def rot(x):
            # Rotate-half: (x1, x2) -> (-x2, x1).
            h = x.shape[-1] // 2
            return torch.cat([-x[..., h:], x[..., :h]], dim=-1)

        return (q * c) + (rot(q) * s), (k * c) + (rot(k) * s)


# ── Transformer Block ─────────────────────────────────────────
class TBlock(nn.Module):
    """Post-norm transformer block: RoPE self-attention + GELU feed-forward."""

    def __init__(self, d, h, ff):
        super().__init__()
        self.dh = d // h  # per-head dimension
        self.h = h
        self.d = d
        self.qkv = nn.Linear(d, 3 * d, bias=False)  # fused Q/K/V projection
        self.out = nn.Linear(d, d, bias=False)
        self.n1 = nn.LayerNorm(d)
        self.n2 = nn.LayerNorm(d)
        self.ff = nn.Sequential(nn.Linear(d, ff), nn.GELU(), nn.Linear(ff, d))
        self.rope = RotaryEmb(self.dh)
        self.drop = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        """x: [B, T, d]; mask: optional [B, T] bool, True at padding keys."""
        B, T, _ = x.shape
        # [B, T, 3d] -> 3 tensors of [B, h, T, dh].
        q, k, v = self.qkv(x).reshape(B, T, 3, self.h, self.dh).permute(2, 0, 3, 1, 4)
        q, k = self.rope(q, k)
        a = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh)
        if mask is not None:
            # Mask padded *key* positions for every query.
            a = a.masked_fill(mask[:, None, None, :], float("-inf"))
        a = self.drop(F.softmax(a, dim=-1))
        o = (a @ v).transpose(1, 2).reshape(B, T, self.d)
        # Post-norm residual connections around attention and FFN.
        x = self.n1(x + self.drop(self.out(o)))
        return self.n2(x + self.drop(self.ff(x)))


# ── Java Encoder ─────────────────────────────────────────────
class JavaEncoder(nn.Module):
    """Transformer encoder over Java token IDs (0 = padding)."""

    def __init__(self, vocab=8000, d=192, h=4, n=4, ff=384, d_proj=48):
        super().__init__()
        self.emb = nn.Embedding(vocab, d, padding_idx=0)
        self.blocks = nn.ModuleList([TBlock(d, h, ff) for _ in range(n)])
        self.norm = nn.LayerNorm(d)
        self.proj = nn.Sequential(
            nn.Linear(d, d // 2), nn.GELU(), nn.Linear(d // 2, d_proj)
        )
        self.drop = nn.Dropout(0.1)

    def encode(self, t):
        """t: [B, T] int token IDs. Returns masked mean-pooled [B, d]."""
        mask = (t == 0)  # True at padding positions
        x = self.drop(self.emb(t))
        for b in self.blocks:
            x = b(x, mask)
        x = self.norm(x)
        # Mean over non-pad positions; clamp guards an all-padding row.
        v = (~mask).unsqueeze(-1).float()
        return (x * v).sum(1) / v.sum(1).clamp(min=1)  # [B, d]

    def forward(self, t):
        return self.proj(self.encode(t))  # [B, d_proj]


# ── XML Encoder ──────────────────────────────────────────────
class XMLEncoder(nn.Module):
    """GCN stack over an XML graph (torch_geometric ``Data`` batch)."""

    def __init__(self, node_dim=64, hidden=192, n=3, d_proj=48):
        super().__init__()
        self.inp = nn.Linear(node_dim, hidden)
        self.convs = nn.ModuleList([GCNConv(hidden, hidden) for _ in range(n)])
        self.norms = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(n)])
        self.drop = nn.Dropout(0.1)
        self.proj = nn.Sequential(
            nn.Linear(hidden, hidden // 2), nn.GELU(), nn.Linear(hidden // 2, d_proj)
        )

    def encode(self, data):
        """data: batch with ``x``, ``edge_index``, ``batch``. Returns [B, hidden]."""
        x, ei, b = data.x, data.edge_index, data.batch
        x = F.relu(self.inp(x))
        for conv, norm in zip(self.convs, self.norms):
            x = F.relu(self.drop(norm(conv(x, ei))))
        return global_mean_pool(x, b)  # per-graph mean over nodes

    def forward(self, data):
        return self.proj(self.encode(data))  # [B, d_proj]


# ── Fusion Model ─────────────────────────────────────────────
class FusionModel(nn.Module):
    """Fuses the two branch representations via a concat-MLP."""

    def __init__(self, d=192, d_proj=48):
        super().__init__()
        self.fusion = nn.Sequential(
            nn.Linear(d * 2, d),
            nn.GELU(),
            nn.Linear(d, d),
            nn.LayerNorm(d),
        )
        self.proj = nn.Sequential(
            nn.Linear(d, d // 2),
            nn.GELU(),
            nn.Linear(d // 2, d_proj),
        )

    def fuse(self, j, x):
        """j, x: [B, d] branch embeddings. Returns fused [B, d]."""
        return self.fusion(torch.cat([j, x], dim=-1))

    def forward(self, j, x):
        return self.proj(self.fuse(j, x))  # [B, d_proj]


# ── Full APK-BERT Model ──────────────────────────────────────
class APKBert(nn.Module):
    """Complete model: Java encoder + XML encoder + fusion head."""

    def __init__(self, vocab=8000):
        super().__init__()
        self.java_encoder = JavaEncoder(vocab=vocab)
        self.xml_encoder = XMLEncoder()
        self.fusion = FusionModel()

    def forward(self, java_tokens, xml_graph):
        """
        Forward pass: takes Java token IDs and XML graph data,
        returns fused APK representation [B, 48].
        """
        # NOTE: forward feeds the 48-dim branch *projections* into
        # FusionModel, whose fusion layer expects 2*192 inputs — this is
        # preserved as-is from the original; verify against training code.
        j = self.java_encoder(java_tokens)   # [B, 48]
        x = self.xml_encoder(xml_graph)      # [B, 48]
        fused = self.fusion(j, x)            # [B, 48]
        return fused

    @classmethod
    def from_pretrained(cls, model_dir, vocab=8000, device="cpu"):
        """Load the three sub-module state_dicts from ``model_dir``.

        Expects ``java_encoder.pt``, ``xml_encoder.pt`` and ``fusion.pt``.
        Returns the model on ``device`` in eval mode.
        """
        model = cls(vocab=vocab)
        # weights_only=True: state_dicts need no pickle code execution,
        # and this guards against malicious checkpoint files.
        model.java_encoder.load_state_dict(
            torch.load(f"{model_dir}/java_encoder.pt",
                       map_location=device, weights_only=True))
        model.xml_encoder.load_state_dict(
            torch.load(f"{model_dir}/xml_encoder.pt",
                       map_location=device, weights_only=True))
        model.fusion.load_state_dict(
            torch.load(f"{model_dir}/fusion.pt",
                       map_location=device, weights_only=True))
        model = model.to(device)
        model.eval()
        print(f"✅ APKBert loaded from {model_dir}")
        return model