|
|
from dataclasses import dataclass, field |
|
|
|
|
|
try: |
|
|
import trackio.utils as utils |
|
|
except ImportError: |
|
|
import utils |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class RunSelection: |
|
|
choices: list[str] = field(default_factory=list) |
|
|
selected: list[str] = field(default_factory=list) |
|
|
locked: bool = False |
|
|
|
|
|
def update_choices( |
|
|
self, runs: list[str], preferred: list[str] | None = None |
|
|
) -> bool: |
|
|
if self.choices == runs: |
|
|
return False |
|
|
new_choices = set(runs) - set(self.choices) |
|
|
self.choices = list(runs) |
|
|
if self.locked: |
|
|
base = set(self.selected) | new_choices |
|
|
elif preferred: |
|
|
base = set(preferred) |
|
|
else: |
|
|
base = set(runs) |
|
|
self.selected = [run for run in self.choices if run in base] |
|
|
return True |
|
|
|
|
|
def select(self, runs: list[str]) -> list[str]: |
|
|
choice_set = set(self.choices) |
|
|
self.selected = [run for run in runs if run in choice_set] |
|
|
self.locked = True |
|
|
return self.selected |
|
|
|
|
|
def replace_group( |
|
|
self, group_runs: list[str], new_subset: list[str] | None |
|
|
) -> tuple[list[str], list[str]]: |
|
|
new_subset = utils.ordered_subset(group_runs, new_subset) |
|
|
selection_set = set(self.selected) |
|
|
selection_set.difference_update(group_runs) |
|
|
selection_set.update(new_subset) |
|
|
self.selected = [run for run in self.choices if run in selection_set] |
|
|
self.locked = True |
|
|
return new_subset, self.selected |
|
|
|