File size: 6,001 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from dataclasses import dataclass, field
from typing import List, Optional

from .linear_probe import LinearProbe, LinearProbeConfig
from .transformer_probe import TransformerForSequenceClassification, TransformerForTokenClassification, TransformerProbeConfig
from .retrievalnet import RetrievalNetForSequenceClassification, RetrievalNetForTokenClassification, RetrievalNetConfig
from .lyra_probe import LyraForSequenceClassification, LyraForTokenClassification, LyraConfig


class ProbeArguments:
    """Configuration container for building a probe via :func:`get_probe`.

    Plain class (not a dataclass): it defines its own ``__init__`` and is
    consumed attribute-wise via ``args.__dict__`` in ``get_probe``. The
    original ``@dataclass`` decorator was removed because, combined with a
    manual ``__init__`` and no annotated fields, it only generated a
    field-less ``__eq__``/``__repr__`` that made all instances compare equal.

    Unknown keyword arguments are accepted and silently ignored so the same
    argument dict can be shared with other components.
    """

    def __init__(
            self,
            probe_type: str = 'linear',  # valid options: linear, transformer, retrievalnet, lyra
            tokenwise: bool = False,  # token-level (True) vs sequence-level (False) classification
            ### Linear Probe
            input_size: int = 960,
            hidden_size: int = 8192,
            dropout: float = 0.2,
            num_labels: int = 2,
            n_layers: int = 1,
            task_type: str = 'singlelabel',
            pre_ln: bool = True,
            sim_type: str = 'dot',
            token_attention: bool = False,
            use_bias: bool = False,
            add_token_ids: bool = False,
            ### Transformer Probe
            transformer_hidden_size: int = 512,  # internal width of the transformer probe
            classifier_size: int = 4096,
            transformer_dropout: float = 0.1,
            classifier_dropout: float = 0.2,
            n_heads: int = 4,
            rotary: bool = True,
            # BUG FIX: the original used dataclasses.field(default_factory=...)
            # as a plain default value here, which stored a Field object instead
            # of a list. Use the None-sentinel idiom for a mutable default.
            probe_pooling_types: Optional[List[str]] = None,
            ### RetrievalNet
            # TODO
            ### LoRA
            lora: bool = False,
            lora_r: int = 8,
            lora_alpha: float = 32.0,
            lora_dropout: float = 0.01,
            **kwargs,
    ):
        self.probe_type = probe_type
        self.tokenwise = tokenwise
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.transformer_hidden_size = transformer_hidden_size
        self.dropout = dropout
        self.num_labels = num_labels
        self.n_layers = n_layers
        self.sim_type = sim_type
        self.token_attention = token_attention
        self.add_token_ids = add_token_ids
        self.task_type = task_type
        self.pre_ln = pre_ln
        self.use_bias = use_bias
        self.classifier_size = classifier_size
        self.transformer_dropout = transformer_dropout
        self.classifier_dropout = classifier_dropout
        self.n_heads = n_heads
        self.rotary = rotary
        # Copy the caller's list so instances never share mutable state.
        self.pooling_types = (
            list(probe_pooling_types) if probe_pooling_types is not None
            else ['mean', 'cls']
        )
        self.lora = lora
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout

def get_probe(args: ProbeArguments):
    """Build and return the probe model selected by ``args``.

    Dispatches on ``(probe_type, tokenwise)`` to the matching
    (config class, model class) pair, constructs the config from the
    argument dict, and instantiates the model.

    Raises:
        ValueError: if the probe_type/tokenwise combination is unsupported
            (e.g. a tokenwise linear probe).
    """
    dispatch = {
        ('linear', False): (LinearProbeConfig, LinearProbe),
        ('transformer', False): (TransformerProbeConfig, TransformerForSequenceClassification),
        ('transformer', True): (TransformerProbeConfig, TransformerForTokenClassification),
        ('retrievalnet', False): (RetrievalNetConfig, RetrievalNetForSequenceClassification),
        ('retrievalnet', True): (RetrievalNetConfig, RetrievalNetForTokenClassification),
        ('lyra', False): (LyraConfig, LyraForSequenceClassification),
        ('lyra', True): (LyraConfig, LyraForTokenClassification),
    }
    key = (args.probe_type, bool(args.tokenwise))
    if key not in dispatch:
        raise ValueError(f"Invalid combination of probe type and tokenwise: {args.probe_type} {args.tokenwise}")
    config_cls, model_cls = dispatch[key]

    config_kwargs = dict(args.__dict__)
    if args.probe_type == 'transformer':
        # The transformer probe's internal dimension comes from
        # transformer_hidden_size, not the (linear-probe) hidden_size.
        config_kwargs['hidden_size'] = args.transformer_hidden_size
    return model_cls(config_cls(**config_kwargs))


def rebuild_probe_from_saved_config(
        probe_type: str,
        tokenwise: bool,
        probe_config: dict,
    ):
    """Reconstruct a probe model from a previously saved config dict.

    Normalizes legacy config keys, then dispatches on
    ``(probe_type, tokenwise)`` to the matching config/model pair.

    Raises:
        ValueError: if the saved probe_type/tokenwise combination is
            not supported.
    """
    config_dict = dict(probe_config)
    # Derive num_labels from the id2label map when it was not stored directly.
    if "num_labels" not in config_dict and "id2label" in config_dict:
        config_dict["num_labels"] = len(config_dict["id2label"])
    # Older checkpoints stored the pooling list under "pooling_types".
    if "pooling_types" in config_dict and "probe_pooling_types" not in config_dict:
        config_dict["probe_pooling_types"] = config_dict["pooling_types"]

    builders = {
        ("linear", False): (LinearProbeConfig, LinearProbe),
        ("transformer", False): (TransformerProbeConfig, TransformerForSequenceClassification),
        ("transformer", True): (TransformerProbeConfig, TransformerForTokenClassification),
        ("retrievalnet", False): (RetrievalNetConfig, RetrievalNetForSequenceClassification),
        ("retrievalnet", True): (RetrievalNetConfig, RetrievalNetForTokenClassification),
        ("lyra", False): (LyraConfig, LyraForSequenceClassification),
        ("lyra", True): (LyraConfig, LyraForTokenClassification),
    }
    key = (probe_type, bool(tokenwise))
    if key not in builders:
        raise ValueError(f"Unsupported saved probe configuration: {probe_type} tokenwise={tokenwise}")
    config_cls, model_cls = builders[key]
    return model_cls(config_cls(**config_dict))