adtrack-backend / model_arch.py
cracker0935's picture
add Backend_files
eb48929 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
class GatedFeatureFusion(nn.Module):
def __init__(self, embed_dim, feature_dim):
super().__init__()
self.feat_proj = nn.Linear(feature_dim, embed_dim)
self.gate = nn.Sequential(
nn.Linear(embed_dim * 2, embed_dim),
nn.Sigmoid()
)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, text_embeds, raw_features):
feat_embeds = F.relu(self.feat_proj(raw_features))
combined = torch.cat([text_embeds, feat_embeds], dim=2)
z = self.gate(combined)
fused = z * text_embeds + (1 - z) * feat_embeds
return self.norm(fused)
class ResearchHybridModel(nn.Module):
def __init__(self, model_name='microsoft/deberta-base', feature_dim=6):
super().__init__()
self.bert = AutoModel.from_pretrained(model_name)
self.bert_hidden = 768
self.fusion = GatedFeatureFusion(self.bert_hidden, feature_dim)
self.lstm = nn.LSTM(
input_size=self.bert_hidden,
hidden_size=256,
num_layers=2,
batch_first=True,
bidirectional=True,
dropout=0.3
)
self.attention = nn.Sequential(
nn.Linear(512, 128),
nn.Tanh(),
nn.Linear(128, 1)
)
self.classifier = nn.Sequential(
nn.Linear(512, 128),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(128, 2)
)
def forward(self, input_ids, attention_mask, linguistic_features, lengths):
batch_size, seq_len, word_len = input_ids.shape
flat_input = input_ids.view(-1, word_len)
flat_mask = attention_mask.view(-1, word_len)
bert_out = self.bert(flat_input, attention_mask=flat_mask).last_hidden_state
sent_embeds = bert_out[:, 0, :].view(batch_size, seq_len, -1)
fused = self.fusion(sent_embeds, linguistic_features)
packed = torch.nn.utils.rnn.pack_padded_sequence(fused, lengths.cpu(), batch_first=True, enforce_sorted=False)
packed_out, _ = self.lstm(packed)
lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=seq_len)
attn_scores = self.attention(lstm_out)
mask = (torch.arange(seq_len, device=input_ids.device)[None, :] < lengths.to(input_ids.device)[:, None]).float().unsqueeze(2)
attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(attn_scores, dim=1)
context = torch.sum(lstm_out * attn_weights, dim=1)
return self.classifier(context), attn_weights.squeeze()