# nikraf's picture
# Upload folder using huggingface_hub
# 714cf46 verified
from dataclasses import dataclass, field
from typing import List, Optional

from .linear_probe import LinearProbe, LinearProbeConfig
from .lyra_probe import LyraForSequenceClassification, LyraForTokenClassification, LyraConfig
from .retrievalnet import RetrievalNetForSequenceClassification, RetrievalNetForTokenClassification, RetrievalNetConfig
from .transformer_probe import TransformerForSequenceClassification, TransformerForTokenClassification, TransformerProbeConfig
@dataclass
class ProbeArguments:
    """Configuration container describing which probe to build and how.

    NOTE(review): the ``@dataclass`` decorator is kept for backward
    compatibility (it supplies ``__repr__``/``__eq__``), but because this
    class defines its own ``__init__`` the decorator does NOT generate one —
    all attributes are assigned manually below.

    BUG FIX: the previous signature used
    ``probe_pooling_types: List[str] = field(default_factory=...)`` as a
    plain default argument. In a hand-written ``__init__`` that expression is
    evaluated once at definition time, so omitting the argument left
    ``self.pooling_types`` set to a ``dataclasses.Field`` object instead of
    ``['mean', 'cls']``. A ``None`` sentinel is used instead.
    """

    def __init__(
        self,
        probe_type: str = 'linear',  # valid options: linear, transformer, retrievalnet, lyra
        tokenwise: bool = False,  # True -> token-level probe, False -> sequence-level probe
        ### Linear Probe
        input_size: int = 960,
        hidden_size: int = 8192,
        dropout: float = 0.2,
        num_labels: int = 2,
        n_layers: int = 1,
        task_type: str = 'singlelabel',
        pre_ln: bool = True,
        sim_type: str = 'dot',
        token_attention: bool = False,
        use_bias: bool = False,
        add_token_ids: bool = False,
        ### Transformer Probe
        transformer_hidden_size: int = 512,  # For transformer probe
        classifier_size: int = 4096,
        transformer_dropout: float = 0.1,
        classifier_dropout: float = 0.2,
        n_heads: int = 4,
        rotary: bool = True,
        # None means "use the default pooling"; stored as self.pooling_types.
        probe_pooling_types: Optional[List[str]] = None,
        ### RetrievalNet
        # TODO
        ### LoRA
        lora: bool = False,
        lora_r: int = 8,
        lora_alpha: float = 32.0,
        lora_dropout: float = 0.01,
        **kwargs,  # unknown keys are deliberately ignored for forward compat
    ):
        self.probe_type = probe_type
        self.tokenwise = tokenwise
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.transformer_hidden_size = transformer_hidden_size
        self.dropout = dropout
        self.num_labels = num_labels
        self.n_layers = n_layers
        self.sim_type = sim_type
        self.token_attention = token_attention
        self.add_token_ids = add_token_ids
        self.task_type = task_type
        self.pre_ln = pre_ln
        self.use_bias = use_bias
        self.classifier_size = classifier_size
        self.transformer_dropout = transformer_dropout
        self.classifier_dropout = classifier_dropout
        self.n_heads = n_heads
        self.rotary = rotary
        # Copy to avoid sharing caller-owned list state; default when omitted.
        self.pooling_types = list(probe_pooling_types) if probe_pooling_types is not None else ['mean', 'cls']
        self.lora = lora
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
def get_probe(args: ProbeArguments):
    """Build and return the probe model selected by ``args``.

    Dispatches on ``args.probe_type`` ('linear', 'transformer',
    'retrievalnet', 'lyra') and ``args.tokenwise`` (token- vs
    sequence-level head). Raises ValueError for unsupported combinations
    (e.g. a tokenwise linear probe).
    """
    kind = args.probe_type
    tokenwise = args.tokenwise
    config_kwargs = dict(args.__dict__)

    if kind == 'linear':
        if not tokenwise:
            return LinearProbe(LinearProbeConfig(**config_kwargs))
    elif kind == 'transformer':
        # The transformer probe's internal width is transformer_hidden_size,
        # not the embedding input width stored under hidden_size.
        config_kwargs['hidden_size'] = args.transformer_hidden_size
        config = TransformerProbeConfig(**config_kwargs)
        head = TransformerForTokenClassification if tokenwise else TransformerForSequenceClassification
        return head(config)
    elif kind == 'retrievalnet':
        config = RetrievalNetConfig(**config_kwargs)
        head = RetrievalNetForTokenClassification if tokenwise else RetrievalNetForSequenceClassification
        return head(config)
    elif kind == 'lyra':
        config = LyraConfig(**config_kwargs)
        head = LyraForTokenClassification if tokenwise else LyraForSequenceClassification
        return head(config)

    raise ValueError(f"Invalid combination of probe type and tokenwise: {args.probe_type} {args.tokenwise}")
def rebuild_probe_from_saved_config(
    probe_type: str,
    tokenwise: bool,
    probe_config: dict,
):
    """Reconstruct a probe model from a serialized configuration dict.

    Normalizes legacy configs first (deriving ``num_labels`` from
    ``id2label``, and mapping the stored ``pooling_types`` key back to the
    constructor's ``probe_pooling_types`` argument), then instantiates the
    config/model pair matching (``probe_type``, ``tokenwise``). Raises
    ValueError for unsupported combinations.
    """
    cfg = dict(probe_config)

    # Legacy configs may carry only the id2label mapping.
    if "num_labels" not in cfg and "id2label" in cfg:
        cfg["num_labels"] = len(cfg["id2label"])
    # The constructor argument is named probe_pooling_types, but saved
    # configs persist it under pooling_types.
    if "pooling_types" in cfg and "probe_pooling_types" not in cfg:
        cfg["probe_pooling_types"] = cfg["pooling_types"]

    if probe_type == "linear":
        if not tokenwise:
            return LinearProbe(LinearProbeConfig(**cfg))
    elif probe_type == "transformer":
        head = TransformerForTokenClassification if tokenwise else TransformerForSequenceClassification
        return head(TransformerProbeConfig(**cfg))
    elif probe_type == "retrievalnet":
        head = RetrievalNetForTokenClassification if tokenwise else RetrievalNetForSequenceClassification
        return head(RetrievalNetConfig(**cfg))
    elif probe_type == "lyra":
        head = LyraForTokenClassification if tokenwise else LyraForSequenceClassification
        return head(LyraConfig(**cfg))

    raise ValueError(f"Unsupported saved probe configuration: {probe_type} tokenwise={tokenwise}")