# Spaces:
# Runtime error
# Runtime error
import torch
import torch.nn as nn
class ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Produces two outputs from an encoder representation of shape
    ``(batch, seq, hidden_dim)``:
      * per-statement logits of shape ``(batch, seq)`` via ``out_proj``;
      * function-level 2-class logits of shape ``(batch, 2)``, obtained by
        pooling the sequence with a GRU and projecting its final hidden state.

    Args:
        hidden_dim: width of the incoming hidden states.
        dropout: dropout probability applied before each projection
            (default 0.1, matching the original hard-coded value).
    """

    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        self.dense = nn.Linear(hidden_dim, hidden_dim)
        # Attribute deliberately kept capitalised (`Dropout`) so that
        # existing checkpoints' state_dict keys continue to load.
        self.Dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(hidden_dim, 1)
        # BUG FIX: the pooling GRU previously hard-coded input_size and
        # hidden_size to 768, crashing for any other `hidden_dim`.
        self.rnn_pool = nn.GRU(input_size=hidden_dim,
                               hidden_size=hidden_dim,
                               num_layers=1,
                               batch_first=True)
        self.func_dense = nn.Linear(hidden_dim, hidden_dim)
        self.func_out_proj = nn.Linear(hidden_dim, 2)

    def forward(self, hidden):
        """Compute statement-level and function-level logits.

        Args:
            hidden: tensor of shape ``(batch, seq, hidden_dim)``.

        Returns:
            Tuple ``(statement_logits, func_logits)`` with shapes
            ``(batch, seq)`` and ``(batch, 2)``.
        """
        x = self.Dropout(hidden)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.Dropout(x)
        x = self.out_proj(x)

        # Pool the whole sequence: the GRU's final hidden state summarises it.
        _, func_x = self.rnn_pool(hidden)
        func_x = func_x.squeeze(0)  # (num_layers=1, batch, dim) -> (batch, dim)
        func_x = self.Dropout(func_x)
        func_x = self.func_dense(func_x)
        func_x = torch.tanh(func_x)
        func_x = self.Dropout(func_x)
        func_x = self.func_out_proj(func_x)
        return x.squeeze(-1), func_x
class StatementT5(nn.Module):
    """Statement-level classification model on top of a T5 encoder.

    Each sample is a grid of token ids, presumably
    ``(batch, num_statements, num_tokens)`` — TODO confirm against the
    caller. Tokens of every statement are pooled into one statement
    embedding with a GRU; the resulting statement sequence is fed to the
    T5 encoder and classified per statement and per function.

    Args:
        t5: a T5 model exposing ``.shared`` (token embedding table) and a
            forward accepting ``inputs_embeds`` / ``attention_mask``.
        tokenizer: kept for API compatibility; not used in ``forward``.
        device: device on which statement embeddings are assembled.
        hidden_dim: embedding/hidden width (default 768).
    """

    def __init__(self, t5, tokenizer, device, hidden_dim=768):
        super().__init__()
        self.max_num_statement = 155  # functions longer than this are truncated
        self.word_embedding = t5.shared  # reuse T5's token embedding table
        # Generalised: sizes previously hard-coded to 768 now follow hidden_dim.
        self.rnn_statement_embedding = nn.GRU(input_size=hidden_dim,
                                              hidden_size=hidden_dim,
                                              num_layers=1,
                                              batch_first=True)
        self.t5 = t5
        self.tokenizer = tokenizer
        self.device = device
        # CLS head
        self.classifier = ClassificationHead(hidden_dim=hidden_dim)

    def _embed_statements(self, input_ids):
        """Pool each statement's tokens into a single embedding.

        Args:
            input_ids: ``(batch, num_statements, num_tokens)`` token ids.

        Returns:
            ``(batch, min(num_statements, max_num_statement), hidden_dim)``
            statement embeddings.
        """
        embed = self.word_embedding(input_ids)  # (B, S, T, D)
        pooled = []
        for sample in embed:  # sample: (S, T, D); statements act as the GRU batch
            _, h_n = self.rnn_statement_embedding(sample)  # h_n: (1, S, D)
            pooled.append(h_n.squeeze(0))
        # torch.stack replaces the original randn-allocated buffer that was
        # immediately overwritten slice by slice.
        inputs_embeds = torch.stack(pooled, dim=0).to(self.device)
        return inputs_embeds[:, :self.max_num_statement, :]

    def forward(self, input_ids, statement_mask, labels=None, func_labels=None):
        """Run the model.

        Args:
            input_ids: ``(batch, num_statements, num_tokens)`` token ids.
            statement_mask: ``(batch, num_statements)`` attention mask.
            labels: statement-level targets (required in training mode).
            func_labels: function-level targets (required in training mode).

        Returns:
            ``(statement_loss, func_loss)`` in training mode, otherwise
            ``(statement_probs, func_probs)``.
        """
        statement_mask = statement_mask[:, :self.max_num_statement]
        inputs_embeds = self._embed_statements(input_ids)
        rep = self.t5(inputs_embeds=inputs_embeds,
                      attention_mask=statement_mask).last_hidden_state
        logits, func_logits = self.classifier(rep)

        if self.training:
            # NOTE(review): training applies CrossEntropyLoss over the
            # statement dimension while inference applies a per-statement
            # sigmoid — asymmetric objectives; confirm this is intended.
            loss_fct = nn.CrossEntropyLoss()
            statement_loss = loss_fct(logits, labels)
            func_loss = loss_fct(func_logits, func_labels)
            return statement_loss, func_loss

        probs = torch.sigmoid(logits)
        func_probs = torch.softmax(func_logits, dim=-1)
        return probs, func_probs