"""
text_detector_model.py
======================
Standalone model definition for HybridAITextDetector.
Import this module from both the training scripts and the Gradio app.
Architecture:
DeBERTa-v3-small → [BiLSTM | CNN | Transformer] → CrossAttentionFusion → Classifier
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
# ─── Constants ───────────────────────────────────────────────────────────────
MODEL_NAME = "microsoft/deberta-v3-small"
MAX_LENGTH = 128
NUM_CLASSES = 1 # binary: sigmoid output
# ─── Sub-modules ─────────────────────────────────────────────────────────────
class AttentionPool(nn.Module):
"""Soft attention pooling over a sequence of vectors."""
def __init__(self, dim: int):
super().__init__()
self.attn = nn.Linear(dim, 1)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
weights = self.attn(x) # (B, T, 1)
if mask is not None:
weights = weights.masked_fill(mask.unsqueeze(-1) == 0, float("-inf"))
weights = torch.softmax(weights, dim=1) # (B, T, 1)
return (weights * x).sum(dim=1) # (B, dim)
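# Usage sketch (illustrative shapes only): positions where mask == 0 receive a weight of
# -inf before the softmax, so they contribute nothing to the pooled vector.
#   pool = AttentionPool(dim=8)
#   x = torch.randn(2, 4, 8)                           # (B=2, T=4, D=8)
#   mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 1 = real token, 0 = pad
#   pooled = pool(x, mask)                             # -> (2, 8)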
class BiLSTMBranch(nn.Module):
"""2-layer Bidirectional LSTM with Attention Pooling."""
def __init__(self, input_dim: int, hidden_dim: int = 128):
super().__init__()
self.lstm = nn.LSTM(
input_dim, hidden_dim,
num_layers=2,
batch_first=True,
dropout=0.2,
bidirectional=True,
)
self.pool = AttentionPool(hidden_dim * 2)
self.proj = nn.Linear(hidden_dim * 2, 128)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
out, _ = self.lstm(x) # (B, T, 256)
pooled = self.pool(out, mask) # (B, 256)
return F.gelu(self.proj(pooled)) # (B, 128)
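# Shape note: the bidirectional LSTM doubles the hidden size (128 * 2 = 256 per step);
# attention pooling collapses the time dimension, and proj maps the pooled 256-dim vector
# down to 128 so that every branch hands the fusion module features of the same width.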
class CNNBranch(nn.Module):
"""Multi-kernel 1D CNN with Global MaxPooling."""
def __init__(self, input_dim: int):
super().__init__()
self.conv3 = nn.Conv1d(input_dim, 64, kernel_size=3, padding=1)
self.conv5 = nn.Conv1d(input_dim, 64, kernel_size=5, padding=2)
self.conv7 = nn.Conv1d(input_dim, 64, kernel_size=7, padding=3)
self.bn3 = nn.BatchNorm1d(64)
self.bn5 = nn.BatchNorm1d(64)
self.bn7 = nn.BatchNorm1d(64)
self.proj = nn.Linear(192, 128)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_t = x.permute(0, 2, 1) # (B, D, T)
c3 = F.gelu(self.bn3(self.conv3(x_t)))
c5 = F.gelu(self.bn5(self.conv5(x_t)))
c7 = F.gelu(self.bn7(self.conv7(x_t)))
p3 = c3.max(dim=-1).values
p5 = c5.max(dim=-1).values
p7 = c7.max(dim=-1).values
cat = torch.cat([p3, p5, p7], dim=-1) # (B, 192)
return F.gelu(self.proj(cat)) # (B, 128)
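# Shape note: Conv1d expects channels-first input, hence the permute to (B, D, T); each
# kernel size (3, 5, 7) produces 64 channels, global max-pooling over time yields three
# (B, 64) vectors, and their concatenation is the (B, 192) tensor fed to proj.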
class TransformerBranch(nn.Module):
"""Lightweight Transformer Encoder with Attention Pooling."""
def __init__(self, input_dim: int):
super().__init__()
self.proj_in = nn.Linear(input_dim, 128)
encoder_layer = nn.TransformerEncoderLayer(
d_model=128, nhead=4,
dim_feedforward=256,
dropout=0.1,
batch_first=True,
norm_first=True,
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
self.pool = AttentionPool(128)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
x = F.gelu(self.proj_in(x)) # (B, T, 128)
src_key_padding_mask = (mask == 0) if mask is not None else None
out = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
return self.pool(out, mask) # (B, 128)
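# Note: nn.TransformerEncoder interprets src_key_padding_mask with the opposite
# convention to attention_mask (True marks a padding position to be ignored), hence
# the (mask == 0) conversion inside TransformerBranch.forward above.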
class CrossAttentionFusion(nn.Module):
"""Fuse 3 branch outputs via multi-head self-attention (3-token sequence)."""
def __init__(self, dim: int = 128):
super().__init__()
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.scale = dim ** 0.5
self.proj = nn.Linear(dim, dim)
def forward(
self,
lstm_out: torch.Tensor,
cnn_out: torch.Tensor,
trans_out: torch.Tensor,
) -> torch.Tensor:
stacked = torch.stack([lstm_out, cnn_out, trans_out], dim=1) # (B, 3, 128)
Q = self.q(stacked)
K = self.k(stacked)
V = self.v(stacked)
attn = torch.softmax(torch.bmm(Q, K.transpose(1, 2)) / self.scale, dim=-1)
out = torch.bmm(attn, V).mean(dim=1) # (B, 128)
return F.gelu(self.proj(out))
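# Fusion math sketch (single attention head, dim = 128):
#   stacked : (B, 3, 128)                     three branch vectors treated as a 3-token sequence
#   attn    = softmax(Q @ K^T / sqrt(128))    (B, 3, 3)   branch-to-branch attention weights
#   out     = (attn @ V).mean(dim=1)          (B, 128)    average of the attended tokens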
# ─── Main Model ──────────────────────────────────────────────────────────────
class HybridAITextDetector(nn.Module):
"""
Hybrid AI-generated text detector.
Inputs
------
input_ids : (B, T) long tensor
    attention_mask : (B, T) long tensor, 1 = real token, 0 = pad
token_type_ids : (B, T) long tensor
Output
------
    logits : (B, 1) float; apply sigmoid to get P(AI-generated)
"""
def __init__(self):
super().__init__()
self.deberta = AutoModel.from_pretrained(MODEL_NAME)
        # Freeze the first 6 transformer layers. deberta-v3-small has 6 encoder layers in
        # total, so this freezes the entire encoder; the embeddings and relative-position
        # embeddings stay trainable because their parameter names lack "layer.{i}.".
for name, param in self.deberta.named_parameters():
if any(f"layer.{i}." in name for i in range(6)):
param.requires_grad = False
else:
param.requires_grad = True
hidden = self.deberta.config.hidden_size # 768 for deberta-v3-small
self.lstm_branch = BiLSTMBranch(hidden)
self.cnn_branch = CNNBranch(hidden)
self.trans_branch = TransformerBranch(hidden)
self.fusion = CrossAttentionFusion(dim=128)
self.classifier = nn.Sequential(
nn.LayerNorm(128),
nn.Linear(128, 128),
nn.GELU(),
nn.Dropout(0.4),
nn.Linear(128, 64),
nn.GELU(),
nn.Dropout(0.3),
nn.Linear(64, 1),
)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor,
) -> torch.Tensor:
out = self.deberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
hidden = out.last_hidden_state # (B, T, 768)
lstm_out = self.lstm_branch(hidden, attention_mask)
cnn_out = self.cnn_branch(hidden)
trans_out = self.trans_branch(hidden, attention_mask)
fused = self.fusion(lstm_out, cnn_out, trans_out)
return self.classifier(fused) # (B, 1)
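# Training note (a sketch, not taken from the accompanying training scripts): the head
# returns raw logits, so a binary objective such as nn.BCEWithLogitsLoss fits naturally;
# apply torch.sigmoid only at inference time. Here ids, mask, type_ids and labels are
# placeholder batch tensors.
#   criterion = nn.BCEWithLogitsLoss()
#   loss = criterion(model(ids, mask, type_ids).squeeze(-1), labels.float())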
# ─── Convenience inference helper ────────────────────────────────────────────
def load_model(checkpoint_path: str, device: torch.device = None) -> HybridAITextDetector:
"""Load a trained HybridAITextDetector from a .pt checkpoint."""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HybridAITextDetector()
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval().to(device)
return model
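# ─── Example usage (illustrative sketch) ─────────────────────────────────────
# A minimal smoke test, assuming a state-dict checkpoint exists at the hypothetical path
# "text_detector.pt" and that inputs are tokenized with the matching
# microsoft/deberta-v3-small tokenizer. Adjust the path to your own checkpoint.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = load_model("text_detector.pt")  # hypothetical checkpoint path
    device = next(model.parameters()).device

    enc = tokenizer(
        "Sample text to score.",
        max_length=MAX_LENGTH,
        padding="max_length",
        truncation=True,
        return_token_type_ids=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(
            enc["input_ids"].to(device),
            enc["attention_mask"].to(device),
            enc["token_type_ids"].to(device),
        )
    print(f"P(AI-generated) = {torch.sigmoid(logits).item():.3f}")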