from dataclasses import dataclass, field
from typing import List, Optional

from .linear_probe import LinearProbe, LinearProbeConfig
from .transformer_probe import (
    TransformerForSequenceClassification,
    TransformerForTokenClassification,
    TransformerProbeConfig,
)
from .retrievalnet import (
    RetrievalNetForSequenceClassification,
    RetrievalNetForTokenClassification,
    RetrievalNetConfig,
)
from .lyra_probe import (
    LyraForSequenceClassification,
    LyraForTokenClassification,
    LyraConfig,
)


@dataclass
class ProbeArguments:
    """Bag of hyper-parameters used to build a probe via :func:`get_probe`.

    Every constructor argument is stored on the instance under the same name,
    except ``probe_pooling_types`` which is stored as ``pooling_types``.
    Unknown keyword arguments are accepted and ignored.

    NOTE(review): the class declares no annotated class-level fields and
    defines its own ``__init__``, so ``@dataclass`` does not generate the
    constructor; the ``__repr__``/``__eq__`` it does generate see zero fields.
    The decorator is kept so existing behaviour is unchanged.
    """

    def __init__(
        self,
        probe_type: str = 'linear',  # valid options: linear, transformer, retrievalnet, lyra
        tokenwise: bool = False,
        ### Linear Probe
        input_size: int = 960,
        hidden_size: int = 8192,
        dropout: float = 0.2,
        num_labels: int = 2,
        n_layers: int = 1,
        task_type: str = 'singlelabel',
        pre_ln: bool = True,
        sim_type: str = 'dot',
        token_attention: bool = False,
        use_bias: bool = False,
        add_token_ids: bool = False,
        ### Transformer Probe
        transformer_hidden_size: int = 512,  # For transformer probe
        classifier_size: int = 4096,
        transformer_dropout: float = 0.1,
        classifier_dropout: float = 0.2,
        n_heads: int = 4,
        rotary: bool = True,
        # ``None`` means "use ['mean', 'cls']".  A ``dataclasses.field(...)``
        # default is only meaningful on class attributes; used as a parameter
        # default it would hand callers a ``Field`` object instead of a list.
        probe_pooling_types: Optional[List[str]] = None,
        ### RetrievalNet
        # TODO
        ### LoRA
        lora: bool = False,
        lora_r: int = 8,
        lora_alpha: float = 32.0,
        lora_dropout: float = 0.01,
        **kwargs,
    ):
        self.probe_type = probe_type
        self.tokenwise = tokenwise
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.transformer_hidden_size = transformer_hidden_size
        self.dropout = dropout
        self.num_labels = num_labels
        self.n_layers = n_layers
        self.sim_type = sim_type
        self.token_attention = token_attention
        self.add_token_ids = add_token_ids
        self.task_type = task_type
        self.pre_ln = pre_ln
        self.use_bias = use_bias
        self.classifier_size = classifier_size
        self.transformer_dropout = transformer_dropout
        self.classifier_dropout = classifier_dropout
        self.n_heads = n_heads
        self.rotary = rotary
        # Build a fresh list per instance so the default is never shared.
        self.pooling_types = (
            ['mean', 'cls'] if probe_pooling_types is None else probe_pooling_types
        )
        self.lora = lora
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout


# Maps (probe_type, tokenwise) to the (config class, model class) pair used to
# build that probe.  ('linear', True) is intentionally absent: neither factory
# below supports a token-wise linear probe.
_PROBE_REGISTRY = {
    ('linear', False): (LinearProbeConfig, LinearProbe),
    ('transformer', False): (TransformerProbeConfig, TransformerForSequenceClassification),
    ('transformer', True): (TransformerProbeConfig, TransformerForTokenClassification),
    ('retrievalnet', False): (RetrievalNetConfig, RetrievalNetForSequenceClassification),
    ('retrievalnet', True): (RetrievalNetConfig, RetrievalNetForTokenClassification),
    ('lyra', False): (LyraConfig, LyraForSequenceClassification),
    ('lyra', True): (LyraConfig, LyraForTokenClassification),
}


def _lookup_probe_classes(probe_type, tokenwise):
    """Return the ``(config_cls, model_cls)`` pair for a probe, or ``None``."""
    return _PROBE_REGISTRY.get((probe_type, bool(tokenwise)))


def get_probe(args: ProbeArguments):
    """Instantiate the probe model selected by ``args``.

    All instance attributes of ``args`` are forwarded as keyword arguments to
    the matching config class.  For transformer probes, ``hidden_size`` is
    replaced by ``args.transformer_hidden_size`` before the config is built.

    Raises:
        ValueError: if ``(args.probe_type, args.tokenwise)`` is not a
            supported combination.
    """
    classes = _lookup_probe_classes(args.probe_type, args.tokenwise)
    if classes is None:
        raise ValueError(f"Invalid combination of probe type and tokenwise: {args.probe_type} {args.tokenwise}")
    config_cls, model_cls = classes

    config_kwargs = dict(args.__dict__)
    if args.probe_type == 'transformer':
        # Use transformer_hidden_size for the transformer probe's internal dimension
        config_kwargs['hidden_size'] = args.transformer_hidden_size
    return model_cls(config_cls(**config_kwargs))


def rebuild_probe_from_saved_config(
    probe_type: str,
    tokenwise: bool,
    probe_config: dict,
):
    """Re-create a probe model from a previously saved config mapping.

    ``probe_config`` is copied, not mutated.  Two keys are back-filled when
    missing: ``num_labels`` (from the length of ``id2label``) and
    ``probe_pooling_types`` (from ``pooling_types``).  Unlike
    :func:`get_probe`, ``hidden_size`` is passed through unchanged.

    Raises:
        ValueError: if ``(probe_type, tokenwise)`` is not a supported
            combination.
    """
    config_dict = dict(probe_config)
    if "num_labels" not in config_dict and "id2label" in config_dict:
        config_dict["num_labels"] = len(config_dict["id2label"])
    if "pooling_types" in config_dict and "probe_pooling_types" not in config_dict:
        config_dict["probe_pooling_types"] = config_dict["pooling_types"]

    classes = _lookup_probe_classes(probe_type, tokenwise)
    if classes is None:
        raise ValueError(f"Unsupported saved probe configuration: {probe_type} tokenwise={tokenwise}")
    config_cls, model_cls = classes
    return model_cls(config_cls(**config_dict))