| import abc |
| from dataclasses import asdict, dataclass |
| from inspect import getsource |
| from typing import Any, Callable, List, Optional, Union |
|
|
|
|
@dataclass
class AggMetricConfig(dict):
    """Configuration for aggregating one metric across a group's subtasks.

    ``__post_init__`` validates ``aggregation`` and normalizes ``filter_list``
    to a list.

    Note: although this class subclasses ``dict``, the dataclass ``__init__``
    stores fields as instance attributes and never populates the dict
    storage, so ``dict(cfg)`` / ``**cfg`` yield an empty mapping.
    """

    # Name of the metric to aggregate.
    metric: Optional[str] = None
    # Either the string "mean" (the only pre-defined aggregation) or a callable.
    aggregation: Optional[Union[str, Callable]] = "mean"
    # Boolean flag; presumably weights subtasks by their size when aggregating
    # (TODO confirm against the aggregation code, which is not in this file).
    weight_by_size: bool = False
    # Filter name(s) to include; a single string is wrapped into a list.
    filter_list: Optional[Union[str, list]] = "none"

    def __post_init__(self):
        """Validate ``aggregation`` and normalize ``filter_list``.

        :raises ValueError: if ``aggregation`` is neither ``"mean"`` nor callable.
        """
        if self.aggregation != "mean" and not callable(self.aggregation):
            raise ValueError(
                f"Currently, 'mean' is the only pre-defined aggregation across groups' subtasks. Got '{self.aggregation}'."
            )

        # Accept a bare filter name as shorthand for a one-element list.
        if isinstance(self.filter_list, str):
            self.filter_list = [self.filter_list]
|
|
|
|
@dataclass
class GroupConfig(dict):
    """Configuration of a task group.

    Subclasses ``dict`` so it can be indexed like a mapping
    (``cfg["group"]``); item access is routed to attribute access, and the
    underlying dict storage itself stays empty.
    """

    # Identifier of the group.
    group: Optional[str] = None
    # Optional alternative display name for the group.
    group_alias: Optional[str] = None
    # Presumably the member task(s) of the group: a single name or a list.
    task: Optional[Union[str, list]] = None
    # Aggregation settings; normalized to List[AggMetricConfig] in __post_init__.
    aggregate_metric_list: Optional[
        Union[List[AggMetricConfig], AggMetricConfig, dict]
    ] = None
    # Free-form metadata; not interpreted by this class.
    metadata: Optional[dict] = None

    def __getitem__(self, item):
        """Mapping-style read access, delegated to attribute lookup."""
        return getattr(self, item)

    def __setitem__(self, item, value):
        """Mapping-style write access, delegated to attribute assignment."""
        return setattr(self, item, value)

    def __post_init__(self):
        """Normalize ``aggregate_metric_list`` to a list of AggMetricConfig.

        Accepts a single mapping / AggMetricConfig or a list of them. Plain
        mappings are converted with ``AggMetricConfig(**item)``; existing
        AggMetricConfig instances are kept unchanged, because their fields
        live in attributes (their dict storage is empty) and re-building
        them via ``**item`` would silently reset them to defaults.
        """
        if self.aggregate_metric_list is None:
            return

        # AggMetricConfig is itself a dict subclass, so this also wraps a
        # single AggMetricConfig instance into a list.
        if isinstance(self.aggregate_metric_list, dict):
            self.aggregate_metric_list = [self.aggregate_metric_list]

        self.aggregate_metric_list = [
            AggMetricConfig(**item)
            if isinstance(item, dict) and not isinstance(item, AggMetricConfig)
            else item
            for item in self.aggregate_metric_list
        ]

    def to_dict(self, keep_callable: bool = False) -> dict:
        """Dump the config as a plain dictionary, e.g. to store alongside results.

        Fields are converted with ``dataclasses.asdict`` (nested
        AggMetricConfig entries become dicts). Every callable in the result,
        whether at the top level or nested inside dicts/lists (such as an
        ``aggregation`` function inside ``aggregate_metric_list``), is passed
        through :meth:`serialize_function`. ``None`` fields are kept.

        :param keep_callable: if True, callables are left as-is instead of
            being replaced by their source code / string form.
        :return: dict
            A dictionary version of this GroupConfig.
        """
        # TODO: should default values be omitted from the dump?

        def _serialize(value):
            # Recurse into the containers asdict() produces, so callables
            # nested below the top level are serialized too.
            if callable(value):
                return self.serialize_function(value, keep_callable=keep_callable)
            if isinstance(value, dict):
                return {k: _serialize(v) for k, v in value.items()}
            if isinstance(value, list):
                return [_serialize(v) for v in value]
            return value

        return {k: _serialize(v) for k, v in asdict(self).items()}

    def serialize_function(
        self, value: Union[Callable, str], keep_callable=False
    ) -> Union[Callable, str]:
        """Serialize a callable (or string) into a printable form.

        If ``keep_callable`` is True, ``value`` is returned unchanged.
        Otherwise the source code of ``value`` is returned via
        ``inspect.getsource``. If that raises ``TypeError`` (e.g. for a
        string or builtin) or ``OSError`` (source unavailable),
        ``str(value)`` is returned instead.
        """
        if keep_callable:
            return value
        try:
            return getsource(value)
        except (TypeError, OSError):
            return str(value)
|
|
|
|
class ConfigurableGroup(abc.ABC):
    """Wrapper that exposes a GroupConfig through read-only properties."""

    def __init__(
        self,
        config: Optional[dict] = None,
    ) -> None:
        """
        :param config: keyword arguments for GroupConfig, or an existing
            GroupConfig instance. ``None`` (the default) yields an
            all-default GroupConfig instead of raising ``TypeError``.
        """
        if isinstance(config, GroupConfig):
            # GroupConfig keeps its fields as attributes and its dict storage
            # is empty, so ``GroupConfig(**config)`` would discard them.
            self._config = config
        else:
            self._config = GroupConfig(**(config or {}))

    @property
    def group(self):
        return self._config.group

    @property
    def group_alias(self):
        return self._config.group_alias

    @property
    def version(self):
        """Version of the group, or ``None`` if none is recorded.

        GroupConfig declares no ``version`` field, so plain attribute access
        used to raise ``AttributeError``. An attribute attached later (e.g.
        via ``cfg["version"] = ...``) wins. Otherwise ``metadata["version"]``
        is used. NOTE(review): the metadata key is an assumption based on
        common config layout; confirm against callers.
        """
        version = getattr(self._config, "version", None)
        if version is None:
            metadata = getattr(self._config, "metadata", None)
            if isinstance(metadata, dict):
                version = metadata.get("version")
        return version

    @property
    def config(self):
        return self._config.to_dict()

    @property
    def group_name(self) -> Any:
        return self._config.group

    def __repr__(self):
        return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})"
|
|