| """ |
| text_detector_model.py |
| ====================== |
| Standalone model definition for HybridAITextDetector. |
| Import this in both training scripts and the Gradio app. |
| |
| Architecture: |
| DeBERTa-v3-small → [BiLSTM | CNN | Transformer] → CrossAttentionFusion → Classifier |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoModel |
|
|
| |
# Hugging Face Hub identifier of the pretrained backbone loaded by the detector.
MODEL_NAME = "microsoft/deberta-v3-small"

# Token-length limit. Not referenced in this module; presumably consumed by
# the tokenizer call in the importing training/app scripts (confirm there).
MAX_LENGTH = 128

# Number of output logits; 1 means a single binary logit.
NUM_CLASSES = 1
|
|
|
|
| |
|
|
class AttentionPool(nn.Module):
    """Soft attention pooling over a sequence of vectors.

    A single linear layer scores every position, the scores are normalised
    with a softmax along the sequence axis (dim 1), and the result is the
    score-weighted sum of the inputs along that axis.
    """

    def __init__(self, dim: int):
        """Build the per-position scorer.

        Args:
            dim: Size of the last axis of the tensors given to ``forward``.
        """
        super().__init__()
        self.attn = nn.Linear(dim, 1)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Pool ``x`` along dim 1.

        Args:
            x: Input whose dim 1 is the sequence axis and whose last axis has
                size ``dim`` (used in this file as ``(batch, seq, dim)``).
            mask: Optional tensor matching ``x`` without its last axis;
                entries equal to 0 mark positions to ignore.

        Returns:
            ``x`` with the sequence axis reduced away. Rows in which every
            position is masked yield an all-zero vector instead of NaN.
        """
        scores = self.attn(x)
        if mask is None:
            weights = torch.softmax(scores, dim=1)
        else:
            pad = mask.unsqueeze(-1) == 0
            # Fill with the most negative finite value rather than -inf: a
            # fully masked row would otherwise give softmax(-inf, ...) = NaN
            # and poison both the output and the gradients.
            scores = scores.masked_fill(pad, torch.finfo(scores.dtype).min)
            # Zero the padded weights explicitly so that fully masked rows
            # (uniform after the softmax) contribute nothing.
            weights = torch.softmax(scores, dim=1).masked_fill(pad, 0.0)
        return (weights * x).sum(dim=1)
|
|
|
|
class BiLSTMBranch(nn.Module):
    """Two-layer bidirectional LSTM encoder followed by attention pooling.

    The LSTM runs over every position of the input; the mask is only used by
    the pooling step. The pooled vector is projected to 128 features and
    passed through GELU.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 128):
        """Create the recurrent encoder, the pooler and the output projection.

        Args:
            input_dim: Feature size of each input position.
            hidden_dim: Hidden size of each LSTM direction.
        """
        super().__init__()
        # Forward and backward states are concatenated by the LSTM.
        bi_width = hidden_dim * 2
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            dropout=0.2,
            bidirectional=True,
            batch_first=True,
        )
        self.pool = AttentionPool(bi_width)
        self.proj = nn.Linear(bi_width, 128)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Encode ``x`` and return one 128-feature vector per batch row.

        Args:
            x: Batch-first sequence tensor with ``input_dim`` features.
            mask: Optional mask forwarded unchanged to the attention pooler.
        """
        # nn.LSTM returns (outputs, (h_n, c_n)); only the outputs are used.
        sequence_states = self.lstm(x)[0]
        summary = self.pool(sequence_states, mask)
        projected = self.proj(summary)
        return F.gelu(projected)
|
|
|
|
class CNNBranch(nn.Module):
    """Parallel 1-D convolutions (kernel widths 3, 5, 7) with global max pooling.

    Each convolution produces 64 channels, is batch-normalised and passed
    through GELU, then reduced with a max over the sequence axis. The three
    pooled vectors are concatenated (192 features) and projected to 128.
    """

    def __init__(self, input_dim: int):
        """Create the three conv/batch-norm pairs and the output projection.

        Args:
            input_dim: Feature size of each input position (conv in-channels).
        """
        super().__init__()

        def same_conv(kernel: int) -> nn.Conv1d:
            # padding = kernel // 2 keeps the sequence length unchanged.
            return nn.Conv1d(input_dim, 64, kernel_size=kernel, padding=kernel // 2)

        self.conv3 = same_conv(3)
        self.conv5 = same_conv(5)
        self.conv7 = same_conv(7)
        self.bn3, self.bn5, self.bn7 = (nn.BatchNorm1d(64) for _ in range(3))
        self.proj = nn.Linear(3 * 64, 128)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one 128-feature vector per batch row.

        Args:
            x: Tensor laid out as (batch, seq, features); it is permuted to
                (batch, features, seq) for ``nn.Conv1d``.
        """
        channels_first = x.permute(0, 2, 1)
        stages = (
            (self.conv3, self.bn3),
            (self.conv5, self.bn5),
            (self.conv7, self.bn7),
        )
        pooled = []
        for conv, norm in stages:
            activated = F.gelu(norm(conv(channels_first)))
            # Global max over the sequence axis.
            pooled.append(activated.max(dim=-1).values)
        merged = torch.cat(pooled, dim=-1)
        return F.gelu(self.proj(merged))
|
|
|
|
class TransformerBranch(nn.Module):
    """Small pre-norm Transformer encoder followed by attention pooling.

    Inputs are projected to 128 features (with GELU), passed through two
    encoder layers (4 heads, feed-forward width 256, dropout 0.1) and pooled
    into a single 128-feature vector per batch row.
    """

    def __init__(self, input_dim: int):
        """Create the input projection, the encoder stack and the pooler.

        Args:
            input_dim: Feature size of each input position.
        """
        super().__init__()
        width = 128
        self.proj_in = nn.Linear(input_dim, width)
        layer = nn.TransformerEncoderLayer(
            d_model=width,
            nhead=4,
            dim_feedforward=256,
            dropout=0.1,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=2)
        self.pool = AttentionPool(width)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Encode ``x`` and pool it to one vector per batch row.

        Args:
            x: Batch-first sequence tensor with ``input_dim`` features.
            mask: Optional mask where 0 marks positions to ignore; it is
                inverted into the encoder's key-padding mask and also passed
                to the pooler as-is.
        """
        if mask is None:
            padding = None
        else:
            # The encoder expects True at positions that must be ignored.
            padding = mask == 0
        embedded = F.gelu(self.proj_in(x))
        encoded = self.transformer(embedded, src_key_padding_mask=padding)
        return self.pool(encoded, mask)
|
|
|
|
class CrossAttentionFusion(nn.Module):
    """Fuse three branch vectors with self-attention over a 3-token sequence.

    The three inputs are stacked into a length-3 sequence, mixed with
    single-head scaled dot-product self-attention, averaged over the three
    tokens, then projected and passed through GELU.
    """

    def __init__(self, dim: int = 128):
        """Create the query/key/value and output projections.

        Args:
            dim: Feature size shared by the three branch vectors.
        """
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        # Attention logits are divided by sqrt(dim).
        self.scale = dim ** 0.5
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        lstm_out: torch.Tensor,
        cnn_out: torch.Tensor,
        trans_out: torch.Tensor,
    ) -> torch.Tensor:
        """Return the fused representation of the three branch outputs.

        Args:
            lstm_out: Output of the BiLSTM branch.
            cnn_out: Output of the CNN branch.
            trans_out: Output of the Transformer branch.

        All three must have identical shape; they are stacked along a new
        dim 1, which is averaged away again after attention.
        """
        tokens = torch.stack((lstm_out, cnn_out, trans_out), dim=1)
        queries = self.q(tokens)
        keys = self.k(tokens)
        values = self.v(tokens)

        logits = torch.bmm(queries, keys.transpose(1, 2)) / self.scale
        weights = torch.softmax(logits, dim=-1)
        mixed = torch.bmm(weights, values)

        averaged = mixed.mean(dim=1)
        return F.gelu(self.proj(averaged))
|
|
|
|
| |
|
|
class HybridAITextDetector(nn.Module):
    """
    Hybrid AI-generated text detector.

    A pretrained encoder produces per-token hidden states, which are fed to
    three parallel branches (BiLSTM, CNN, Transformer). The branch outputs
    are fused by ``CrossAttentionFusion`` and classified by a small MLP.

    Inputs
    ------
    input_ids      : (B, T) long tensor
    attention_mask : (B, T) long tensor — 1 = real token, 0 = pad
    token_type_ids : (B, T) long tensor, optional

    Output
    ------
    logits : (B, NUM_CLASSES) float — with the default single logit, apply
             sigmoid to get P(AI-generated)
    """

    def __init__(self, model_name: str = MODEL_NAME, num_frozen_layers: int = 6):
        """Load the backbone, freeze its lower layers and build the heads.

        Args:
            model_name: Identifier passed to ``AutoModel.from_pretrained``.
            num_frozen_layers: Backbone parameters whose name contains
                ``layer.<i>.`` for ``i < num_frozen_layers`` are frozen; all
                other backbone parameters stay trainable.
        """
        super().__init__()
        self.deberta = AutoModel.from_pretrained(model_name)

        # NOTE(review): if the backbone has no more than `num_frozen_layers`
        # encoder layers, the default freezes every encoder layer and only
        # the remaining parameters (e.g. embeddings) are trained. Confirm
        # that this is the intended fine-tuning depth.
        frozen_tags = tuple(f"layer.{i}." for i in range(num_frozen_layers))
        for name, param in self.deberta.named_parameters():
            param.requires_grad = not any(tag in name for tag in frozen_tags)

        hidden = self.deberta.config.hidden_size

        self.lstm_branch = BiLSTMBranch(hidden)
        self.cnn_branch = CNNBranch(hidden)
        self.trans_branch = TransformerBranch(hidden)
        self.fusion = CrossAttentionFusion(dim=128)

        self.classifier = nn.Sequential(
            nn.LayerNorm(128),
            nn.Linear(128, 128),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Dropout(0.3),
            # Output width follows the module-level constant instead of a
            # second hard-coded 1.
            nn.Linear(64, NUM_CLASSES),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor = None,
    ) -> torch.Tensor:
        """Compute classification logits for a batch of token sequences.

        Args:
            input_ids: Token ids.
            attention_mask: 1 for real tokens, 0 for padding. Used by the
                backbone and by the BiLSTM/Transformer pooling; the CNN
                branch does not receive it.
            token_type_ids: Optional segment ids; ``None`` is forwarded to
                the backbone unchanged.

        Returns:
            Raw (pre-sigmoid) logits from the classifier head.
        """
        out = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        hidden = out.last_hidden_state
        lstm_out = self.lstm_branch(hidden, attention_mask)
        cnn_out = self.cnn_branch(hidden)
        trans_out = self.trans_branch(hidden, attention_mask)
        fused = self.fusion(lstm_out, cnn_out, trans_out)
        return self.classifier(fused)
|
|
|
|
| |
|
|
def load_model(checkpoint_path: str, device: torch.device = None) -> HybridAITextDetector:
    """Load a trained HybridAITextDetector from a .pt checkpoint.

    Args:
        checkpoint_path: Path to a file holding the model's ``state_dict``.
        device: Target device; when ``None``, CUDA is used if available,
            otherwise the CPU.

    Returns:
        The model with the checkpoint weights loaded, moved to ``device``
        and switched to evaluation mode.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Read the checkpoint first so that a bad path fails before the backbone
    # is built. weights_only=True restricts unpickling to tensors and plain
    # containers, so a tampered checkpoint cannot run arbitrary code.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)

    model = HybridAITextDetector()
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
|
|