# Page header captured by the file viewer (not part of the program):
#   Spaces: Sleeping
#   File size: 1,229 Bytes
#   8953138
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput
class ComplexityFusionModel(nn.Module):
    """Classifier that fuses a pretrained encoder's output with static features.

    The encoder's hidden state at sequence position 0 is concatenated with a
    small MLP projection of the static feature vector, and one linear layer
    maps the fused vector to ``num_labels`` logits.

    Args:
        model_name: Identifier or path handed to ``AutoConfig.from_pretrained``
            and ``AutoModel.from_pretrained``.
        num_labels: Number of logits produced per example.
        num_static_features: Size of the last dimension of ``static_features``.
        static_hidden_dim: Output width of the static-feature MLP.
        static_dropout: Dropout probability applied after the MLP's ReLU.
    """

    def __init__(self, model_name, num_labels, num_static_features,
                 static_hidden_dim=16, static_dropout=0.1):
        super().__init__()

        # Load config and base model
        self.config = AutoConfig.from_pretrained(model_name)
        self.codebert = AutoModel.from_pretrained(model_name)

        # Projects the static features before they are fused with the
        # encoder output.
        self.static_mlp = nn.Sequential(
            nn.Linear(num_static_features, static_hidden_dim),
            nn.ReLU(),
            nn.Dropout(static_dropout),
        )

        # The classifier consumes the concatenation of both representations.
        fusion_dim = self.config.hidden_size + static_hidden_dim
        self.classifier = nn.Linear(fusion_dim, num_labels)

    def forward(self, input_ids=None, attention_mask=None, static_features=None):
        """Return the classification logits for a batch.

        Args:
            input_ids: Forwarded unchanged to the encoder.
            attention_mask: Forwarded unchanged to the encoder.
            static_features: 2-D tensor whose last dimension is
                ``num_static_features``. Required despite the ``None``
                default; any numeric dtype is accepted.

        Returns:
            The raw output tensor of the final linear layer (no softmax).

        Raises:
            TypeError: If ``static_features`` is not provided.
        """
        # Fail with a clear message before paying for the encoder pass.
        if static_features is None:
            raise TypeError("forward() requires the 'static_features' tensor")

        outputs = self.codebert(input_ids=input_ids, attention_mask=attention_mask)
        # Keep only the hidden state at sequence position 0 of each example.
        bert_output = outputs.last_hidden_state[:, 0, :]

        # Features built from NumPy/pandas often arrive as int64 or float64;
        # nn.Linear needs them in the same dtype as its weights.
        static_features = static_features.to(dtype=self.static_mlp[0].weight.dtype)
        static_output = self.static_mlp(static_features)

        combined_features = torch.cat((bert_output, static_output), dim=1)
        return self.classifier(combined_features)