cracker0935's picture
Audit Backend
97ea4f7
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
class GatedFeatureFusion(nn.Module):
def __init__(self, embed_dim, feature_dim):
super().__init__()
self.feat_proj = nn.Linear(feature_dim, embed_dim)
self.gate = nn.Sequential(
nn.Linear(embed_dim * 2, embed_dim),
nn.Sigmoid()
)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, text_embeds, raw_features):
feat_embeds = F.relu(self.feat_proj(raw_features))
combined = torch.cat([text_embeds, feat_embeds], dim=2)
z = self.gate(combined)
# Gated Residual Connection
fused = z * text_embeds + (1 - z) * feat_embeds
return self.norm(fused)
class ResearchHybridModel(nn.Module):
def __init__(self, model_name='microsoft/deberta-base', feature_dim=6):
super().__init__()
# 1. DeBERTa
self.bert = AutoModel.from_pretrained(model_name)
# Enable Gradient Checkpointing for Memory Efficiency
self.bert.gradient_checkpointing_enable()
self.bert_hidden = 768
# FREEZE bottom 6 layers (Keep basic grammar frozen)
for i in range(6):
for param in self.bert.encoder.layer[i].parameters():
param.requires_grad = False
# UNFREEZE top 6 layers (Allow adaptation to medical speech)
for i in range(6, 12):
for param in self.bert.encoder.layer[i].parameters():
param.requires_grad = True
# 2. Components
self.fusion = GatedFeatureFusion(self.bert_hidden, feature_dim)
# BiLSTM
self.lstm = nn.LSTM(
input_size=self.bert_hidden,
hidden_size=256,
num_layers=2,
batch_first=True,
bidirectional=True,
dropout=0.3
)
# Attention Pooling
self.attention = nn.Sequential(
nn.Linear(512, 128),
nn.Tanh(),
nn.Linear(128, 1)
)
# Classifier
self.classifier = nn.Sequential(
nn.Linear(512 + 2, 128), # Added +2 for age and gender
nn.BatchNorm1d(128),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(128, 2)
)
def forward(self, input_ids, attention_mask, linguistic_features, lengths, age, gender):
batch_size, seq_len, word_len = input_ids.shape
# DeBERTa
flat_input = input_ids.view(-1, word_len)
flat_mask = attention_mask.view(-1, word_len)
bert_out = self.bert(flat_input, attention_mask=flat_mask).last_hidden_state
sent_embeds = bert_out[:, 0, :].view(batch_size, seq_len, -1)
# Fusion
fused = self.fusion(sent_embeds, linguistic_features)
# BiLSTM
packed = torch.nn.utils.rnn.pack_padded_sequence(fused, lengths.cpu(), batch_first=True, enforce_sorted=False)
lstm_out, _ = self.lstm(packed)
lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True, total_length=seq_len)
# Attention
attn_scores = self.attention(lstm_out)
mask = (torch.arange(seq_len, device=input_ids.device)[None, :] < lengths.to(input_ids.device)[:, None]).float().unsqueeze(2)
attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(attn_scores, dim=1)
context = torch.sum(lstm_out * attn_weights, dim=1)
# Concatenate Age and Gender
global_feats = torch.stack([age, gender], dim=1)
combined_context = torch.cat([context, global_feats], dim=1)
return self.classifier(combined_context)
class ExplainableModel(ResearchHybridModel):
def forward(self, input_ids, attention_mask, linguistic_features, lengths, age, gender):
batch_size, seq_len, word_len = input_ids.shape
# A. DeBERTa
flat_input = input_ids.view(-1, word_len)
flat_mask = attention_mask.view(-1, word_len)
bert_out = self.bert(flat_input, attention_mask=flat_mask).last_hidden_state
sent_embeds = bert_out[:, 0, :].view(batch_size, seq_len, -1)
# B. Fusion
fused = self.fusion(sent_embeds, linguistic_features)
# C. BiLSTM
packed = torch.nn.utils.rnn.pack_padded_sequence(fused, lengths.cpu(), batch_first=True, enforce_sorted=False)
packed_out, _ = self.lstm(packed)
lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=seq_len)
# D. INTERCEPT ATTENTION WEIGHTS
# The 'self.attention' is nn.Sequential(Linear, Tanh, Linear)
attn_scores = self.attention(lstm_out) # Shape: [Batch, Seq, 1]
# Mask padding
mask = (torch.arange(seq_len, device=input_ids.device)[None, :] < lengths.to(input_ids.device)[:, None]).float().unsqueeze(2)
attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
# Get the actual importance weights (0 to 1)
attn_weights = F.softmax(attn_scores, dim=1)
# Compute Context & Logits
context = torch.sum(lstm_out * attn_weights, dim=1)
# Concatenate Age and Gender
global_feats = torch.stack([age, gender], dim=1)
combined_context = torch.cat([context, global_feats], dim=1)
logits = self.classifier(combined_context)
return logits, attn_weights.squeeze()