# modeling_apkbert.py — APKBert: dual-encoder APK model (Java token transformer
# + XML graph GCN, fused into a single representation).
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch_geometric.nn import GCNConv, global_mean_pool
# ── Rotary Embedding ─────────────────────────────────────────
class RotaryEmb(nn.Module):
def __init__(self, dim):
super().__init__()
inv = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv", inv)
self._cos = self._sin = None
self._len = 0
def _build(self, L, dev):
if L <= self._len: return
t = torch.arange(L, device=dev, dtype=torch.float32)
e = torch.cat([torch.outer(t, self.inv)] * 2, dim=-1)
self._cos = e.cos()[None, None]
self._sin = e.sin()[None, None]
self._len = L
def forward(self, q, k):
self._build(q.shape[2], q.device)
c = self._cos[:, :, :q.shape[2]].to(q.device)
s = self._sin[:, :, :q.shape[2]].to(q.device)
def rot(x):
h = x.shape[-1] // 2
return torch.cat([-x[..., h:], x[..., :h]], dim=-1)
return (q * c) + (rot(q) * s), (k * c) + (rot(k) * s)
# ── Transformer Block ─────────────────────────────────────────
class TBlock(nn.Module):
    """Post-LN transformer layer: multi-head self-attention with RoPE, then FFN."""

    def __init__(self, d, h, ff):
        super().__init__()
        self.dh = d // h
        self.h = h
        self.d = d
        self.qkv = nn.Linear(d, 3 * d, bias=False)  # fused q/k/v projection
        self.out = nn.Linear(d, d, bias=False)
        self.n1 = nn.LayerNorm(d)
        self.n2 = nn.LayerNorm(d)
        self.ff = nn.Sequential(nn.Linear(d, ff), nn.GELU(), nn.Linear(ff, d))
        self.rope = RotaryEmb(self.dh)
        self.drop = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        """x: [B, T, d]; mask: optional [B, T] bool, True at padding positions
        (padded keys are excluded from attention)."""
        bsz, seq, _ = x.shape
        fused = self.qkv(x).reshape(bsz, seq, 3, self.h, self.dh)
        q, k, v = fused.permute(2, 0, 3, 1, 4)  # each [B, heads, T, dh]
        q, k = self.rope(q, k)
        scores = q @ k.transpose(-2, -1)
        scores = scores / math.sqrt(self.dh)
        if mask is not None:
            # -inf on padded keys so softmax assigns them zero weight.
            scores = scores.masked_fill(mask[:, None, None, :], float("-inf"))
        attn = self.drop(F.softmax(scores, dim=-1))
        ctx = (attn @ v).transpose(1, 2).reshape(bsz, seq, self.d)
        x = self.n1(x + self.drop(self.out(ctx)))
        return self.n2(x + self.drop(self.ff(x)))
# ── Java Encoder ─────────────────────────────────────────────
class JavaEncoder(nn.Module):
    """Transformer encoder over Java token ids, mean-pooled over non-pad positions."""

    def __init__(self, vocab=8000, d=192, h=4, n=4, ff=384, d_proj=48):
        super().__init__()
        self.emb = nn.Embedding(vocab, d, padding_idx=0)
        self.blocks = nn.ModuleList([TBlock(d, h, ff) for _ in range(n)])
        self.norm = nn.LayerNorm(d)
        self.proj = nn.Sequential(nn.Linear(d, d // 2), nn.GELU(), nn.Linear(d // 2, d_proj))
        self.drop = nn.Dropout(0.1)

    def encode(self, t):
        """t: [B, T] int token ids, 0 = pad. Returns pooled features [B, d]."""
        pad = t == 0
        x = self.drop(self.emb(t))
        for block in self.blocks:
            x = block(x, pad)
        x = self.norm(x)
        keep = (~pad).unsqueeze(-1).float()
        # Mean over real tokens only; clamp guards divide-by-zero on an all-pad row.
        return (x * keep).sum(1) / keep.sum(1).clamp(min=1)

    def forward(self, t):
        """Returns the projected embedding [B, d_proj]."""
        return self.proj(self.encode(t))
# ── XML Encoder ──────────────────────────────────────────────
class XMLEncoder(nn.Module):
    """GCN encoder over the XML graph, mean-pooled per graph."""

    def __init__(self, node_dim=64, hidden=192, n=3, d_proj=48):
        super().__init__()
        self.inp = nn.Linear(node_dim, hidden)
        self.convs = nn.ModuleList([GCNConv(hidden, hidden) for _ in range(n)])
        self.norms = nn.ModuleList([nn.BatchNorm1d(hidden) for _ in range(n)])
        self.drop = nn.Dropout(0.1)
        self.proj = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.GELU(),
            nn.Linear(hidden // 2, d_proj),
        )

    def encode(self, data):
        """data: PyG batch with .x [N, node_dim], .edge_index, .batch. Returns [B, hidden]."""
        feats, edges, graph_ids = data.x, data.edge_index, data.batch
        feats = F.relu(self.inp(feats))
        for conv, bn in zip(self.convs, self.norms):
            # NOTE: dropout sits between batch-norm and ReLU — order kept as trained.
            feats = F.relu(self.drop(bn(conv(feats, edges))))
        return global_mean_pool(feats, graph_ids)

    def forward(self, data):
        """Returns the projected graph embedding [B, d_proj]."""
        return self.proj(self.encode(data))
# ── Fusion Model ─────────────────────────────────────────────
class FusionModel(nn.Module):
def __init__(self, d=192, d_proj=48):
super().__init__()
self.fusion = nn.Sequential(
nn.Linear(d * 2, d), nn.GELU(),
nn.Linear(d, d), nn.LayerNorm(d),
)
self.proj = nn.Sequential(
nn.Linear(d, d // 2), nn.GELU(),
nn.Linear(d // 2, d_proj),
)
def fuse(self, j, x):
return self.fusion(torch.cat([j, x], dim=-1)) # [B, 192]
def forward(self, j, x):
return self.proj(self.fuse(j, x)) # [B, 48]
# ── Full APK-BERT Model ──────────────────────────────────────
class APKBert(nn.Module):
    """Dual-branch APK encoder: Java token transformer + XML graph GCN + fusion head."""

    def __init__(self, vocab=8000):
        super().__init__()
        self.java_encoder = JavaEncoder(vocab=vocab)
        self.xml_encoder = XMLEncoder()
        self.fusion = FusionModel()

    def forward(self, java_tokens, xml_graph):
        """
        Forward pass: takes Java token IDs and XML graph data,
        returns fused APK representation [B, 48].
        """
        # BUG FIX: FusionModel's first layer is Linear(d*2=384, d), so it must
        # receive the 192-dim encode() outputs. The previous code passed the
        # 48-dim projected forward() outputs (cat -> 96 dims), which raised a
        # shape-mismatch error at runtime.
        j = self.java_encoder.encode(java_tokens)   # [B, 192]
        x = self.xml_encoder.encode(xml_graph)      # [B, 192]
        return self.fusion(j, x)                    # [B, 48]

    @classmethod
    def from_pretrained(cls, model_dir, vocab=8000, device="cpu"):
        """
        Load the three sub-module state dicts (java_encoder.pt, xml_encoder.pt,
        fusion.pt) from `model_dir`, move the model to `device`, set eval mode.

        SECURITY NOTE: torch.load deserializes via pickle — only load
        checkpoints from a trusted source (or pass weights_only=True on
        torch >= 1.13).
        """
        model = cls(vocab=vocab)
        # Sub-module attribute names double as checkpoint file names.
        for part in ("java_encoder", "xml_encoder", "fusion"):
            state = torch.load(f"{model_dir}/{part}.pt", map_location=device)
            getattr(model, part).load_state_dict(state)
        model = model.to(device)
        model.eval()
        print(f"✅ APKBert loaded from {model_dir}")  # fixed mojibake ("βœ…")
        return model