File size: 1,900 Bytes
6f72e2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# encoding: utf-8
"""Miscellaneous context managers."""

from __future__ import annotations

import warnings
from types import TracebackType
from typing import Any

# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.


class preserve_keys:
    """Context manager that snapshots selected keys of a dict and restores them.

    When the ``with`` block is entered, the current value of every listed key
    that is present in the dictionary is recorded, and every listed key that
    is absent is remembered as absent. When the block exits (normally or via
    an exception), keys that were absent on entry are removed again and the
    recorded values are written back. Keys that were not listed are left
    exactly as the block left them. Exceptions raised inside the block are not
    suppressed.

    Examples
    --------

    >>> d = {'a': 1, 'b': 2, 'c': 3}
    >>> with preserve_keys(d, 'b', 'c', 'd'):
    ...     del d['a']
    ...     del d['b']      # will be reset to 2
    ...     d['c'] = None   # will be reset to 3
    ...     d['d'] = 4      # will be deleted
    ...     d['e'] = 5
    ...     print(sorted(d.items()))
    ...
    [('c', None), ('d', 4), ('e', 5)]
    >>> print(sorted(d.items()))
    [('b', 2), ('c', 3), ('e', 5)]
    """

    def __init__(self, dictionary: dict[Any, Any], *keys: Any) -> None:
        # The dictionary is held by reference: restoration mutates it in place.
        self.dictionary, self.keys = dictionary, keys

    def __enter__(self) -> None:
        """Record which listed keys exist right now, and their values."""
        source = self.dictionary
        saved: dict[Any, Any] = {}
        missing: list[Any] = []

        # Single pass so each key is probed once, in the order it was given.
        for key in self.keys:
            if key not in source:
                missing.append(key)
                continue
            saved[key] = source[key]

        # Read back by __exit__ when the block finishes.
        self.to_delete = missing
        self.to_update = saved

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Undo changes to the preserved keys; never swallows exceptions."""
        restore_target = self.dictionary

        # Drop keys that did not exist on entry (tolerating their absence now),
        # then reinstate the saved values of keys that did exist.
        for stale_key in self.to_delete:
            restore_target.pop(stale_key, None)
        restore_target.update(self.to_update)