Prompt48 commited on
Commit
deaf69e
·
verified ·
1 Parent(s): 1885203

Upload edit\Qwen3-TTS-test\.venv\Lib\site-packages\sklearn\externals\array_api_compat\torch\linalg.py with huggingface_hub

Browse files
edit//Qwen3-TTS-test//.venv//Lib//site-packages//sklearn//externals//array_api_compat//torch//linalg.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from typing import Optional, Union, Tuple
5
+
6
+ from torch.linalg import * # noqa: F403
7
+
8
+ # torch.linalg doesn't define __all__
9
+ # from torch.linalg import __all__ as linalg_all
10
+ from torch import linalg as torch_linalg
11
+ linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
12
+
13
+ # outer is implemented in torch but aren't in the linalg namespace
14
+ from torch import outer
15
+ from ._aliases import _fix_promotion, sum
16
+ # These functions are in both the main and linalg namespaces
17
+ from ._aliases import matmul, matrix_transpose, tensordot
18
+ from ._typing import Array, DType
19
+ from ..common._typing import JustInt, JustFloat
20
+
21
+ # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
22
+ # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
23
+
24
+ # torch.cross also does not support broadcasting when it would add new
25
+ # dimensions https://github.com/pytorch/pytorch/issues/39656
26
+ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
27
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
28
+ if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
29
+ raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
30
+ if not (x1.shape[axis] == x2.shape[axis] == 3):
31
+ raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
32
+ x1, x2 = torch.broadcast_tensors(x1, x2)
33
+ return torch_linalg.cross(x1, x2, dim=axis)
34
+
35
+ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array:
36
+ from ._aliases import isdtype
37
+
38
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
39
+
40
+ # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
41
+ if x1.shape[axis] != x2.shape[axis]:
42
+ raise ValueError("x1 and x2 must have the same size along the given axis")
43
+
44
+ # torch.linalg.vecdot doesn't support integer dtypes
45
+ if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
46
+ if kwargs:
47
+ raise RuntimeError("vecdot kwargs not supported for integral dtypes")
48
+
49
+ x1_ = torch.moveaxis(x1, axis, -1)
50
+ x2_ = torch.moveaxis(x2, axis, -1)
51
+ x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
52
+
53
+ res = x1_[..., None, :] @ x2_[..., None]
54
+ return res[..., 0, 0]
55
+ return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
56
+
57
+ def solve(x1: Array, x2: Array, /, **kwargs) -> Array:
58
+ x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
59
+ # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
60
+ # whenever
61
+ # 1. x1.ndim - 1 == x2.ndim
62
+ # 2. x1.shape[:-1] == x2.shape
63
+ #
64
+ # See linalg_solve_is_vector_rhs in
65
+ # aten/src/ATen/native/LinearAlgebraUtils.h and
66
+ # TORCH_META_FUNC(_linalg_solve_ex) in
67
+ # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
68
+ #
69
+ # The easiest way to work around this is to prepend a size 1 dimension to
70
+ # x2, since x2 is already one dimension less than x1.
71
+ #
72
+ # See https://github.com/pytorch/pytorch/issues/52915
73
+ if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
74
+ x2 = x2[None]
75
+ return torch.linalg.solve(x1, x2, **kwargs)
76
+
77
+ # torch.trace doesn't support the offset argument and doesn't support stacking
78
+ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array:
79
+ # Use our wrapped sum to make sure it does upcasting correctly
80
+ return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
81
+
82
+ def vector_norm(
83
+ x: Array,
84
+ /,
85
+ *,
86
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
87
+ keepdims: bool = False,
88
+ # JustFloat stands for inf | -inf, which are not valid for Literal
89
+ ord: JustInt | JustFloat = 2,
90
+ **kwargs,
91
+ ) -> Array:
92
+ # torch.vector_norm incorrectly treats axis=() the same as axis=None
93
+ if axis == ():
94
+ out = kwargs.get('out')
95
+ if out is None:
96
+ dtype = None
97
+ if x.dtype == torch.complex64:
98
+ dtype = torch.float32
99
+ elif x.dtype == torch.complex128:
100
+ dtype = torch.float64
101
+
102
+ out = torch.zeros_like(x, dtype=dtype)
103
+
104
+ # The norm of a single scalar works out to abs(x) in every case except
105
+ # for ord=0, which is x != 0.
106
+ if ord == 0:
107
+ out[:] = (x != 0)
108
+ else:
109
+ out[:] = torch.abs(x)
110
+ return out
111
+ return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
112
+
113
+ __all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
114
+ 'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
115
+
116
+ _all_ignore = ['torch_linalg', 'sum']
117
+
118
+ del linalg_all
119
+
120
+ def __dir__() -> list[str]:
121
+ return __all__