from dataclasses import dataclass, field
from typing import List, Optional

from .linear_probe import LinearProbe, LinearProbeConfig
from .lyra_probe import LyraConfig, LyraForSequenceClassification, LyraForTokenClassification
from .retrievalnet import RetrievalNetConfig, RetrievalNetForSequenceClassification, RetrievalNetForTokenClassification
from .transformer_probe import TransformerForSequenceClassification, TransformerForTokenClassification, TransformerProbeConfig
class ProbeArguments:
    """Hyperparameter container shared by every probe type.

    NOTE(review): the original stacked ``@dataclass`` on a hand-written
    ``__init__``. With no annotated class attributes the decorator still
    generated a field-less ``__eq__`` (every instance compared equal), and
    because the manual ``__init__`` survives dataclass processing, the
    ``field(default_factory=...)`` expression was evaluated as a plain
    default argument — ``self.pooling_types`` defaulted to a raw
    ``dataclasses.Field`` object instead of a list. Fixed by dropping the
    decorator and using a ``None`` sentinel for the mutable default.
    """

    def __init__(
        self,
        probe_type: str = 'linear',  # valid options: linear, transformer, retrievalnet, lyra
        tokenwise: bool = False,  # True -> per-token classification head
        ### Linear Probe
        input_size: int = 960,
        hidden_size: int = 8192,
        dropout: float = 0.2,
        num_labels: int = 2,
        n_layers: int = 1,
        task_type: str = 'singlelabel',
        pre_ln: bool = True,
        sim_type: str = 'dot',
        token_attention: bool = False,
        use_bias: bool = False,
        add_token_ids: bool = False,
        ### Transformer Probe
        transformer_hidden_size: int = 512,  # transformer probe's internal model dim
        classifier_size: int = 4096,
        transformer_dropout: float = 0.1,
        classifier_dropout: float = 0.2,
        n_heads: int = 4,
        rotary: bool = True,
        probe_pooling_types: Optional[List[str]] = None,  # defaults to ['mean', 'cls']
        ### RetrievalNet
        # TODO
        ### LoRA
        lora: bool = False,
        lora_r: int = 8,
        lora_alpha: float = 32.0,
        lora_dropout: float = 0.01,
        **kwargs,  # absorbs unknown keys (e.g. from serialized configs)
    ):
        self.probe_type = probe_type
        self.tokenwise = tokenwise
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.transformer_hidden_size = transformer_hidden_size
        self.dropout = dropout
        self.num_labels = num_labels
        self.n_layers = n_layers
        self.sim_type = sim_type
        self.token_attention = token_attention
        self.add_token_ids = add_token_ids
        self.task_type = task_type
        self.pre_ln = pre_ln
        self.use_bias = use_bias
        self.classifier_size = classifier_size
        self.transformer_dropout = transformer_dropout
        self.classifier_dropout = classifier_dropout
        self.n_heads = n_heads
        self.rotary = rotary
        # Fresh list per instance; never share one mutable default across instances.
        self.pooling_types = ['mean', 'cls'] if probe_pooling_types is None else probe_pooling_types
        self.lora = lora
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
def get_probe(args: ProbeArguments):
    """Build and return a probe model selected by ``args.probe_type``
    and ``args.tokenwise``.

    Raises:
        ValueError: if the (probe_type, tokenwise) pair is unsupported
            (e.g. a tokenwise linear probe).
    """
    selector = (args.probe_type, bool(args.tokenwise))

    if args.probe_type == 'transformer':
        # The transformer probe interprets ``hidden_size`` as its internal
        # model dimension, so substitute ``transformer_hidden_size``.
        cfg_kwargs = dict(args.__dict__)
        cfg_kwargs['hidden_size'] = args.transformer_hidden_size
    else:
        cfg_kwargs = args.__dict__

    if selector == ('linear', False):
        return LinearProbe(LinearProbeConfig(**cfg_kwargs))
    if selector == ('transformer', False):
        return TransformerForSequenceClassification(TransformerProbeConfig(**cfg_kwargs))
    if selector == ('transformer', True):
        return TransformerForTokenClassification(TransformerProbeConfig(**cfg_kwargs))
    if selector == ('retrievalnet', False):
        return RetrievalNetForSequenceClassification(RetrievalNetConfig(**cfg_kwargs))
    if selector == ('retrievalnet', True):
        return RetrievalNetForTokenClassification(RetrievalNetConfig(**cfg_kwargs))
    if selector == ('lyra', False):
        return LyraForSequenceClassification(LyraConfig(**cfg_kwargs))
    if selector == ('lyra', True):
        return LyraForTokenClassification(LyraConfig(**cfg_kwargs))
    raise ValueError(f"Invalid combination of probe type and tokenwise: {args.probe_type} {args.tokenwise}")
def rebuild_probe_from_saved_config(
    probe_type: str,
    tokenwise: bool,
    probe_config: dict,
):
    """Recreate a probe model from a previously serialized config dict.

    Normalizes legacy key names before constructing the config, then
    dispatches on ``(probe_type, tokenwise)``.

    Raises:
        ValueError: if the (probe_type, tokenwise) pair is unsupported.
    """
    config_dict = dict(probe_config)
    # Older checkpoints may only carry labels via ``id2label``.
    if "num_labels" not in config_dict and "id2label" in config_dict:
        config_dict["num_labels"] = len(config_dict["id2label"])
    # ``pooling_types`` is the serialized name for ``probe_pooling_types``.
    if "pooling_types" in config_dict and "probe_pooling_types" not in config_dict:
        config_dict["probe_pooling_types"] = config_dict["pooling_types"]

    # (probe_type, tokenwise) -> (config class, model class)
    registry = {
        ("linear", False): (LinearProbeConfig, LinearProbe),
        ("transformer", False): (TransformerProbeConfig, TransformerForSequenceClassification),
        ("transformer", True): (TransformerProbeConfig, TransformerForTokenClassification),
        ("retrievalnet", False): (RetrievalNetConfig, RetrievalNetForSequenceClassification),
        ("retrievalnet", True): (RetrievalNetConfig, RetrievalNetForTokenClassification),
        ("lyra", False): (LyraConfig, LyraForSequenceClassification),
        ("lyra", True): (LyraConfig, LyraForTokenClassification),
    }
    try:
        config_cls, model_cls = registry[(probe_type, bool(tokenwise))]
    except KeyError:
        raise ValueError(f"Unsupported saved probe configuration: {probe_type} tokenwise={tokenwise}") from None
    return model_cls(config_cls(**config_dict))