joebruce1313's picture
Upload 38004 files
1f5470c verified
from functools import wraps
from keras.src import tree
from keras.src.backend.common.global_state import get_global_attribute
from keras.src.backend.common.global_state import set_global_attribute
from keras.src.utils import python_utils
class DotNotTrackScope:
    """Context manager that switches automatic tracking off for its body.

    Entering records the value reported by `is_tracking_enabled()` and sets
    the global `"tracking_on"` attribute to `False`. Exiting writes the
    recorded value back, so nested scopes restore whatever was active before.
    """

    def __enter__(self):
        # Snapshot the current state before overriding it.
        previous_state = is_tracking_enabled()
        self.original_value = previous_state
        set_global_attribute("tracking_on", False)

    def __exit__(self, *args, **kwargs):
        # Put the snapshot back. Nothing is returned (None), so an exception
        # raised inside the `with` body is not suppressed.
        restored_state = self.original_value
        set_global_attribute("tracking_on", restored_state)
def is_tracking_enabled():
    """Return whether automatic tracking is currently switched on.

    Reads the global `"tracking_on"` attribute. When that attribute has not
    been set, tracking counts as enabled (`True`).
    """
    enabled_by_default = True
    return get_global_attribute("tracking_on", enabled_by_default)
def no_automatic_dependency_tracking(fn):
    """Decorate `fn` so that each call runs inside a `DotNotTrackScope`.

    The returned callable forwards all positional and keyword arguments to
    `fn` and returns its result. `functools.wraps` copies `fn`'s metadata
    (name, docstring, `__wrapped__`, ...) onto it.
    """

    def _call_without_tracking(*args, **kwargs):
        with DotNotTrackScope():
            outcome = fn(*args, **kwargs)
        return outcome

    return wraps(fn)(_call_without_tracking)
class Tracker:
    """Attribute tracker, used for e.g. Variable tracking.

    Monitors certain attribute types
    and puts them in appropriate lists in case of a match.

    Also passively tracks certain mutable collections
    (dict, list, set) so that items added to them later
    still get tracked. This is done by wrapping these
    collections into an equivalent, tracking-aware object.

    Example:

    ```python
    def __init__(self):
        self._tracker = Tracker(
            # Format: `name: (test_fn, store)`
            {
                "variables":
                    (lambda x: isinstance(x, Variable), self._variables),
                "metrics": (lambda x: isinstance(x, Metric), self._metrics),
                "layers": (lambda x: isinstance(x, Layer), self._layers),
            }
        )

    def __setattr__(self, name, value):
        if hasattr(self, "_tracker"):
            value = self._tracker.track(value)
        return super().__setattr__(name, value)
    ```

    Args:
        config: Mapping from store name to a `(test_fn, store)` pair.
            `test_fn(value)` decides whether a value belongs to that store.
            `store` is a list that matching values are appended to.
        exclusions: Optional mapping from a store name to an iterable of
            other store names. A value that matches `store_name`'s test but
            is already held by one of the listed stores is not added.
    """

    def __init__(self, config, exclusions=None):
        self.config = config
        # One set of `id()`s per store; membership is decided by identity,
        # not equality.
        self.stored_ids = {name: set() for name in self.config.keys()}
        self.locked = False
        self._lock_violation_msg = None
        self.exclusions = exclusions or {}

    def track(self, attr):
        """Register `attr` in its matching store, or wrap it for tracking.

        Stores are tested in `config` order, and the first store whose test
        accepts `attr` is used. Tuples (including named tuples) are rebuilt
        with every element tracked. Lists, dicts and sets are wrapped in
        `TrackedList`, `TrackedDict` and `TrackedSet` respectively.

        Returns:
            `attr` itself when tracking is disabled, when it matched a store,
            or when it is of no handled kind. Otherwise, the rebuilt tuple or
            the tracking wrapper.
        """
        if not is_tracking_enabled():
            return attr

        for store_name, (is_attr_type, _) in self.config.items():
            if not is_attr_type(attr):
                continue
            # Leave the value alone if an excluded store already holds it.
            if store_name in self.exclusions:
                for excl in self.exclusions[store_name]:
                    if self.is_in_store(excl, attr):
                        return attr
            if not self.is_in_store(store_name, attr):
                self.add_to_store(store_name, attr)
            return attr

        if isinstance(attr, tuple) and hasattr(attr, "_fields"):
            # Named tuple case: rebuild it with each field tracked.
            wrapped_attr = {
                name: self.track(e) for name, e in attr._asdict().items()
            }
            return attr.__class__(**wrapped_attr)
        if isinstance(attr, tuple):
            return attr.__class__([self.track(e) for e in attr])
        if isinstance(attr, list):
            return TrackedList(attr, self)
        if isinstance(attr, dict):
            # TODO: OrderedDict?
            return TrackedDict(attr, self)
        if isinstance(attr, set):
            return TrackedSet(attr, self)
        return attr

    def untrack(self, value):
        """Remove `value` (matched by identity) from every store holding it."""
        value_id = id(value)
        for store_name, ids in self.stored_ids.items():
            if value_id in ids:
                ids.remove(value_id)
                python_utils.remove_by_id(self.config[store_name][1], value)

    def lock(self, msg=None):
        """Forbid further `add_to_store` calls.

        Args:
            msg: Optional message for the `ValueError` raised on violation.
                When omitted, the previously set message (initially `None`)
                is kept.
        """
        self.locked = True
        if msg is not None:
            self._lock_violation_msg = msg

    def unlock(self):
        """Allow `add_to_store` calls again."""
        self.locked = False

    def add_to_store(self, store_name, value):
        """Append `value` to the named store and record its id.

        Raises:
            ValueError: If the tracker is locked. The error carries the
                message given to `lock`.
        """
        if self.locked:
            raise ValueError(self._lock_violation_msg)
        self.config[store_name][1].append(value)
        self.stored_ids[store_name].add(id(value))

    def is_in_store(self, store_name, value):
        """Return whether `value` (by identity) is recorded in the store."""
        return id(value) in self.stored_ids[store_name]

    def replace_tracked_value(self, store_name, old_value, new_value):
        """Swap `old_value` for `new_value` in place within a store.

        Raises:
            ValueError: If `old_value` is not tracked in `store_name`.
        """
        if not self.is_in_store(store_name, old_value):
            raise ValueError(f"Unknown value: {old_value}")
        store_list = self.config[store_name][1]
        # Locate the entry by identity, consistent with `stored_ids`.
        # `list.index` would evaluate `==` against earlier entries, which for
        # objects overloading `__eq__` can match (or raise on) a different
        # element.
        for index, candidate in enumerate(store_list):
            if candidate is old_value:
                break
        else:
            raise ValueError(f"Unknown value: {old_value}")
        store_list[index] = new_value
        self.stored_ids[store_name].remove(id(old_value))
        self.stored_ids[store_name].add(id(new_value))
@tree.register_tree_node_class
class TrackedList(list):
    """A `list` that reports the values it holds to a `Tracker`.

    When a `tracker` is given:
    - Values passed to the constructor, `append`, `insert` and `extend` go
      through `tracker.track`.
    - Values removed via `remove`, `pop`, `clear` or `del` (including slice
      deletion) go through `tracker.untrack`.

    Without a tracker, no tracking calls are made. Other mutators inherited
    from `list` (e.g. item assignment, `+=`) are not intercepted.

    Args:
        values: Optional initial iterable of values.
        tracker: Optional `Tracker`-like object providing `track` and
            `untrack`.
    """

    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            values = [tracker.track(v) for v in values]
        super().__init__(values or [])

    def append(self, value):
        if self.tracker:
            # NOTE(review): the result of `track` is discarded here (and in
            # `insert`), whereas `extend` stores it; a nested list/dict/set is
            # therefore kept unwrapped. Preserved as-is for compatibility.
            self.tracker.track(value)
        super().append(value)

    def insert(self, index, value):
        if self.tracker:
            self.tracker.track(value)
        super().insert(index, value)

    def extend(self, values):
        if self.tracker:
            values = [self.tracker.track(v) for v in values]
        super().extend(values)

    def remove(self, value):
        if self.tracker:
            self.tracker.untrack(value)
        try:
            super().remove(value)
        except ValueError:
            # `list.remove` compares with `==`. Fall back to removing by
            # identity when that raises ValueError (no equal element, or an
            # `__eq__` that raises it).
            python_utils.remove_by_id(self, value)

    def pop(self, index=-1):
        if self.tracker:
            # Reading first means an out-of-range index raises IndexError
            # before anything is untracked.
            self.tracker.untrack(self[index])
        return super().pop(index)

    def clear(self):
        if self.tracker:
            for value in self:
                self.tracker.untrack(value)
        super().clear()

    def __delitem__(self, index):
        # Capture what is about to be removed. For a slice this is a list of
        # elements, each of which must be untracked individually.
        removed = self[index]
        super().__delitem__(index)
        if self.tracker:
            if isinstance(index, slice):
                for value in removed:
                    self.tracker.untrack(value)
            else:
                self.tracker.untrack(removed)

    def tree_flatten(self):
        # For optree / dmtree
        return (self, None)

    @classmethod
    def tree_unflatten(cls, metadata, children):
        # For optree / dmtree. The rebuilt list has no tracker attached.
        return cls(children)
@tree.register_tree_node_class
class TrackedDict(dict):
    """A `dict` that reports the values it holds to a `Tracker`.

    When a `tracker` is given:
    - Values passed to the constructor, `__setitem__` and `update` go
      through `tracker.track`.
    - Values removed via `pop`, `popitem`, `clear` or `del` go through
      `tracker.untrack`.

    Keys are never tracked. Other mutators inherited from `dict` (e.g.
    `setdefault`, `|=`) are not intercepted.

    Args:
        values: Optional initial contents. With a tracker this must provide
            `.items()`; without one, anything `dict()` accepts works.
        tracker: Optional `Tracker`-like object providing `track` and
            `untrack`.
    """

    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            values = {k: tracker.track(v) for k, v in values.items()}
        super().__init__(values or [])

    def __setitem__(self, key, value):
        if self.tracker:
            self.tracker.track(value)
        super().__setitem__(key, value)

    def __delitem__(self, key):
        # Mirror `pop`: the value dropped by `del d[key]` is untracked too.
        # A missing key raises KeyError before anything changes.
        value = self[key]
        super().__delitem__(key)
        if self.tracker:
            self.tracker.untrack(value)

    def update(self, mapping=(), **kwargs):
        """Merge entries in, tracking each new value.

        Accepts the same arguments as `dict.update`: a mapping or an iterable
        of key/value pairs, plus keyword arguments.
        """
        if self.tracker:
            items = dict(mapping, **kwargs)
            mapping = {k: self.tracker.track(v) for k, v in items.items()}
            super().update(mapping)
        else:
            super().update(mapping, **kwargs)

    def pop(self, key, default=None):
        """Remove `key` and return its value, untracking that value.

        NOTE: unlike `dict.pop`, a missing key returns `default` (None unless
        given) rather than raising KeyError.
        """
        if self.tracker:
            # Decide on membership rather than `value is not default`. The
            # identity check skips untracking when the stored value happens
            # to be the same object as `default`.
            present = key in self
            value = super().pop(key, default)
            if present:
                self.tracker.untrack(value)
            return value
        return super().pop(key, default)

    def popitem(self):
        key, value = super().popitem()
        if self.tracker:
            self.tracker.untrack(value)
        return key, value

    def clear(self):
        if self.tracker:
            for value in self.values():
                self.tracker.untrack(value)
        super().clear()

    def tree_flatten(self):
        # For optree / dmtree. Returns the values in sorted-key order, then
        # the sorted keys twice. `tree_unflatten` receives them back as
        # `keys`.
        keys = sorted(self)
        values = [self[k] for k in keys]
        return values, keys, keys

    @classmethod
    def tree_unflatten(cls, keys, values):
        # For optree / dmtree. The rebuilt dict has no tracker attached.
        return cls(zip(keys, values))
@tree.register_tree_node_class
class TrackedSet(set):
    """A `set` that reports the values it holds to a `Tracker`.

    When a `tracker` is given:
    - Values passed to the constructor, `add` and `update` go through
      `tracker.track`.
    - Values removed via `remove`, `discard`, `pop` or `clear` go through
      `tracker.untrack`.

    Other mutators inherited from `set` (e.g. `|=`, `difference_update`)
    are not intercepted.

    Args:
        values: Optional initial iterable of values.
        tracker: Optional `Tracker`-like object providing `track` and
            `untrack`.
    """

    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            values = {tracker.track(v) for v in values}
        super().__init__(values or [])

    def add(self, value):
        if self.tracker:
            self.tracker.track(value)
        super().add(value)

    def update(self, values, *others):
        """Add every element of `values` (and of any extra iterables)."""
        iterables = (values, *others)
        if self.tracker:
            iterables = [
                [self.tracker.track(v) for v in iterable]
                for iterable in iterables
            ]
        super().update(*iterables)

    def remove(self, value):
        # Remove first: `set.remove` raises KeyError for a missing value, and
        # in that case nothing should be untracked.
        super().remove(value)
        if self.tracker:
            self.tracker.untrack(value)

    def discard(self, value):
        # Like `set.discard`, a missing value is silently ignored. A present
        # value is removed and then untracked.
        if value in self:
            super().discard(value)
            if self.tracker:
                self.tracker.untrack(value)

    def pop(self):
        value = super().pop()
        if self.tracker:
            self.tracker.untrack(value)
        return value

    def clear(self):
        if self.tracker:
            for value in self:
                self.tracker.untrack(value)
        super().clear()

    def tree_flatten(self):
        # For optree / dmtree
        return (self, None)

    @classmethod
    def tree_unflatten(cls, metadata, children):
        # For optree / dmtree. The rebuilt set has no tracker attached.
        return cls(children)