Spaces:
Sleeping
Sleeping
| import pathlib | |
| from copy import copy | |
| import numpy as np | |
| import torch | |
| import yaml | |
| from colorama import Fore | |
| from omegaconf import OmegaConf | |
| from yaml.constructor import ConstructorError | |
# Tags accepted by FrequencyScheduler.check_iteration. Each tag has a matching
# enable_<tag> constructor flag on the scheduler that switches it on or off.
KNOWN_TAGS = ["target", "context", "info", "debug"]
class CustomPath(pathlib.Path):
    """A custom path class that can be formatted to display as a hyperlink in terminal.

    Supported format specs (``f"{path:<spec>}"``):

    * ``""``       -- plain ``str(path)``.
    * ``"link"``   -- ``file:///`` link to the resolved path. If the path is missing, a red
      error message followed by the nearest existing parent and its contents.
    * ``"lastN"``  -- the last ``N`` path components joined with ``/``.
    * ``"exists"`` -- link to the resolved path, plus a note and the parent link if missing.
    * anything else is applied to ``str(path)`` as a regular string format spec.
    """

    # This is a hack to inherit pathlib.Path and initialize the _flavour property.
    # https://stackoverflow.com/questions/61689391/error-with-simple-subclassing-of-pathlib-path-no-flavour-attribute
    # Newer Python releases no longer expose the private attribute (and do not need
    # the hack), so only copy it when it is actually there; an unconditional access
    # would raise AttributeError while the class is being defined.
    # noinspection PyProtectedMember
    # noinspection PyUnresolvedReferences
    if hasattr(type(pathlib.Path()), "_flavour"):
        _flavour = type(pathlib.Path())._flavour

    def __format__(self, format_spec):
        """Format the path according to ``format_spec`` (see the class docstring)."""
        if format_spec == '':
            return str(self)
        if format_spec == 'link':
            if self.exists():
                return _create_hyperlink(self.resolve())
            return self._missing_link_message()
        if format_spec.startswith('last'):
            # 'last3' -> keep the three trailing components.
            i = int(format_spec[4:])
            return "/".join(self.parts[-i:])
        if format_spec == 'exists':
            if self.exists():
                # Normal case: just print the link
                return _create_hyperlink(self.resolve())
            return _create_hyperlink(self.resolve()) + ' does not exist. \nParent directory: ' + _create_hyperlink(
                self.parent.resolve())
        # Fall back to ordinary string formatting (width, alignment, ...).
        return format(str(self), format_spec)

    def _missing_link_message(self):
        """Build the ``link`` output for a path that does not exist."""
        missing_path = self.resolve()

        # Walk upwards until an existing ancestor is found. ``parent`` of a root
        # (or of '.') is itself, so stop at that fixed point instead of looping forever.
        existing_parent = self.parent
        while not existing_parent.exists():
            next_parent = existing_parent.parent
            if next_parent == existing_parent:
                break
            existing_parent = next_parent

        # Build base error message (bold red)
        base_msg = f"\033[1;31m{missing_path} does not exist.\033[0m"
        if not existing_parent.exists():
            return f"{base_msg}\n(No existing parent found.)"

        parent_link = _create_hyperlink(existing_parent.resolve())
        # Gather the existing parent's contents
        content_msg = ""
        if existing_parent.is_dir():
            content = list(existing_parent.iterdir())
            if content:
                content_msg = (
                    "\n" + cyan("Nearest existing directory contents:") + "\n" +
                    "\n".join([' ' + _create_hyperlink(p.resolve()) for p in content])
                )
        return f"{base_msg}\nNearest existing directory: {parent_link}{content_msg}"

    def __iadd__(self, other: str):
        """Append ``other`` to the path text (plain string concatenation, no separator)."""
        return CustomPath(str(self) + other)

    def __add__(self, other: str):
        """Return a new path whose text is ``str(self) + other`` (no separator added)."""
        return CustomPath(str(self) + other)

    def is_json(self):
        """Return True if the suffix is exactly ``.json``."""
        return self.suffix == '.json'

    def is_yaml(self):
        """Return True if the suffix is exactly ``.yaml``."""
        return self.suffix == '.yaml'

    def json_encoder(self):
        """Return the string form of the path."""
        return str(self)

    def __sub__(self, other):
        """Return this path relative to ``other`` (both are resolved first).

        Raises ValueError (from ``relative_to``) if ``other`` is not an ancestor.
        """
        return CustomPath(self.resolve().relative_to(other.resolve()))
| def _create_hyperlink(text: str | pathlib.Path): | |
| if isinstance(text, pathlib.Path): | |
| text = str(text) | |
| return f'file:///' + text.replace('\\', '/') | |
def cyan(text: str) -> str:
    """Wrap ``text`` in colorama's cyan foreground code followed by a foreground reset."""
    return "{}{}{}".format(Fore.CYAN, text, Fore.RESET)
| class FrequencyScheduler: | |
| def __init__( | |
| self, | |
| last_step: int, | |
| frequencies: list[int] | None = None, | |
| steps: list[int] | None = None, | |
| iters: list[int] | None = None, | |
| enable_target: bool = True, | |
| enable_context: bool = True, | |
| enable_info: bool = True, | |
| enable_debug: bool = True, | |
| ): | |
| if iters is not None: | |
| print("FrequencyScheduler: using iters argument, ignoring frequencies and steps.") | |
| # assert frequencies is None and steps is None, "When iters is provided, frequencies and steps must be None" | |
| elif frequencies is None and steps is None: | |
| # Make sure frequencies and steps are both either None or lists of the same length | |
| frequencies = [99999999] # effectively never | |
| steps = [0] | |
| elif frequencies is None or steps is None: | |
| raise ValueError("frequencies and steps must both be None or both be lists") | |
| else: | |
| assert len(frequencies) == len( | |
| steps), f"frequencies and steps must be same length. Got {len(frequencies)} and {len(steps)}" | |
| assert steps[0] == 0, f"first step must be 0. Got {steps}" | |
| if iters is not None: | |
| self.iterations = copy(iters) | |
| # check if last step in iters, else add it to iters and sort, remove higher than last_step | |
| self.iterations = sorted([i for i in self.iterations if i <= last_step]) | |
| if last_step not in self.iterations: | |
| self.iterations.append(last_step) | |
| if 0 not in self.iterations: | |
| self.iterations.insert(0, 0) | |
| else: | |
| frequencies = copy(frequencies) | |
| steps = copy(steps) | |
| steps.pop(0) # remove the first step which is always 0 | |
| if last_step not in steps: | |
| steps.append(last_step) # ensure last step is included | |
| pairs = list(zip(frequencies, steps)) | |
| self.iterations: list[int] = self.get_all_iterations(pairs, last_step) | |
| self.verbose = False | |
| self.last_step = last_step | |
| self.enabled_tags = { | |
| "target": enable_target, | |
| "context": enable_context, | |
| "info": enable_info, | |
| "debug": enable_debug | |
| } | |
| self.is_disabled = False | |
| def set_verbose(self, verbose: bool): | |
| self.verbose = verbose | |
| def set_all_tags(self, enabled: bool): | |
| for key in self.enabled_tags: | |
| self.enabled_tags[key] = enabled | |
| def check_iteration(self, iteration: int, tag: str) -> bool: | |
| """Returns True if any frequency event occurs at this iteration.""" | |
| assert tag in KNOWN_TAGS, f"Invalid tag: {tag}, must be in {KNOWN_TAGS}" | |
| if self.enabled_tags[tag]: | |
| return iteration in self.iterations | |
| else: | |
| return False | |
| def _occurs_at(self, iteration: int, pairs, last_step) -> bool: | |
| """Returns True if any frequency event occurs at this iteration.""" | |
| if iteration == last_step: | |
| return True | |
| for freq, end in pairs: | |
| if iteration <= end: | |
| if iteration % freq == 0: | |
| return True | |
| else: | |
| break | |
| return False | |
| def get_all_iterations(self, pairs, last_step) -> list[int]: | |
| """Returns a list of all iterations where an event occurs up to the last step.""" | |
| t = 0 | |
| iterations = [] | |
| while t <= last_step: | |
| if self._occurs_at(t, pairs, last_step): | |
| iterations.append(t) | |
| t += 1 | |
| return iterations | |
| def get_iterations(self, length_of_event: int) -> list[int]: | |
| """Returns a list of all iterations where an event occurs up to the given length.""" | |
| if self.iterations is not None and len(self.iterations) >= length_of_event: | |
| if length_of_event == 1: | |
| return [self.iterations[-1]] | |
| return self.iterations[:length_of_event] | |
| else: | |
| raise ValueError( | |
| f"Not enough iterations up to last_step {self.last_step} to get {length_of_event} events. " | |
| f"Only got {len(self.iterations)} events.") | |
| def disable(self, flag): | |
| self.is_disabled = flag | |
| def __call__(self, iteration: int, tag: str = "") -> bool: | |
| if self.is_disabled: | |
| return False | |
| return self.check_iteration(iteration, tag) | |
| def __repr__(self): | |
| return f"FrequencyScheduler({self.iterations})" | |
def log_mem(tag=""):
    """Print CUDA memory statistics (in MB, i.e. bytes / 1e6), prefixed with ``tag``.

    Reports the currently allocated, currently reserved and peak allocated memory as
    returned by ``torch.cuda``. On a machine without CUDA a short notice is printed
    instead, so the helper can stay in code that also runs on CPU.
    """
    if not torch.cuda.is_available():
        # torch.cuda.synchronize() raises when no CUDA device is usable.
        print(f"{tag}: CUDA not available, skipping memory report")
        return
    # Wait for pending kernels so the numbers reflect completed work.
    torch.cuda.synchronize()
    print(f"{tag}: allocated={torch.cuda.memory_allocated() / 1e6:.1f}MB, "
          f"reserved={torch.cuda.memory_reserved() / 1e6:.1f}MB, "
          f"max_allocated={torch.cuda.max_memory_allocated() / 1e6:.1f}MB")
def _custompath_constructor(loader, node):
    """YAML constructor that turns a serialized ``CustomPath`` node back into a ``CustomPath``.

    A scalar node is used as the path text; a sequence node is treated as a list of
    path parts that are joined one after another.

    Raises TypeError for any other node type.
    """
    if isinstance(node, yaml.ScalarNode):
        return CustomPath(loader.construct_scalar(node))
    if isinstance(node, yaml.SequenceNode):
        # Join the sequence parts into a single path.
        joined = CustomPath()
        for part in loader.construct_sequence(node):
            joined = joined / str(part)
        return joined
    raise TypeError(f"Unsupported YAML node type for CustomPath: {type(node)}")


def read_omega_cfg(path: pathlib.Path) -> OmegaConf:
    """Reads an OmegaConf YAML file, handling custom tags safely.

    The file is first loaded with ``OmegaConf.load``. If that fails with a YAML
    ``ConstructorError``, the file is re-read with PyYAML after registering a
    constructor for the serialized ``CustomPath`` tags, and the result is converted
    with ``OmegaConf.create``.

    Note: registering the constructor is a process-wide side effect on PyYAML's loaders.
    """
    try:
        return OmegaConf.load(path)
    except ConstructorError:
        # Register for both the current tag and the legacy `src.` tag:
        # checkpoints released/trained before the src->optgs package rename
        # embed `...apply:src.misc.io.CustomPath` in their saved config.yaml.
        for _tag in (
                'tag:yaml.org,2002:python/object/apply:optgs.misc.io.CustomPath',
                'tag:yaml.org,2002:python/object/apply:src.misc.io.CustomPath',
        ):
            yaml.add_constructor(_tag, _custompath_constructor)

        # NOTE(review): yaml.load with FullLoader is only appropriate for config files
        # from a trusted source (e.g. your own checkpoints) -- confirm for your use case.
        # Explicit encoding so the result does not depend on the platform default.
        with open(path, "r", encoding="utf-8") as f:
            raw_cfg = yaml.load(f, Loader=yaml.FullLoader)

        return OmegaConf.create(raw_cfg)
if __name__ == '__main__':
    # Manual demo: frequency 1 up to step 5, 2 up to step 10, then 5 until the end.
    demo_last_step = 37
    print_every = FrequencyScheduler(
        last_step=demo_last_step,
        frequencies=[1, 2, 5],
        steps=[0, 5, 10],
        # iters=[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 56, 67, 100]
    )
    # Query every step once; the result is intentionally discarded.
    for step in range(demo_last_step + 1):
        print_every(step, "target")
    print(print_every.get_iterations(15))