File size: 8,152 Bytes
c6c6312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import NotImplementedType
from typing import Any, Optional, TypeVar, Union

import numpy as np
from typing_extensions import Protocol

from cirq._doc import doc_private
from cirq.protocols import qid_shape_protocol
from cirq.protocols.apply_unitary_protocol import apply_unitaries, ApplyUnitaryArgs
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits

# This is a special indicator value used by the unitary method to determine
# whether or not the caller provided a 'default' argument. It must be of type
# np.ndarray to ensure the method has the correct type signature in that case.
# It is checked for using `is`, so it won't have a false positive if the user
# provides a different np.array([]) value.
RaiseTypeErrorIfNotProvided: np.ndarray = np.array([])

# Type of the caller-supplied fallback value that `unitary` hands back
# unchanged when the given value turns out to have no unitary effect.
TDefault = TypeVar('TDefault')


class SupportsUnitary(Protocol):
    """An object that may be describable by a unitary matrix."""

    @doc_private
    def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
        """A unitary matrix describing this value, e.g. the matrix of a gate.

        This method is used by the global `cirq.unitary` method. If this method
        is not present, or returns NotImplemented, it is assumed that the
        receiving object doesn't have a unitary matrix (resulting in a TypeError
        or default result when calling `cirq.unitary` on it). (The ability to
        return NotImplemented is useful when a class cannot know if it has a
        matrix until runtime, e.g. cirq.X**c normally has a matrix but
        cirq.X**sympy.Symbol('a') doesn't.)

        The order of cells in the matrix is always implicit with respect to the
        object being called. For example, for gates the matrix must be ordered
        with respect to the list of qubits that the gate is applied to. For
        operations, the matrix is ordered to match the list returned by its
        `qubits` attribute. The qubit-to-amplitude order mapping matches the
        ordering of numpy.kron(A, B), where A is a qubit earlier in the list
        than the qubit B.

        Returns:
            A unitary matrix describing this value, or NotImplemented if there
            is no such matrix.
        """

    @doc_private
    def _has_unitary_(self) -> bool:
        """Whether this value has a unitary matrix representation.

        This method is used by the global `cirq.has_unitary` method. If this
        method is not present, or returns NotImplemented, `cirq.has_unitary`
        will fall back to using `_unitary_` with a default value, or to False
        if neither method exists.

        Returns:
            True if the value has a unitary matrix representation, False
            otherwise.
        """


def unitary(
    val: Any, default: Union[np.ndarray, TDefault] = RaiseTypeErrorIfNotProvided
) -> Union[np.ndarray, TDefault]:
    """Returns a unitary matrix describing the given value.

    The matrix is determined by any one of the following techniques:

    - The value has a `_unitary_` method that returns something besides None or
        NotImplemented. The matrix is whatever the method returned.
    - The value has a `_decompose_` method that returns a list of operations,
        and each operation in the list has a unitary effect. The matrix is
        created by aggregating the sub-operations' unitary effects.
    - The value has an `_apply_unitary_` method, and it returns something
        besides None or NotImplemented. The matrix is created by applying
        `_apply_unitary_` to an identity matrix.

    If none of these techniques succeeds, it is assumed that `val` doesn't have
    a unitary effect. The order in which techniques are attempted is
    unspecified.

    A technique that reports NotImplemented is skipped in favor of the next
    one, whereas a technique that reports None ends the search immediately and
    is handled the same way as no technique succeeding.

    Args:
        val: The value to describe with a unitary matrix.
        default: Determines the fallback behavior when `val` doesn't have
            a unitary effect. If `default` is not set, a TypeError is raised. If
            `default` is set to a value, that value is returned.

    Returns:
        If `val` has a unitary effect, the corresponding unitary matrix.
        Otherwise, if `default` is specified, it is returned.

    Raises:
        TypeError: `val` doesn't have a unitary effect and no default value was
            specified.
    """
    strats = [
        _strat_unitary_from_unitary,
        _strat_unitary_from_apply_unitary,
        _strat_unitary_from_decompose,
    ]
    for strat in strats:
        result = strat(val)
        if result is None:
            # None is a definitive "no unitary": stop trying other strategies.
            break
        if result is not NotImplemented:
            return result
        # NotImplemented only means this strategy doesn't apply; try the next.

    # The sentinel is compared by identity so that a caller-supplied empty
    # array is still honored as a real default.
    if default is not RaiseTypeErrorIfNotProvided:
        return default
    raise TypeError(
        "cirq.unitary failed. "
        "Value doesn't have a (non-parameterized) unitary effect.\n"
        "\n"
        f"type: {type(val)}\n"
        f"value: {val!r}\n"
        "\n"
        "The value failed to satisfy any of the following criteria:\n"
        "- A `_unitary_(self)` method that returned a value "
        "besides None or NotImplemented.\n"
        "- A `_decompose_(self)` method that returned a "
        "list of unitary operations.\n"
        "- An `_apply_unitary_(self, args)` method that returned a value "
        "besides None or NotImplemented."
    )


def _strat_unitary_from_unitary(val: Any) -> Optional[np.ndarray]:
    """Asks `val` for its unitary by calling its `_unitary_` method.

    Returns NotImplemented when `val` exposes no `_unitary_` attribute (or the
    attribute is None). Otherwise whatever `_unitary_()` produces is passed
    through untouched, which may itself be None or NotImplemented.
    """
    unitary_method = getattr(val, '_unitary_', None)
    return NotImplemented if unitary_method is None else unitary_method()


def _strat_unitary_from_apply_unitary(val: Any) -> Optional[np.ndarray]:
    """Builds a value's unitary by running its `_apply_unitary_` method.

    Returns NotImplemented when `val` has no `_apply_unitary_` method or when
    no qid shape can be determined for it. A None or NotImplemented answer
    from `_apply_unitary_` is passed through unchanged. Any other answer is
    reshaped into a square matrix whose side is the product of the qid shape.
    """
    apply_method = getattr(val, '_apply_unitary_', None)
    if apply_method is None:
        return NotImplemented

    # The shape is needed to build the identity that the method acts on.
    shape = qid_shape_protocol.qid_shape(val, None)
    if shape is None:
        return NotImplemented

    # Let the value act on an identity; the outcome encodes its matrix.
    applied = apply_method(ApplyUnitaryArgs.for_unitary(qid_shape=shape))
    if applied is None or applied is NotImplemented:
        return applied

    dim = np.prod(shape, dtype=np.int64)
    return applied.reshape((dim, dim))


def _strat_unitary_from_decompose(val: Any) -> Optional[np.ndarray]:
    """Builds a value's unitary from the operations it decomposes into.

    Returns NotImplemented when no decomposition is available, and None when
    applying the decomposed operations to an identity yields None. Otherwise
    returns the square block of the combined matrix that acts on the value's
    own qubits.
    """
    operations, qubits, val_qid_shape = _try_decompose_into_operations_and_qubits(val)
    if operations is None:
        return NotImplemented

    # Qubits used by the decomposition but not among `qubits` are ancillas.
    touched_qubits = frozenset(q for op in operations for q in op.qubits)
    ancillas = tuple(sorted(touched_qubits - frozenset(qubits)))
    ancilla_count = len(ancillas)

    # Ancillas go first so that they occupy the leading axes of the shape.
    ordered_qubits = ancillas + tuple(qubits)
    full_shape = qid_shape_protocol.qid_shape(ancillas) + val_qid_shape

    # Apply sub-operations' unitary effects to an identity matrix.
    identity_args = ApplyUnitaryArgs.for_unitary(qid_shape=full_shape)
    applied = apply_unitaries(operations, ordered_qubits, identity_args, None)
    if applied is None:
        return None

    full_dim = np.prod(full_shape, dtype=np.int64)
    work_dim = np.prod(full_shape[ancilla_count:], dtype=np.int64)
    # Assuming borrowable qubits are restored to their original state and
    # clean qubits are restored to the zero state, the desired unitary is
    # the upper left square.
    return applied.reshape((full_dim, full_dim))[:work_dim, :work_dim]