import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel


class GatedFeatureFusion(nn.Module):
    """Fuses sentence embeddings with hand-crafted linguistic features via a learned sigmoid gate."""

    def __init__(self, embed_dim, feature_dim):
        super().__init__()
        # Project the raw feature vector into the same space as the text embeddings.
        self.feat_proj = nn.Linear(feature_dim, embed_dim)
        # Gate z in (0, 1) decides, per dimension, how much of each source to keep.
        self.gate = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.Sigmoid()
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, text_embeds, raw_features):
        # text_embeds: [batch, seq_len, embed_dim]; raw_features: [batch, seq_len, feature_dim]
        feat_embeds = F.relu(self.feat_proj(raw_features))
        combined = torch.cat([text_embeds, feat_embeds], dim=2)
        z = self.gate(combined)
        fused = z * text_embeds + (1 - z) * feat_embeds
        return self.norm(fused)


class ResearchHybridModel(nn.Module):
    """Hierarchical classifier: a transformer encodes each sentence, then a BiLSTM with
    additive attention aggregates the sentence sequence into a document representation."""

    def __init__(self, model_name='microsoft/deberta-base', feature_dim=6):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        # Read the hidden size from the config (768 for deberta-base) instead of
        # hard-coding it, so other checkpoints work without changes.
        self.bert_hidden = self.bert.config.hidden_size
        self.fusion = GatedFeatureFusion(self.bert_hidden, feature_dim)
        self.lstm = nn.LSTM(
            input_size=self.bert_hidden,
            hidden_size=256,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.3
        )
        # Additive attention over BiLSTM outputs (2 * 256 = 512 per step).
        self.attention = nn.Sequential(
            nn.Linear(512, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )
        self.classifier = nn.Sequential(
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, 2)
        )

    def forward(self, input_ids, attention_mask, linguistic_features, lengths):
        # input_ids / attention_mask: [batch, num_sentences, tokens_per_sentence]
        # linguistic_features:        [batch, num_sentences, feature_dim]
        # lengths:                    number of real (non-padded) sentences per document
        batch_size, seq_len, word_len = input_ids.shape
        flat_input = input_ids.view(-1, word_len)
        flat_mask = attention_mask.view(-1, word_len)
        # Encode every sentence independently; keep the first-token embedding as its summary.
        bert_out = self.bert(flat_input, attention_mask=flat_mask).last_hidden_state
        sent_embeds = bert_out[:, 0, :].view(batch_size, seq_len, -1)
        fused = self.fusion(sent_embeds, linguistic_features)
        # Pack so the LSTM skips padded sentence positions.
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            fused, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=seq_len)
        # Attention pooling over sentences, masking out padded positions before softmax.
        attn_scores = self.attention(lstm_out)
        mask = (torch.arange(seq_len, device=input_ids.device)[None, :]
                < lengths.to(input_ids.device)[:, None]).float().unsqueeze(2)
        attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(attn_scores, dim=1)
        context = torch.sum(lstm_out * attn_weights, dim=1)
        # squeeze(-1) drops only the trailing score dim, so a batch of size 1 keeps its batch axis.
        return self.classifier(context), attn_weights.squeeze(-1)
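
# A minimal, hypothetical smoke test (not part of the original file) showing the tensor
# shapes the model expects; all values are random and only illustrate the forward() API.
if __name__ == "__main__":
    model = ResearchHybridModel(feature_dim=6)
    model.eval()  # eval mode so BatchNorm/Dropout behave deterministically on tiny batches

    batch, n_sent, n_tok = 2, 4, 16
    input_ids = torch.randint(0, model.bert.config.vocab_size, (batch, n_sent, n_tok))
    attention_mask = torch.ones(batch, n_sent, n_tok, dtype=torch.long)
    linguistic_features = torch.randn(batch, n_sent, 6)
    lengths = torch.tensor([4, 3])  # real (non-padded) sentence count per document

    with torch.no_grad():
        logits, attn = model(input_ids, attention_mask, linguistic_features, lengths)
    print(logits.shape)  # torch.Size([2, 2])
    print(attn.shape)    # torch.Size([2, 4])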