import copy
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers.modeling_outputs import ModelOutput

from .pooling import *
from .loss import *


@dataclass
class AllOutput(ModelOutput):
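    '''
    Container for all model outputs and losses. Unsuffixed fields belong to
    the first input sequence, the "_b" fields hold the same quantities for the
    second input sequence, and the "pair_*" fields hold pair-level results.
    '''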
    losses: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
    outputs: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    global_attentions: Optional[Tuple[torch.FloatTensor]] = None
    contacts: Optional[Tuple[torch.FloatTensor]] = None
    losses_b: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
    outputs_b: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
    hidden_states_b: Optional[Tuple[torch.FloatTensor]] = None
    attentions_b: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions_b: Optional[Tuple[torch.FloatTensor]] = None
    global_attentions_b: Optional[Tuple[torch.FloatTensor]] = None
    contacts_b: Optional[Tuple[torch.FloatTensor]] = None
    pair_outputs: Optional[Tuple[torch.FloatTensor]] = None
    pair_losses: Optional[dict[str, dict[str, torch.FloatTensor]]] = None


def create_pooler(task_level_type, task_level_name, config, args):
    '''
    Build the pooling module for one task.
    :param task_level_type: task granularity, e.g. "token_level", "whole_level", or "pair_level"
    :param task_level_name: task name used to index the per-task entries in config and args
    :param config: model config holding the per-task hidden sizes
    :param args: run arguments holding the per-task pooling types
    :return: the pooling module, or None for an unrecognized pooling type
    '''
    hidden_size = config.hidden_size[task_level_type][task_level_name]
    pooling_type = args.pooling_type[task_level_type][task_level_name]

    if pooling_type == "max":
        return GlobalMaskMaxPooling1D()
    elif pooling_type == "sum":
        return GlobalMaskSumPooling1D(axis=1)
    elif pooling_type == "avg":
        return GlobalMaskAvgPooling1D()
    elif pooling_type in ("attention", "context_attention"):
        # "attention" is an alias for context attention pooling
        return GlobalMaskContextAttentionPooling1D(embed_size=hidden_size)
    elif pooling_type == "weighted_attention":
        return GlobalMaskWeightedAttentionPooling1D(embed_size=hidden_size)
    elif pooling_type == "value_attention":
        return GlobalMaskValueAttentionPooling1D(embed_size=hidden_size)
    elif pooling_type == "transformer":
        # the transformer pooler reads hidden_size from its config, so pass a
        # deep copy with the per-task hidden size substituted in
        copy_config = copy.deepcopy(config)
        copy_config.hidden_size = hidden_size
        return GlobalMaskTransformerPooling1D(copy_config)
    else:
        return None
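

# A usage sketch (the task level/name and the pooler's (hidden_states, mask)
# call signature are illustrative assumptions, not fixed by this module):
#   pooler = create_pooler("whole_level", "prot_taxonomy", config, args)
#   pooled = pooler(hidden_states, mask)

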
def create_output_loss_lucagplm(task_level_type, task_level_name, config, args):
    '''Create the per-task output transform and loss function (no cls head module here).'''
    # args is unused here but accepted so that all factory helpers share one signature
    # record whether a sigmoid is applied to the logits: multi-class and
    # regression outputs do not use one, binary/multi-label outputs do
    if not hasattr(config, "sigmoid"):
        config.sigmoid = {task_level_type: {}}
    elif task_level_type not in config.sigmoid:
        config.sigmoid[task_level_type] = {}
    config.sigmoid[task_level_type][task_level_name] = \
        config.output_mode[task_level_type][task_level_name] not in ["multi_class", "multi-class", "regression"]

    # contact prediction always uses a sigmoid output
    if task_level_name == "prot_contact":
        config.sigmoid[task_level_type][task_level_name] = True
    config.num_labels = config.label_size[task_level_type][task_level_name]
    if task_level_type in ["token_level", "whole_level"]:
        # these levels come with their own cls head, so only the output
        # transform and the loss are requested
        return_types = ["output", "loss"]
    else:
        return_types = ["dropout", "hidden_layer", "hidden_act", "classifier", "output", "loss"]
    return create_loss_function(config,
                                task_level_type=task_level_type,
                                task_level_name=task_level_name,
                                sigmoid=config.sigmoid[task_level_type][task_level_name],
                                output_mode=config.output_mode[task_level_type][task_level_name],
                                num_labels=config.num_labels,
                                loss_type=config.loss_type[task_level_type][task_level_name],
                                ignore_index=config.ignore_index,
                                pair_level=(task_level_type == "pair_level"),
                                return_types=return_types)


def create_output_loss(task_level_type, task_level_name, cls_module, config, args):
    cls = None
    if task_level_type in ["token_level", "whole_level"]:
        cls = cls_module(config)
        # create_output_loss_lucagplm returns one value per requested return
        # type, so token/whole level yields only (output, loss); the missing
        # slots are padded with None to keep the return shape uniform
        output, loss_fct = create_output_loss_lucagplm(task_level_type, task_level_name, config, args)
        return cls, None, None, None, None, output, loss_fct
    dropout, hidden_layer, hidden_act, classifier, output, loss_fct = \
        create_output_loss_lucagplm(task_level_type, task_level_name, config, args)
    return cls, dropout, hidden_layer, hidden_act, classifier, output, loss_fct
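

# A usage sketch ("prot_ppi" is an illustrative task name, not defined in this
# module); for pair-level tasks the cls head is unused, so None can be passed:
#   (cls, dropout, hidden_layer, hidden_act,
#    classifier, output, loss_fct) = create_output_loss(
#        "pair_level", "prot_ppi", None, config, args)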