File size: 4,820 Bytes
73d1860
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python
# encoding: utf-8

from typing import Optional, Tuple
from dataclasses import dataclass
from transformers.modeling_outputs import ModelOutput
import sys, copy, math

from .pooling import *
from .loss import *

@dataclass
class AllOutput(ModelOutput):
    """Container for everything a forward pass of the model may return.

    Subclasses transformers' ``ModelOutput``; every field defaults to None.
    Un-suffixed fields describe the primary input. Fields ending in ``_b``
    are a parallel set, presumably for the second input of a pair-level task
    (TODO confirm against the model's forward()). ``pair_outputs`` /
    ``pair_losses`` presumably hold pair-level head results.
    """
    # NOTE(review): keep declaration order stable. ModelOutput exposes its
    # non-None fields positionally in declaration order, so reordering would
    # change what tuple-style callers receive.
    # NOTE(review): the builtin-generic ``dict[...]`` annotations require
    # Python 3.9+, and ``torch`` is not imported in the visible header;
    # presumably it arrives through the ``from .pooling import *`` /
    # ``from .loss import *`` star imports. Confirm.

    # Per-task losses / outputs for the primary input; presumably keyed
    # {task_level_type: {task_level_name: tensor}}. Verify against the model.
    losses: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
    outputs: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
    # Encoder diagnostics for the primary input (per-layer tuples, presumably).
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    global_attentions: Optional[Tuple[torch.FloatTensor]] = None
    contacts: Optional[Tuple[torch.FloatTensor]] = None
    # Same quantities for the second ("_b") input.
    losses_b: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
    outputs_b: Optional[dict[str, dict[str, torch.FloatTensor]]] = None
    hidden_states_b: Optional[Tuple[torch.FloatTensor]] = None
    attentions_b: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions_b: Optional[Tuple[torch.FloatTensor]] = None
    global_attentions_b: Optional[Tuple[torch.FloatTensor]] = None
    contacts_b: Optional[Tuple[torch.FloatTensor]] = None
    # Results of heads that consume both inputs together (assumed).
    pair_outputs: Optional[Tuple[torch.FloatTensor]] = None
    pair_losses: Optional[dict[str, dict[str, torch.FloatTensor]]] = None


def create_pooler(task_level_type, task_level_name, config, args):
    '''
    Build the pooling module configured for one task.

    :param task_level_type: first-level key into ``config.hidden_size`` and
        ``args.pooling_type`` (this file uses "token_level", "whole_level"
        and "pair_level").
    :param task_level_name: second-level key naming the task.
    :param config: model config; ``config.hidden_size`` must be indexable as
        [task_level_type][task_level_name].
    :param args: run arguments; ``args.pooling_type`` must be indexable the
        same way.
    :return: a freshly constructed pooler, or None when the configured
        pooling type is not one of the recognised names.
    :raises KeyError: (from the lookups) when the task has no hidden size
        or pooling type configured.
    '''
    embed_size = config.hidden_size[task_level_type][task_level_name]
    kind = args.pooling_type[task_level_type][task_level_name]

    # Ordered (name, factory) table. Factories are lambdas, so a pooler is
    # only instantiated for the name that actually matches.
    # NOTE(review): "attention" builds the same pooler as "context_attention".
    # Confirm this aliasing is intended.
    simple_poolers = (
        ("max", lambda: GlobalMaskMaxPooling1D()),
        ("sum", lambda: GlobalMaskSumPooling1D(axis=1)),
        ("avg", lambda: GlobalMaskAvgPooling1D()),
        ("attention", lambda: GlobalMaskContextAttentionPooling1D(embed_size=embed_size)),
        ("context_attention", lambda: GlobalMaskContextAttentionPooling1D(embed_size=embed_size)),
        ("weighted_attention", lambda: GlobalMaskWeightedAttentionPooling1D(embed_size=embed_size)),
        ("value_attention", lambda: GlobalMaskValueAttentionPooling1D(embed_size=embed_size)),
    )
    for name, build in simple_poolers:
        if kind == name:
            return build()

    if kind == "transformer":
        # The transformer pooler takes a whole config. Give it a private copy
        # whose hidden_size is the per-task width, so the shared config is
        # not mutated.
        pooler_config = copy.deepcopy(config)
        pooler_config.hidden_size = embed_size
        return GlobalMaskTransformerPooling1D(pooler_config)

    return None


def create_output_loss_lucagplm(task_level_type, task_level_name, config):
    '''
    Build the output/loss objects for one task head. Any cls module is
    created by the caller, not here.

    Side effects on ``config``:
      * ``config.sigmoid[task_level_type][task_level_name]`` is set to a bool
        (the nested dicts are created on demand);
      * ``config.num_labels`` is set from ``config.label_size``.

    :param task_level_type: e.g. "token_level", "whole_level", "pair_level".
    :param task_level_name: task name (second-level key in the config maps).
    :param config: model config providing ``output_mode``, ``label_size`` and
        ``loss_type`` (indexed [task_level_type][task_level_name]) plus
        ``ignore_index``.
    :return: whatever ``create_loss_function`` returns for the requested
        ``return_types``: ["output", "loss"] for token_level / whole_level,
        the six-part list otherwise.
    '''
    # Ensure config.sigmoid[task_level_type] exists before writing into it.
    if not hasattr(config, "sigmoid"):
        config.sigmoid = {}
    if task_level_type not in config.sigmoid:
        config.sigmoid[task_level_type] = {}

    output_mode = config.output_mode[task_level_type][task_level_name]
    # Multi-class and regression heads are flagged sigmoid=False; every other
    # output mode is flagged True. Special case: contact prediction always
    # needs sigmoid. (Original author's open question: should structure
    # tasks also use sigmoid?)
    use_sigmoid = (task_level_name == "prot_contact"
                   or output_mode not in ["multi_class", "multi-class", "regression"])
    config.sigmoid[task_level_type][task_level_name] = use_sigmoid

    config.num_labels = config.label_size[task_level_type][task_level_name]

    # For token_level / whole_level, create_output_loss instantiates a cls
    # module, so only the output and loss objects are requested here.
    has_cls_module = task_level_type in ["token_level", "whole_level"]
    return_types = (["output", "loss"] if has_cls_module
                    else ["dropout", "hidden_layer", "hidden_act", "classifier", "output", "loss"])

    return create_loss_function(config,
                                task_level_type=task_level_type,
                                task_level_name=task_level_name,
                                sigmoid=use_sigmoid,
                                output_mode=output_mode,
                                num_labels=config.num_labels,
                                loss_type=config.loss_type[task_level_type][task_level_name],
                                ignore_index=config.ignore_index,
                                pair_level=task_level_type == "pair_level",
                                return_types=return_types)


def create_output_loss(task_level_type, task_level_name, cls_module, config, args):
    '''
    Build the complete head for one task: an optional cls module plus the
    dropout / hidden layer / activation / classifier / output / loss objects.

    :param task_level_type: e.g. "token_level", "whole_level", "pair_level".
    :param task_level_name: task name (second-level key in the config maps).
    :param cls_module: callable taking ``config``; only invoked for
        token_level and whole_level tasks.
    :param config: model config; mutated by create_output_loss_lucagplm
        (``sigmoid`` and ``num_labels`` are written).
    :param args: unused; kept so existing call sites keep working.
    :return: (cls, dropout, hidden_layer, hidden_act, classifier, output,
        loss_fct). ``cls`` is None unless task_level_type is token_level or
        whole_level. The four layer slots are None when the loss factory
        returns only (output, loss).
    '''
    cls = None
    if task_level_type in ["token_level", "whole_level"]:
        cls = cls_module(config)

    # create_output_loss_lucagplm takes exactly (task_level_type,
    # task_level_name, config). The previous call also passed `args`, which
    # raised TypeError on every invocation.
    head_objs = tuple(create_output_loss_lucagplm(task_level_type, task_level_name, config))

    dropout = hidden_layer = hidden_act = classifier = None
    if len(head_objs) == 2:
        # token_level / whole_level tasks request only ("output", "loss").
        # Presumably create_loss_function then returns just those two objects
        # (confirm). Unpacking six values here used to fail with ValueError.
        output, loss_fct = head_objs
    else:
        dropout, hidden_layer, hidden_act, classifier, output, loss_fct = head_objs
    return cls, dropout, hidden_layer, hidden_act, classifier, output, loss_fct