from functools import wraps

from keras.src import tree
from keras.src.backend.common.global_state import get_global_attribute
from keras.src.backend.common.global_state import set_global_attribute
from keras.src.utils import python_utils


class DotNotTrackScope:
    """Context manager that temporarily disables automatic tracking.

    On entry the current value of the global `"tracking_on"` attribute is
    saved and the attribute is set to `False`; on exit the saved value is
    restored, so nested scopes unwind correctly.
    """

    def __enter__(self):
        self.original_value = is_tracking_enabled()
        set_global_attribute("tracking_on", False)

    def __exit__(self, *args, **kwargs):
        set_global_attribute("tracking_on", self.original_value)


def is_tracking_enabled():
    """Return the global `"tracking_on"` flag (`True` when it was never set)."""
    return get_global_attribute("tracking_on", True)


def no_automatic_dependency_tracking(fn):
    """Decorator that runs `fn` inside a `DotNotTrackScope`."""

    @wraps(fn)
    def wrapper(*args, **kwargs):
        with DotNotTrackScope():
            return fn(*args, **kwargs)

    return wrapper


class Tracker:
    """Attribute tracker, used for e.g. Variable tracking.

    Monitors certain attribute types
    and puts them in appropriate lists in case of a match.

    Also passively tracks certain mutable collections
    (dict, list, set) so that items added to them later
    still get tracked. This is done by wrapping these
    collections into an equivalent, tracking-aware object.

    Example:

    ```python
    def __init__(self):
        self._tracker = Tracker(
            # Format: `name: (test_fn, store)`
            {
                "variables":
                    (lambda x: isinstance(x, Variable), self._variables),
                "metrics": (lambda x: isinstance(x, Metric), self._metrics),
                "layers": (lambda x: isinstance(x, Layer), self._layers),
            }
        )

    def __setattr__(self, name, value):
        if hasattr(self, "_tracker"):
            value = self._tracker.track(value)
        return super().__setattr__(name, value)
    ```
    """

    # Used when `lock()` was called without a message, so that a lock
    # violation never surfaces as the unhelpful `ValueError(None)`.
    _DEFAULT_LOCK_VIOLATION_MSG = (
        "This tracker is locked: new values cannot be added to its stores."
    )

    def __init__(self, config, exclusions=None):
        """Create a tracker.

        Args:
            config: Dict mapping a store name to a `(test_fn, store)` pair.
                `test_fn(value)` decides whether `value` belongs to the
                store; `store` is the list matching values are appended to.
            exclusions: Optional dict mapping a store name to an iterable of
                other store names. A value is not added to the first store
                if it is already present in any of the listed stores.
        """
        self.config = config
        # Per-store sets of `id(value)`, used for identity-based membership.
        self.stored_ids = {name: set() for name in self.config.keys()}
        self.locked = False
        self._lock_violation_msg = None
        self.exclusions = exclusions or {}

    def track(self, attr):
        """Track `attr` and return the value that should be stored.

        - If tracking is globally disabled, `attr` is returned unchanged.
        - If `attr` matches a store's test function (checked in `config`
          order, first match wins), it is added to that store (unless it is
          already there or excluded) and returned unchanged.
        - Tuples and named tuples are rebuilt with each element tracked.
        - Lists, dicts and sets are wrapped in `TrackedList`,
          `TrackedDict` and `TrackedSet` respectively.
        - Anything else is returned unchanged.

        Raises:
            ValueError: If a new value must be stored while locked.
        """
        if not is_tracking_enabled():
            return attr

        for store_name, (is_attr_type, _) in self.config.items():
            if is_attr_type(attr):
                if store_name in self.exclusions:
                    for excl in self.exclusions[store_name]:
                        if self.is_in_store(excl, attr):
                            return attr
                if not self.is_in_store(store_name, attr):
                    self.add_to_store(store_name, attr)
                return attr
        if isinstance(attr, tuple) and hasattr(attr, "_fields"):
            # Named tuple case: rebuild by field name.
            wrapped_attr = {
                name: self.track(e) for name, e in attr._asdict().items()
            }
            return attr.__class__(**wrapped_attr)
        if isinstance(attr, tuple):
            return attr.__class__([self.track(e) for e in attr])
        if isinstance(attr, list):
            return TrackedList(attr, self)
        if isinstance(attr, dict):
            # TODO: OrderedDict?
            return TrackedDict(attr, self)
        if isinstance(attr, set):
            return TrackedSet(attr, self)
        return attr

    def untrack(self, value):
        """Remove `value` (by identity) from every store that holds it."""
        for store_name, ids in self.stored_ids.items():
            if id(value) in ids:
                ids.remove(id(value))
                python_utils.remove_by_id(self.config[store_name][1], value)

    def lock(self, msg=None):
        """Forbid adding new values; `msg` becomes the violation message."""
        self.locked = True
        if msg is not None:
            self._lock_violation_msg = msg

    def unlock(self):
        """Allow adding new values again."""
        self.locked = False

    def add_to_store(self, store_name, value):
        """Append `value` to the named store and record its id.

        Raises:
            ValueError: If the tracker is locked.
        """
        if self.locked:
            msg = self._lock_violation_msg
            if msg is None:
                msg = self._DEFAULT_LOCK_VIOLATION_MSG
            raise ValueError(msg)
        self.config[store_name][1].append(value)
        self.stored_ids[store_name].add(id(value))

    def is_in_store(self, store_name, value):
        """Return whether `value` (by identity) is in the named store."""
        return id(value) in self.stored_ids[store_name]

    def replace_tracked_value(self, store_name, old_value, new_value):
        """Replace `old_value` with `new_value` in place within a store.

        Raises:
            ValueError: If `old_value` is not tracked in `store_name`.
        """
        if not self.is_in_store(store_name, old_value):
            raise ValueError(f"Unknown value: {old_value}")
        store_list = self.config[store_name][1]
        # Search by identity, consistent with the id-based `stored_ids`;
        # `list.index` would use `==`, which may be overloaded
        # (e.g. elementwise) for tensor-like values.
        index = next(
            (i for i, v in enumerate(store_list) if v is old_value), None
        )
        if index is None:
            raise ValueError(f"Unknown value: {old_value}")
        store_list[index] = new_value
        self.stored_ids[store_name].remove(id(old_value))
        self.stored_ids[store_name].add(id(new_value))


@tree.register_tree_node_class
class TrackedList(list):
    """A `list` that reports additions and removals to a `Tracker`."""

    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            values = [tracker.track(v) for v in values]
        super().__init__(values or [])

    def append(self, value):
        if self.tracker:
            self.tracker.track(value)
        super().append(value)

    def insert(self, index, value):
        if self.tracker:
            self.tracker.track(value)
        super().insert(index, value)

    def extend(self, values):
        if self.tracker:
            values = [self.tracker.track(v) for v in values]
        super().extend(values)

    def remove(self, value):
        """Remove `value`, then untrack it."""
        try:
            super().remove(value)
        except ValueError:
            # `list.remove` compares with `==`; fall back to identity.
            python_utils.remove_by_id(self, value)
        if self.tracker:
            self.tracker.untrack(value)

    def pop(self, index=-1):
        value = super().pop(index)
        if self.tracker:
            self.tracker.untrack(value)
        return value

    def clear(self):
        if self.tracker:
            for value in self:
                self.tracker.untrack(value)
        super().clear()

    def __delitem__(self, index):
        """Delete by index or slice, untracking every removed element."""
        # Capture before deleting; a slice yields a plain list of elements.
        removed = self[index]
        super().__delitem__(index)
        if self.tracker:
            if isinstance(index, slice):
                for value in removed:
                    self.tracker.untrack(value)
            else:
                self.tracker.untrack(removed)

    def tree_flatten(self):
        # For optree / dmtree
        return (self, None)

    @classmethod
    def tree_unflatten(cls, metadata, children):
        # For optree / dmtree
        return cls(children)


@tree.register_tree_node_class
class TrackedDict(dict):
    """A `dict` whose values are reported to a `Tracker`."""

    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            # `dict(...)` also accepts an iterable of key/value pairs.
            values = {k: tracker.track(v) for k, v in dict(values).items()}
        super().__init__(values or [])

    def __setitem__(self, key, value):
        if self.tracker:
            self.tracker.track(value)
        super().__setitem__(key, value)

    def __delitem__(self, key):
        """Delete `key`, untracking its value (KeyError if missing)."""
        value = self[key]
        super().__delitem__(key)
        if self.tracker:
            self.tracker.untrack(value)

    def update(self, mapping=None, **kwargs):
        """Like `dict.update`: accepts a mapping or an iterable of pairs,
        plus keyword arguments; all new values are tracked."""
        items = dict(mapping if mapping is not None else (), **kwargs)
        if self.tracker:
            items = {k: self.tracker.track(v) for k, v in items.items()}
        super().update(items)

    def pop(self, key, default=None):
        """Remove and return `key`'s value, untracking it.

        Unlike `dict.pop`, a missing key returns `default` (which itself
        defaults to `None`) instead of raising `KeyError`.
        """
        if key not in self:
            return default
        value = super().pop(key)
        if self.tracker:
            self.tracker.untrack(value)
        return value

    def popitem(self):
        key, value = super().popitem()
        if self.tracker:
            self.tracker.untrack(value)
        return key, value

    def clear(self):
        if self.tracker:
            for value in self.values():
                self.tracker.untrack(value)
        super().clear()

    def tree_flatten(self):
        # For optree / dmtree
        keys = sorted(self.keys())
        values = [self[k] for k in keys]
        return values, keys, keys

    @classmethod
    def tree_unflatten(cls, keys, values):
        # For optree / dmtree
        return cls(zip(keys, values))


@tree.register_tree_node_class
class TrackedSet(set):
    """A `set` that reports additions and removals to a `Tracker`."""

    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            values = {tracker.track(v) for v in values}
        super().__init__(values or [])

    def add(self, value):
        if self.tracker:
            self.tracker.track(value)
        super().add(value)

    def update(self, values):
        if self.tracker:
            values = [self.tracker.track(v) for v in values]
        super().update(values)

    def remove(self, value):
        """Remove `value` (KeyError if absent), then untrack it.

        Removing first ensures a value that is not in this set is not
        untracked from the tracker's stores.
        """
        super().remove(value)
        if self.tracker:
            self.tracker.untrack(value)

    def discard(self, value):
        """Remove `value` if present (untracking it); otherwise do nothing."""
        if value in self:
            self.remove(value)

    def pop(self):
        value = super().pop()
        if self.tracker:
            self.tracker.untrack(value)
        return value

    def clear(self):
        if self.tracker:
            for value in self:
                self.tracker.untrack(value)
        super().clear()

    def tree_flatten(self):
        # For optree / dmtree
        return (self, None)

    @classmethod
    def tree_unflatten(cls, metadata, children):
        # For optree / dmtree
        return cls(children)