| | from dataclasses import dataclass, field |
| |
|
| | try: |
| | import trackio.utils as utils |
| | except ImportError: |
| | import utils |
| |
|
| |
|
| | @dataclass |
| | class RunSelection: |
| | choices: list[str] = field(default_factory=list) |
| | selected: list[str] = field(default_factory=list) |
| | locked: bool = False |
| |
|
| | def update_choices( |
| | self, runs: list[str], preferred: list[str] | None = None |
| | ) -> bool: |
| | if self.choices == runs: |
| | return False |
| | new_choices = set(runs) - set(self.choices) |
| | self.choices = list(runs) |
| | if self.locked: |
| | base = set(self.selected) | new_choices |
| | elif preferred: |
| | base = set(preferred) |
| | else: |
| | base = set(runs) |
| | self.selected = [run for run in self.choices if run in base] |
| | return True |
| |
|
| | def select(self, runs: list[str]) -> list[str]: |
| | choice_set = set(self.choices) |
| | self.selected = [run for run in runs if run in choice_set] |
| | self.locked = True |
| | return self.selected |
| |
|
| | def replace_group( |
| | self, group_runs: list[str], new_subset: list[str] | None |
| | ) -> tuple[list[str], list[str]]: |
| | new_subset = utils.ordered_subset(group_runs, new_subset) |
| | selection_set = set(self.selected) |
| | selection_set.difference_update(group_runs) |
| | selection_set.update(new_subset) |
| | self.selected = [run for run in self.choices if run in selection_set] |
| | self.locked = True |
| | return new_subset, self.selected |
| |
|