""" text_detector_model.py ====================== Standalone model definition for HybridAITextDetector. Import this in both training scripts and the Gradio app. Architecture: DeBERTa-v3-small → [BiLSTM | CNN | Transformer] → CrossAttentionFusion → Classifier """ import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoModel # ─── Constants ─────────────────────────────────────────────────────────────── MODEL_NAME = "microsoft/deberta-v3-small" MAX_LENGTH = 128 NUM_CLASSES = 1 # binary: sigmoid output # ─── Sub-modules ───────────────────────────────────────────────────────────── class AttentionPool(nn.Module): """Soft attention pooling over a sequence of vectors.""" def __init__(self, dim: int): super().__init__() self.attn = nn.Linear(dim, 1) def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: weights = self.attn(x) # (B, T, 1) if mask is not None: weights = weights.masked_fill(mask.unsqueeze(-1) == 0, float("-inf")) weights = torch.softmax(weights, dim=1) # (B, T, 1) return (weights * x).sum(dim=1) # (B, dim) class BiLSTMBranch(nn.Module): """2-layer Bidirectional LSTM with Attention Pooling.""" def __init__(self, input_dim: int, hidden_dim: int = 128): super().__init__() self.lstm = nn.LSTM( input_dim, hidden_dim, num_layers=2, batch_first=True, dropout=0.2, bidirectional=True, ) self.pool = AttentionPool(hidden_dim * 2) self.proj = nn.Linear(hidden_dim * 2, 128) def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: out, _ = self.lstm(x) # (B, T, 256) pooled = self.pool(out, mask) # (B, 256) return F.gelu(self.proj(pooled)) # (B, 128) class CNNBranch(nn.Module): """Multi-kernel 1D CNN with Global MaxPooling.""" def __init__(self, input_dim: int): super().__init__() self.conv3 = nn.Conv1d(input_dim, 64, kernel_size=3, padding=1) self.conv5 = nn.Conv1d(input_dim, 64, kernel_size=5, padding=2) self.conv7 = nn.Conv1d(input_dim, 64, kernel_size=7, padding=3) self.bn3 = nn.BatchNorm1d(64) self.bn5 = nn.BatchNorm1d(64) self.bn7 = nn.BatchNorm1d(64) self.proj = nn.Linear(192, 128) def forward(self, x: torch.Tensor) -> torch.Tensor: x_t = x.permute(0, 2, 1) # (B, D, T) c3 = F.gelu(self.bn3(self.conv3(x_t))) c5 = F.gelu(self.bn5(self.conv5(x_t))) c7 = F.gelu(self.bn7(self.conv7(x_t))) p3 = c3.max(dim=-1).values p5 = c5.max(dim=-1).values p7 = c7.max(dim=-1).values cat = torch.cat([p3, p5, p7], dim=-1) # (B, 192) return F.gelu(self.proj(cat)) # (B, 128) class TransformerBranch(nn.Module): """Lightweight Transformer Encoder with Attention Pooling.""" def __init__(self, input_dim: int): super().__init__() self.proj_in = nn.Linear(input_dim, 128) encoder_layer = nn.TransformerEncoderLayer( d_model=128, nhead=4, dim_feedforward=256, dropout=0.1, batch_first=True, norm_first=True, ) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2) self.pool = AttentionPool(128) def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: x = F.gelu(self.proj_in(x)) # (B, T, 128) src_key_padding_mask = (mask == 0) if mask is not None else None out = self.transformer(x, src_key_padding_mask=src_key_padding_mask) return self.pool(out, mask) # (B, 128) class CrossAttentionFusion(nn.Module): """Fuse 3 branch outputs via multi-head self-attention (3-token sequence).""" def __init__(self, dim: int = 128): super().__init__() self.q = nn.Linear(dim, dim) self.k = nn.Linear(dim, dim) self.v = nn.Linear(dim, dim) self.scale = dim ** 0.5 self.proj = nn.Linear(dim, dim) def forward( self, lstm_out: 

class CrossAttentionFusion(nn.Module):
    """Fuse the 3 branch outputs via single-head scaled dot-product
    self-attention over a 3-token sequence (one token per branch)."""

    def __init__(self, dim: int = 128):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.scale = dim ** 0.5
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        lstm_out: torch.Tensor,
        cnn_out: torch.Tensor,
        trans_out: torch.Tensor,
    ) -> torch.Tensor:
        stacked = torch.stack([lstm_out, cnn_out, trans_out], dim=1)  # (B, 3, 128)
        Q = self.q(stacked)
        K = self.k(stacked)
        V = self.v(stacked)
        attn = torch.softmax(torch.bmm(Q, K.transpose(1, 2)) / self.scale, dim=-1)
        out = torch.bmm(attn, V).mean(dim=1)  # (B, 128)
        return F.gelu(self.proj(out))


# ─── Main Model ──────────────────────────────────────────────────────────────

class HybridAITextDetector(nn.Module):
    """
    Hybrid AI-generated text detector.

    Inputs
    ------
    input_ids      : (B, T) long tensor
    attention_mask : (B, T) long tensor — 1 = real token, 0 = pad
    token_type_ids : (B, T) long tensor

    Output
    ------
    logits : (B, 1) float — apply sigmoid to get P(AI-generated)
    """

    def __init__(self):
        super().__init__()
        self.deberta = AutoModel.from_pretrained(MODEL_NAME)

        # Freeze the first 6 transformer layers. Note that deberta-v3-small
        # has exactly 6 encoder layers, so this freezes the entire encoder
        # stack; shrink the range to leave the top layers trainable.
        for name, param in self.deberta.named_parameters():
            if any(f"layer.{i}." in name for i in range(6)):
                param.requires_grad = False
            else:
                param.requires_grad = True

        hidden = self.deberta.config.hidden_size  # 768 for deberta-v3-small

        self.lstm_branch = BiLSTMBranch(hidden)
        self.cnn_branch = CNNBranch(hidden)
        self.trans_branch = TransformerBranch(hidden)
        self.fusion = CrossAttentionFusion(dim=128)

        self.classifier = nn.Sequential(
            nn.LayerNorm(128),
            nn.Linear(128, 128),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
    ) -> torch.Tensor:
        out = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        hidden = out.last_hidden_state  # (B, T, 768)

        lstm_out = self.lstm_branch(hidden, attention_mask)
        cnn_out = self.cnn_branch(hidden)
        trans_out = self.trans_branch(hidden, attention_mask)

        fused = self.fusion(lstm_out, cnn_out, trans_out)
        return self.classifier(fused)  # (B, 1)


# ─── Convenience inference helper ────────────────────────────────────────────

def load_model(checkpoint_path: str, device: torch.device = None) -> HybridAITextDetector:
    """Load a trained HybridAITextDetector from a .pt checkpoint."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = HybridAITextDetector()
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval().to(device)
    return model
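
# ─── Example usage (illustrative) ────────────────────────────────────────────
# A minimal inference sketch, assuming a state-dict checkpoint at "detector.pt"
# (a hypothetical path) and the tokenizer matching MODEL_NAME. Guarded by
# __main__ so it never runs on import.

if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = load_model("detector.pt")  # hypothetical checkpoint path

    # Tokenize a single text; the model's forward() expects input_ids,
    # attention_mask, and token_type_ids, all shaped (B, T).
    enc = tokenizer(
        "Sample text to classify.",
        truncation=True,
        max_length=MAX_LENGTH,
        padding="max_length",
        return_token_type_ids=True,
        return_tensors="pt",
    )
    device = next(model.parameters()).device
    enc = {k: v.to(device) for k, v in enc.items()}

    with torch.no_grad():
        logit = model(**enc)  # (1, 1) raw logit
    print(f"P(AI-generated) = {torch.sigmoid(logit).item():.4f}")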