|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Configuration definitions for multi-task training."""
|
| import dataclasses
|
| from typing import Optional, Tuple
|
|
|
| from official.core import config_definitions as cfg
|
| from official.modeling import hyperparams
|
| from official.modeling.privacy import configs as dp_configs
|
|
|
|
|
@dataclasses.dataclass
class TaskRoutine(hyperparams.Config):
  """Settings for one task taking part in a multi-task setup.

  Attributes:
    task_name: Identifier for this routine; "" by default.
    task_config: The `cfg.TaskConfig` for this task; None by default.
    eval_steps: Optional number of evaluation steps; None by default.
      NOTE(review): presumably None means "run the whole eval set" --
      confirm against the evaluator.
    task_weight: Relative weight of this task; 1.0 by default.
      NOTE(review): presumably consumed by task sampling / loss
      aggregation -- verify at the call sites.
  """

  # NOTE(review): `task_config` keeps the bare `cfg.TaskConfig` annotation
  # (not Optional[...]) even though it defaults to None. hyperparams.Config
  # may read annotations to build sub-configs -- confirm before changing.
  task_name: str = dataclasses.field(default="")
  task_config: cfg.TaskConfig = dataclasses.field(default=None)
  eval_steps: Optional[int] = dataclasses.field(default=None)
  task_weight: Optional[float] = dataclasses.field(default=1.0)
|
|
|
|
|
@dataclasses.dataclass
class MultiTaskConfig(hyperparams.Config):
  """Task-level configuration for a multi-task job.

  Attributes:
    init_checkpoint: Checkpoint to initialize from; "" by default.
      NOTE(review): presumably "" means "no initialization" -- confirm.
    model: Model configuration; None by default.
    task_routines: Tuple of `TaskRoutine` entries; empty by default.
    differential_privacy_config: Optional differential-privacy settings
      (`dp_configs.DifferentialPrivacyConfig`); None by default.
  """

  init_checkpoint: str = dataclasses.field(default="")
  model: hyperparams.Config = dataclasses.field(default=None)
  task_routines: Tuple[TaskRoutine, ...] = dataclasses.field(default=())
  differential_privacy_config: Optional[
      dp_configs.DifferentialPrivacyConfig
  ] = dataclasses.field(default=None)
|
|
|
|
|
@dataclasses.dataclass
class ProportionalSampleConfig(hyperparams.Config):
  """Options for the "proportional" task-sampling strategy.

  Attributes:
    alpha: Float parameter of the proportional sampler; 1.0 by default.
      NOTE(review): presumably an exponent applied to per-task weights or
      sizes -- confirm in the sampler implementation.
  """

  alpha: float = dataclasses.field(default=1.0)
|
|
|
|
|
@dataclasses.dataclass
class AnnealingSampleConfig(hyperparams.Config):
  """Options for the "annealing" task-sampling strategy.

  Attributes:
    steps_per_epoch: Integer step count; 5 by default.
    total_steps: Integer step count; 20 by default.

  NOTE(review): presumably the sampler re-anneals task probabilities once
  per `steps_per_epoch` steps over `total_steps` -- confirm in the sampler.
  """

  steps_per_epoch: int = dataclasses.field(default=5)
  total_steps: int = dataclasses.field(default=20)
|
|
|
|
|
@dataclasses.dataclass
class TaskSamplingConfig(hyperparams.OneOfConfig):
  """Chooses one task-sampling strategy.

  This is a `hyperparams.OneOfConfig`: `type` holds the name of the
  strategy field that applies ("uniform", "proportional" or "annealing"),
  and each of those fields carries that strategy's options.

  Attributes:
    type: Name of the selected strategy; "" by default.
    uniform: Options for uniform sampling (a plain `hyperparams.Config`).
    proportional: Options for proportional sampling.
    annealing: Options for annealing sampling.
  """

  type: str = ""
  # default_factory gives every instance its own fresh sub-config object.
  uniform: hyperparams.Config = dataclasses.field(
      default_factory=hyperparams.Config)
  proportional: ProportionalSampleConfig = dataclasses.field(
      default_factory=ProportionalSampleConfig)
  annealing: AnnealingSampleConfig = dataclasses.field(
      default_factory=AnnealingSampleConfig)
|
|
|
|
|
def _default_task_sampler():
  """Returns the default sampler config, which selects "proportional"."""
  return TaskSamplingConfig(type="proportional")


@dataclasses.dataclass
class MultiTaskTrainerConfig(cfg.TrainerConfig):
  """Trainer settings for multi-task training.

  Attributes:
    trainer_type: Trainer variant identifier; "interleaving" by default.
    task_sampler: Task-sampling strategy. The default selects the
      "proportional" strategy with default `ProportionalSampleConfig`
      options.
  """

  trainer_type: str = "interleaving"
  # A named factory rather than an inline lambda. It is still called per
  # instance, so instances never share a sampler config.
  task_sampler: TaskSamplingConfig = dataclasses.field(
      default_factory=_default_task_sampler)
|
|
|
|
|
@dataclasses.dataclass
class MultiTaskExperimentConfig(hyperparams.Config):
  """An experiment config for multi-task training and multi-task evaluation.

  It derives directly from `hyperparams.Config` and declares its own task,
  trainer and runtime sections.

  Attributes:
    task: `MultiTaskConfig` holding the task routines and shared model.
    trainer: `MultiTaskTrainerConfig` with the trainer and sampler settings.
    runtime: `cfg.RuntimeConfig` runtime settings.
  """

  task: MultiTaskConfig = dataclasses.field(
      default_factory=MultiTaskConfig)
  trainer: MultiTaskTrainerConfig = dataclasses.field(
      default_factory=MultiTaskTrainerConfig)
  runtime: cfg.RuntimeConfig = dataclasses.field(
      default_factory=cfg.RuntimeConfig)
|
|
|
|
|
@dataclasses.dataclass
class MultiEvalExperimentConfig(cfg.ExperimentConfig):
  """An experiment config for single-task training and multi-task evaluation.

  Extends `cfg.ExperimentConfig` with extra evaluation routines.
  NOTE(review): the training task, trainer and runtime sections presumably
  come from `cfg.ExperimentConfig` -- confirm in config_definitions.

  Attributes:
    eval_tasks: Tuple of `TaskRoutine` entries, one per evaluation task;
      empty by default.
  """

  eval_tasks: Tuple[TaskRoutine, ...] = dataclasses.field(default=())
|
|
|