File size: 10,346 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
import pathlib
from copy import copy
import numpy as np
import torch
import yaml
from colorama import Fore
from omegaconf import OmegaConf
from yaml.constructor import ConstructorError

# Tags a FrequencyScheduler may be queried with; each tag has its own
# enable/disable flag on the scheduler, and any other tag is rejected.
KNOWN_TAGS = ["target", "context", "info", "debug"]


class CustomPath(pathlib.Path):
    """A ``pathlib.Path`` subclass with extra format specs and string-style ``+``.

    Supported format specs (``f"{path:<spec>}"``):

    * ``""``       -- plain ``str(path)``.
    * ``"link"``   -- ``file:///`` hyperlink of the resolved path; when the path is
      missing, a red error message followed by the nearest existing ancestor
      directory and a listing of its contents.
    * ``"lastN"``  -- the last ``N`` path components joined with ``/``
      (``N`` must parse as an integer, otherwise ``ValueError`` is raised).
    * ``"exists"`` -- hyperlink of the resolved path; when the path is missing the
      link is followed by a note and a link to the parent directory.
    * anything else is forwarded to ``str.__format__``.
    """

    # Hack needed by older Python versions to subclass pathlib.Path:
    # https://stackoverflow.com/questions/61689391/error-with-simple-subclassing-of-pathlib-path-no-flavour-attribute
    # ``_flavour`` is private; newer pathlib versions support subclassing directly
    # and may not expose it at all, so fall back to ``None`` instead of raising
    # AttributeError while the class body is being executed.
    # noinspection PyProtectedMember
    # noinspection PyUnresolvedReferences
    _flavour = getattr(type(pathlib.Path()), "_flavour", None)

    def __format__(self, format_spec):
        """Format the path according to the specs listed in the class docstring."""
        if format_spec == '':
            return str(self)
        if format_spec == 'link':
            if self.exists():
                return _create_hyperlink(self.resolve())
            return self._describe_missing()
        if format_spec.startswith('last'):
            i = int(format_spec[4:])
            return "/".join(self.parts[-i:])
        if format_spec == 'exists':
            link = _create_hyperlink(self.resolve())
            if self.exists():
                # Normal case: just print the link
                return link
            return link + ' does not exist. \nParent directory: ' + _create_hyperlink(
                self.parent.resolve())
        return str(self).__format__(format_spec)

    def _describe_missing(self):
        """Build the ``link`` message for a path that does not exist."""
        missing_path = self.resolve()
        base_msg = f"\033[1;31m{missing_path} does not exist.\033[0m"

        # ``self.parents`` is finite, so this terminates even when no ancestor
        # exists; walking ``.parent`` by hand would loop forever in that case
        # because the parent of a root is the root itself.
        existing_parent = next((p for p in self.parents if p.exists()), None)
        if existing_parent is None:
            return f"{base_msg}\n(No existing parent found.)"

        parent_link = _create_hyperlink(existing_parent.resolve())

        # List what is inside the nearest existing directory to help spot typos.
        content_msg = ""
        if existing_parent.is_dir():
            content = list(existing_parent.iterdir())
            if content:
                content_msg = (
                        "\n" + cyan("Nearest existing directory contents:") + "\n" +
                        "\n".join(['  ' + _create_hyperlink(p.resolve()) for p in content])
                )

        return f"{base_msg}\nNearest existing directory: {parent_link}{content_msg}"

    def __add__(self, other: str):
        """Append ``other`` to the string form of the path (no separator added)."""
        return CustomPath(str(self) + other)

    # ``path += "suffix"`` behaves exactly like ``path = path + "suffix"``.
    __iadd__ = __add__

    def is_json(self):
        """Return True when the suffix is exactly ``.json``."""
        return self.suffix == '.json'

    def is_yaml(self):
        """Return True when the suffix is exactly ``.yaml``."""
        return self.suffix == '.yaml'

    def json_encoder(self):
        """Return the string form of the path."""
        return str(self)

    def __sub__(self, other):
        """Return this path relative to ``other`` (both resolved first).

        ``other`` may be a path object or a plain string. Raises ``ValueError``
        (from ``relative_to``) when this path is not located under ``other``.
        """
        return CustomPath(self.resolve().relative_to(pathlib.Path(other).resolve()))


def _create_hyperlink(text: str | pathlib.Path):
    if isinstance(text, pathlib.Path):
        text = str(text)
    return f'file:///' + text.replace('\\', '/')


def cyan(text: str) -> str:
    """Wrap ``text`` between colorama's ``Fore.CYAN`` and ``Fore.RESET`` codes."""
    return "{}{}{}".format(Fore.CYAN, text, Fore.RESET)


class FrequencyScheduler:
    def __init__(
            self,
            last_step: int,
            frequencies: list[int] | None = None,
            steps: list[int] | None = None,
            iters: list[int] | None = None,
            enable_target: bool = True,
            enable_context: bool = True,
            enable_info: bool = True,
            enable_debug: bool = True,
    ):
        if iters is not None:
            print("FrequencyScheduler: using iters argument, ignoring frequencies and steps.")
            # assert frequencies is None and steps is None, "When iters is provided, frequencies and steps must be None"
        elif frequencies is None and steps is None:
            # Make sure frequencies and steps are both either None or lists of the same length
            frequencies = [99999999]  # effectively never
            steps = [0]
        elif frequencies is None or steps is None:
            raise ValueError("frequencies and steps must both be None or both be lists")
        else:
            assert len(frequencies) == len(
                steps), f"frequencies and steps must be same length. Got {len(frequencies)} and {len(steps)}"
            assert steps[0] == 0, f"first step must be 0. Got {steps}"

        if iters is not None:
            self.iterations = copy(iters)
            # check if last step in iters, else add it to iters and sort, remove higher than last_step
            self.iterations = sorted([i for i in self.iterations if i <= last_step])
            if last_step not in self.iterations:
                self.iterations.append(last_step)
            if 0 not in self.iterations:
                self.iterations.insert(0, 0)
        else:
            frequencies = copy(frequencies)
            steps = copy(steps)
            steps.pop(0)  # remove the first step which is always 0
            if last_step not in steps:
                steps.append(last_step)  # ensure last step is included
            pairs = list(zip(frequencies, steps))
            self.iterations: list[int] = self.get_all_iterations(pairs, last_step)

        self.verbose = False
        self.last_step = last_step

        self.enabled_tags = {
            "target": enable_target,
            "context": enable_context,
            "info": enable_info,
            "debug": enable_debug
        }

        self.is_disabled = False

    def set_verbose(self, verbose: bool):
        self.verbose = verbose

    def set_all_tags(self, enabled: bool):
        for key in self.enabled_tags:
            self.enabled_tags[key] = enabled

    def check_iteration(self, iteration: int, tag: str) -> bool:
        """Returns True if any frequency event occurs at this iteration."""
        assert tag in KNOWN_TAGS, f"Invalid tag: {tag}, must be in {KNOWN_TAGS}"
        if self.enabled_tags[tag]:
            return iteration in self.iterations
        else:
            return False

    def _occurs_at(self, iteration: int, pairs, last_step) -> bool:
        """Returns True if any frequency event occurs at this iteration."""

        if iteration == last_step:
            return True

        for freq, end in pairs:
            if iteration <= end:
                if iteration % freq == 0:
                    return True
                else:
                    break

        return False

    def get_all_iterations(self, pairs, last_step) -> list[int]:
        """Returns a list of all iterations where an event occurs up to the last step."""
        t = 0
        iterations = []
        while t <= last_step:
            if self._occurs_at(t, pairs, last_step):
                iterations.append(t)
            t += 1
        return iterations

    def get_iterations(self, length_of_event: int) -> list[int]:
        """Returns a list of all iterations where an event occurs up to the given length."""
        if self.iterations is not None and len(self.iterations) >= length_of_event:
            if length_of_event == 1:
                return [self.iterations[-1]]
            return self.iterations[:length_of_event]
        else:
            raise ValueError(
                f"Not enough iterations up to last_step {self.last_step} to get {length_of_event} events. "
                f"Only got {len(self.iterations)} events.")

    def disable(self, flag):
        self.is_disabled = flag

    def __call__(self, iteration: int, tag: str = "") -> bool:
        if self.is_disabled:
            return False
        return self.check_iteration(iteration, tag)

    def __repr__(self):
        return f"FrequencyScheduler({self.iterations})"


def log_mem(tag=""):
    """Print CUDA memory counters prefixed by ``tag``.

    Calls ``torch.cuda.synchronize()`` first, then prints the allocated,
    reserved and max-allocated counters, each divided by 1e6 and labelled MB.
    """
    torch.cuda.synchronize()
    scale = 1e6
    allocated = torch.cuda.memory_allocated() / scale
    reserved = torch.cuda.memory_reserved() / scale
    peak = torch.cuda.max_memory_allocated() / scale
    report = (
        f"{tag}: allocated={allocated:.1f}MB, "
        f"reserved={reserved:.1f}MB, "
        f"max_allocated={peak:.1f}MB"
    )
    print(report)


def read_omega_cfg(path: pathlib.Path) -> OmegaConf:
    """Load a YAML config as an OmegaConf object, tolerating ``CustomPath`` tags.

    ``OmegaConf.load`` is tried first. If it raises ``ConstructorError`` the
    file is re-read with PyYAML after registering a constructor for the
    ``python/object/apply:...CustomPath`` tags, and the result is wrapped with
    ``OmegaConf.create``.

    Args:
        path: Location of the YAML file.

    Returns:
        The loaded configuration.

    Raises:
        TypeError: If a ``CustomPath`` tag is attached to a node that is
            neither a scalar nor a sequence.
    """

    def custompath_constructor(loader, node):
        """Rebuild a CustomPath from a tagged scalar or sequence node."""
        if isinstance(node, yaml.ScalarNode):
            return CustomPath(loader.construct_scalar(node))
        if isinstance(node, yaml.SequenceNode):
            # Join the sequence items into one path, part by part.
            joined = CustomPath()
            for part in loader.construct_sequence(node):
                joined = joined / str(part)
            return joined
        raise TypeError(f"Unsupported YAML node type for CustomPath: {type(node)}")

    try:
        return OmegaConf.load(path)
    except ConstructorError:
        # Register for both the current tag and the legacy `src.` tag:
        # checkpoints released/trained before the src->optgs package rename
        # embed `...apply:src.misc.io.CustomPath` in their saved config.yaml.
        for _tag in (
            'tag:yaml.org,2002:python/object/apply:optgs.misc.io.CustomPath',
            'tag:yaml.org,2002:python/object/apply:src.misc.io.CustomPath',
        ):
            yaml.add_constructor(_tag, custompath_constructor)

        # NOTE(security): yaml.load with FullLoader plus the constructors above
        # is meant for config files this project wrote itself; do not point it
        # at untrusted input.
        # Explicit UTF-8 so the result does not depend on the platform's
        # default text encoding.
        with open(path, "r", encoding="utf-8") as f:
            raw_cfg = yaml.load(f, Loader=yaml.FullLoader)

        return OmegaConf.create(raw_cfg)


if __name__ == '__main__':
    # Small manual demo: build a schedule, query every step once, then print
    # the first 15 scheduled iterations.
    # (An explicit schedule can be passed instead via ``iters=[...]``.)
    final_step = 37

    print_every = FrequencyScheduler(
        last_step=final_step,
        frequencies=[1, 2, 5],
        steps=[0, 5, 10],
    )
    for step in range(final_step + 1):
        # The result is deliberately ignored; this only exercises the call path.
        print_every(step, "target")

    print(print_every.get_iterations(15))