Add files using upload-large-folder tool
- .gitattributes +2 -0
- backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/BEq.lean +155 -0
- backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/Basic.lean +134 -0
- backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/DecEq.lean +212 -0
- backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/FromToJson.lean +249 -0
- backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Term.lean +2128 -0
- backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Time.lean +26 -0
- backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Util.lean +249 -0
- backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/WhereFinally.lean +29 -0
- external/alphageometry/.venv-ag/Lib/site-packages/absl/app.py +488 -0
- external/alphageometry/.venv-ag/Lib/site-packages/absl/app.pyi +88 -0
- external/alphageometry/.venv-ag/Lib/site-packages/absl/command_name.py +63 -0
- external/alphageometry/.venv-ag/Lib/site-packages/absl/py.typed +0 -0
- external/alphageometry/.venv-ag/Lib/site-packages/distutils-precedence.pth +1 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/__init__.py +27 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/__init__.py +213 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_array_methods.py +45 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_creation_functions.py +31 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_data_type_functions.py +78 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_elementwise_functions.py +75 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_fft_functions.py +25 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_linear_algebra_functions.py +28 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_manipulation_functions.py +25 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_statistical_functions.py +25 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_utility_functions.py +86 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_version.py +15 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/fft.py +33 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/linalg.py +43 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_serialization/__init__.py +13 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_serialization/serialization.py +635 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_serialization/serialization_test.py +493 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/compilation_cache/__init__.py +13 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/compilation_cache/compilation_cache.py +20 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/export/__init__.py +36 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/__init__.py +23 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/call_tf.py +682 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/__init__.py +13 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/keras_reuse_main.py +78 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/keras_reuse_main_test.py +50 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/mnist_lib.py +324 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/saved_model_lib.py +154 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/saved_model_main.py +210 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/saved_model_main_test.py +70 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/serving/__init__.py +13 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/serving/model_server_request.py +128 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/impl_no_xla.py +1287 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/jax2tf.py +0 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/tests/__init__.py +13 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/tests/back_compat_tf_test.py +154 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/tests/call_tf_test.py +1821 -0
.gitattributes
CHANGED
@@ -4510,3 +4510,5 @@ external/alphageometry/.venv-ag/Lib/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
 external/alphageometry/.venv-ag/Lib/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
 external/alphageometry/.venv-ag/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text
 external/alphageometry/.venv-ag/Lib/site-packages/pip/_vendor/distlib/w64.exe filter=lfs diff=lfs merge=lfs -text
+hfenv/Lib/site-packages/setuptools/cli-arm64.exe filter=lfs diff=lfs merge=lfs -text
+hfenv/Lib/site-packages/setuptools/gui-arm64.exe filter=lfs diff=lfs merge=lfs -text
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/BEq.lean
ADDED
@@ -0,0 +1,155 @@
+/-
+Copyright (c) 2020 Microsoft Corporation. All rights reserved.
+Released under Apache 2.0 license as described in the file LICENSE.
+Authors: Leonardo de Moura
+-/
+prelude
+import Lean.Meta.Transform
+import Lean.Elab.Deriving.Basic
+import Lean.Elab.Deriving.Util
+
+namespace Lean.Elab.Deriving.BEq
+open Lean.Parser.Term
+open Meta
+
+def mkBEqHeader (indVal : InductiveVal) : TermElabM Header := do
+  mkHeader `BEq 2 indVal
+
+def mkMatch (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do
+  let discrs ← mkDiscrs header indVal
+  let alts ← mkAlts
+  `(match $[$discrs],* with $alts:matchAlt*)
+where
+  mkElseAlt : TermElabM (TSyntax ``matchAltExpr) := do
+    let mut patterns := #[]
+    -- add `_` pattern for indices
+    for _ in [:indVal.numIndices] do
+      patterns := patterns.push (← `(_))
+    patterns := patterns.push (← `(_))
+    patterns := patterns.push (← `(_))
+    let altRhs ← `(false)
+    `(matchAltExpr| | $[$patterns:term],* => $altRhs:term)
+
+  mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do
+    let mut alts := #[]
+    for ctorName in indVal.ctors do
+      let ctorInfo ← getConstInfoCtor ctorName
+      let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
+        let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
+        let mut patterns := #[]
+        -- add `_` pattern for indices
+        for _ in [:indVal.numIndices] do
+          patterns := patterns.push (← `(_))
+        let mut ctorArgs1 := #[]
+        let mut ctorArgs2 := #[]
+        let mut rhs ← `(true)
+        let mut rhs_empty := true
+        for i in [:ctorInfo.numFields] do
+          let pos := indVal.numParams + ctorInfo.numFields - i - 1
+          let x := xs[pos]!
+          if type.containsFVar x.fvarId! then
+            -- If resulting type depends on this field, we don't need to compare
+            ctorArgs1 := ctorArgs1.push (← `(_))
+            ctorArgs2 := ctorArgs2.push (← `(_))
+          else
+            let a := mkIdent (← mkFreshUserName `a)
+            let b := mkIdent (← mkFreshUserName `b)
+            ctorArgs1 := ctorArgs1.push a
+            ctorArgs2 := ctorArgs2.push b
+            let xType ← inferType x
+            if (← isProp xType) then
+              continue
+            if xType.isAppOf indVal.name then
+              if rhs_empty then
+                rhs ← `($(mkIdent auxFunName):ident $a:ident $b:ident)
+                rhs_empty := false
+              else
+                rhs ← `($(mkIdent auxFunName):ident $a:ident $b:ident && $rhs)
+            /- If `x` appears in the type of another field, use `eq_of_beq` to
+               unify the types of the subsequent variables -/
+            else if ← xs[(pos+1)...*].anyM
+                (fun fvar => (Expr.containsFVar · x.fvarId!) <$> (inferType fvar)) then
+              rhs ← `(if h : $a:ident == $b:ident then by
+                        cases (eq_of_beq h)
+                        exact $rhs
+                      else false)
+              rhs_empty := false
+            else
+              if rhs_empty then
+                rhs ← `($a:ident == $b:ident)
+                rhs_empty := false
+              else
+                rhs ← `($a:ident == $b:ident && $rhs)
+        -- add `_` for inductive parameters, they are inaccessible
+        for _ in [:indVal.numParams] do
+          ctorArgs1 := ctorArgs1.push (← `(_))
+          ctorArgs2 := ctorArgs2.push (← `(_))
+        patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs1.reverse:term*))
+        patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs2.reverse:term*))
+        `(matchAltExpr| | $[$patterns:term],* => $rhs:term)
+      alts := alts.push alt
+    alts := alts.push (← mkElseAlt)
+    return alts
+
+def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
+  let auxFunName := ctx.auxFunNames[i]!
+  let indVal := ctx.typeInfos[i]!
+  let header ← mkBEqHeader indVal
+  let mut body ← mkMatch header indVal auxFunName
+  if ctx.usePartial then
+    let letDecls ← mkLocalInstanceLetDecls ctx `BEq header.argNames
+    body ← mkLet letDecls body
+  let binders := header.binders
+  if ctx.usePartial then
+    `(partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term)
+  else
+    `(@[expose] def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term)
+
+def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
+  let mut auxDefs := #[]
+  for i in [:ctx.typeInfos.size] do
+    auxDefs := auxDefs.push (← mkAuxFunction ctx i)
+  `(mutual
+     set_option match.ignoreUnusedAlts true
+     $auxDefs:command*
+    end)
+
+private def mkBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
+  let ctx ← mkContext "beq" declName
+  let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq #[declName])
+  trace[Elab.Deriving.beq] "\n{cmds}"
+  return cmds
+
+private def mkBEqEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do
+  let auxFunName := ctx.auxFunNames[0]!
+  `(@[expose] def $(mkIdent auxFunName):ident (x y : $(mkIdent name)) : Bool := x.toCtorIdx == y.toCtorIdx)
+
+private def mkBEqEnumCmd (name : Name): TermElabM (Array Syntax) := do
+  let ctx ← mkContext "beq" name
+  let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq #[name])
+  trace[Elab.Deriving.beq] "\n{cmds}"
+  return cmds
+
+open Command
+
+def mkBEqInstance (declName : Name) : CommandElabM Unit := do
+  let cmds ← liftTermElabM <|
+    if (← isEnumType declName) then
+      mkBEqEnumCmd declName
+    else
+      mkBEqInstanceCmds declName
+  cmds.forM elabCommand
+
+def mkBEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
+  if (← declNames.allM isInductive) then
+    for declName in declNames do
+      mkBEqInstance declName
+    return true
+  else
+    return false
+
+builtin_initialize
+  registerDerivingHandler `BEq mkBEqInstanceHandler
+  registerTraceClass `Elab.Deriving.beq
+
+end Lean.Elab.Deriving.BEq
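For orientation, a minimal usage sketch (illustrative only, not part of the diff): `deriving BEq` on an inductive type is routed to the `mkBEqInstanceHandler` registered above, which emits a `mutual` block of `Bool`-valued comparison functions plus the `BEq` instance commands. The `Tree` type below is a hypothetical example.

```lean
-- Hypothetical example type, not part of the diff.
inductive Tree (α : Type) where
  | leaf
  | node (left : Tree α) (value : α) (right : Tree α)
  deriving BEq  -- handled by `mkBEqInstanceHandler` from BEq.lean

#eval Tree.node Tree.leaf 1 Tree.leaf == Tree.node Tree.leaf 1 Tree.leaf  -- true
#eval Tree.node Tree.leaf 1 Tree.leaf == Tree.leaf                        -- false
```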
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/Basic.lean
ADDED
@@ -0,0 +1,134 @@
+/-
+Copyright (c) 2020 Microsoft Corporation. All rights reserved.
+Released under Apache 2.0 license as described in the file LICENSE.
+Authors: Leonardo de Moura, Wojciech Nawrocki
+-/
+prelude
+import Lean.Elab.Command
+import Lean.Elab.DeclarationRange
+
+namespace Lean.Elab
+open Command
+
+namespace Term
+open Meta
+
+/-- Result for `mkInst?` -/
+structure MkInstResult where
+  instVal : Expr
+  instType : Expr
+  outParams : Array Expr := #[]
+
+/--
+Construct an instance for `className out₁ ... outₙ type`.
+The method support classes with a prefix of `outParam`s (e.g. `MonadReader`). -/
+private partial def mkInst? (className : Name) (type : Expr) : MetaM (Option MkInstResult) := do
+  let rec go? (instType instTypeType : Expr) (outParams : Array Expr) : MetaM (Option MkInstResult) := do
+    let instTypeType ← whnfD instTypeType
+    unless instTypeType.isForall do
+      return none
+    let d := instTypeType.bindingDomain!
+    if d.isOutParam then
+      let mvar ← mkFreshExprMVar d
+      go? (mkApp instType mvar) (instTypeType.bindingBody!.instantiate1 mvar) (outParams.push mvar)
+    else
+      unless (← isDefEqGuarded (← inferType type) d) do
+        return none
+      let instType ← instantiateMVars (mkApp instType type)
+      let instVal ← synthInstance instType
+      return some { instVal, instType, outParams }
+  let instType ← mkConstWithFreshMVarLevels className
+  go? instType (← inferType instType) #[]
+
+def processDefDeriving (className : Name) (declName : Name) : TermElabM Bool := do
+  try
+    let ConstantInfo.defnInfo info ← getConstInfo declName | return false
+    let some result ← mkInst? className info.value | return false
+    let instTypeNew := mkApp result.instType.appFn! (Lean.mkConst declName (info.levelParams.map mkLevelParam))
+    Meta.check instTypeNew
+    let instName ← liftMacroM <| mkUnusedBaseName (declName.appendBefore "inst" |>.appendAfter className.getString!)
+    addAndCompile <| Declaration.defnDecl {
+      name := instName
+      levelParams := info.levelParams
+      type := (← instantiateMVars instTypeNew)
+      value := (← instantiateMVars result.instVal)
+      hints := info.hints
+      safety := info.safety
+    }
+    addInstance instName AttributeKind.global (eval_prio default)
+    addDeclarationRangesFromSyntax instName (← getRef)
+    return true
+  catch _ =>
+    return false
+
+end Term
+
+def DerivingHandler := (typeNames : Array Name) → CommandElabM Bool
+
+builtin_initialize derivingHandlersRef : IO.Ref (NameMap (List DerivingHandler)) ← IO.mkRef {}
+
+/-- A `DerivingHandler` is called on the fully qualified names of all types it is running for
+as well as the syntax of a `with` argument, if present.
+
+For example, `deriving instance Foo with fooArgs for Bar, Baz` invokes
+``fooHandler #[`Bar, `Baz] `(fooArgs)``. -/
+def registerDerivingHandler (className : Name) (handler : DerivingHandler) : IO Unit := do
+  unless (← initializing) do
+    throw (IO.userError "failed to register deriving handler, it can only be registered during initialization")
+  derivingHandlersRef.modify fun m => match m.find? className with
+    | some handlers => m.insert className (handler :: handlers)
+    | none => m.insert className [handler]
+
+def defaultHandler (className : Name) (typeNames : Array Name) : CommandElabM Unit := do
+  throwError "default handlers have not been implemented yet, class: '{className}' types: {typeNames}"
+
+def applyDerivingHandlers (className : Name) (typeNames : Array Name) : CommandElabM Unit := do
+  withTraceNode `Elab.Deriving (fun _ => return m!"running deriving handlers for '{className}'") do
+    match (← derivingHandlersRef.get).find? className with
+    | some handlers =>
+      for handler in handlers do
+        if (← handler typeNames) then
+          return ()
+      defaultHandler className typeNames
+    | none => defaultHandler className typeNames
+
+private def tryApplyDefHandler (className : Name) (declName : Name) : CommandElabM Bool :=
+  liftTermElabM do
+    Term.processDefDeriving className declName
+
+@[builtin_command_elab «deriving»] def elabDeriving : CommandElab
+  | `(deriving instance $[$classes],* for $[$declNames],*) => do
+    let declNames ← liftCoreM <| declNames.mapM realizeGlobalConstNoOverloadWithInfo
+    for cls in classes do
+      try
+        let className ← liftCoreM <| realizeGlobalConstNoOverloadWithInfo cls
+        withRef cls do
+          if declNames.size == 1 then
+            if (← tryApplyDefHandler className declNames[0]!) then
+              return ()
+          applyDerivingHandlers className declNames
+      catch ex =>
+        logException ex
+  | _ => throwUnsupportedSyntax
+
+structure DerivingClassView where
+  ref : Syntax
+  className : Name
+
+def getOptDerivingClasses (optDeriving : Syntax) : CoreM (Array DerivingClassView) := do
+  match optDeriving with
+  | `(Parser.Command.optDeriving| deriving $[$classes],*) =>
+    let mut ret := #[]
+    for cls in classes do
+      let className ← realizeGlobalConstNoOverloadWithInfo cls
+      ret := ret.push { ref := cls, className := className }
+    return ret
+  | _ => return #[]
+
+def DerivingClassView.applyHandlers (view : DerivingClassView) (declNames : Array Name) : CommandElabM Unit :=
+  withRef view.ref do applyDerivingHandlers view.className declNames
+
+builtin_initialize
+  registerTraceClass `Elab.Deriving
+
+end Lean.Elab
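A hedged usage sketch (not part of the diff): `Basic.lean` provides the handler registry and the standalone `deriving instance ... for ...` command elaborated by `elabDeriving`, which the other files plug into via `registerDerivingHandler`. The `Point` structure below is hypothetical.

```lean
-- Hypothetical declarations, not part of the diff.
structure Point where
  x : Nat
  y : Nat

-- Elaborated by `elabDeriving`: each class name is resolved and then
-- `applyDerivingHandlers` runs the registered handlers for it.
deriving instance BEq, Repr for Point

#eval ({ x := 1, y := 2 } : Point) == { x := 1, y := 2 }  -- true
```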
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/DecEq.lean
ADDED
@@ -0,0 +1,212 @@
+/-
+Copyright (c) 2020 Microsoft Corporation. All rights reserved.
+Released under Apache 2.0 license as described in the file LICENSE.
+Authors: Leonardo de Moura
+-/
+prelude
+import Lean.Meta.Transform
+import Lean.Meta.Inductive
+import Lean.Elab.Deriving.Basic
+import Lean.Elab.Deriving.Util
+
+namespace Lean.Elab.Deriving.DecEq
+open Lean.Parser.Term
+open Meta
+
+def mkDecEqHeader (indVal : InductiveVal) : TermElabM Header := do
+  mkHeader `DecidableEq 2 indVal
+
+def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do
+  let discrs ← mkDiscrs header indVal
+  let alts ← mkAlts
+  `(match $[$discrs],* with $alts:matchAlt*)
+where
+  mkSameCtorRhs : List (Ident × Ident × Option Name × Bool) → TermElabM Term
+    | [] => ``(isTrue rfl)
+    | (a, b, recField, isProof) :: todo => withFreshMacroScope do
+      let rhs ← if isProof then
+        `(have h : @$a = @$b := rfl; by subst h; exact $(← mkSameCtorRhs todo):term)
+      else
+        let sameCtor ← mkSameCtorRhs todo
+        `(if h : @$a = @$b then
+            by subst h; exact $sameCtor:term
+          else
+            isFalse (by intro n; injection n; apply h _; assumption))
+      if let some auxFunName := recField then
+        -- add local instance for `a = b` using the function being defined `auxFunName`
+        `(let inst := $(mkIdent auxFunName) @$a @$b; $rhs)
+      else
+        return rhs
+
+  mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do
+    let mut alts := #[]
+    for ctorName₁ in indVal.ctors do
+      let ctorInfo ← getConstInfoCtor ctorName₁
+      for ctorName₂ in indVal.ctors do
+        let mut patterns := #[]
+        -- add `_` pattern for indices
+        for _ in [:indVal.numIndices] do
+          patterns := patterns.push (← `(_))
+        if ctorName₁ == ctorName₂ then
+          let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
+            let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
+            let mut patterns := patterns
+            let mut ctorArgs1 := #[]
+            let mut ctorArgs2 := #[]
+            -- add `_` for inductive parameters, they are inaccessible
+            for _ in [:indVal.numParams] do
+              ctorArgs1 := ctorArgs1.push (← `(_))
+              ctorArgs2 := ctorArgs2.push (← `(_))
+            let mut todo := #[]
+            for i in [:ctorInfo.numFields] do
+              let x := xs[indVal.numParams + i]!
+              if type.containsFVar x.fvarId! then
+                -- If resulting type depends on this field, we don't need to compare
+                ctorArgs1 := ctorArgs1.push (← `(_))
+                ctorArgs2 := ctorArgs2.push (← `(_))
+              else
+                let a := mkIdent (← mkFreshUserName `a)
+                let b := mkIdent (← mkFreshUserName `b)
+                ctorArgs1 := ctorArgs1.push a
+                ctorArgs2 := ctorArgs2.push b
+                let xType ← inferType x
+                let indValNum :=
+                  ctx.typeInfos.findIdx?
+                    (xType.isAppOf ∘ ConstantVal.name ∘ InductiveVal.toConstantVal)
+                let recField := indValNum.map (ctx.auxFunNames[·]!)
+                let isProof ← isProp xType
+                todo := todo.push (a, b, recField, isProof)
+            patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs1:term*))
+            patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs2:term*))
+            let rhs ← mkSameCtorRhs todo.toList
+            `(matchAltExpr| | $[$patterns:term],* => $rhs:term)
+          alts := alts.push alt
+        else if (← compatibleCtors ctorName₁ ctorName₂) then
+          patterns := patterns ++ #[(← `($(mkIdent ctorName₁) ..)), (← `($(mkIdent ctorName₂) ..))]
+          let rhs ← `(isFalse (by intro h; injection h))
+          alts := alts.push (← `(matchAltExpr| | $[$patterns:term],* => $rhs:term))
+    return alts
+
+def mkAuxFunction (ctx : Context) (auxFunName : Name) (indVal : InductiveVal): TermElabM (TSyntax `command) := do
+  let header ← mkDecEqHeader indVal
+  let body ← mkMatch ctx header indVal
+  let binders := header.binders
+  let target₁ := mkIdent header.targetNames[0]!
+  let target₂ := mkIdent header.targetNames[1]!
+  let termSuffix ← if indVal.isRec
+    then `(Parser.Termination.suffix|termination_by structural $target₁)
+    else `(Parser.Termination.suffix|)
+  let type ← `(Decidable ($target₁ = $target₂))
+  `(def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $type:term := $body:term
+    $termSuffix:suffix)
+
+def mkAuxFunctions (ctx : Context) : TermElabM (TSyntax `command) := do
+  let mut res : Array (TSyntax `command) := #[]
+  for i in [:ctx.auxFunNames.size] do
+    let auxFunName := ctx.auxFunNames[i]!
+    let indVal := ctx.typeInfos[i]!
+    res := res.push (← mkAuxFunction ctx auxFunName indVal)
+  `(command| mutual $[$res:command]* end)
+
+def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do
+  let ctx ← mkContext "decEq" indVal.name
+  let cmds := #[← mkAuxFunctions ctx] ++ (← mkInstanceCmds ctx `DecidableEq #[indVal.name] (useAnonCtor := false))
+  trace[Elab.Deriving.decEq] "\n{cmds}"
+  return cmds
+
+open Command
+
+def mkDecEq (declName : Name) : CommandElabM Bool := do
+  let indVal ← getConstInfoInduct declName
+  if indVal.isNested then
+    return false -- nested inductive types are not supported yet
+  else
+    let cmds ← liftTermElabM <| mkDecEqCmds indVal
+    -- `cmds` can have a number of syntax nodes quadratic in the number of constructors
+    -- and thus create as many info tree nodes, which we never make use of but which can
+    -- significantly slow down e.g. the unused variables linter; avoid creating them
+    withEnableInfoTree false do
+      cmds.forM elabCommand
+    return true
+
+partial def mkEnumOfNat (declName : Name) : MetaM Unit := do
+  let indVal ← getConstInfoInduct declName
+  let enumType := mkConst declName
+  let ctors := indVal.ctors.toArray
+  withLocalDeclD `n (mkConst ``Nat) fun n => do
+    let cond := mkConst ``cond [1]
+    let rec mkDecTree (low high : Nat) : Expr :=
+      if low + 1 == high then
+        mkConst ctors[low]!
+      else if low + 2 == high then
+        mkApp4 cond enumType (mkApp2 (mkConst ``Nat.beq) n (mkRawNatLit low)) (mkConst ctors[low]!) (mkConst ctors[low+1]!)
+      else
+        let mid := (low + high)/2
+        let lowBranch := mkDecTree low mid
+        let highBranch := mkDecTree mid high
+        mkApp4 cond enumType (mkApp2 (mkConst ``Nat.ble) (mkRawNatLit mid) n) highBranch lowBranch
+    let value ← mkLambdaFVars #[n] (mkDecTree 0 ctors.size)
+    let type ← mkArrow (mkConst ``Nat) enumType
+    addAndCompile <| Declaration.defnDecl {
+      name := Name.mkStr declName "ofNat"
+      levelParams := []
+      safety := DefinitionSafety.safe
+      hints := ReducibilityHints.abbrev
+      value, type
+    }
+
+def mkEnumOfNatThm (declName : Name) : MetaM Unit := do
+  let indVal ← getConstInfoInduct declName
+  let toCtorIdx := mkConst (Name.mkStr declName "toCtorIdx")
+  let ofNat := mkConst (Name.mkStr declName "ofNat")
+  let enumType := mkConst declName
+  let eqEnum := mkApp (mkConst ``Eq [levelOne]) enumType
+  let rflEnum := mkApp (mkConst ``Eq.refl [levelOne]) enumType
+  let ctors := indVal.ctors
+  withLocalDeclD `x enumType fun x => do
+    let resultType := mkApp2 eqEnum (mkApp ofNat (mkApp toCtorIdx x)) x
+    let motive ← mkLambdaFVars #[x] resultType
+    let casesOn := mkConst (mkCasesOnName declName) [levelZero]
+    let mut value := mkApp2 casesOn motive x
+    for ctor in ctors do
+      value := mkApp value (mkApp rflEnum (mkConst ctor))
+    value ← mkLambdaFVars #[x] value
+    let type ← mkForallFVars #[x] resultType
+    addAndCompile <| Declaration.thmDecl {
+      name := Name.mkStr declName "ofNat_toCtorIdx"
+      levelParams := []
+      value, type
+    }
+
+def mkDecEqEnum (declName : Name) : CommandElabM Unit := do
+  liftTermElabM <| mkEnumOfNat declName
+  liftTermElabM <| mkEnumOfNatThm declName
+  let ofNatIdent := mkIdent (Name.mkStr declName "ofNat")
+  let auxThmIdent := mkIdent (Name.mkStr declName "ofNat_toCtorIdx")
+  let cmd ← `(
+    instance : DecidableEq $(mkIdent declName) :=
+      fun x y =>
+        if h : x.toCtorIdx = y.toCtorIdx then
+          -- We use `rfl` in the following proof because the first script fails for unit-like datatypes due to etaStruct.
+          isTrue (by first | have aux := congrArg $ofNatIdent h; rw [$auxThmIdent:ident, $auxThmIdent:ident] at aux; assumption | rfl)
+        else
+          isFalse fun h => by subst h; contradiction
+  )
+  trace[Elab.Deriving.decEq] "\n{cmd}"
+  elabCommand cmd
+
+def mkDecEqInstance (declName : Name) : CommandElabM Bool := do
+  if (← isEnumType declName) then
+    mkDecEqEnum declName
+    return true
+  else
+    mkDecEq declName
+
+def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
+  declNames.foldlM (fun b n => andM (pure b) (mkDecEqInstance n)) true
+
+builtin_initialize
+  registerDerivingHandler `DecidableEq mkDecEqInstanceHandler
+  registerTraceClass `Elab.Deriving.decEq
+
+end Lean.Elab.Deriving.DecEq
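Another hedged sketch (not part of the diff): for an enumeration, `mkDecEqInstance` takes the `mkDecEqEnum` path, which compares `toCtorIdx` values and closes the proof with the generated `ofNat_toCtorIdx` lemma. The `Color` type below is hypothetical.

```lean
-- Hypothetical enum, not part of the diff.
inductive Color where
  | red | green | blue
  deriving DecidableEq  -- enum path: `mkDecEqEnum` via `toCtorIdx`

example : Color.red ≠ Color.blue := by decide

#eval if Color.red = Color.green then "same" else "different"  -- "different"
```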
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/FromToJson.lean
ADDED
@@ -0,0 +1,249 @@
+/-
+Copyright (c) 2020 Sebastian Ullrich. All rights reserved.
+Released under Apache 2.0 license as described in the file LICENSE.
+Authors: Sebastian Ullrich, Dany Fabian
+-/
+prelude
+import Lean.Meta.Transform
+import Lean.Elab.Deriving.Basic
+import Lean.Elab.Deriving.Util
+import Lean.Data.Json.FromToJson
+
+namespace Lean.Elab.Deriving.FromToJson
+open Lean.Elab.Command
+open Lean.Json
+open Lean.Parser.Term
+open Lean.Meta
+
+def mkToJsonHeader (indVal : InductiveVal) : TermElabM Header := do
+  mkHeader ``ToJson 1 indVal
+
+def mkFromJsonHeader (indVal : InductiveVal) : TermElabM Header := do
+  let header ← mkHeader ``FromJson 0 indVal
+  let jsonArg ← `(bracketedBinderF|(json : Json))
+  return {header with
+    binders := header.binders.push jsonArg}
+
+def mkJsonField (n : Name) : CoreM (Bool × Term) := do
+  let .str .anonymous s := n | throwError "invalid json field name {n}"
+  let s₁ := s.dropRightWhile (· == '?')
+  return (s != s₁, Syntax.mkStrLit s₁)
+
+def mkToJsonBodyForStruct (header : Header) (indName : Name) : TermElabM Term := do
+  let fields := getStructureFieldsFlattened (← getEnv) indName (includeSubobjectFields := false)
+  let fields ← fields.mapM fun field => do
+    let (isOptField, nm) ← mkJsonField field
+    let target := mkIdent header.targetNames[0]!
+    if isOptField then ``(opt $nm $target.$(mkIdent field))
+    else ``([($nm, toJson ($target).$(mkIdent field))])
+  `(mkObj <| List.flatten [$fields,*])
+
+def mkToJsonBodyForInduct (ctx : Context) (header : Header) (indName : Name) : TermElabM Term := do
+  let indVal ← getConstInfoInduct indName
+  let toJsonFuncId := mkIdent ctx.auxFunNames[0]!
+  -- Return syntax to JSONify `id`, either via `ToJson` or recursively
+  -- if `id`'s type is the type we're deriving for.
+  let mkToJson (id : Ident) (type : Expr) : TermElabM Term := do
+    if type.isAppOf indVal.name then `($toJsonFuncId:ident $id:ident)
+    else ``(toJson $id:ident)
+  let discrs ← mkDiscrs header indVal
+  let alts ← mkAlts indVal fun ctor args userNames => do
+    let ctorStr := ctor.name.eraseMacroScopes.getString!
+    match args, userNames with
+    | #[], _ => ``(toJson $(quote ctorStr))
+    | #[(x, t)], none => ``(mkObj [($(quote ctorStr), $(← mkToJson x t))])
+    | xs, none =>
+      let xs ← xs.mapM fun (x, t) => mkToJson x t
+      ``(mkObj [($(quote ctorStr), Json.arr #[$[$xs:term],*])])
+    | xs, some userNames =>
+      let xs ← xs.mapIdxM fun idx (x, t) => do
+        `(($(quote userNames[idx]!.eraseMacroScopes.getString!), $(← mkToJson x t)))
+      ``(mkObj [($(quote ctorStr), mkObj [$[$xs:term],*])])
+  `(match $[$discrs],* with $alts:matchAlt*)
+
+where
+  mkAlts
+      (indVal : InductiveVal)
+      (rhs : ConstructorVal → Array (Ident × Expr) → Option (Array Name) → TermElabM Term): TermElabM (Array (TSyntax ``matchAlt)) := do
+    let mut alts := #[]
+    for ctorName in indVal.ctors do
+      let ctorInfo ← getConstInfoCtor ctorName
+      let alt ← forallTelescopeReducing ctorInfo.type fun xs _ => do
+        let mut patterns := #[]
+        -- add `_` pattern for indices
+        for _ in [:indVal.numIndices] do
+          patterns := patterns.push (← `(_))
+        let mut ctorArgs := #[]
+        -- add `_` for inductive parameters, they are inaccessible
+        for _ in [:indVal.numParams] do
+          ctorArgs := ctorArgs.push (← `(_))
+        -- bound constructor arguments and their types
+        let mut binders := #[]
+        let mut userNames := #[]
+        for i in [:ctorInfo.numFields] do
+          let x := xs[indVal.numParams + i]!
+          let localDecl ← x.fvarId!.getDecl
+          if !localDecl.userName.hasMacroScopes then
+            userNames := userNames.push localDecl.userName
+          let a := mkIdent (← mkFreshUserName `a)
+          binders := binders.push (a, localDecl.type)
+          ctorArgs := ctorArgs.push a
+        patterns := patterns.push (← `(@$(mkIdent ctorInfo.name):ident $ctorArgs:term*))
+        let rhs ← rhs ctorInfo binders (if userNames.size == binders.size then some userNames else none)
+        `(matchAltExpr| | $[$patterns:term],* => $rhs:term)
+      alts := alts.push alt
+    return alts
+
+def mkFromJsonBodyForStruct (indName : Name) : TermElabM Term := do
+  let fields := getStructureFieldsFlattened (← getEnv) indName (includeSubobjectFields := false)
+  let getters ← fields.mapM (fun field => do
+    let getter ← `(getObjValAs? json _ $(Prod.snd <| ← mkJsonField field))
+    let getter ← `(doElem| Except.mapError (fun s => (toString $(quote indName)) ++ "." ++ (toString $(quote field)) ++ ": " ++ s) <| $getter)
+    return getter
+  )
+  let fields := fields.map mkIdent
+  `(do
+    $[let $fields:ident ← $getters]*
+    return { $[$fields:ident := $(id fields)],* })
+
+def mkFromJsonBodyForInduct (ctx : Context) (indName : Name) : TermElabM Term := do
+  let indVal ← getConstInfoInduct indName
+  let alts ← mkAlts indVal
+  let auxTerm ← alts.foldrM (fun xs x => `(Except.orElseLazy $xs (fun _ => $x))) (← `(Except.error "no inductive constructor matched"))
+  `($auxTerm)
+where
+  mkAlts (indVal : InductiveVal) : TermElabM (Array Term) := do
+    let mut alts := #[]
+    for ctorName in indVal.ctors do
+      let ctorInfo ← getConstInfoCtor ctorName
+      let alt ← do forallTelescopeReducing ctorInfo.type fun xs _ => do
+        let mut binders := #[]
+        let mut userNames := #[]
+        for i in [:ctorInfo.numFields] do
+          let x := xs[indVal.numParams + i]!
+          let localDecl ← x.fvarId!.getDecl
+          if !localDecl.userName.hasMacroScopes then
+            userNames := userNames.push localDecl.userName
+          let a := mkIdent (← mkFreshUserName `a)
+          binders := binders.push (a, localDecl.type)
+        let fromJsonFuncId := mkIdent ctx.auxFunNames[0]!
+        -- Return syntax to parse `id`, either via `FromJson` or recursively
+        -- if `id`'s type is the type we're deriving for.
+        let mkFromJson (idx : Nat) (type : Expr) : TermElabM (TSyntax ``doExpr) :=
+          if type.isAppOf indVal.name then `(Lean.Parser.Term.doExpr| $fromJsonFuncId:ident jsons[$(quote idx)]!)
+          else `(Lean.Parser.Term.doExpr| fromJson? jsons[$(quote idx)]!)
+        let identNames := binders.map Prod.fst
+        let fromJsons ← binders.mapIdxM fun idx (_, type) => mkFromJson idx type
+        let userNamesOpt ← if binders.size == userNames.size then
+          ``(some #[$[$(userNames.map quote)],*])
+        else
+          ``(none)
+        let stx ←
+          `((Json.parseTagged json $(quote ctorName.eraseMacroScopes.getString!) $(quote ctorInfo.numFields) $(quote userNamesOpt)).bind
+            (fun jsons => do
+              $[let $identNames:ident ← $fromJsons:doExpr]*
+              return $(mkIdent ctorName):ident $identNames*))
+        pure (stx, ctorInfo.numFields)
+      alts := alts.push alt
+    -- the smaller cases, especially the ones without fields are likely faster
+    let alts' := alts.qsort (fun (_, x) (_, y) => x < y)
+    return alts'.map Prod.fst
+
+def mkToJsonBody (ctx : Context) (header : Header) (e : Expr): TermElabM Term := do
+  let indName := e.getAppFn.constName!
+  if isStructure (← getEnv) indName then
+    mkToJsonBodyForStruct header indName
+  else
+    mkToJsonBodyForInduct ctx header indName
+
+def mkToJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
+  let auxFunName := ctx.auxFunNames[i]!
+  let header ← mkToJsonHeader ctx.typeInfos[i]!
+  let binders := header.binders
+  Term.elabBinders binders fun _ => do
+    let type ← Term.elabTerm header.targetType none
+    let mut body ← mkToJsonBody ctx header type
+    if ctx.usePartial then
+      let letDecls ← mkLocalInstanceLetDecls ctx ``ToJson header.argNames
+      body ← mkLet letDecls body
+      `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term)
+    else
+      `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term)
+
+def mkFromJsonBody (ctx : Context) (e : Expr) : TermElabM Term := do
+  let indName := e.getAppFn.constName!
+  if isStructure (← getEnv) indName then
+    mkFromJsonBodyForStruct indName
+  else
+    mkFromJsonBodyForInduct ctx indName
+
+def mkFromJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
+  let auxFunName := ctx.auxFunNames[i]!
+  let indval := ctx.typeInfos[i]!
+  let header ← mkFromJsonHeader indval --TODO fix header info
+  let binders := header.binders
+  Term.elabBinders binders fun _ => do
+    let type ← Term.elabTerm header.targetType none
+    let mut body ← mkFromJsonBody ctx type
+    if ctx.usePartial || indval.isRec then
+      let letDecls ← mkLocalInstanceLetDecls ctx ``FromJson header.argNames
+      body ← mkLet letDecls body
+      `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← mkInductiveApp ctx.typeInfos[i]! header.argNames) := $body:term)
+    else
+      `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← mkInductiveApp ctx.typeInfos[i]! header.argNames) := $body:term)
+
+
+def mkToJsonMutualBlock (ctx : Context) : TermElabM Command := do
+  let mut auxDefs := #[]
+  for i in [:ctx.typeInfos.size] do
+    auxDefs := auxDefs.push (← mkToJsonAuxFunction ctx i)
+  `(mutual
+    $auxDefs:command*
+    end)
+
+def mkFromJsonMutualBlock (ctx : Context) : TermElabM Command := do
+  let mut auxDefs := #[]
+  for i in [:ctx.typeInfos.size] do
+    auxDefs := auxDefs.push (← mkFromJsonAuxFunction ctx i)
+  `(mutual
+    $auxDefs:command*
+    end)
+
+private def mkToJsonInstance (declName : Name) : TermElabM (Array Command) := do
+  let ctx ← mkContext "toJson" declName
+  let cmds := #[← mkToJsonMutualBlock ctx] ++ (← mkInstanceCmds ctx ``ToJson #[declName])
+  trace[Elab.Deriving.toJson] "\n{cmds}"
+  return cmds
+
+private def mkFromJsonInstance (declName : Name) : TermElabM (Array Command) := do
+  let ctx ← mkContext "fromJson" declName
+  let cmds := #[← mkFromJsonMutualBlock ctx] ++ (← mkInstanceCmds ctx ``FromJson #[declName])
+  trace[Elab.Deriving.fromJson] "\n{cmds}"
+  return cmds
+
+def mkToJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
+  if (← declNames.allM isInductive) && declNames.size > 0 then
+    for declName in declNames do
+      let cmds ← liftTermElabM <| mkToJsonInstance declName
+      cmds.forM elabCommand
+    return true
+  else
+    return false
+
+def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
+  if (← declNames.allM isInductive) && declNames.size > 0 then
+    for declName in declNames do
+      let cmds ← liftTermElabM <| mkFromJsonInstance declName
+      cmds.forM elabCommand
+    return true
+  else
+    return false
+
+builtin_initialize
+  registerDerivingHandler ``ToJson mkToJsonInstanceHandler
+  registerDerivingHandler ``FromJson mkFromJsonInstanceHandler
+
+  registerTraceClass `Elab.Deriving.toJson
+  registerTraceClass `Elab.Deriving.fromJson
+
+end Lean.Elab.Deriving.FromToJson
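Finally, a hedged sketch (not part of the diff) of the JSON derivation these handlers perform. Structure field names ending in `?` are emitted as optional JSON keys with the `?` stripped (see `mkJsonField`). The `User` structure below is hypothetical; the expected outputs are illustrative.

```lean
import Lean.Data.Json
open Lean

-- Hypothetical structure, not part of the diff.
structure User where
  name : String
  age? : Option Nat
  deriving ToJson, FromJson, Repr

def ada : User := { name := "ada", age? := some 36 }

#eval (toJson ada).compress
-- a JSON object with "name" and "age" keys

#eval match (fromJson? (toJson ada) : Except String User) with
  | .ok u => u.name
  | .error e => e
-- expected: "ada"
```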
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Term.lean
ADDED
|
@@ -0,0 +1,2128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+ /-
+ Copyright (c) 2019 Microsoft Corporation. All rights reserved.
+ Released under Apache 2.0 license as described in the file LICENSE.
+ Authors: Leonardo de Moura, Sebastian Ullrich
+ -/
+ prelude
+ import Lean.ReservedNameAction
+ import Lean.Meta.AppBuilder
+ import Lean.Meta.CollectMVars
+ import Lean.Meta.Coe
+ import Lean.Util.CollectLevelMVars
+ import Lean.Linter.Deprecated
+ import Lean.Elab.Config
+ import Lean.Elab.Level
+ import Lean.Elab.DeclModifiers
+ import Lean.Elab.PreDefinition.TerminationHint
+ import Lean.Elab.DeclarationRange
+ import Lean.Elab.WhereFinally
+ import Lean.Language.Basic
+ import Lean.Elab.InfoTree.InlayHints
+
+ namespace Lean.Elab
+
+ namespace Term
+
+ /-- Saved context for postponed terms and tactics to be executed. -/
+ structure SavedContext where
+   declName? : Option Name
+   options : Options
+   openDecls : List OpenDecl
+   macroStack : MacroStack
+   errToSorry : Bool
+   levelNames : List Name
+
+ /-- The kind of a tactic metavariable, used for additional error reporting. -/
+ inductive TacticMVarKind
+   /-- Standard tactic metavariable, arising from `by ...` syntax. -/
+   | term
+   /-- Tactic metavariable arising from an autoparam for a function application. -/
+   | autoParam (argName : Name)
+   /-- Tactic metavariable arising from an autoparam for a structure field. -/
+   | fieldAutoParam (fieldName structName : Name)
+
+ /-- We use synthetic metavariables as placeholders for pending elaboration steps. -/
+ inductive SyntheticMVarKind where
+   /--
+   Use typeclass resolution to synthesize value for metavariable.
+   If `extraErrorMsg?` is `some msg`, `msg` contains additional information to include in error messages
+   regarding type class synthesis failure.
+   -/
+   | typeClass (extraErrorMsg? : Option MessageData)
+   /--
+   Use coercion to synthesize value for the metavariable.
+   If synthesis fails, then throws an error.
+   - If `mkErrorMsg?` is provided, then the error `mkErrorMsg expectedType e` is thrown.
+     The `mkErrorMsg` function is allowed to throw an error itself.
+   - Otherwise, throws a default type mismatch error message.
+     If `header?` is not provided, the default header is "type mismatch".
+     If `f?` is provided, then throws an application type mismatch error.
+   -/
+   | coe (header? : Option String) (expectedType : Expr) (e : Expr) (f? : Option Expr)
+       (mkErrorMsg? : Option (MVarId → Expr → Expr → MetaM MessageData))
+   /--
+   Use tactic to synthesize value for metavariable.
+
+   If `delayOnMVars` is true, the tactic will not be executed until the goal is free of unassigned
+   expr metavariables.
+   -/
+   | tactic (tacticCode : Syntax) (ctx : SavedContext) (kind : TacticMVarKind) (delayOnMVars := false)
+   /-- Metavariable represents a hole whose elaboration has been postponed. -/
+   | postponed (ctx : SavedContext)
+   deriving Inhabited
+
+ /--
+ Convert an "extra" optional error message into a message `"\n{msg}"` (if `some msg`) and `MessageData.nil` (if `none`)
+ -/
+ def extraMsgToMsg (extraErrorMsg? : Option MessageData) : MessageData :=
+   if let some msg := extraErrorMsg? then m!"\n{msg}" else .nil
+
+ instance : ToString SyntheticMVarKind where
+   toString
+     | .typeClass .. => "typeclass"
+     | .coe .. => "coe"
+     | .tactic .. => "tactic"
+     | .postponed .. => "postponed"
+
+ structure SyntheticMVarDecl where
+   stx : Syntax
+   kind : SyntheticMVarKind
+   deriving Inhabited
+
| 92 |
+
/--
|
| 93 |
+
We can optionally associate an error context with a metavariable (see `MVarErrorInfo`).
|
| 94 |
+
We have three different kinds of error context.
|
| 95 |
+
-/
|
| 96 |
+
inductive MVarErrorKind where
|
| 97 |
+
/-- Metavariable for implicit arguments. `ctx` is the parent application,
|
| 98 |
+
`lctx` is a local context where it is valid (necessary for eta feature for named arguments). -/
|
| 99 |
+
| implicitArg (lctx : LocalContext) (ctx : Expr)
|
| 100 |
+
/-- Metavariable for explicit holes provided by the user (e.g., `_` and `?m`) -/
|
| 101 |
+
| hole
|
| 102 |
+
/-- "Custom", `msgData` stores the additional error messages. -/
|
| 103 |
+
| custom (msgData : MessageData)
|
| 104 |
+
deriving Inhabited
|
| 105 |
+
|
| 106 |
+
instance : ToString MVarErrorKind where
|
| 107 |
+
toString
|
| 108 |
+
| .implicitArg _ _ => "implicitArg"
|
| 109 |
+
| .hole => "hole"
|
| 110 |
+
| .custom _ => "custom"
|
| 111 |
+
|
| 112 |
+
/--
|
| 113 |
+
We can optionally associate an error context with metavariables.
|
| 114 |
+
-/
|
| 115 |
+
structure MVarErrorInfo where
|
| 116 |
+
mvarId : MVarId
|
| 117 |
+
ref : Syntax
|
| 118 |
+
kind : MVarErrorKind
|
| 119 |
+
deriving Inhabited
|
| 120 |
+
|
| 121 |
+
/--
|
| 122 |
+
When reporting unexpected universe level metavariables, it is useful to localize the errors
|
| 123 |
+
to particular terms, especially at `let` bindings and function binders,
|
| 124 |
+
where universe polymorphism is not permitted.
|
| 125 |
+
-/
|
| 126 |
+
structure LevelMVarErrorInfo where
|
| 127 |
+
lctx : LocalContext
|
| 128 |
+
expr : Expr
|
| 129 |
+
ref : Syntax
|
| 130 |
+
msgData? : Option MessageData := none
|
| 131 |
+
deriving Inhabited
|
| 132 |
+
|
| 133 |
+
/--
|
| 134 |
+
Nested `let rec` expressions are eagerly lifted by the elaborator.
|
| 135 |
+
We store the information necessary for performing the lifting here.
|
| 136 |
+
-/
|
| 137 |
+
structure LetRecToLift where
|
| 138 |
+
ref : Syntax
|
| 139 |
+
fvarId : FVarId
|
| 140 |
+
attrs : Array Attribute
|
| 141 |
+
shortDeclName : Name
|
| 142 |
+
declName : Name
|
| 143 |
+
parentName? : Option Name
|
| 144 |
+
lctx : LocalContext
|
| 145 |
+
localInstances : LocalInstances
|
| 146 |
+
type : Expr
|
| 147 |
+
val : Expr
|
| 148 |
+
mvarId : MVarId
|
| 149 |
+
termination : TerminationHints
|
| 150 |
+
deriving Inhabited
|
| 151 |
+
|
| 152 |
+
/--
|
| 153 |
+
State of the `TermElabM` monad.
|
| 154 |
+
-/
|
| 155 |
+
structure State where
|
| 156 |
+
levelNames : List Name := []
|
| 157 |
+
syntheticMVars : MVarIdMap SyntheticMVarDecl := {}
|
| 158 |
+
pendingMVars : List MVarId := {}
|
| 159 |
+
/-- List of errors associated to a metavariable that are shown to the user if the metavariable could not be fully instantiated -/
|
| 160 |
+
mvarErrorInfos : List MVarErrorInfo := []
|
| 161 |
+
/-- List of data to be able to localize universe level metavariable errors to particular expressions. -/
|
| 162 |
+
levelMVarErrorInfos : List LevelMVarErrorInfo := []
|
| 163 |
+
/--
|
| 164 |
+
`mvarArgNames` stores the argument names associated to metavariables.
|
| 165 |
+
These are used in combination with `mvarErrorInfos` for throwing errors about metavariables that could not be fully instantiated.
|
| 166 |
+
For example when elaborating `List _`, the argument name of the placeholder will be `α`.
|
| 167 |
+
|
| 168 |
+
While elaborating an application, `mvarArgNames` is set for each metavariable argument, using the available argument name.
|
| 169 |
+
This may happen before or after the `mvarErrorInfos` is set for the same metavariable.
|
| 170 |
+
|
| 171 |
+
We used to store the argument names in `mvarErrorInfos`, updating the `MVarErrorInfos` to add the argument name when it is available,
|
| 172 |
+
but this doesn't work if the argument name is available _before_ the `mvarErrorInfos` is set for that metavariable.
|
| 173 |
+
-/
|
| 174 |
+
mvarArgNames : MVarIdMap Name := {}
|
| 175 |
+
letRecsToLift : List LetRecToLift := []
|
| 176 |
+
deriving Inhabited
|
| 177 |
+
|
| 178 |
+
/--
|
| 179 |
+
Backtrackable state for the `TermElabM` monad.
|
| 180 |
+
-/
|
| 181 |
+
structure SavedState where
|
| 182 |
+
«meta» : Meta.SavedState
|
| 183 |
+
«elab» : State
|
| 184 |
+
deriving Nonempty
|
| 185 |
+
|
| 186 |
+
end Term
|
| 187 |
+
|
| 188 |
+
namespace Tactic
|
| 189 |
+
|
| 190 |
+
/--
|
| 191 |
+
State of the `TacticM` monad.
|
| 192 |
+
-/
|
| 193 |
+
structure State where
|
| 194 |
+
goals : List MVarId
|
| 195 |
+
deriving Inhabited
|
| 196 |
+
|
| 197 |
+
/--
|
| 198 |
+
Snapshots are used to implement the `save` tactic.
|
| 199 |
+
This tactic caches the state of the system, and allows us to "replay"
|
| 200 |
+
expensive proofs efficiently. This is only relevant implementing the
|
| 201 |
+
LSP server.
|
| 202 |
+
-/
|
| 203 |
+
structure Snapshot where
|
| 204 |
+
core : Core.State
|
| 205 |
+
«meta» : Meta.State
|
| 206 |
+
term : Term.State
|
| 207 |
+
tactic : Tactic.State
|
| 208 |
+
stx : Syntax
|
| 209 |
+
|
| 210 |
+
/--
|
| 211 |
+
Key for the cache used to implement the `save` tactic.
|
| 212 |
+
-/
|
| 213 |
+
structure CacheKey where
|
| 214 |
+
mvarId : MVarId -- TODO: should include all goals
|
| 215 |
+
pos : String.Pos
|
| 216 |
+
deriving BEq, Hashable, Inhabited
|
| 217 |
+
|
| 218 |
+
/--
|
| 219 |
+
Cache for the `save` tactic.
|
| 220 |
+
-/
|
| 221 |
+
structure Cache where
|
| 222 |
+
pre : PHashMap CacheKey Snapshot := {}
|
| 223 |
+
post : PHashMap CacheKey Snapshot := {}
|
| 224 |
+
deriving Inhabited
|
| 225 |
+
|
| 226 |
+
section Snapshot
|
| 227 |
+
open Language
|
| 228 |
+
|
| 229 |
+
structure SavedState where
|
| 230 |
+
term : Term.SavedState
|
| 231 |
+
tactic : State
|
| 232 |
+
|
| 233 |
+
/-- Snapshot after finishing execution of a tactic. -/
|
| 234 |
+
structure TacticFinishedSnapshot extends Language.Snapshot where
|
| 235 |
+
/-- State saved for reuse, if no fatal exception occurred. -/
|
| 236 |
+
state? : Option SavedState
|
| 237 |
+
/-- Untyped snapshots from `logSnapshotTask`, saved at this level for cancellation. -/
|
| 238 |
+
moreSnaps : Array (SnapshotTask SnapshotTree)
|
| 239 |
+
deriving Inhabited
|
| 240 |
+
instance : ToSnapshotTree TacticFinishedSnapshot where
|
| 241 |
+
toSnapshotTree s := ⟨s.toSnapshot, s.moreSnaps⟩
|
| 242 |
+
|
| 243 |
+
/-- Snapshot just before execution of a tactic. -/
|
| 244 |
+
structure TacticParsedSnapshot extends Language.Snapshot where
|
| 245 |
+
/-- Syntax tree of the tactic, stored and compared for incremental reuse. -/
|
| 246 |
+
stx : Syntax
|
| 247 |
+
/-- Task for nested incrementality, if enabled for tactic. -/
|
| 248 |
+
inner? : Option (SnapshotTask TacticParsedSnapshot) := none
|
| 249 |
+
/-- Task for state after tactic execution. -/
|
| 250 |
+
finished : SnapshotTask TacticFinishedSnapshot
|
| 251 |
+
/-- Tasks for subsequent, potentially parallel, tactic steps. -/
|
| 252 |
+
next : Array (SnapshotTask TacticParsedSnapshot) := #[]
|
| 253 |
+
deriving Inhabited
|
| 254 |
+
partial instance : ToSnapshotTree TacticParsedSnapshot where
|
| 255 |
+
toSnapshotTree := go where
|
| 256 |
+
go := fun s => ⟨s.toSnapshot,
|
| 257 |
+
s.inner?.toArray.map (·.map (sync := true) go) ++
|
| 258 |
+
#[s.finished.map (sync := true) toSnapshotTree] ++
|
| 259 |
+
s.next.map (·.map (sync := true) go)⟩
|
| 260 |
+
|
| 261 |
+
end Snapshot
|
| 262 |
+
end Tactic
|
| 263 |
+
|
| 264 |
+
namespace Term
|
| 265 |
+
|
| 266 |
+
structure Context where
|
| 267 |
+
declName? : Option Name := none
|
| 268 |
+
macroStack : MacroStack := []
|
| 269 |
+
/--
|
| 270 |
+
When `mayPostpone == true`, an elaboration function may interrupt its execution by throwing `Exception.postpone`.
|
| 271 |
+
The function `elabTerm` catches this exception and creates fresh synthetic metavariable `?m`, stores `?m` in
|
| 272 |
+
the list of pending synthetic metavariables, and returns `?m`. -/
|
| 273 |
+
mayPostpone : Bool := true
|
| 274 |
+
/--
|
| 275 |
+
When `errToSorry` is set to true, the method `elabTerm` catches
|
| 276 |
+
exceptions and converts them into synthetic `sorry`s.
|
| 277 |
+
The implementation of choice nodes and overloaded symbols rely on the fact
|
| 278 |
+
that when `errToSorry` is set to false for an elaboration function `F`, then
|
| 279 |
+
`errToSorry` remains `false` for all elaboration functions invoked by `F`.
|
| 280 |
+
That is, it is safe to transition `errToSorry` from `true` to `false`, but
|
| 281 |
+
we must not set `errToSorry` to `true` when it is currently set to `false`. -/
|
| 282 |
+
errToSorry : Bool := true
|
| 283 |
+
/--
|
| 284 |
+
When `autoBoundImplicit` is set to true, instead of producing
|
| 285 |
+
an "unknown identifier" error for unbound variables, we generate an
|
| 286 |
+
internal exception. This exception is caught at `elabBinders` and
|
| 287 |
+
`elabTypeWithUnboldImplicit`. Both methods add implicit declarations
|
| 288 |
+
for the unbound variable and try again. -/
|
| 289 |
+
autoBoundImplicit : Bool := false
|
| 290 |
+
autoBoundImplicits : PArray Expr := {}
|
| 291 |
+
/--
|
| 292 |
+
A name `n` is only eligible to be an auto implicit name if `autoBoundImplicitForbidden n = false`.
|
| 293 |
+
We use this predicate to disallow `f` to be considered an auto implicit name in a definition such
|
| 294 |
+
as
|
| 295 |
+
```
|
| 296 |
+
def f : f → Bool := fun _ => true
|
| 297 |
+
```
|
| 298 |
+
-/
|
| 299 |
+
autoBoundImplicitForbidden : Name → Bool := fun _ => false
|
| 300 |
+
/-- Map from user name to internal unique name -/
|
| 301 |
+
sectionVars : NameMap Name := {}
|
| 302 |
+
/-- Map from internal name to fvar -/
|
| 303 |
+
sectionFVars : NameMap Expr := {}
|
| 304 |
+
/-- Enable/disable implicit lambdas feature. -/
|
| 305 |
+
implicitLambda : Bool := true
|
| 306 |
+
/-- Heed `elab_as_elim` attribute. -/
|
| 307 |
+
heedElabAsElim : Bool := true
|
| 308 |
+
/-- Noncomputable sections automatically add the `noncomputable` modifier to any declaration we cannot generate code for. -/
|
| 309 |
+
isNoncomputableSection : Bool := false
|
| 310 |
+
/-- When `true` we skip TC failures. We use this option when processing patterns. -/
|
| 311 |
+
ignoreTCFailures : Bool := false
|
| 312 |
+
/-- `true` when elaborating patterns. It affects how we elaborate named holes. -/
|
| 313 |
+
inPattern : Bool := false
|
| 314 |
+
/--
|
| 315 |
+
Snapshot for incremental processing of current tactic, if any.
|
| 316 |
+
|
| 317 |
+
Invariant: if the bundle's `old?` is set, then the state *up to the start* of the tactic is
|
| 318 |
+
unchanged, i.e. reuse is possible.
|
| 319 |
+
-/
|
| 320 |
+
tacSnap? : Option (Language.SnapshotBundle Tactic.TacticParsedSnapshot) := none
|
| 321 |
+
/--
|
| 322 |
+
If `true`, we store in the `Expr` the `Syntax` for recursive applications (i.e., applications
|
| 323 |
+
of free variables tagged with `isAuxDecl`). We store the `Syntax` using `mkRecAppWithSyntax`.
|
| 324 |
+
We use the `Syntax` object to produce better error messages at `Structural.lean` and `WF.lean`. -/
|
| 325 |
+
saveRecAppSyntax : Bool := true
|
| 326 |
+
/--
|
| 327 |
+
If `holesAsSyntheticOpaque` is `true`, then we mark metavariables associated
|
| 328 |
+
with `_`s as `syntheticOpaque` if they do not occur in patterns.
|
| 329 |
+
This option is useful when elaborating terms in tactics such as `refine'` where
|
| 330 |
+
we want holes there to become new goals. See issue #1681, we have
|
| 331 |
+
`refine' (fun x => _)
|
| 332 |
+
-/
|
| 333 |
+
holesAsSyntheticOpaque : Bool := false
|
| 334 |
+
/--
|
| 335 |
+
If `checkDeprecated := true`, then `Linter.checkDeprecated` when creating constants.
|
| 336 |
+
-/
|
| 337 |
+
checkDeprecated : Bool := true
|
| 338 |
+
|
| 339 |
+
abbrev TermElabM := ReaderT Context $ StateRefT State MetaM
|
| 340 |
+
abbrev TermElab := Syntax → Option Expr → TermElabM Expr
|
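For orientation, `TermElab` above is the shape every term elaborator has: syntax in, optional expected type in, elaborated expression out. A deliberately trivial, hypothetical instance of that shape (illustration only; `elabAsUnit` is not part of this file) could look like:

```lean
import Lean
open Lean Elab Term

/-- Hypothetical elaborator with the `TermElab` signature (illustration only):
    maps any syntax to the constant `Unit.unit`, ignoring the expected type. -/
def elabAsUnit : TermElab := fun _stx _expectedType? =>
  return mkConst ``Unit.unit
```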
| 341 |
+
|
| 342 |
+
/-
|
| 343 |
+
Make the compiler generate specialized `pure`/`bind` so we do not have to optimize through the
|
| 344 |
+
whole monad stack at every use site. May eventually be covered by `deriving`.
|
| 345 |
+
-/
|
| 346 |
+
@[always_inline]
|
| 347 |
+
instance : Monad TermElabM :=
|
| 348 |
+
let i := inferInstanceAs (Monad TermElabM)
|
| 349 |
+
{ pure := i.pure, bind := i.bind }
|
| 350 |
+
|
| 351 |
+
open Meta
|
| 352 |
+
|
| 353 |
+
instance : Inhabited (TermElabM α) where
|
| 354 |
+
default := throw default
|
| 355 |
+
|
| 356 |
+
protected def saveState : TermElabM SavedState :=
|
| 357 |
+
return { «meta» := (← Meta.saveState), «elab» := (← get) }
|
| 358 |
+
|
| 359 |
+
def SavedState.restore (s : SavedState) (restoreInfo : Bool := false) : TermElabM Unit := do
|
| 360 |
+
let traceState ← getTraceState -- We never backtrack trace message
|
| 361 |
+
let infoState ← getInfoState -- We also do not backtrack the info nodes when `restoreInfo == false`
|
| 362 |
+
s.meta.restore
|
| 363 |
+
set s.elab
|
| 364 |
+
setTraceState traceState
|
| 365 |
+
unless restoreInfo do
|
| 366 |
+
setInfoState infoState
|
| 367 |
+
|
| 368 |
+
/--
|
| 369 |
+
Like `Meta.withRestoreOrSaveFull` for `TermElabM`, but also takes a `tacSnap?` that
|
| 370 |
+
* when running `act`, is set as `Context.tacSnap?`
|
| 371 |
+
* otherwise (i.e. on restore) is used to update the new snapshot promise to the old task's
|
| 372 |
+
value.
|
| 373 |
+
This extra restore step is necessary because while `reusableResult?` can be used to replay any
|
| 374 |
+
effects on `State`, `Context.tacSnap?` is not part of it but changed via an `IO` side effect, so
|
| 375 |
+
it needs to be replayed separately.
|
| 376 |
+
|
| 377 |
+
We use an explicit parameter instead of accessing `Context.tacSnap?` directly because this prevents
|
| 378 |
+
`withRestoreOrSaveFull` and `withReader` from being used in the wrong order.
|
| 379 |
+
-/
|
| 380 |
+
@[specialize]
|
| 381 |
+
def withRestoreOrSaveFull (reusableResult? : Option (α × SavedState))
|
| 382 |
+
(tacSnap? : Option (Language.SnapshotBundle Tactic.TacticParsedSnapshot)) (act : TermElabM α) :
|
| 383 |
+
TermElabM (α × SavedState) := do
|
| 384 |
+
if let some (_, state) := reusableResult? then
|
| 385 |
+
set state.elab
|
| 386 |
+
if let some snap := tacSnap? then
|
| 387 |
+
let some old := snap.old?
|
| 388 |
+
| throwError "withRestoreOrSaveFull: expected old snapshot in `tacSnap?`"
|
| 389 |
+
snap.new.resolve old.val.get
|
| 390 |
+
|
| 391 |
+
let reusableResult? := reusableResult?.map (fun (val, state) => (val, state.meta))
|
| 392 |
+
let (a, «meta») ← withReader ({ · with tacSnap? }) do
|
| 393 |
+
controlAt MetaM fun runInBase => do
|
| 394 |
+
Meta.withRestoreOrSaveFull reusableResult? <| runInBase act
|
| 395 |
+
return (a, { «meta», «elab» := (← get) })
|
| 396 |
+
|
| 397 |
+
instance : MonadBacktrack SavedState TermElabM where
|
| 398 |
+
saveState := Term.saveState
|
| 399 |
+
restoreState b := b.restore
|
| 400 |
+
|
| 401 |
+
/--
|
| 402 |
+
Incremental elaboration helper. Avoids leakage of data from outside syntax via the monadic context
|
| 403 |
+
when running `act` on `stx` by
|
| 404 |
+
* setting `stx` as the `ref` and
|
| 405 |
+
* deactivating `suppressElabErrors` if `stx` is `missing`-free, which also helps with not hiding
|
| 406 |
+
useful errors in this part of the input. Note that if `stx` has `missing`, this should always be
|
| 407 |
+
true for the outer syntax as well, so taking the old value of `suppressElabErrors` into account
|
| 408 |
+
should not introduce data leakage.
|
| 409 |
+
|
| 410 |
+
This combinator should always be used when narrowing reuse to a syntax subtree, usually (in the case
|
| 411 |
+
of tactics, to be generalized) via `withNarrowed(Arg)TacticReuse`.
|
| 412 |
+
-/
|
| 413 |
+
def withReuseContext [Monad m] [MonadWithReaderOf Core.Context m] (stx : Syntax) (act : m α) :
|
| 414 |
+
m α := do
|
| 415 |
+
withTheReader Core.Context (fun ctx => { ctx with
|
| 416 |
+
ref := stx
|
| 417 |
+
suppressElabErrors := ctx.suppressElabErrors && stx.hasMissing }) act
|
| 418 |
+
|
| 419 |
+
/--
|
| 420 |
+
Manages reuse information for nested tactics by `split`ting given syntax into an outer and inner
|
| 421 |
+
part. `act` is then run on the inner part but with reuse information adjusted as following:
|
| 422 |
+
* If the old (from `tacSnap?`'s `SyntaxGuarded.stx`) and new (from `stx`) outer syntax are not
|
| 423 |
+
identical according to `Syntax.eqWithInfo`, reuse is disabled.
|
| 424 |
+
* Otherwise, the old syntax as stored in `tacSnap?` is updated to the old *inner* syntax.
|
| 425 |
+
* In any case, `withReuseContext` is used on the new inner syntax to further prepare the monadic
|
| 426 |
+
context.
|
| 427 |
+
|
| 428 |
+
For any tactic that participates in reuse, `withNarrowedTacticReuse` should be applied to the
|
| 429 |
+
tactic's syntax and `act` should be used to do recursive tactic evaluation of nested parts. Also,
|
| 430 |
+
after this function, `getAndEmptySnapshotTasks` should be called and the result stored in a snapshot
|
| 431 |
+
so that the tasks don't end up in a snapshot further up and are cancelled together with it; see
|
| 432 |
+
note [Incremental Cancellation].
|
| 433 |
+
-/
|
| 434 |
+
def withNarrowedTacticReuse [Monad m] [MonadReaderOf Context m] [MonadLiftT BaseIO m]
|
| 435 |
+
[MonadWithReaderOf Core.Context m] [MonadWithReaderOf Context m] [MonadOptions m]
|
| 436 |
+
(split : Syntax → Syntax × Syntax) (act : Syntax → m α) (stx : Syntax) : m α := do
|
| 437 |
+
let (outer, inner) := split stx
|
| 438 |
+
let opts ← getOptions
|
| 439 |
+
let ctx ← readThe Term.Context
|
| 440 |
+
withTheReader Term.Context (fun ctx => { ctx with tacSnap? := ctx.tacSnap?.map fun tacSnap =>
|
| 441 |
+
{ tacSnap with old? := tacSnap.old?.bind fun old => do
|
| 442 |
+
let (oldOuter, oldInner) := split old.stx
|
| 443 |
+
guard <| outer.eqWithInfoAndTraceReuse opts oldOuter
|
| 444 |
+
return { old with stx := oldInner }
|
| 445 |
+
}
|
| 446 |
+
}) do
|
| 447 |
+
if let some oldOuter := ctx.tacSnap?.bind (·.old?) then
|
| 448 |
+
if (← read).tacSnap?.bind (·.old?) |>.isNone then
|
| 449 |
+
oldOuter.val.cancelRec
|
| 450 |
+
withReuseContext inner (act inner)
|
| 451 |
+
|
| 452 |
+
/--
|
| 453 |
+
A variant of `withNarrowedTacticReuse` that uses `stx[argIdx]` as the inner syntax and all `stx`
|
| 454 |
+
child nodes before that as the outer syntax, i.e. reuse is disabled if there was any change before
|
| 455 |
+
`argIdx`.
|
| 456 |
+
|
| 457 |
+
NOTE: child nodes after `argIdx` are not tested (which would almost always disable reuse as they are
|
| 458 |
+
necessarily shifted by changes at `argIdx`) so it must be ensured that the result of `arg` does not
|
| 459 |
+
depend on them (i.e. they should not be inspected beforehand).
|
| 460 |
+
-/
|
| 461 |
+
def withNarrowedArgTacticReuse [Monad m] [MonadReaderOf Context m] [MonadLiftT BaseIO m]
|
| 462 |
+
[MonadWithReaderOf Core.Context m] [MonadWithReaderOf Context m] [MonadOptions m]
|
| 463 |
+
(argIdx : Nat) (act : Syntax → m α) (stx : Syntax) : m α :=
|
| 464 |
+
withNarrowedTacticReuse (fun stx => (mkNullNode stx.getArgs[*...argIdx], stx[argIdx])) act stx
|
| 465 |
+
|
| 466 |
+
/--
|
| 467 |
+
Disables incremental tactic reuse *and* reporting for `act` if `cond` is true by setting `tacSnap?`
|
| 468 |
+
to `none`. This should be done for tactic blocks that are run multiple times as otherwise the
|
| 469 |
+
reported progress will jump back and forth (and partial reuse for these kinds of tact blocks is
|
| 470 |
+
similarly questionable).
|
| 471 |
+
-/
|
| 472 |
+
def withoutTacticIncrementality [Monad m] [MonadWithReaderOf Context m] [MonadOptions m]
|
| 473 |
+
(cond : Bool) (act : m α) : m α := do
|
| 474 |
+
let opts ← getOptions
|
| 475 |
+
withTheReader Term.Context (fun ctx => { ctx with tacSnap? := ctx.tacSnap?.filter fun tacSnap => Id.run do
|
| 476 |
+
if let some old := tacSnap.old? then
|
| 477 |
+
if cond && opts.getBool `trace.Elab.reuse then
|
| 478 |
+
dbg_trace "reuse stopped: guard failed at {old.stx}"
|
| 479 |
+
return !cond
|
| 480 |
+
}) act
|
| 481 |
+
|
| 482 |
+
/-- Disables incremental tactic reuse for `act` if `cond` is true. -/
|
| 483 |
+
def withoutTacticReuse [Monad m] [MonadWithReaderOf Context m] [MonadOptions m]
|
| 484 |
+
(cond : Bool) (act : m α) : m α := do
|
| 485 |
+
let opts ← getOptions
|
| 486 |
+
withTheReader Term.Context (fun ctx => { ctx with tacSnap? := ctx.tacSnap?.map fun tacSnap =>
|
| 487 |
+
{ tacSnap with old? := tacSnap.old?.filter fun old => Id.run do
|
| 488 |
+
if cond && opts.getBool `trace.Elab.reuse then
|
| 489 |
+
dbg_trace "reuse stopped: guard failed at {old.stx}"
|
| 490 |
+
return !cond }
|
| 491 |
+
}) act
|
| 492 |
+
|
| 493 |
+
@[inherit_doc Core.wrapAsyncAsSnapshot]
|
| 494 |
+
def wrapAsyncAsSnapshot {α : Type} (act : α → TermElabM Unit) (cancelTk? : Option IO.CancelToken)
|
| 495 |
+
(desc : String := by exact decl_name%.toString) :
|
| 496 |
+
TermElabM (α → BaseIO Language.SnapshotTree) := do
|
| 497 |
+
let ctx ← read
|
| 498 |
+
let st ← get
|
| 499 |
+
let metaCtx ← readThe Meta.Context
|
| 500 |
+
let metaSt ← getThe Meta.State
|
| 501 |
+
Core.wrapAsyncAsSnapshot (cancelTk? := cancelTk?) (desc := desc) fun a =>
|
| 502 |
+
act a |>.run ctx |>.run' st |>.run' metaCtx metaSt
|
| 503 |
+
|
+ abbrev TermElabResult (α : Type) := EStateM.Result Exception SavedState α
+
+ /--
+ Execute `x`, save resulting expression and new state.
+ We remove any `Info` created by `x`.
+ The info nodes are committed when we execute `applyResult`.
+ We use `observing` to implement overloaded notation and decls.
+ We want to save `Info` nodes for the chosen alternative.
+ -/
+ def observing (x : TermElabM α) : TermElabM (TermElabResult α) := do
+   let s ← saveState
+   try
+     let e ← x
+     let sNew ← saveState
+     s.restore (restoreInfo := true)
+     return EStateM.Result.ok e sNew
+   catch
+     | ex@(.error ..) =>
+       let sNew ← saveState
+       s.restore (restoreInfo := true)
+       return .error ex sNew
+     | ex@(.internal id _) =>
+       if id == postponeExceptionId then
+         s.restore (restoreInfo := true)
+       throw ex
+
+ /--
+ Apply the result/exception and state captured with `observing`.
+ We use this method to implement overloaded notation and symbols. -/
+ def applyResult (result : TermElabResult α) : TermElabM α := do
+   match result with
+   | .ok a r => r.restore (restoreInfo := true); return a
+   | .error ex r => r.restore (restoreInfo := true); throw ex
+
+ /--
+ Execute `x`, but keep state modifications only if `x` did not postpone.
+ This method is useful to implement elaboration functions that cannot decide whether
+ they need to postpone or not without updating the state. -/
+ def commitIfDidNotPostpone (x : TermElabM α) : TermElabM α := do
+   -- We just reuse the implementation of `observing` and `applyResult`.
+   let r ← observing x
+   applyResult r
+
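A sketch of how `observing`/`applyResult` are typically combined when trying alternatives (hypothetical helper, not part of this file): whichever alternative is chosen gets its saved state, including its `Info` nodes, replayed.

```lean
import Lean
open Lean Elab Term

/-- Hypothetical helper (illustration only): elaborate with `x`; if that fails,
    fall back to `y`. `applyResult` replays the saved state and info nodes of
    whichever alternative was chosen. -/
def elabFirstThatWorks (x y : TermElabM Expr) : TermElabM Expr := do
  match (← observing x) with
  | r@(.ok ..) => applyResult r
  | _          => applyResult (← observing y)
```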
| 547 |
+
/--
|
| 548 |
+
Return the universe level names explicitly provided by the user.
|
| 549 |
+
-/
|
| 550 |
+
def getLevelNames : TermElabM (List Name) :=
|
| 551 |
+
return (← get).levelNames
|
| 552 |
+
|
| 553 |
+
/--
|
| 554 |
+
Given a free variable `fvar`, return its declaration.
|
| 555 |
+
This function panics if `fvar` is not a free variable.
|
| 556 |
+
-/
|
| 557 |
+
def getFVarLocalDecl! (fvar : Expr) : TermElabM LocalDecl := do
|
| 558 |
+
match (← getLCtx).find? fvar.fvarId! with
|
| 559 |
+
| some d => pure d
|
| 560 |
+
| none => unreachable!
|
| 561 |
+
|
| 562 |
+
instance : AddErrorMessageContext TermElabM where
|
| 563 |
+
add ref msg := do
|
| 564 |
+
let ctx ← read
|
| 565 |
+
let ref := getBetterRef ref ctx.macroStack
|
| 566 |
+
let msg ← addMessageContext msg
|
| 567 |
+
let msg ← addMacroStack msg ctx.macroStack
|
| 568 |
+
pure (ref, msg)
|
| 569 |
+
|
| 570 |
+
/--
|
| 571 |
+
Execute `x` without storing `Syntax` for recursive applications. See `saveRecAppSyntax` field at `Context`.
|
| 572 |
+
-/
|
| 573 |
+
def withoutSavingRecAppSyntax (x : TermElabM α) : TermElabM α :=
|
| 574 |
+
withReader (fun ctx => { ctx with saveRecAppSyntax := false }) x
|
| 575 |
+
|
+ unsafe def mkTermElabAttributeUnsafe (ref : Name) : IO (KeyedDeclsAttribute TermElab) :=
+   mkElabAttribute TermElab `builtin_term_elab `term_elab `Lean.Parser.Term `Lean.Elab.Term.TermElab "term" ref
+
+ @[implemented_by mkTermElabAttributeUnsafe]
+ opaque mkTermElabAttribute (ref : Name) : IO (KeyedDeclsAttribute TermElab)
+
+ /--
+ Registers a term elaborator for the given syntax node kind.
+
+ A term elaborator should have type `Lean.Elab.Term.TermElab` (which is
+ `Lean.Syntax → Option Lean.Expr → Lean.Elab.Term.TermElabM Lean.Expr`), i.e. should take syntax of
+ the given syntax node kind and an optional expected type as parameters and produce an expression.
+
+ The `elab_rules` and `elab` commands should usually be preferred over using this attribute
+ directly.
+ -/
+ @[builtin_doc]
+ builtin_initialize termElabAttribute : KeyedDeclsAttribute TermElab ← mkTermElabAttribute decl_name%
+
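As the doc comment above notes, user code normally reaches `termElabAttribute` through the `elab`/`elab_rules` front end rather than by applying the attribute directly. A minimal sketch of that workflow (illustration only; the `answerTerm` syntax is made up and not part of this file):

```lean
import Lean
open Lean Elab Term

-- Registers a term elaborator for the new syntax `answerTerm`; the `elab` command
-- expands to `elab_rules`, which ultimately populates `termElabAttribute`.
elab "answerTerm" : term => do
  -- Produce the natural-number literal 42, ignoring the expected type.
  return mkNatLit 42

#eval (answerTerm : Nat)  -- 42
```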
| 595 |
+
/--
|
| 596 |
+
Auxiliary datatype for presenting a Lean lvalue modifier.
|
| 597 |
+
We represent an unelaborated lvalue as a `Syntax` (or `Expr`) and `List LVal`.
|
| 598 |
+
Example: `a.foo.1` is represented as the `Syntax` `a` and the list
|
| 599 |
+
`[LVal.fieldName "foo", LVal.fieldIdx 1]`.
|
| 600 |
+
-/
|
| 601 |
+
inductive LVal where
|
| 602 |
+
| fieldIdx (ref : Syntax) (i : Nat)
|
| 603 |
+
/-- Field `suffix?` is for producing better error messages because `x.y` may be a field access or a hierarchical/composite name.
|
| 604 |
+
`ref` is the syntax object representing the field. `fullRef` includes the LHS. -/
|
| 605 |
+
| fieldName (ref : Syntax) (name : String) (suffix? : Option Name) (fullRef : Syntax)
|
| 606 |
+
|
| 607 |
+
def LVal.getRef : LVal → Syntax
|
| 608 |
+
| .fieldIdx ref _ => ref
|
| 609 |
+
| .fieldName ref .. => ref
|
| 610 |
+
|
| 611 |
+
def LVal.isFieldName : LVal → Bool
|
| 612 |
+
| .fieldName .. => true
|
| 613 |
+
| _ => false
|
| 614 |
+
|
| 615 |
+
instance : ToString LVal where
|
| 616 |
+
toString
|
| 617 |
+
| .fieldIdx _ i => toString i
|
| 618 |
+
| .fieldName _ n .. => n
|
| 619 |
+
|
| 620 |
+
/-- Return the name of the declaration being elaborated if available. -/
|
| 621 |
+
def getDeclName? : TermElabM (Option Name) := return (← read).declName?
|
| 622 |
+
/-- Return the list of nested `let rec` declarations that need to be lifted. -/
|
| 623 |
+
def getLetRecsToLift : TermElabM (List LetRecToLift) := return (← get).letRecsToLift
|
| 624 |
+
/-- Return the declaration of the given metavariable -/
|
| 625 |
+
def getMVarDecl (mvarId : MVarId) : TermElabM MetavarDecl := return (← getMCtx).getDecl mvarId
|
| 626 |
+
|
| 627 |
+
instance : MonadParentDecl TermElabM where
|
| 628 |
+
getParentDeclName? := getDeclName?
|
| 629 |
+
|
| 630 |
+
/--
|
| 631 |
+
Executes `x` in the context of the given declaration name. Ensures that the info tree is set up
|
| 632 |
+
correctly and adjusts the declaration name generator to generate names below this name, resetting
|
| 633 |
+
the nested counter.
|
| 634 |
+
-/
|
| 635 |
+
def withDeclName (name : Name) (x : TermElabM α) : TermElabM α :=
|
| 636 |
+
withReader (fun ctx => { ctx with declName? := name }) do
|
| 637 |
+
withSaveParentDeclInfoContext do
|
| 638 |
+
withDeclNameForAuxNaming name do
|
| 639 |
+
x
|
| 640 |
+
|
| 641 |
+
/-- Update the universe level parameter names. -/
|
| 642 |
+
def setLevelNames (levelNames : List Name) : TermElabM Unit :=
|
| 643 |
+
modify fun s => { s with levelNames := levelNames }
|
| 644 |
+
|
| 645 |
+
/-- Execute `x` using `levelNames` as the universe level parameter names. See `getLevelNames`. -/
|
| 646 |
+
def withLevelNames (levelNames : List Name) (x : TermElabM α) : TermElabM α := do
|
| 647 |
+
let levelNamesSaved ← getLevelNames
|
| 648 |
+
setLevelNames levelNames
|
| 649 |
+
try x finally setLevelNames levelNamesSaved
|
| 650 |
+
|
| 651 |
+
def withoutErrToSorryImp (x : TermElabM α) : TermElabM α :=
|
| 652 |
+
withReader (fun ctx => { ctx with errToSorry := false }) x
|
| 653 |
+
|
| 654 |
+
/--
|
| 655 |
+
Execute `x` without converting errors (i.e., exceptions) to `sorry` applications.
|
| 656 |
+
Recall that when `errToSorry = true`, the method `elabTerm` catches exceptions and converts them into `sorry` applications.
|
| 657 |
+
-/
|
| 658 |
+
def withoutErrToSorry [MonadFunctorT TermElabM m] : m α → m α :=
|
| 659 |
+
monadMap (m := TermElabM) withoutErrToSorryImp
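A small sketch of the intended use (hypothetical helper, not part of this file): wrapping an elaboration call in `withoutErrToSorry` so failures surface as exceptions the caller can catch, instead of silently becoming synthetic `sorry`s.

```lean
import Lean
open Lean Elab Term

/-- Hypothetical helper (illustration only): elaborate `stx` against `expectedType?`,
    letting elaboration errors propagate as exceptions rather than `sorry`s. -/
def elabStrict (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr :=
  withoutErrToSorry do
    elabTerm stx expectedType?
```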
|
| 660 |
+
|
| 661 |
+
def withoutHeedElabAsElimImp (x : TermElabM α) : TermElabM α :=
|
| 662 |
+
withReader (fun ctx => { ctx with heedElabAsElim := false }) x
|
| 663 |
+
|
| 664 |
+
/--
|
| 665 |
+
Execute `x` without heeding the `elab_as_elim` attribute. Useful when there is
|
| 666 |
+
no expected type (so `elabAppArgs` would fail), but expect that the user wants
|
| 667 |
+
to use such constants.
|
| 668 |
+
-/
|
| 669 |
+
def withoutHeedElabAsElim [MonadFunctorT TermElabM m] : m α → m α :=
|
| 670 |
+
monadMap (m := TermElabM) withoutHeedElabAsElimImp
|
| 671 |
+
|
| 672 |
+
/--
|
| 673 |
+
Execute `x` but discard changes performed at `Term.State` and `Meta.State`.
|
| 674 |
+
Recall that the `Environment`, `InfoState` and messages are at `Core.State`. Thus, any updates to
|
| 675 |
+
it will be preserved.
|
| 676 |
+
This method is useful for performing computations where all metavariable must be resolved or
|
| 677 |
+
discarded.
|
| 678 |
+
The `InfoTree`s are not discarded, however, and wrapped in `InfoTree.Context`
|
| 679 |
+
to store their metavariable context.
|
| 680 |
+
-/
|
| 681 |
+
def withoutModifyingElabMetaStateWithInfo (x : TermElabM α) : TermElabM α := do
|
| 682 |
+
let s ← get
|
| 683 |
+
let sMeta ← getThe Meta.State
|
| 684 |
+
try
|
| 685 |
+
withSaveInfoContext x
|
| 686 |
+
finally
|
| 687 |
+
set s
|
| 688 |
+
set sMeta
|
| 689 |
+
|
| 690 |
+
/--
|
| 691 |
+
Execute `x` but discard changes performed to the state.
|
| 692 |
+
However, the info trees and messages are not discarded. -/
|
| 693 |
+
private def withoutModifyingStateWithInfoAndMessagesImpl (x : TermElabM α) : TermElabM α := do
|
| 694 |
+
let saved ← saveState
|
| 695 |
+
try
|
| 696 |
+
withSaveInfoContext x
|
| 697 |
+
finally
|
| 698 |
+
let saved := { saved with meta.core.infoState := (← getInfoState), meta.core.messages := (← getThe Core.State).messages }
|
| 699 |
+
restoreState saved
|
| 700 |
+
|
| 701 |
+
/--
|
| 702 |
+
Wraps the trees returned from `getInfoTrees`, if any, in an `InfoTree.context` node based on the
|
| 703 |
+
current monadic context and state. This is mainly used to report info trees early via
|
| 704 |
+
`Snapshot.infoTree?`. The trees are not removed from the `getInfoTrees` state as the final info tree
|
| 705 |
+
of the elaborated command should be complete and not depend on whether parts have been reported
|
| 706 |
+
early.
|
| 707 |
+
|
| 708 |
+
As `InfoTree.context` can have only one child, this function panics if `trees` contains more than 1
|
| 709 |
+
tree. Also, `PartialContextInfo.parentDeclCtx` is not currently generated as that information is not
|
| 710 |
+
available in the monadic context and only needed for the final info tree.
|
| 711 |
+
-/
|
| 712 |
+
def getInfoTreeWithContext? : TermElabM (Option InfoTree) := do
|
| 713 |
+
let st ← getInfoState
|
| 714 |
+
if st.trees.size > 1 then
|
| 715 |
+
return panic! "getInfoTreeWithContext: overfull tree"
|
| 716 |
+
let some t := st.trees[0]? |
|
| 717 |
+
return none
|
| 718 |
+
let t := t.substitute st.assignment
|
| 719 |
+
let ctx ← readThe Core.Context
|
| 720 |
+
let s ← getThe Core.State
|
| 721 |
+
let ctx := PartialContextInfo.commandCtx {
|
| 722 |
+
env := s.env, fileMap := ctx.fileMap, mctx := {}, currNamespace := ctx.currNamespace,
|
| 723 |
+
openDecls := ctx.openDecls, options := ctx.options, ngen := s.ngen
|
| 724 |
+
}
|
| 725 |
+
return InfoTree.context ctx t
|
| 726 |
+
|
| 727 |
+
/-- For testing `TermElabM` methods. The #eval command will sign the error. -/
|
| 728 |
+
def throwErrorIfErrors : TermElabM Unit := do
|
| 729 |
+
if (← MonadLog.hasErrors) then
|
| 730 |
+
throwError "Error(s)"
|
| 731 |
+
|
| 732 |
+
def traceAtCmdPos (cls : Name) (msg : Unit → MessageData) : TermElabM Unit :=
|
| 733 |
+
withRef Syntax.missing <| trace cls msg
|
| 734 |
+
|
| 735 |
+
def ppGoal (mvarId : MVarId) : TermElabM Format :=
|
| 736 |
+
Meta.ppGoal mvarId
|
| 737 |
+
|
| 738 |
+
open Level (LevelElabM)
|
| 739 |
+
|
| 740 |
+
def liftLevelM (x : LevelElabM α) : TermElabM α := do
|
| 741 |
+
let ctx ← read
|
| 742 |
+
let mctx ← getMCtx
|
| 743 |
+
let ngen ← getNGen
|
| 744 |
+
let lvlCtx : Level.Context := { options := (← getOptions), ref := (← getRef), autoBoundImplicit := ctx.autoBoundImplicit }
|
| 745 |
+
match (x lvlCtx).run { ngen := ngen, mctx := mctx, levelNames := (← getLevelNames) } with
|
| 746 |
+
| .ok a newS => setMCtx newS.mctx; setNGen newS.ngen; setLevelNames newS.levelNames; pure a
|
| 747 |
+
| .error ex _ => throw ex
|
| 748 |
+
|
| 749 |
+
def elabLevel (stx : Syntax) : TermElabM Level :=
|
| 750 |
+
liftLevelM <| Level.elabLevel stx
|
| 751 |
+
|
| 752 |
+
/-- Elaborate `x` with `stx` on the macro stack -/
|
| 753 |
+
def withPushMacroExpansionStack (beforeStx afterStx : Syntax) (x : TermElabM α) : TermElabM α :=
|
| 754 |
+
withReader (fun ctx => { ctx with macroStack := { before := beforeStx, after := afterStx } :: ctx.macroStack }) x
|
| 755 |
+
|
| 756 |
+
/-- Elaborate `x` with `stx` on the macro stack and produce macro expansion info -/
|
| 757 |
+
def withMacroExpansion (beforeStx afterStx : Syntax) (x : TermElabM α) : TermElabM α :=
|
| 758 |
+
withMacroExpansionInfo beforeStx afterStx do
|
| 759 |
+
withPushMacroExpansionStack beforeStx afterStx x
|
| 760 |
+
|
| 761 |
+
/--
|
| 762 |
+
Add the given metavariable to the list of pending synthetic metavariables.
|
| 763 |
+
The method `synthesizeSyntheticMVars` is used to process the metavariables on this list. -/
|
| 764 |
+
def registerSyntheticMVar (stx : Syntax) (mvarId : MVarId) (kind : SyntheticMVarKind) : TermElabM Unit := do
|
| 765 |
+
modify fun s => { s with syntheticMVars := s.syntheticMVars.insert mvarId { stx, kind }, pendingMVars := mvarId :: s.pendingMVars }
|
| 766 |
+
|
| 767 |
+
def registerSyntheticMVarWithCurrRef (mvarId : MVarId) (kind : SyntheticMVarKind) : TermElabM Unit := do
|
| 768 |
+
registerSyntheticMVar (← getRef) mvarId kind
|
| 769 |
+
|
| 770 |
+
def registerMVarErrorInfo (mvarErrorInfo : MVarErrorInfo) : TermElabM Unit :=
|
| 771 |
+
modify fun s => { s with mvarErrorInfos := mvarErrorInfo :: s.mvarErrorInfos }
|
| 772 |
+
|
| 773 |
+
def registerMVarErrorHoleInfo (mvarId : MVarId) (ref : Syntax) : TermElabM Unit :=
|
| 774 |
+
registerMVarErrorInfo { mvarId, ref, kind := .hole }
|
| 775 |
+
|
| 776 |
+
def registerMVarErrorImplicitArgInfo (mvarId : MVarId) (ref : Syntax) (app : Expr) : TermElabM Unit := do
|
| 777 |
+
registerMVarErrorInfo { mvarId, ref, kind := .implicitArg (← getLCtx) app }
|
| 778 |
+
|
| 779 |
+
def registerMVarErrorCustomInfo (mvarId : MVarId) (ref : Syntax) (msgData : MessageData) : TermElabM Unit := do
|
| 780 |
+
registerMVarErrorInfo { mvarId, ref, kind := .custom msgData }
|
| 781 |
+
|
| 782 |
+
def registerCustomErrorIfMVar (e : Expr) (ref : Syntax) (msgData : MessageData) : TermElabM Unit :=
|
| 783 |
+
match e.getAppFn with
|
| 784 |
+
| Expr.mvar mvarId => registerMVarErrorCustomInfo mvarId ref msgData
|
| 785 |
+
| _ => pure ()
|
| 786 |
+
|
| 787 |
+
def registerMVarArgName (mvarId : MVarId) (argName : Name) : TermElabM Unit :=
|
| 788 |
+
modify fun s => { s with mvarArgNames := s.mvarArgNames.insert mvarId argName }
|
| 789 |
+
|
| 790 |
+
/--
|
| 791 |
+
Auxiliary method for reporting errors of the form "... contains metavariables ...".
|
| 792 |
+
This kind of error is thrown, for example, at `Match.lean` where elaboration
|
| 793 |
+
cannot continue if there are metavariables in patterns.
|
| 794 |
+
We only want to log it if we haven't logged any errors so far. -/
|
| 795 |
+
def throwMVarError (m : MessageData) : TermElabM α := do
|
| 796 |
+
if (← MonadLog.hasErrors) then
|
| 797 |
+
throwAbortTerm
|
| 798 |
+
else
|
| 799 |
+
throwError m
|
| 800 |
+
|
| 801 |
+
def MVarErrorInfo.logError (mvarErrorInfo : MVarErrorInfo) (extraMsg? : Option MessageData) : TermElabM Unit := do
|
| 802 |
+
match mvarErrorInfo.kind with
|
| 803 |
+
| MVarErrorKind.implicitArg lctx app => withLCtx lctx {} do
|
| 804 |
+
let app ← instantiateMVars app
|
| 805 |
+
let msg ← addArgName "don't know how to synthesize implicit argument"
|
| 806 |
+
let msg := msg ++ m!"{indentExpr app.setAppPPExplicitForExposingMVars}" ++ Format.line ++ "context:" ++ Format.line ++ MessageData.ofGoal mvarErrorInfo.mvarId
|
| 807 |
+
logErrorAt mvarErrorInfo.ref (appendExtra msg)
|
| 808 |
+
| MVarErrorKind.hole => do
|
| 809 |
+
let msg ← addArgName "don't know how to synthesize placeholder" " for argument"
|
| 810 |
+
let msg := msg ++ Format.line ++ "context:" ++ Format.line ++ MessageData.ofGoal mvarErrorInfo.mvarId
|
| 811 |
+
logErrorAt mvarErrorInfo.ref (MessageData.tagged `Elab.synthPlaceholder <| appendExtra msg)
|
| 812 |
+
| MVarErrorKind.custom msg =>
|
| 813 |
+
logErrorAt mvarErrorInfo.ref (appendExtra msg)
|
| 814 |
+
where
|
| 815 |
+
/-- Append the argument name (if available) to the message.
|
| 816 |
+
Remark: if the argument name contains macro scopes we do not append it. -/
|
| 817 |
+
addArgName (msg : MessageData) (extra : String := "") : TermElabM MessageData := do
|
| 818 |
+
match (← get).mvarArgNames.find? mvarErrorInfo.mvarId with
|
| 819 |
+
| none => return msg
|
| 820 |
+
| some argName => return if argName.hasMacroScopes then msg else msg ++ extra ++ m!" '{argName}'"
|
| 821 |
+
|
| 822 |
+
appendExtra (msg : MessageData) : MessageData :=
|
| 823 |
+
match extraMsg? with
|
| 824 |
+
| none => msg
|
| 825 |
+
| some extraMsg => msg ++ extraMsg
|
| 826 |
+
|
| 827 |
+
/--
|
| 828 |
+
Try to log errors for the unassigned metavariables `pendingMVarIds`.
|
| 829 |
+
|
| 830 |
+
Return `true` if there were "unfilled holes", and we should "abort" declaration.
|
| 831 |
+
TODO: try to fill "all" holes using synthetic "sorry's"
|
| 832 |
+
|
| 833 |
+
Remark: We only log the "unfilled holes" as new errors if no error has been logged so far. -/
|
| 834 |
+
def logUnassignedUsingErrorInfos (pendingMVarIds : Array MVarId) (extraMsg? : Option MessageData := none) : TermElabM Bool := do
|
| 835 |
+
if pendingMVarIds.isEmpty then
|
| 836 |
+
return false
|
| 837 |
+
else
|
| 838 |
+
let hasOtherErrors ← MonadLog.hasErrors
|
| 839 |
+
let mut hasNewErrors := false
|
| 840 |
+
let mut alreadyVisited : MVarIdSet := {}
|
| 841 |
+
let mut errors : Array MVarErrorInfo := #[]
|
| 842 |
+
for mvarErrorInfo in (← get).mvarErrorInfos do
|
| 843 |
+
let mvarId := mvarErrorInfo.mvarId
|
| 844 |
+
unless alreadyVisited.contains mvarId do
|
| 845 |
+
alreadyVisited := alreadyVisited.insert mvarId
|
| 846 |
+
/- The metavariable `mvarErrorInfo.mvarId` may have been assigned or
|
| 847 |
+
delayed assigned to another metavariable that is unassigned. -/
|
| 848 |
+
let mvarDeps ← getMVars (mkMVar mvarId)
|
| 849 |
+
if mvarDeps.any pendingMVarIds.contains then do
|
| 850 |
+
unless hasOtherErrors do
|
| 851 |
+
errors := errors.push mvarErrorInfo
|
| 852 |
+
hasNewErrors := true
|
| 853 |
+
-- To sort the errors by position use
|
| 854 |
+
-- let sortedErrors := errors.qsort fun e₁ e₂ => e₁.ref.getPos?.getD 0 < e₂.ref.getPos?.getD 0
|
| 855 |
+
for error in errors do
|
| 856 |
+
error.mvarId.withContext do
|
| 857 |
+
error.logError extraMsg?
|
| 858 |
+
return hasNewErrors
|
| 859 |
+
|
| 860 |
+
def registerLevelMVarErrorInfo (levelMVarErrorInfo : LevelMVarErrorInfo) : TermElabM Unit :=
|
| 861 |
+
modify fun s => { s with levelMVarErrorInfos := levelMVarErrorInfo :: s.levelMVarErrorInfos }
|
| 862 |
+
|
| 863 |
+
def registerLevelMVarErrorExprInfo (expr : Expr) (ref : Syntax) (msgData? : Option MessageData := none) : TermElabM Unit := do
|
| 864 |
+
registerLevelMVarErrorInfo { lctx := (← getLCtx), expr, ref, msgData? }
|
| 865 |
+
|
| 866 |
+
def exposeLevelMVars (e : Expr) : MetaM Expr :=
|
| 867 |
+
Core.transform e
|
| 868 |
+
(post := fun e => do
|
| 869 |
+
match e with
|
| 870 |
+
| .const _ us => return .done <| if us.any (·.isMVar) then e.setPPUniverses true else e
|
| 871 |
+
| .sort u => return .done <| if u.isMVar then e.setPPUniverses true else e
|
| 872 |
+
| .lam _ t _ _ => return .done <| if t.hasLevelMVar then e.setOption `pp.funBinderTypes true else e
|
| 873 |
+
| .letE _ t _ _ _ => return .done <| if t.hasLevelMVar then e.setOption `pp.letVarTypes true else e
|
| 874 |
+
| _ => return .done e)
|
| 875 |
+
|
| 876 |
+
def LevelMVarErrorInfo.logError (levelMVarErrorInfo : LevelMVarErrorInfo) : TermElabM Unit :=
|
| 877 |
+
Meta.withLCtx levelMVarErrorInfo.lctx {} do
|
| 878 |
+
let e' ← exposeLevelMVars (← instantiateMVars levelMVarErrorInfo.expr)
|
| 879 |
+
let msg := levelMVarErrorInfo.msgData?.getD m!"don't know how to synthesize universe level metavariables"
|
| 880 |
+
let msg := m!"{msg}{indentExpr e'}"
|
| 881 |
+
logErrorAt levelMVarErrorInfo.ref msg
|
| 882 |
+
|
| 883 |
+
/--
|
| 884 |
+
Try to log errors for unassigned level metavariables `pendingLevelMVarIds`.
|
| 885 |
+
|
| 886 |
+
Returns `true` if there are any relevant `LevelMVarErrorInfo`s and we should "abort" the declaration.
|
| 887 |
+
|
| 888 |
+
Remark: we only log unassigned level metavariables as new errors if no error has been logged so far.
|
| 889 |
+
-/
|
| 890 |
+
def logUnassignedLevelMVarsUsingErrorInfos (pendingLevelMVarIds : Array LMVarId) : TermElabM Bool := do
|
| 891 |
+
if pendingLevelMVarIds.isEmpty then
|
| 892 |
+
return false
|
| 893 |
+
else
|
| 894 |
+
let hasOtherErrors ← MonadLog.hasErrors
|
| 895 |
+
let mut hasNewErrors := false
|
| 896 |
+
let mut errors : Array LevelMVarErrorInfo := #[]
|
| 897 |
+
for levelMVarErrorInfo in (← get).levelMVarErrorInfos do
|
| 898 |
+
let e ← instantiateMVars levelMVarErrorInfo.expr
|
| 899 |
+
let lmvars := (collectLevelMVars {} e).result
|
| 900 |
+
if lmvars.any pendingLevelMVarIds.contains then do
|
| 901 |
+
unless hasOtherErrors do
|
| 902 |
+
errors := errors.push levelMVarErrorInfo
|
| 903 |
+
hasNewErrors := true
|
| 904 |
+
for error in errors do
|
| 905 |
+
error.logError
|
| 906 |
+
return hasNewErrors
|
| 907 |
+
|
| 908 |
+
/-- Ensure metavariables registered using `registerMVarErrorInfos` (and used in the given declaration) have been assigned. -/
|
| 909 |
+
def ensureNoUnassignedMVars (decl : Declaration) : TermElabM Unit := do
|
| 910 |
+
let pendingMVarIds ← getMVarsAtDecl decl
|
| 911 |
+
if (← logUnassignedUsingErrorInfos pendingMVarIds) then
|
| 912 |
+
throwAbortCommand
|
| 913 |
+
|
| 914 |
+
/--
|
| 915 |
+
Execute `x` without allowing it to postpone elaboration tasks.
|
| 916 |
+
That is, `tryPostpone` is a noop. -/
|
| 917 |
+
def withoutPostponing (x : TermElabM α) : TermElabM α :=
|
| 918 |
+
withReader (fun ctx => { ctx with mayPostpone := false }) x
|
| 919 |
+
|
| 920 |
+
/-- Creates syntax for `(` <ident> `:` <type> `)` -/
|
| 921 |
+
def mkExplicitBinder (ident : Syntax) (type : Syntax) : Syntax :=
|
| 922 |
+
mkNode ``Lean.Parser.Term.explicitBinder #[mkAtom "(", mkNullNode #[ident], mkNullNode #[mkAtom ":", type], mkNullNode, mkAtom ")"]
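A quick sanity check of what `mkExplicitBinder` builds (illustration only; the `#eval` below is not part of this file):

```lean
import Lean
open Lean Elab Term

#eval do
  -- Build the binder syntax `(x : Nat)` and inspect its node kind.
  let b := mkExplicitBinder (mkIdent `x).raw (mkIdent `Nat).raw
  IO.println b.getKind  -- Lean.Parser.Term.explicitBinder
```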
|
| 923 |
+
|
| 924 |
+
/--
|
| 925 |
+
Convert unassigned universe level metavariables into parameters.
|
| 926 |
+
The new parameter names are fresh names of the form `u_i` with regard to `ctx.levelNames`, which is updated with the new names. -/
|
| 927 |
+
def levelMVarToParam (e : Expr) (except : LMVarId → Bool := fun _ => false) : TermElabM Expr := do
|
| 928 |
+
let levelNames ← getLevelNames
|
| 929 |
+
let r := (← getMCtx).levelMVarToParam (fun n => levelNames.elem n) except e `u 1
|
| 930 |
+
-- Recall that the most recent universe is the first element of the field `levelNames`.
|
| 931 |
+
setLevelNames (r.newParamNames.reverse.toList ++ levelNames)
|
| 932 |
+
setMCtx r.mctx
|
| 933 |
+
return r.expr
|
| 934 |
+
|
| 935 |
+
/--
|
| 936 |
+
Creates a fresh inaccessible binder name based on `x`.
|
| 937 |
+
Equivalent to ``Lean.Core.mkFreshUserName `x``.
|
| 938 |
+
|
| 939 |
+
Do not confuse with `Lean.mkFreshId`, for creating fresh free variable and metavariable ids.
|
| 940 |
+
-/
|
| 941 |
+
def mkFreshBinderName [Monad m] [MonadQuotation m] : m Name :=
|
| 942 |
+
withFreshMacroScope <| MonadQuotation.addMacroScope `x
|
| 943 |
+
|
| 944 |
+
/--
|
| 945 |
+
Auxiliary method for creating a `Syntax.ident` containing
|
| 946 |
+
a fresh name. This method is intended for creating fresh binder names.
|
| 947 |
+
It is just a thin layer on top of `mkFreshUserName`. -/
|
| 948 |
+
def mkFreshIdent [Monad m] [MonadQuotation m] (ref : Syntax) (canonical := false) : m Ident :=
|
| 949 |
+
return mkIdentFrom ref (← mkFreshBinderName) canonical
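A sketch of the typical use from a macro (hypothetical macro, not part of this file): the binder introduced below gets a fresh, inaccessible name from `mkFreshIdent`, so it can never capture an identifier written by the user.

```lean
import Lean
open Lean Lean.Elab.Term

/-- Hypothetical macro (illustration only): wrap `t` in a constant function whose
    binder name comes from `mkFreshIdent` and is therefore inaccessible to users. -/
macro "constFn% " t:term : term => do
  let x ← mkFreshIdent t.raw
  `(fun $x => $t)

#eval (constFn% 5) "ignored"  -- 5
```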
|
| 950 |
+
|
| 951 |
+
private def applyAttributesCore
|
    (declName : Name) (attrs : Array Attribute)
    (applicationTime? : Option AttributeApplicationTime) : TermElabM Unit := do profileitM Exception "attribute application" (← getOptions) do
  /-
  Remark: if the declaration has syntax errors, `declName` may be `.anonymous` see issue #4309
  In this case, we skip attribute application.
  -/
  if declName == .anonymous then
    return
  withDeclName declName do
  for attr in attrs do
    withTraceNode `Elab.attribute (fun _ => pure m!"applying [{attr.stx}]") do
    withRef attr.stx do withLogging do
      let env ← getEnv
      match getAttributeImpl env attr.name with
      | Except.error errMsg => throwError errMsg
      | Except.ok attrImpl =>
        let runAttr := attrImpl.add declName attr.stx attr.kind
        let runAttr := do
          -- not truly an elaborator, but a sensible target for go-to-definition
          let elaborator := attrImpl.ref
          if (← getInfoState).enabled then
            withInfoContext (mkInfo := return .ofCommandInfo { elaborator, stx := attr.stx }) do
              try runAttr
              finally if attr.stx[0].isIdent || attr.stx[0].isAtom then
                -- Add an additional node over the leading identifier if there is one to make it look more function-like.
                -- Do this last because we want user-created infos to take precedence
                pushInfoLeaf <| .ofCommandInfo { elaborator, stx := attr.stx[0] }
          else
            runAttr
        match applicationTime? with
        | none => runAttr
        | some applicationTime =>
          if applicationTime == attrImpl.applicationTime then
            runAttr

/-- Apply given attributes **at** a given application time -/
def applyAttributesAt (declName : Name) (attrs : Array Attribute) (applicationTime : AttributeApplicationTime) : TermElabM Unit :=
  applyAttributesCore declName attrs applicationTime

def applyAttributes (declName : Name) (attrs : Array Attribute) : TermElabM Unit :=
  applyAttributesCore declName attrs none

def mkTypeMismatchError (header? : Option MessageData) (e : Expr) (eType : Expr) (expectedType : Expr) : MetaM MessageData := do
  let header : MessageData := match header? with
    | some header => m!"{header} "
    | none => m!"type mismatch{indentExpr e}\n"
  return m!"{header}{← mkHasTypeButIsExpectedMsg eType expectedType}"

def throwTypeMismatchError (header? : Option MessageData) (expectedType : Expr) (eType : Expr) (e : Expr)
    (f? : Option Expr := none) (_extraMsg? : Option MessageData := none) : MetaM α := do
  /-
  We ignore `extraMsg?` for now. In all our tests, it contained no useful information. It was
  always of the form:
  ```
  failed to synthesize instance
    CoeT <eType> <e> <expectedType>
  ```
  We should revisit this decision in the future and decide whether it may contain useful information
  or not. -/
  let extraMsg := Format.nil
  /-
  let extraMsg : MessageData := match extraMsg? with
    | none => Format.nil
    | some extraMsg => Format.line ++ extraMsg;
  -/
  match f? with
  | none => throwError "{← mkTypeMismatchError header? e eType expectedType}{extraMsg}"
  | some f => Meta.throwAppTypeMismatch f e

def withoutMacroStackAtErr (x : TermElabM α) : TermElabM α :=
  withTheReader Core.Context (fun (ctx : Core.Context) => { ctx with options := pp.macroStack.set ctx.options false }) x

namespace ContainsPendingMVar

abbrev M := MonadCacheT Expr Unit (OptionT MetaM)

/-- See `containsPostponedTerm` -/
partial def visit (e : Expr) : M Unit := do
  checkCache e fun _ => do
    match e with
    | .forallE _ d b _ => visit d; visit b
    | .lam _ d b _ => visit d; visit b
    | .letE _ t v b _ => visit t; visit v; visit b
    | .app f a => visit f; visit a
    | .mdata _ b => visit b
    | .proj _ _ b => visit b
    | .fvar fvarId .. =>
      match (← fvarId.getDecl) with
      | .cdecl .. => return ()
      | .ldecl (value := v) .. => visit v
    | .mvar mvarId .. =>
      let e' ← instantiateMVars e
      if e' != e then
        visit e'
      else
        match (← getDelayedMVarAssignment? mvarId) with
        | some d => visit (mkMVar d.mvarIdPending)
        | none => failure
    | _ => return ()

end ContainsPendingMVar

/-- Return `true` if `e` contains a pending metavariable. Remark: it also visits let-declarations. -/
def containsPendingMVar (e : Expr) : MetaM Bool := do
  match (← ContainsPendingMVar.visit e |>.run.run) with
  | some _ => return false
  | none => return true

/--
Try to synthesize metavariable using type class resolution.
This method assumes the local context and local instances of `instMVar` coincide
with the current local context and local instances.
Return `true` if the instance was synthesized successfully, and `false` if
the instance contains unassigned metavariables that are blocking the type class
resolution procedure. Throw an exception if resolution or assignment irrevocably fails.

If `extraErrorMsg?` is not none, it contains additional information that should be attached
to type class synthesis failures.
-/
def synthesizeInstMVarCore (instMVar : MVarId) (maxResultSize? : Option Nat := none) (extraErrorMsg? : Option MessageData := none) : TermElabM Bool := do
  let extraErrorMsg := extraMsgToMsg extraErrorMsg?
  let instMVarDecl ← getMVarDecl instMVar
  let type := instMVarDecl.type
  let type ← instantiateMVars type
  let result ← trySynthInstance type maxResultSize?
  match result with
  | LOption.some val =>
    if (← instMVar.isAssigned) then
      let oldVal ← instantiateMVars (mkMVar instMVar)
      unless (← isDefEq oldVal val) do
        if (← containsPendingMVar oldVal <||> containsPendingMVar val) then
          /- If `val` or `oldVal` contains metavariables directly or indirectly (e.g., in a let-declaration),
             we return `false` to indicate we should try again later. This is very coarse grain since
             the metavariable may not be responsible for the failure. We should refine the test in the future if needed.
             This check has been added to address dependencies between postponed metavariables. The following
             example demonstrates the issue fixed by this test.
             ```
             structure Point where
               x : Nat
               y : Nat

             def Point.compute (p : Point) : Point :=
               let p := { p with x := 1 }
               let p := { p with y := 0 }
               if (p.x - p.y) > p.x then p else p
             ```
             The `isDefEq` test above fails for `Decidable (p.x - p.y ≤ p.x)` when the structure instance assigned to
             `p` has not been elaborated yet.
          -/
          return false -- we will try again later
        let oldValType ← inferType oldVal
        let valType ← inferType val
        unless (← isDefEq oldValType valType) do
          let (oldValType, valType) ← addPPExplicitToExposeDiff oldValType valType
          throwError "synthesized type class instance type is not definitionally equal to expected type, synthesized{indentExpr val}\nhas type{indentExpr valType}\nexpected{indentExpr oldValType}{extraErrorMsg}"
        let (oldVal, val) ← addPPExplicitToExposeDiff oldVal val
        throwError "synthesized type class instance is not definitionally equal to expression inferred by typing rules, synthesized{indentExpr val}\ninferred{indentExpr oldVal}{extraErrorMsg}"
    else
      unless (← isDefEq (mkMVar instMVar) val) do
        throwError "failed to assign synthesized type class instance{indentExpr val}{extraErrorMsg}"
    return true
  | .undef => return false -- we will try later
  | .none =>
    if (← read).ignoreTCFailures then
      return false
    else
      throwError "failed to synthesize{indentExpr type}{extraErrorMsg}{useDiagnosticMsg}"

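/-
Editor's note: the following block is an illustrative sketch added for this document; it is
not part of the original `Term.lean`. It shows a minimal use of `synthesizeInstMVarCore`:
create a synthetic metavariable for `Inhabited α` and try to solve it by type class
resolution. The helper name `synthInhabitedSketch` is hypothetical.
-/
private def synthInhabitedSketch (α : Expr) : TermElabM (Option Expr) := do
  -- build the instance type `Inhabited α` and a synthetic metavariable for it
  let instType ← Meta.mkAppM ``Inhabited #[α]
  let instMVar ← mkFreshExprMVar instType MetavarKind.synthetic
  -- `false` means synthesis is blocked on unassigned metavariables and should be retried later;
  -- hard failures are reported as exceptions by `synthesizeInstMVarCore` itself
  if (← synthesizeInstMVarCore instMVar.mvarId!) then
    return some (← instantiateMVars instMVar)
  else
    return none
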
def mkCoe (expectedType : Expr) (e : Expr) (f? : Option Expr := none) (errorMsgHeader? : Option String := none)
    (mkErrorMsg? : Option (MVarId → (expectedType e : Expr) → MetaM MessageData) := none)
    (mkImmedErrorMsg? : Option ((errorMsg? : Option MessageData) → (expectedType e : Expr) → MetaM MessageData) := none) : TermElabM Expr := do
  withTraceNode `Elab.coe (fun _ => return m!"adding coercion for {e} : {← inferType e} =?= {expectedType}") do
  try
    withoutMacroStackAtErr do
      match ← coerce? e expectedType with
      | .some eNew => return eNew
      | .none => failure
      | .undef =>
        let mvarAux ← mkFreshExprMVar expectedType MetavarKind.syntheticOpaque
        registerSyntheticMVarWithCurrRef mvarAux.mvarId! (.coe errorMsgHeader? expectedType e f? mkErrorMsg?)
        return mvarAux
  catch
    | .error _ msg =>
      if let some mkImmedErrorMsg := mkImmedErrorMsg? then
        throwError (← mkImmedErrorMsg msg expectedType e)
      else
        throwTypeMismatchError errorMsgHeader? expectedType (← inferType e) e f? msg
    | _ =>
      if let some mkImmedErrorMsg := mkImmedErrorMsg? then
        throwError (← mkImmedErrorMsg none expectedType e)
      else
        throwTypeMismatchError errorMsgHeader? expectedType (← inferType e) e f?

def mkCoeWithErrorMsgs (expectedType : Expr) (e : Expr)
    (mkImmedErrorMsg : (errorMsg? : Option MessageData) → (expectedType e : Expr) → MetaM MessageData)
    (mkErrorMsg : MVarId → (expectedType e : Expr) → MetaM MessageData) : TermElabM Expr := do
  mkCoe expectedType e (mkImmedErrorMsg? := mkImmedErrorMsg) (mkErrorMsg? := mkErrorMsg)

/--
If `expectedType?` is `some t`, then ensures `t` and `eType` are definitionally equal by inserting a coercion if necessary.

Argument `f?` is used only for generating error messages when inserting coercions fails.
-/
def ensureHasType (expectedType? : Option Expr) (e : Expr)
    (errorMsgHeader? : Option String := none) (f? : Option Expr := none) : TermElabM Expr := do
  let some expectedType := expectedType? | return e
  if (← isDefEq (← inferType e) expectedType) then
    return e
  else
    mkCoe expectedType e f? errorMsgHeader?

def ensureHasTypeWithErrorMsgs (expectedType? : Option Expr) (e : Expr)
    (mkImmedErrorMsg : (errorMsg? : Option MessageData) → (expectedType e : Expr) → MetaM MessageData)
    (mkErrorMsg : MVarId → (expectedType e : Expr) → MetaM MessageData) : TermElabM Expr := do
  let some expectedType := expectedType? | return e
  if (← isDefEq (← inferType e) expectedType) then
    return e
  else
    mkCoeWithErrorMsgs expectedType e mkImmedErrorMsg mkErrorMsg

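/-
Editor's note: illustrative sketch added for this document, not part of the original
`Term.lean`. `ensureHasType` either returns `e` unchanged (when its type is already
definitionally equal to the expected type) or inserts a coercion via `mkCoe`, reporting
a type mismatch if no coercion applies. The helper name `coerceToPropSketch` is hypothetical.
-/
private def coerceToPropSketch (e : Expr) : TermElabM Expr :=
  -- coerce `e` to a proposition, i.e. to the expected type `Prop` (`Sort 0`)
  ensureHasType (some (mkSort levelZero)) e
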
/--
Create a synthetic sorry for the given expected type. If `expectedType? = none`, then a fresh
metavariable is created to represent the type.
-/
private def mkSyntheticSorryFor (expectedType? : Option Expr) : TermElabM Expr := do
  let expectedType ← match expectedType? with
    | none => mkFreshTypeMVar
    | some expectedType => pure expectedType
  mkLabeledSorry expectedType (synthetic := true) (unique := false)

/--
Log the given exception, and create a synthetic sorry for representing the failed
elaboration step with exception `ex`.
-/
def exceptionToSorry (ex : Exception) (expectedType? : Option Expr) : TermElabM Expr := do
  let syntheticSorry ← mkSyntheticSorryFor expectedType?
  logException ex
  pure syntheticSorry

/-- If `mayPostpone == true`, throw `Exception.postpone`. -/
def tryPostpone : TermElabM Unit := do
  if (← read).mayPostpone then
    throwPostpone

/-- Return `true` if `e` reduces (by unfolding only `[reducible]` declarations) to `?m ...` -/
def isMVarApp (e : Expr) : TermElabM Bool :=
  return (← whnfR e).getAppFn.isMVar

/-- If `mayPostpone == true` and `e`'s head is a metavariable, throw `Exception.postpone`. -/
def tryPostponeIfMVar (e : Expr) : TermElabM Unit := do
  if (← isMVarApp e) then
    tryPostpone

/-- If `e? = some e`, then `tryPostponeIfMVar e`, otherwise it is just `tryPostpone`. -/
def tryPostponeIfNoneOrMVar (e? : Option Expr) : TermElabM Unit :=
  match e? with
  | some e => tryPostponeIfMVar e
  | none => tryPostpone

/--
Throws `Exception.postpone`, if `expectedType?` contains unassigned metavariables.
It is a noop if `mayPostpone == false`.
-/
def tryPostponeIfHasMVars? (expectedType? : Option Expr) : TermElabM (Option Expr) := do
  tryPostponeIfNoneOrMVar expectedType?
  let some expectedType := expectedType? | return none
  let expectedType ← instantiateMVars expectedType
  if expectedType.hasExprMVar then
    tryPostpone
    return none
  return some expectedType

/--
Throws `Exception.postpone`, if `expectedType?` contains unassigned metavariables.
If `mayPostpone == false`, it throws error `msg`.
-/
def tryPostponeIfHasMVars (expectedType? : Option Expr) (msg : String) : TermElabM Expr := do
  let some expectedType ← tryPostponeIfHasMVars? expectedType? |
    throwError "{msg}, expected type contains metavariables{indentD expectedType?}"
  return expectedType

def withExpectedType (expectedType? : Option Expr) (x : Expr → TermElabM Expr) : TermElabM Expr := do
  tryPostponeIfNoneOrMVar expectedType?
  let some expectedType ← pure expectedType?
    | throwError "expected type must be known"
  x expectedType

/--
Save relevant context for term elaboration postponement.
-/
def saveContext : TermElabM SavedContext :=
  return {
    macroStack := (← read).macroStack
    declName? := (← read).declName?
    options := (← getOptions)
    openDecls := (← getOpenDecls)
    errToSorry := (← read).errToSorry
    levelNames := (← get).levelNames
  }

/--
Execute `x` with the context saved using `saveContext`.
-/
def withSavedContext (savedCtx : SavedContext) (x : TermElabM α) : TermElabM α := do
  withReader (fun ctx => { ctx with declName? := savedCtx.declName?, macroStack := savedCtx.macroStack, errToSorry := savedCtx.errToSorry }) <|
    withTheReader Core.Context (fun ctx => { ctx with options := savedCtx.options, openDecls := savedCtx.openDecls }) <|
      withLevelNames savedCtx.levelNames x

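/-
Editor's note: illustrative sketch added for this document, not part of the original
`Term.lean`. Elaborators that cannot proceed while the expected type still contains
metavariables typically call `tryPostponeIfHasMVars`: it postpones while postponement is
allowed and otherwise fails with the given message. The helper name is hypothetical.
-/
private def elabStuckWithoutTypeSketch (expectedType? : Option Expr) : TermElabM Expr := do
  -- either postpones (throwing `Exception.postpone`) or returns a metavariable-free type
  let expectedType ← tryPostponeIfHasMVars expectedType? "sketch elaborator"
  -- placeholder result for the sketch; a real elaborator would build a term of `expectedType`
  mkLabeledSorry expectedType (synthetic := false) (unique := false)
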
/--
Delay the elaboration of `stx`, and return a fresh metavariable that works as a placeholder.
Remark: the caller is responsible for making sure the info tree is properly updated.
This method is used only at `elabUsingElabFnsAux`.
-/
private def postponeElabTermCore (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
  trace[Elab.postpone] "{stx} : {expectedType?}"
  let mvar ← mkFreshExprMVar expectedType? MetavarKind.syntheticOpaque
  registerSyntheticMVar stx mvar.mvarId! (SyntheticMVarKind.postponed (← saveContext))
  return mvar

def getSyntheticMVarDecl? (mvarId : MVarId) : TermElabM (Option SyntheticMVarDecl) :=
  return (← get).syntheticMVars.find? mvarId

register_builtin_option debug.byAsSorry : Bool := {
  defValue := false
  group := "debug"
  descr := "replace `by ..` blocks with `sorry` IF the expected type is a proposition"
}

/--
Creates a new metavariable of type `type` that will be synthesized using the tactic code.
The `tacticCode` syntax is the full `by ..` syntax.
-/
def mkTacticMVar (type : Expr) (tacticCode : Syntax) (kind : TacticMVarKind)
    (delayOnMVars := false) : TermElabM Expr := do
  if ← pure (debug.byAsSorry.get (← getOptions)) <&&> isProp type then
    withRef tacticCode <| mkLabeledSorry type false (unique := true)
  else
    let mvar ← mkFreshExprMVar type MetavarKind.syntheticOpaque
    let mvarId := mvar.mvarId!
    let ref ← getRef
    registerSyntheticMVar ref mvarId <| .tactic tacticCode (← saveContext) kind delayOnMVars
    return mvar

| 1295 |
+
/--
|
| 1296 |
+
Create an auxiliary annotation to make sure we create an `Info` even if `e` is a metavariable.
|
| 1297 |
+
See `mkTermInfo`.
|
| 1298 |
+
|
| 1299 |
+
We use this function because some elaboration functions elaborate subterms that may not be immediately
|
| 1300 |
+
part of the resulting term. Example:
|
| 1301 |
+
```
|
| 1302 |
+
let_mvar% ?m := b; wait_if_type_mvar% ?m; body
|
| 1303 |
+
```
|
| 1304 |
+
If the type of `b` is not known, then `wait_if_type_mvar% ?m; body` is postponed and just returns a fresh
|
| 1305 |
+
metavariable `?n`. The elaborator for
|
| 1306 |
+
```
|
| 1307 |
+
let_mvar% ?m := b; wait_if_type_mvar% ?m; body
|
| 1308 |
+
```
|
| 1309 |
+
returns `mkSaveInfoAnnotation ?n` to make sure the info nodes created when elaborating `b` are "saved".
|
| 1310 |
+
This is a bit hackish, but elaborators like `let_mvar%` are rare.
|
| 1311 |
+
-/
|
| 1312 |
+
def mkSaveInfoAnnotation (e : Expr) : Expr :=
|
| 1313 |
+
if e.isMVar then
|
| 1314 |
+
mkAnnotation `save_info e
|
| 1315 |
+
else
|
| 1316 |
+
e
|
| 1317 |
+
|
| 1318 |
+
def isSaveInfoAnnotation? (e : Expr) : Option Expr :=
|
| 1319 |
+
annotation? `save_info e
|
| 1320 |
+
|
| 1321 |
+
partial def removeSaveInfoAnnotation (e : Expr) : Expr :=
|
| 1322 |
+
match isSaveInfoAnnotation? e with
|
| 1323 |
+
| some e => removeSaveInfoAnnotation e
|
| 1324 |
+
| _ => e
|
| 1325 |
+
|
| 1326 |
+
/--
|
| 1327 |
+
Return `some mvarId` if `e` corresponds to a hole that is going to be filled "later" by executing a tactic or resuming elaboration.
|
| 1328 |
+
|
| 1329 |
+
We do not save `ofTermInfo` for this kind of node in the `InfoTree`.
|
| 1330 |
+
-/
|
| 1331 |
+
def isTacticOrPostponedHole? (e : Expr) : TermElabM (Option MVarId) := do
|
| 1332 |
+
match e with
|
| 1333 |
+
| Expr.mvar mvarId =>
|
| 1334 |
+
match (← getSyntheticMVarDecl? mvarId) with
|
| 1335 |
+
| some { kind := .tactic .., .. } => return mvarId
|
| 1336 |
+
| some { kind := .postponed .., .. } => return mvarId
|
| 1337 |
+
| _ => return none
|
| 1338 |
+
| _ => pure none
|
| 1339 |
+
|
| 1340 |
+
def mkTermInfo (elaborator : Name) (stx : Syntax) (e : Expr) (expectedType? : Option Expr := none)
|
| 1341 |
+
(lctx? : Option LocalContext := none) (isBinder := false) :
|
| 1342 |
+
TermElabM (Sum Info MVarId) := do
|
| 1343 |
+
match (← isTacticOrPostponedHole? e) with
|
| 1344 |
+
| some mvarId => return Sum.inr mvarId
|
| 1345 |
+
| none =>
|
| 1346 |
+
let e := removeSaveInfoAnnotation e
|
| 1347 |
+
return Sum.inl <| Info.ofTermInfo { elaborator, lctx := lctx?.getD (← getLCtx), expr := e, stx, expectedType?, isBinder }
|
| 1348 |
+
|
| 1349 |
+
def mkPartialTermInfo (elaborator : Name) (stx : Syntax) (expectedType? : Option Expr := none)
|
| 1350 |
+
(lctx? : Option LocalContext := none) :
|
| 1351 |
+
TermElabM Info := do
|
| 1352 |
+
return Info.ofPartialTermInfo { elaborator, lctx := lctx?.getD (← getLCtx), stx, expectedType? }
|
| 1353 |
+
|
| 1354 |
+
/--
|
| 1355 |
+
Pushes a new leaf node to the info tree associating the expression `e` to the syntax `stx`.
|
| 1356 |
+
As a result, when the user hovers over `stx` they will see the type of `e`, and if `e`
|
| 1357 |
+
is a constant they will see the constant's doc string.
|
| 1358 |
+
|
| 1359 |
+
* `expectedType?`: the expected type of `e` at the point of elaboration, if available
|
| 1360 |
+
* `lctx?`: the local context in which to interpret `e` (otherwise it will use `← getLCtx`)
|
| 1361 |
+
* `elaborator`: a declaration name used as an alternative target for go-to-definition
|
| 1362 |
+
* `isBinder`: if true, this will be treated as defining `e` (which should be a local constant)
|
| 1363 |
+
for the purpose of go-to-definition on local variables
|
| 1364 |
+
* `force`: In patterns, the effect of `addTermInfo` is usually suppressed and replaced
|
| 1365 |
+
by a `patternWithRef?` annotation which will be turned into a term info on the
|
| 1366 |
+
post-match-elaboration expression. This flag overrides that behavior and adds the term
|
| 1367 |
+
info immediately. (See https://github.com/leanprover/lean4/pull/1664.)
|
| 1368 |
+
-/
|
| 1369 |
+
def addTermInfo (stx : Syntax) (e : Expr) (expectedType? : Option Expr := none)
|
| 1370 |
+
(lctx? : Option LocalContext := none) (elaborator := Name.anonymous)
|
| 1371 |
+
(isBinder := false) (force := false) : TermElabM Expr := do
|
| 1372 |
+
if (← read).inPattern && !force then
|
| 1373 |
+
return mkPatternWithRef e stx
|
| 1374 |
+
else
|
| 1375 |
+
discard <| withInfoContext'
|
| 1376 |
+
(pure ())
|
| 1377 |
+
(fun _ => mkTermInfo elaborator stx e expectedType? lctx? isBinder)
|
| 1378 |
+
(mkPartialTermInfo elaborator stx expectedType? lctx?)
|
| 1379 |
+
return e
|
| 1380 |
+
|
| 1381 |
+
def addTermInfo' (stx : Syntax) (e : Expr) (expectedType? : Option Expr := none) (lctx? : Option LocalContext := none) (elaborator := Name.anonymous) (isBinder := false) : TermElabM Unit :=
|
| 1382 |
+
discard <| addTermInfo stx e expectedType? lctx? elaborator isBinder
|
| 1383 |
+
|
| 1384 |
+
def withInfoContext' (stx : Syntax) (x : TermElabM Expr)
|
| 1385 |
+
(mkInfo : Expr → TermElabM (Sum Info MVarId)) (mkInfoOnError : TermElabM Info) :
|
| 1386 |
+
TermElabM Expr := do
|
| 1387 |
+
if (← read).inPattern then
|
| 1388 |
+
let e ← x
|
| 1389 |
+
return mkPatternWithRef e stx
|
| 1390 |
+
else
|
| 1391 |
+
Elab.withInfoContext' x mkInfo mkInfoOnError
|
| 1392 |
+
|
| 1393 |
+
/-- Info node capturing `def/let rec` bodies, used by the unused variables linter. -/
|
| 1394 |
+
structure BodyInfo where
|
| 1395 |
+
/-- The body as a fully elaborated term. `none` if the body failed to elaborate. -/
|
| 1396 |
+
value? : Option Expr
|
| 1397 |
+
deriving TypeName
|
| 1398 |
+
|
| 1399 |
+
/-- Creates an `Info.ofCustomInfo` node backed by a `BodyInfo`. -/
|
| 1400 |
+
def mkBodyInfo (stx : Syntax) (value? : Option Expr) : Info :=
|
| 1401 |
+
.ofCustomInfo { stx, value := .mk { value? : BodyInfo } }
|
| 1402 |
+
|
| 1403 |
+
/-- Extracts a `BodyInfo` custom info. -/
|
| 1404 |
+
def getBodyInfo? : Info → Option BodyInfo
|
| 1405 |
+
| .ofCustomInfo { value, .. } => value.get? BodyInfo
|
| 1406 |
+
| _ => none
|
| 1407 |
+
|
| 1408 |
+
def withTermInfoContext' (elaborator : Name) (stx : Syntax) (x : TermElabM Expr)
|
| 1409 |
+
(expectedType? : Option Expr := none) (lctx? : Option LocalContext := none)
|
| 1410 |
+
(isBinder : Bool := false) :
|
| 1411 |
+
TermElabM Expr :=
|
| 1412 |
+
withInfoContext' stx x
|
| 1413 |
+
(mkTermInfo elaborator stx (expectedType? := expectedType?) (lctx? := lctx?) (isBinder := isBinder))
|
| 1414 |
+
(mkPartialTermInfo elaborator stx (expectedType? := expectedType?) (lctx? := lctx?))
|
| 1415 |
+
|
| 1416 |
+
/--
|
| 1417 |
+
Postpone the elaboration of `stx`, return a metavariable that acts as a placeholder, and
|
| 1418 |
+
ensures the info tree is updated and a hole id is introduced.
|
| 1419 |
+
When `stx` is elaborated, new info nodes are created and attached to the new hole id in the info tree.
|
| 1420 |
+
-/
|
| 1421 |
+
def postponeElabTerm (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
|
| 1422 |
+
withTermInfoContext' .anonymous stx (expectedType? := expectedType?) do
|
| 1423 |
+
postponeElabTermCore stx expectedType?
|
| 1424 |
+
|
| 1425 |
+
/--
|
| 1426 |
+
Helper function for `elabTerm` that tries the registered elaboration functions for `stxNode` kind until it finds one that supports the syntax or
|
| 1427 |
+
an error is found. -/
|
| 1428 |
+
private def elabUsingElabFnsAux (s : SavedState) (stx : Syntax) (expectedType? : Option Expr) (catchExPostpone : Bool)
|
| 1429 |
+
: List (KeyedDeclsAttribute.AttributeEntry TermElab) → TermElabM Expr
|
| 1430 |
+
| [] => do throwError "unexpected syntax{indentD stx}"
|
| 1431 |
+
| (elabFn::elabFns) =>
|
| 1432 |
+
try
|
| 1433 |
+
-- record elaborator in info tree, but only when not backtracking to other elaborators (outer `try`)
|
| 1434 |
+
withTermInfoContext' elabFn.declName stx (expectedType? := expectedType?)
|
| 1435 |
+
(try
|
| 1436 |
+
elabFn.value stx expectedType?
|
| 1437 |
+
catch ex => match ex with
|
| 1438 |
+
| .error .. =>
|
| 1439 |
+
if (← read).errToSorry then
|
| 1440 |
+
exceptionToSorry ex expectedType?
|
| 1441 |
+
else
|
| 1442 |
+
throw ex
|
| 1443 |
+
| .internal id _ =>
|
| 1444 |
+
if (← read).errToSorry && id == abortTermExceptionId then
|
| 1445 |
+
exceptionToSorry ex expectedType?
|
| 1446 |
+
else if id == unsupportedSyntaxExceptionId then
|
| 1447 |
+
throw ex -- to outer try
|
| 1448 |
+
else if catchExPostpone && id == postponeExceptionId then
|
| 1449 |
+
/- If `elab` threw `Exception.postpone`, we reset any state modifications.
|
| 1450 |
+
For example, we want to make sure pending synthetic metavariables created by `elab` before
|
| 1451 |
+
it threw `Exception.postpone` are discarded.
|
| 1452 |
+
Note that we are also discarding the messages created by `elab`.
|
| 1453 |
+
|
| 1454 |
+
For example, consider the expression.
|
| 1455 |
+
`((f.x a1).x a2).x a3`
|
| 1456 |
+
Now, suppose the elaboration of `f.x a1` produces an `Exception.postpone`.
|
| 1457 |
+
Then, a new metavariable `?m` is created. Then, `?m.x a2` also throws `Exception.postpone`
|
| 1458 |
+
because the type of `?m` is not yet known. Then another, metavariable `?n` is created, and
|
| 1459 |
+
finally `?n.x a3` also throws `Exception.postpone`. If we did not restore the state, we would
|
| 1460 |
+
keep "dead" metavariables `?m` and `?n` on the pending synthetic metavariable list. This is
|
| 1461 |
+
wasteful because when we resume the elaboration of `((f.x a1).x a2).x a3`, we start it from scratch
|
| 1462 |
+
and new metavariables are created for the nested functions. -/
|
| 1463 |
+
s.restore
|
| 1464 |
+
postponeElabTermCore stx expectedType?
|
| 1465 |
+
else
|
| 1466 |
+
throw ex)
|
| 1467 |
+
catch ex => match ex with
|
| 1468 |
+
| .internal id _ =>
|
| 1469 |
+
if id == unsupportedSyntaxExceptionId then
|
| 1470 |
+
s.restore -- also removes the info tree created above
|
| 1471 |
+
elabUsingElabFnsAux s stx expectedType? catchExPostpone elabFns
|
| 1472 |
+
else
|
| 1473 |
+
throw ex
|
| 1474 |
+
| _ => throw ex
|
| 1475 |
+
|
| 1476 |
+
private def elabUsingElabFns (stx : Syntax) (expectedType? : Option Expr) (catchExPostpone : Bool) : TermElabM Expr := do
|
| 1477 |
+
let s ← saveState
|
| 1478 |
+
let k := stx.getKind
|
| 1479 |
+
match termElabAttribute.getEntries (← getEnv) k with
|
| 1480 |
+
| [] => throwError "elaboration function for '{k}' has not been implemented{indentD stx}"
|
| 1481 |
+
| elabFns => elabUsingElabFnsAux s stx expectedType? catchExPostpone elabFns
|
| 1482 |
+
|
| 1483 |
+
instance : MonadMacroAdapter TermElabM where
|
| 1484 |
+
getCurrMacroScope := getCurrMacroScope
|
| 1485 |
+
getNextMacroScope := return (← getThe Core.State).nextMacroScope
|
| 1486 |
+
setNextMacroScope next := modifyThe Core.State fun s => { s with nextMacroScope := next }
|
| 1487 |
+
|
| 1488 |
+
private def isExplicit (stx : Syntax) : Bool :=
|
| 1489 |
+
match stx with
|
| 1490 |
+
| `(@$_) => true
|
| 1491 |
+
| _ => false
|
| 1492 |
+
|
| 1493 |
+
private def isExplicitApp (stx : Syntax) : Bool :=
|
| 1494 |
+
stx.getKind == ``Lean.Parser.Term.app && isExplicit stx[0]
|
| 1495 |
+
|
| 1496 |
+
/--
|
| 1497 |
+
Return true if `stx` is a lambda abstraction containing a `{}` or `[]` binder annotation.
|
| 1498 |
+
Example: `fun {α} (a : α) => a` -/
|
| 1499 |
+
private def isLambdaWithImplicit (stx : Syntax) : Bool :=
|
| 1500 |
+
match stx with
|
| 1501 |
+
| `(fun $binders* => $_) => binders.raw.any fun b => b.isOfKind ``Lean.Parser.Term.implicitBinder || b.isOfKind `Lean.Parser.Term.instBinder
|
| 1502 |
+
| _ => false
|
| 1503 |
+
|
| 1504 |
+
private partial def dropTermParens : Syntax → Syntax := fun stx =>
|
| 1505 |
+
match stx with
|
| 1506 |
+
| `(($stx)) => dropTermParens stx
|
| 1507 |
+
| _ => stx
|
| 1508 |
+
|
| 1509 |
+
private def isHole (stx : Syntax) : Bool :=
|
| 1510 |
+
stx.isOfKind ``Lean.Parser.Term.hole || stx.isOfKind ``Lean.Parser.Term.syntheticHole
|
| 1511 |
+
|
| 1512 |
+
private def isTacticBlock (stx : Syntax) : Bool :=
|
| 1513 |
+
match stx with
|
| 1514 |
+
| `(by $_:tacticSeq) => true
|
| 1515 |
+
| _ => false
|
| 1516 |
+
|
| 1517 |
+
private def isNoImplicitLambda (stx : Syntax) : Bool :=
|
| 1518 |
+
match stx with
|
| 1519 |
+
| `(no_implicit_lambda% $_:term) => true
|
| 1520 |
+
| _ => false
|
| 1521 |
+
|
| 1522 |
+
private def isTypeAscription (stx : Syntax) : Bool :=
|
| 1523 |
+
match stx with
|
| 1524 |
+
| `(($_ : $[$_]?)) => true
|
| 1525 |
+
| _ => false
|
| 1526 |
+
|
| 1527 |
+
def hasNoImplicitLambdaAnnotation (type : Expr) : Bool :=
|
| 1528 |
+
annotation? `noImplicitLambda type |>.isSome
|
| 1529 |
+
|
| 1530 |
+
def mkNoImplicitLambdaAnnotation (type : Expr) : Expr :=
|
| 1531 |
+
if hasNoImplicitLambdaAnnotation type then
|
| 1532 |
+
type
|
| 1533 |
+
else
|
| 1534 |
+
mkAnnotation `noImplicitLambda type
|
| 1535 |
+
|
| 1536 |
+
/-- Block usage of implicit lambdas if `stx` is `@f` or `@f arg1 ...` or `fun` with an implicit binder annotation. -/
|
| 1537 |
+
def blockImplicitLambda (stx : Syntax) : Bool :=
|
| 1538 |
+
let stx := dropTermParens stx
|
| 1539 |
+
-- TODO: make it extensible
|
| 1540 |
+
isExplicit stx || isExplicitApp stx || isLambdaWithImplicit stx || isHole stx || isTacticBlock stx ||
|
| 1541 |
+
isNoImplicitLambda stx || isTypeAscription stx
|
| 1542 |
+
|
| 1543 |
+
/-- Return true iff `stx` is a `Syntax.ident`, and it is a local variable. -/
|
| 1544 |
+
def isLocalIdent? (stx : Syntax) : TermElabM (Option Expr) :=
|
| 1545 |
+
match stx with
|
| 1546 |
+
| Syntax.ident _ _ val _ => do
|
| 1547 |
+
let r? ← resolveLocalName val
|
| 1548 |
+
match r? with
|
| 1549 |
+
| some (fvar, []) => return some fvar
|
| 1550 |
+
| _ => return none
|
| 1551 |
+
| _ => return none
|
| 1552 |
+
|
| 1553 |
+
inductive UseImplicitLambdaResult where
|
| 1554 |
+
| no
|
| 1555 |
+
| yes (expectedType : Expr)
|
| 1556 |
+
| postpone
|
| 1557 |
+
|
| 1558 |
+
/--
|
| 1559 |
+
Return normalized expected type if it is of the form `{a : α} → β` or `[a : α] → β` and
|
| 1560 |
+
`blockImplicitLambda stx` is not true, else return `none`.
|
| 1561 |
+
|
| 1562 |
+
Remark: implicit lambdas are not triggered by the strict implicit binder annotation `{{a : α}} → β`
|
| 1563 |
+
-/
|
| 1564 |
+
private def useImplicitLambda (stx : Syntax) (expectedType? : Option Expr) : TermElabM UseImplicitLambdaResult := do
|
| 1565 |
+
if blockImplicitLambda stx then
|
| 1566 |
+
return .no
|
| 1567 |
+
let some expectedType := expectedType? | return .no
|
| 1568 |
+
if hasNoImplicitLambdaAnnotation expectedType then
|
| 1569 |
+
return .no
|
| 1570 |
+
let expectedType ← whnfForall expectedType
|
| 1571 |
+
let .forallE _ _ _ c := expectedType | return .no
|
| 1572 |
+
unless c.isImplicit || c.isInstImplicit do
|
| 1573 |
+
return .no
|
| 1574 |
+
if let some x ← isLocalIdent? stx then
|
| 1575 |
+
if (← isMVarApp (← inferType x)) then
|
| 1576 |
+
/-
|
| 1577 |
+
If `stx` is a local variable without type information, then adding implicit lambdas makes elaboration fail.
|
| 1578 |
+
We should try to postpone elaboration until the type of the local variable becomes available, or disable
|
| 1579 |
+
implicit lambdas if we cannot postpone anymore.
|
| 1580 |
+
Here is an example where this special case is useful.
|
| 1581 |
+
```
|
| 1582 |
+
def foo2mk (_ : ∀ {α : Type} (a : α), a = a) : nat := 37
|
| 1583 |
+
example (x) : foo2mk x = foo2mk x := rfl
|
| 1584 |
+
```
|
| 1585 |
+
The example about would fail without this special case.
|
| 1586 |
+
The expected type would be `(a : α✝) → a = a`, where `α✝` is a new free variable introduced by the implicit lambda.
|
| 1587 |
+
Now, let `?m` be the type of `x`. Then, the constraint `?m =?= (a : α✝) → a = a` cannot be solved using the
|
| 1588 |
+
assignment `?m := (a : α✝) → a = a` since `α✝` is not in the scope of `?m`.
|
| 1589 |
+
|
| 1590 |
+
Note that, this workaround does not prevent the following example from failing.
|
| 1591 |
+
```
|
| 1592 |
+
example (x) : foo2mk (id x) = 37 := rfl
|
| 1593 |
+
```
|
| 1594 |
+
The user can write
|
| 1595 |
+
```
|
| 1596 |
+
example (x) : foo2mk (id @x) = 37 := rfl
|
| 1597 |
+
```
|
| 1598 |
+
-/
|
| 1599 |
+
return .postpone
|
| 1600 |
+
return .yes expectedType
|
| 1601 |
+
|
| 1602 |
+
private def decorateErrorMessageWithLambdaImplicitVars (ex : Exception) (impFVars : Array Expr) : TermElabM Exception := do
|
| 1603 |
+
match ex with
|
| 1604 |
+
| .error ref msg =>
|
| 1605 |
+
if impFVars.isEmpty then
|
| 1606 |
+
return Exception.error ref msg
|
| 1607 |
+
else
|
| 1608 |
+
let mut msg := m!"{msg}\nthe following variables have been introduced by the implicit lambda feature"
|
| 1609 |
+
for impFVar in impFVars do
|
| 1610 |
+
let auxMsg := m!"{impFVar} : {← inferType impFVar}"
|
| 1611 |
+
let auxMsg ← addMessageContext auxMsg
|
| 1612 |
+
msg := m!"{msg}{indentD auxMsg}"
|
| 1613 |
+
msg := m!"{msg}\nyou can disable implicit lambdas using `@` or writing a lambda expression with `\{}` or `[]` binder annotations."
|
| 1614 |
+
return Exception.error ref msg
|
| 1615 |
+
| _ => return ex
|
| 1616 |
+
|
| 1617 |
+
private def elabImplicitLambdaAux (stx : Syntax) (catchExPostpone : Bool) (expectedType : Expr) (impFVars : Array Expr) : TermElabM Expr := do
|
| 1618 |
+
let body ← elabUsingElabFns stx expectedType catchExPostpone
|
| 1619 |
+
try
|
| 1620 |
+
let body ← ensureHasType expectedType body
|
| 1621 |
+
let r ← mkLambdaFVars impFVars body
|
| 1622 |
+
trace[Elab.implicitForall] r
|
| 1623 |
+
return r
|
| 1624 |
+
catch ex =>
|
| 1625 |
+
throw (← decorateErrorMessageWithLambdaImplicitVars ex impFVars)
|
| 1626 |
+
|
| 1627 |
+
private partial def elabImplicitLambda (stx : Syntax) (catchExPostpone : Bool) (type : Expr) : TermElabM Expr :=
|
| 1628 |
+
loop type #[]
|
| 1629 |
+
where
|
| 1630 |
+
loop (type : Expr) (fvars : Array Expr) : TermElabM Expr := do
|
| 1631 |
+
match (← whnfForall type) with
|
| 1632 |
+
| .forallE n d b c =>
|
| 1633 |
+
if c.isExplicit then
|
| 1634 |
+
elabImplicitLambdaAux stx catchExPostpone type fvars
|
| 1635 |
+
else withFreshMacroScope do
|
| 1636 |
+
let n ← MonadQuotation.addMacroScope n
|
| 1637 |
+
withLocalDecl n c d fun fvar => do
|
| 1638 |
+
let type := b.instantiate1 fvar
|
| 1639 |
+
loop type (fvars.push fvar)
|
| 1640 |
+
| _ =>
|
| 1641 |
+
elabImplicitLambdaAux stx catchExPostpone type fvars
|
| 1642 |
+
|
| 1643 |
+
/-- Main loop for `elabTerm` -/
|
| 1644 |
+
private partial def elabTermAux (expectedType? : Option Expr) (catchExPostpone : Bool) (implicitLambda : Bool) : Syntax → TermElabM Expr
|
| 1645 |
+
| .missing => mkSyntheticSorryFor expectedType?
|
| 1646 |
+
| stx => withFreshMacroScope <| withIncRecDepth do
|
| 1647 |
+
withTraceNode `Elab.step (fun _ => return m!"expected type: {expectedType?}, term\n{stx}")
|
| 1648 |
+
(tag := stx.getKind.toString) do
|
| 1649 |
+
checkSystem "elaborator"
|
| 1650 |
+
let env ← getEnv
|
| 1651 |
+
let result ← match (← liftMacroM (expandMacroImpl? env stx)) with
|
| 1652 |
+
| some (decl, stxNew?) =>
|
| 1653 |
+
let stxNew ← liftMacroM <| liftExcept stxNew?
|
| 1654 |
+
withTermInfoContext' decl stx (expectedType? := expectedType?) <|
|
| 1655 |
+
withMacroExpansion stx stxNew <|
|
| 1656 |
+
withRef stxNew <|
|
| 1657 |
+
elabTermAux expectedType? catchExPostpone implicitLambda stxNew
|
| 1658 |
+
| _ =>
|
| 1659 |
+
let useImplicitResult ← if implicitLambda && (← read).implicitLambda then useImplicitLambda stx expectedType? else pure .no
|
| 1660 |
+
match useImplicitResult with
|
| 1661 |
+
| .yes expectedType => elabImplicitLambda stx catchExPostpone expectedType
|
| 1662 |
+
| .no => elabUsingElabFns stx expectedType? catchExPostpone
|
| 1663 |
+
| .postpone =>
|
| 1664 |
+
/-
|
| 1665 |
+
Try to postpone elaboration, and if we cannot postpone anymore disable implicit lambdas.
|
| 1666 |
+
See comment at `useImplicitLambda`.
|
| 1667 |
+
-/
|
| 1668 |
+
if (← read).mayPostpone then
|
| 1669 |
+
if catchExPostpone then
|
| 1670 |
+
postponeElabTerm stx expectedType?
|
| 1671 |
+
else
|
| 1672 |
+
throwPostpone
|
| 1673 |
+
else
|
| 1674 |
+
elabUsingElabFns stx expectedType? catchExPostpone
|
| 1675 |
+
trace[Elab.step.result] result
|
| 1676 |
+
pure result
|
| 1677 |
+
|
| 1678 |
+
/-- Store in the `InfoTree` that `e` is a "dot"-completion target. `stx` should cover the entire term. -/
|
| 1679 |
+
def addDotCompletionInfo (stx : Syntax) (e : Expr) (expectedType? : Option Expr) : TermElabM Unit := do
|
| 1680 |
+
addCompletionInfo <| CompletionInfo.dot { expr := e, stx, lctx := (← getLCtx), elaborator := .anonymous, expectedType? } (expectedType? := expectedType?)
|
| 1681 |
+
|
| 1682 |
+
/--
|
| 1683 |
+
Main function for elaborating terms.
|
| 1684 |
+
It extracts the elaboration methods from the environment using the node kind.
|
| 1685 |
+
Recall that the environment has a mapping from `SyntaxNodeKind` to `TermElab` methods.
|
| 1686 |
+
It creates a fresh macro scope for executing the elaboration method.
|
| 1687 |
+
All unlogged trace messages produced by the elaboration method are logged using
|
| 1688 |
+
the position information at `stx`. If the elaboration method throws an `Exception.error` and `errToSorry == true`,
|
| 1689 |
+
the error is logged and a synthetic sorry expression is returned.
|
| 1690 |
+
If the elaboration throws `Exception.postpone` and `catchExPostpone == true`,
|
| 1691 |
+
a new synthetic metavariable of kind `SyntheticMVarKind.postponed` is created, registered,
|
| 1692 |
+
and returned.
|
| 1693 |
+
The option `catchExPostpone == false` is used to implement `resumeElabTerm`
|
| 1694 |
+
to prevent the creation of another synthetic metavariable when resuming the elaboration.
|
| 1695 |
+
|
| 1696 |
+
If `implicitLambda == false`, then disable implicit lambdas feature for the given syntax, but not for its subterms.
|
| 1697 |
+
We use this flag to implement, for example, the `@` modifier. If `Context.implicitLambda == false`, then this parameter has no effect.
|
| 1698 |
+
-/
|
| 1699 |
+
def elabTerm (stx : Syntax) (expectedType? : Option Expr) (catchExPostpone := true) (implicitLambda := true) : TermElabM Expr :=
|
| 1700 |
+
withRef stx <| elabTermAux expectedType? catchExPostpone implicitLambda stx
|
| 1701 |
+
|
| 1702 |
+
/--
|
| 1703 |
+
Similar to `Lean.Elab.Term.elabTerm`, but ensures that the type of the elaborated term is `expectedType?`
|
| 1704 |
+
by inserting coercions if necessary.
|
| 1705 |
+
|
| 1706 |
+
If `errToSorry` is true, then if coercion insertion fails, this function returns `sorry` and logs the error.
|
| 1707 |
+
Otherwise, it throws the error.
|
| 1708 |
+
-/
|
| 1709 |
+
def elabTermEnsuringType (stx : Syntax) (expectedType? : Option Expr) (catchExPostpone := true) (implicitLambda := true) (errorMsgHeader? : Option String := none) : TermElabM Expr := do
|
| 1710 |
+
let e ← elabTerm stx expectedType? catchExPostpone implicitLambda
|
| 1711 |
+
try
|
| 1712 |
+
withRef stx <| ensureHasType expectedType? e errorMsgHeader?
|
| 1713 |
+
catch ex =>
|
| 1714 |
+
if (← read).errToSorry && ex matches .error .. then
|
| 1715 |
+
withRef stx <| exceptionToSorry ex expectedType?
|
| 1716 |
+
else
|
| 1717 |
+
throw ex
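/-
Editor's note: illustrative sketch added for this document, not part of the original
`Term.lean`. `elabTermEnsuringType` is the common entry point when the expected type is
known up front; here we elaborate a piece of syntax as a `Nat`. The helper name
`elabAsNatSketch` is hypothetical.
-/
private def elabAsNatSketch (stx : Syntax) : TermElabM Expr :=
  -- inserts a coercion if needed, or reports a type mismatch (a `sorry` under `errToSorry`)
  elabTermEnsuringType stx (some (Lean.mkConst ``Nat))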
|
| 1718 |
+
|
| 1719 |
+
/-- Execute `x` and return `some` if no new errors were recorded or exceptions were thrown. Otherwise, return `none`. -/
|
| 1720 |
+
def commitIfNoErrors? (x : TermElabM α) : TermElabM (Option α) := do
|
| 1721 |
+
let saved ← saveState
|
| 1722 |
+
Core.resetMessageLog
|
| 1723 |
+
try
|
| 1724 |
+
let a ← x
|
| 1725 |
+
if (← MonadLog.hasErrors) then
|
| 1726 |
+
restoreState saved
|
| 1727 |
+
return none
|
| 1728 |
+
else
|
| 1729 |
+
Core.setMessageLog (saved.meta.core.messages ++ (← Core.getMessageLog))
|
| 1730 |
+
return a
|
| 1731 |
+
catch _ =>
|
| 1732 |
+
restoreState saved
|
| 1733 |
+
return none
|
| 1734 |
+
|
| 1735 |
+
/-- Adapt a syntax transformation to a regular, term-producing elaborator. -/
|
| 1736 |
+
def adaptExpander (exp : Syntax → TermElabM Syntax) : TermElab := fun stx expectedType? => do
|
| 1737 |
+
let stx' ← exp stx
|
| 1738 |
+
withMacroExpansion stx stx' <| elabTerm stx' expectedType?
|
| 1739 |
+
|
| 1740 |
+
/--
|
| 1741 |
+
Create a new metavariable with the given type, and try to synthesize it.
|
| 1742 |
+
If type class resolution cannot be executed (e.g., it is stuck because of metavariables in `type`),
|
| 1743 |
+
register metavariable as a pending one.
|
| 1744 |
+
-/
|
| 1745 |
+
def mkInstMVar (type : Expr) (extraErrorMsg? : Option MessageData := none) : TermElabM Expr := do
|
| 1746 |
+
let mvar ← mkFreshExprMVar type MetavarKind.synthetic
|
| 1747 |
+
let mvarId := mvar.mvarId!
|
| 1748 |
+
unless (← synthesizeInstMVarCore mvarId (extraErrorMsg? := extraErrorMsg?)) do
|
| 1749 |
+
registerSyntheticMVarWithCurrRef mvarId (.typeClass extraErrorMsg?)
|
| 1750 |
+
return mvar
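/-
Editor's note: illustrative sketch added for this document, not part of the original
`Term.lean`. `mkInstMVar` either synthesizes the instance immediately or registers the
metavariable as a pending synthetic goal to be retried later. The helper name
`mkInhabitedInstSketch` is hypothetical.
-/
private def mkInhabitedInstSketch (α : Expr) : TermElabM Expr := do
  -- an `Inhabited α` instance, synthesized now or deferred as a synthetic metavariable
  mkInstMVar (← Meta.mkAppM ``Inhabited #[α])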
|
| 1751 |
+
|
| 1752 |
+
/--
|
| 1753 |
+
Make sure `e` is a type by inferring its type and making sure it is an `Expr.sort`
|
| 1754 |
+
or is unifiable with `Expr.sort`, or can be coerced into one. -/
|
| 1755 |
+
def ensureType (e : Expr) : TermElabM Expr := do
|
| 1756 |
+
if (← isType e) then
|
| 1757 |
+
return e
|
| 1758 |
+
else
|
| 1759 |
+
let eType ← inferType e
|
| 1760 |
+
let u ← mkFreshLevelMVar
|
| 1761 |
+
if (← isDefEq eType (mkSort u)) then
|
| 1762 |
+
return e
|
| 1763 |
+
else if let some coerced ← coerceToSort? e then
|
| 1764 |
+
return coerced
|
| 1765 |
+
else
|
| 1766 |
+
if (← instantiateMVars e).hasSyntheticSorry then
|
| 1767 |
+
throwAbortTerm
|
| 1768 |
+
throwError "type expected, got\n ({← instantiateMVars e} : {← instantiateMVars eType})"
|
| 1769 |
+
|
| 1770 |
+
/-- Elaborate `stx` and ensure result is a type. -/
|
| 1771 |
+
def elabType (stx : Syntax) : TermElabM Expr := do
|
| 1772 |
+
let u ← mkFreshLevelMVar
|
| 1773 |
+
let type ← elabTerm stx (mkSort u)
|
| 1774 |
+
withRef stx <| ensureType type
|
| 1775 |
+
|
| 1776 |
+
/--
|
| 1777 |
+
Enable auto-bound implicits, and execute `k` while catching auto bound implicit exceptions. When an exception is caught,
|
| 1778 |
+
a new local declaration is created, registered, and `k` is tried to be executed again. -/
|
| 1779 |
+
partial def withAutoBoundImplicit (k : TermElabM α) : TermElabM α := do
|
| 1780 |
+
let flag := autoImplicit.get (← getOptions)
|
| 1781 |
+
if flag then
|
| 1782 |
+
withReader (fun ctx => { ctx with autoBoundImplicit := flag, autoBoundImplicits := {} }) do
|
| 1783 |
+
let rec loop (s : SavedState) : TermElabM α := withIncRecDepth do
|
| 1784 |
+
checkSystem "auto-implicit"
|
| 1785 |
+
try
|
| 1786 |
+
k
|
| 1787 |
+
catch
|
| 1788 |
+
| ex => match isAutoBoundImplicitLocalException? ex with
|
| 1789 |
+
| some n =>
|
| 1790 |
+
-- Restore state, declare `n`, and try again
|
| 1791 |
+
s.restore (restoreInfo := true)
|
| 1792 |
+
withLocalDecl n .implicit (← mkFreshTypeMVar) fun x =>
|
| 1793 |
+
withReader (fun ctx => { ctx with autoBoundImplicits := ctx.autoBoundImplicits.push x } ) do
|
| 1794 |
+
loop (← saveState)
|
| 1795 |
+
| none => throw ex
|
| 1796 |
+
loop (← saveState)
|
| 1797 |
+
else
|
| 1798 |
+
k
|
| 1799 |
+
|
| 1800 |
+
def withoutAutoBoundImplicit (k : TermElabM α) : TermElabM α := do
|
| 1801 |
+
withReader (fun ctx => { ctx with autoBoundImplicit := false, autoBoundImplicits := {} }) k
|
| 1802 |
+
|
| 1803 |
+
partial def withAutoBoundImplicitForbiddenPred (p : Name → Bool) (x : TermElabM α) : TermElabM α := do
|
| 1804 |
+
withReader (fun ctx => { ctx with autoBoundImplicitForbidden := fun n => p n || ctx.autoBoundImplicitForbidden n }) x
|
| 1805 |
+
|
| 1806 |
+
/--
|
| 1807 |
+
Collect unassigned metavariables in `type` that are not already in `init` and not satisfying `except`.
|
| 1808 |
+
-/
|
| 1809 |
+
partial def collectUnassignedMVars (type : Expr) (init : Array Expr := #[]) (except : MVarId → Bool := fun _ => false)
|
| 1810 |
+
: TermElabM (Array Expr) := do
|
| 1811 |
+
let mvarIds ← getMVars type
|
| 1812 |
+
if mvarIds.isEmpty then
|
| 1813 |
+
return init
|
| 1814 |
+
else
|
| 1815 |
+
go mvarIds.toList init init
|
| 1816 |
+
where
|
| 1817 |
+
go (mvarIds : List MVarId) (result visited : Array Expr) : TermElabM (Array Expr) := do
|
| 1818 |
+
match mvarIds with
|
| 1819 |
+
| [] => return result
|
| 1820 |
+
| mvarId :: mvarIds => do
|
| 1821 |
+
let visited := visited.push (mkMVar mvarId)
|
| 1822 |
+
if (← mvarId.isAssigned) then
|
| 1823 |
+
go mvarIds result visited
|
| 1824 |
+
else if result.contains (mkMVar mvarId) || except mvarId then
|
| 1825 |
+
go mvarIds result visited
|
| 1826 |
+
else
|
| 1827 |
+
let mvarType := (← getMVarDecl mvarId).type
|
| 1828 |
+
let mvarIdsNew ← getMVars mvarType
|
| 1829 |
+
let mvarIdsNew := mvarIdsNew.filter fun mvarId => !visited.contains (mkMVar mvarId)
|
| 1830 |
+
if mvarIdsNew.isEmpty then
|
| 1831 |
+
go mvarIds (result.push (mkMVar mvarId)) visited
|
| 1832 |
+
else
|
| 1833 |
+
go (mvarIdsNew.toList ++ mvarId :: mvarIds) result visited
|
| 1834 |
+
|
| 1835 |
+
/--
|
| 1836 |
+
Adds an `InlayHintInfo` for the fvar auto implicits in `autos` at `inlayHintPos`.
|
| 1837 |
+
The inserted inlay hint has a hover that denotes the type of the auto-implicit (with meta-variables)
|
| 1838 |
+
and can be inserted at `inlayHintPos`.
|
| 1839 |
+
-/
|
| 1840 |
+
def addAutoBoundImplicitsInlayHint (autos : Array Expr) (inlayHintPos : String.Pos) : TermElabM Unit := do
|
| 1841 |
+
-- If the list of auto-implicits contains a non-type fvar, then the list of auto-implicits will
|
| 1842 |
+
-- also contain an mvar that denotes the type of the non-type fvar.
|
| 1843 |
+
-- For example, the auto-implicit `x` in a type `Foo x` for `Foo.{u} {α : Sort u} (x : α) : Type`
|
| 1844 |
+
-- also comes with an auto-implicit mvar denoting the type of `x`.
|
| 1845 |
+
-- We have no way of displaying this mvar to the user in an inlay hint, as it doesn't have a name,
|
| 1846 |
+
-- so we filter it.
|
| 1847 |
+
-- This also means that inserting the inlay hint with the syntax displayed in the inlay hint will
|
| 1848 |
+
-- cause a "failed to infer binder type" error, since we don't have a name to insert in the code.
|
| 1849 |
+
let autos := autos.filter (· matches .fvar ..)
|
| 1850 |
+
if autos.isEmpty then
|
| 1851 |
+
return
|
| 1852 |
+
let autoNames ← autos.mapM (·.fvarId!.getUserName)
|
| 1853 |
+
let formattedHint := s!" \{{" ".intercalate <| Array.toList <| autoNames.map toString}}"
|
| 1854 |
+
let deferredResolution ih := do
|
| 1855 |
+
let description := "Automatically-inserted implicit parameters:"
|
| 1856 |
+
let codeBlockStart := "```lean"
|
| 1857 |
+
let typeInfos ← autos.mapM fun auto => do
|
| 1858 |
+
let name := toString <| ← auto.fvarId!.getUserName
|
| 1859 |
+
let type := toString <| ← Meta.ppExpr <| ← instantiateMVars (← inferType auto)
|
| 1860 |
+
return s!"{name} : {type}"
|
| 1861 |
+
let codeBlockEnd := "```"
|
| 1862 |
+
let tooltip := "\n".intercalate <| description :: codeBlockStart :: typeInfos.toList ++ [codeBlockEnd]
|
| 1863 |
+
return { ih with tooltip? := tooltip }
|
| 1864 |
+
pushInfoLeaf <| .ofCustomInfo {
|
| 1865 |
+
position := inlayHintPos
|
| 1866 |
+
label := .name formattedHint
|
| 1867 |
+
textEdits := #[{
|
| 1868 |
+
range := ⟨inlayHintPos, inlayHintPos⟩,
|
| 1869 |
+
newText := formattedHint
|
| 1870 |
+
}]
|
| 1871 |
+
kind? := some .parameter
|
| 1872 |
+
lctx := ← getLCtx
|
| 1873 |
+
deferredResolution
|
| 1874 |
+
: InlayHint
|
| 1875 |
+
}.toCustomInfo
|
| 1876 |
+
|
| 1877 |
+
/--
|
| 1878 |
+
Return `autoBoundImplicits ++ xs`
|
| 1879 |
+
This method throws an error if a variable in `autoBoundImplicits` depends on some `x` in `xs`.
|
| 1880 |
+
The `autoBoundImplicits` may contain free variables created by the auto-implicit feature, and unassigned free variables.
|
| 1881 |
+
It avoids the hack used at `autoBoundImplicitsOld`.
|
| 1882 |
+
|
| 1883 |
+
If `inlayHintPos?` is set, this function also inserts an inlay hint denoting `autoBoundImplicits`.
|
| 1884 |
+
See `addAutoBoundImplicitsInlayHint` for more information.
|
| 1885 |
+
|
| 1886 |
+
Remark: we cannot simply replace every occurrence of `addAutoBoundImplicitsOld` with this one because a particular
|
| 1887 |
+
use-case may not be able to handle the metavariables in the array being given to `k`.
|
| 1888 |
+
-/
|
| 1889 |
+
def addAutoBoundImplicits (xs : Array Expr) (inlayHintPos? : Option String.Pos) : TermElabM (Array Expr) := do
|
| 1890 |
+
let autos := (← read).autoBoundImplicits
|
| 1891 |
+
go autos.toList #[]
|
| 1892 |
+
where
|
| 1893 |
+
go (todo : List Expr) (autos : Array Expr) : TermElabM (Array Expr) := do
|
| 1894 |
+
match todo with
|
| 1895 |
+
| [] =>
|
| 1896 |
+
if let some inlayHintPos := inlayHintPos? then
|
| 1897 |
+
addAutoBoundImplicitsInlayHint autos inlayHintPos
|
| 1898 |
+
for auto in autos do
|
| 1899 |
+
if auto.isFVar then
|
| 1900 |
+
let localDecl ← auto.fvarId!.getDecl
|
| 1901 |
+
for x in xs do
|
| 1902 |
+
if (← localDeclDependsOn localDecl x.fvarId!) then
|
| 1903 |
+
throwError "invalid auto implicit argument '{auto}', it depends on explicitly provided argument '{x}'"
|
| 1904 |
+
return autos ++ xs
|
| 1905 |
+
| auto :: todo =>
|
| 1906 |
+
let autos ← collectUnassignedMVars (← inferType auto) autos
|
| 1907 |
+
go todo (autos.push auto)
|
| 1908 |
+
|
| 1909 |
+
/--
|
| 1910 |
+
Similar to `addAutoBoundImplicits`, but converts all metavariables into free variables.
|
| 1911 |
+
|
| 1912 |
+
It uses `mkForallFVars` + `forallBoundedTelescope` to convert metavariables into free variables.
|
| 1913 |
+
The type `type` is modified during the process if type depends on `xs`.
|
| 1914 |
+
We use this method to simplify the conversion of code using `autoBoundImplicitsOld` to `autoBoundImplicits`.
|
| 1915 |
+
-/
|
| 1916 |
+
def addAutoBoundImplicits' (xs : Array Expr) (type : Expr) (k : Array Expr → Expr → TermElabM α) (inlayHintPos? : Option String.Pos := none) : TermElabM α := do
|
| 1917 |
+
let xs ← addAutoBoundImplicits xs inlayHintPos?
|
| 1918 |
+
if xs.all (·.isFVar) then
|
| 1919 |
+
k xs type
|
| 1920 |
+
else
|
| 1921 |
+
forallBoundedTelescope (← mkForallFVars xs type) xs.size fun xs type => k xs type
|
| 1922 |
+
|
| 1923 |
+
def mkAuxName (suffix : Name) : TermElabM Name := mkAuxDeclName (kind := suffix)
|
| 1924 |
+
|
| 1925 |
+
builtin_initialize registerTraceClass `Elab.letrec
|
| 1926 |
+
|
| 1927 |
+
/-- Return true if mvarId is an auxiliary metavariable created for compiling `let rec` or it
|
| 1928 |
+
is delayed assigned to one. -/
|
| 1929 |
+
def isLetRecAuxMVar (mvarId : MVarId) : TermElabM Bool := do
|
| 1930 |
+
trace[Elab.letrec] "mvarId: {mkMVar mvarId} letrecMVars: {(← get).letRecsToLift.map (mkMVar $ ·.mvarId)}"
|
| 1931 |
+
let mvarId ← getDelayedMVarRoot mvarId
|
| 1932 |
+
trace[Elab.letrec] "mvarId root: {mkMVar mvarId}"
|
| 1933 |
+
return (← get).letRecsToLift.any (·.mvarId == mvarId)
|
| 1934 |
+
|
| 1935 |
+
private def checkDeprecatedCore (constName : Name) : TermElabM Unit := do
|
| 1936 |
+
if (← read).checkDeprecated then
|
| 1937 |
+
Linter.checkDeprecated constName
|
| 1938 |
+
|
| 1939 |
+
/--
|
| 1940 |
+
Create an `Expr.const` using the given name and explicit levels.
|
| 1941 |
+
Remark: fresh universe metavariables are created if the constant has more universe
|
| 1942 |
+
parameters than `explicitLevels`.
|
| 1943 |
+
|
| 1944 |
+
If `checkDeprecated := true`, then `Linter.checkDeprecated` is invoked.
|
| 1945 |
+
-/
|
| 1946 |
+
def mkConst (constName : Name) (explicitLevels : List Level := []) : TermElabM Expr := do
|
| 1947 |
+
checkDeprecatedCore constName
|
| 1948 |
+
let cinfo ← getConstVal constName
|
| 1949 |
+
if explicitLevels.length > cinfo.levelParams.length then
|
| 1950 |
+
throwError "too many explicit universe levels for '{constName}'"
|
| 1951 |
+
else
|
| 1952 |
+
let numMissingLevels := cinfo.levelParams.length - explicitLevels.length
|
| 1953 |
+
let us ← mkFreshLevelMVars numMissingLevels
|
| 1954 |
+
return Lean.mkConst constName (explicitLevels ++ us)
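/-
Editor's note: illustrative sketch added for this document, not part of the original
`Term.lean`. The `TermElabM` version of `mkConst` defined above fills in any universe
levels not given explicitly with fresh level metavariables. The helper name
`mkListNilSketch` is hypothetical.
-/
private def mkListNilSketch : TermElabM Expr :=
  -- produces `@List.nil.{?u}`; `?u` is solved by later unification
  mkConst ``List.nil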
|
| 1955 |
+
|
| 1956 |
+
def checkDeprecated (ref : Syntax) (e : Expr) : TermElabM Unit := do
|
| 1957 |
+
if let .const declName _ := e.getAppFn then
|
| 1958 |
+
withRef ref do checkDeprecatedCore declName
|
| 1959 |
+
|
| 1960 |
+
@[inline] def withoutCheckDeprecated [MonadWithReaderOf Context m] : m α → m α :=
|
| 1961 |
+
withTheReader Context (fun ctx => { ctx with checkDeprecated := false })
|
| 1962 |
+
|
| 1963 |
+
private def mkConsts (candidates : List (Name × List String)) (explicitLevels : List Level) : TermElabM (List (Expr × List String)) := do
|
| 1964 |
+
candidates.foldlM (init := []) fun result (declName, projs) => do
|
| 1965 |
+
-- TODO: better support for `mkConst` failure. We may want to cache the failures, and report them if all candidates fail.
|
| 1966 |
+
/-
|
| 1967 |
+
We disable `checkDeprecated` here because there may be many overloaded symbols.
|
| 1968 |
+
Note that, this method and `resolveName` and `resolveName'` return a list of pairs instead of a list of `TermElabResult`s.
|
| 1969 |
+
We perform the `checkDeprecated` test at `resolveId?` and `elabAppFnId`.
|
| 1970 |
+
At `elabAppFnId`, we perform the check when converting the list returned by `resolveName'` into a list of
|
| 1971 |
+
`TermElabResult`s.
|
| 1972 |
+
-/
|
| 1973 |
+
let const ← withoutCheckDeprecated <| mkConst declName explicitLevels
|
| 1974 |
+
return (const, projs) :: result
|
| 1975 |
+
|
| 1976 |
+
def resolveName (stx : Syntax) (n : Name) (preresolved : List Syntax.Preresolved) (explicitLevels : List Level) (expectedType? : Option Expr := none) : TermElabM (List (Expr × List String)) := do
|
| 1977 |
+
addCompletionInfo <| CompletionInfo.id stx stx.getId (danglingDot := false) (← getLCtx) expectedType?
|
| 1978 |
+
if let some (e, projs) ← resolveLocalName n then
|
| 1979 |
+
unless explicitLevels.isEmpty do
|
| 1980 |
+
throwError "invalid use of explicit universe parameters, '{e}' is a local variable"
|
| 1981 |
+
return [(e, projs)]
|
| 1982 |
+
let preresolved := preresolved.filterMap fun
|
| 1983 |
+
| .decl n projs => some (n, projs)
|
| 1984 |
+
| _ => none
|
| 1985 |
+
    -- check for section variable capture by a quotation
    let ctx ← read
    if let some (e, projs) := preresolved.findSome? fun (n, projs) => ctx.sectionFVars.find? n |>.map (·, projs) then
      return [(e, projs)] -- section variables should shadow global decls
    if preresolved.isEmpty then
      process (← realizeGlobalName n)
    else
      process preresolved
where
  process (candidates : List (Name × List String)) : TermElabM (List (Expr × List String)) := do
    if candidates.isEmpty then
      if (← read).autoBoundImplicit &&
         !(← read).autoBoundImplicitForbidden n &&
         isValidAutoBoundImplicitName n (relaxedAutoImplicit.get (← getOptions)) then
        throwAutoBoundImplicitLocal n
      else
        throwUnknownIdentifierAt stx m!"unknown identifier '{Lean.mkConst n}'"
    mkConsts candidates explicitLevels

/--
Similar to `resolveName`, but creates identifiers for the main part and each projection with position information derived from `ident`.
Example: Assume `resolveName` on `v.head.bla.boo` produces `(v.head, ["bla", "boo"])`, then this method produces
`(v.head, id, [f₁, f₂])` where `id` is an identifier for `v.head`, and `f₁` and `f₂` are identifiers for fields `"bla"` and `"boo"`. -/
def resolveName' (ident : Syntax) (explicitLevels : List Level) (expectedType? : Option Expr := none) : TermElabM (List (Expr × Syntax × List Syntax)) := do
  match ident with
  | .ident _ _ n preresolved =>
    let r ← resolveName ident n preresolved explicitLevels expectedType?
    r.mapM fun (c, fields) => do
      let ids := ident.identComponents (nFields? := fields.length)
      return (c, ids.head!, ids.tail!)
  | _ => throwError "identifier expected"

def resolveId? (stx : Syntax) (kind := "term") (withInfo := false) : TermElabM (Option Expr) := withRef stx do
  match stx with
  | .ident _ _ val preresolved =>
    let rs ← try resolveName stx val preresolved [] catch _ => pure []
    let rs := rs.filter fun ⟨_, projs⟩ => projs.isEmpty
    let fs := rs.map fun (f, _) => f
    match fs with
    | []  => return none
    | [f] =>
      let f ← if withInfo then addTermInfo stx f else pure f
      checkDeprecated stx f
      return some f
    | _ => throwError "ambiguous {kind}, use fully qualified name, possible interpretations {fs}"
  | _ => throwError "identifier expected"

def TermElabM.run (x : TermElabM α) (ctx : Context := {}) (s : State := {}) : MetaM (α × State) :=
  withConfig setElabConfig (x ctx |>.run s)

@[inline] def TermElabM.run' (x : TermElabM α) (ctx : Context := {}) (s : State := {}) : MetaM α :=
  (·.1) <$> x.run ctx s

def TermElabM.toIO (x : TermElabM α)
    (ctxCore : Core.Context) (sCore : Core.State)
    (ctxMeta : Meta.Context) (sMeta : Meta.State)
    (ctx : Context) (s : State) : IO (α × Core.State × Meta.State × State) := do
  let ((a, s), sCore, sMeta) ← (x.run ctx s).toIO ctxCore sCore ctxMeta sMeta
  return (a, sCore, sMeta, s)

/--
Executes `x` and then tries to solve pending universe constraints.
Note that stuck constraints will not be discarded.
-/
def universeConstraintsCheckpoint (x : TermElabM α) : TermElabM α := do
  let a ← x
  discard <| processPostponed (mayPostpone := true) (exceptionOnFailure := true)
  return a

def expandDeclId (currNamespace : Name) (currLevelNames : List Name) (declId : Syntax) (modifiers : Modifiers) : TermElabM ExpandDeclIdResult := do
  let r ← Elab.expandDeclId currNamespace currLevelNames declId modifiers
  if (← read).sectionVars.contains r.shortName then
    throwError "invalid declaration name '{r.shortName}', there is a section variable with the same name"
  return r

/--
Helper function for "embedding" an `Expr` in `Syntax`.
It creates a named hole `?m` and immediately assigns `e` to it.
Examples:
```lean
let e := mkConst ``Nat.zero
`(Nat.succ $(← exprToSyntax e))
```
-/
def exprToSyntax (e : Expr) : TermElabM Term := withFreshMacroScope do
  let result ← `(?m)
  let eType ← inferType e
  let mvar ← elabTerm result eType
  mvar.mvarId!.assign e
  return result

end Term

open Term in
def withoutModifyingStateWithInfoAndMessages [MonadControlT TermElabM m] [Monad m] (x : m α) : m α := do
  controlAt TermElabM fun runInBase => withoutModifyingStateWithInfoAndMessagesImpl <| runInBase x

builtin_initialize
  registerTraceClass `Elab.postpone
  registerTraceClass `Elab.coe
  registerTraceClass `Elab.debug
  registerTraceClass `Elab.reuse

/--
Marks an elaborator (tactic or command, currently) as supporting incremental elaboration.

For unmarked elaborators, the corresponding snapshot bundle field in the elaboration context is
unset so as to prevent accidental, incorrect reuse.
-/
@[builtin_doc]
builtin_initialize incrementalAttr : TagAttribute ←
  registerTagAttribute `incremental "Marks an elaborator (tactic or command, currently) as \
    supporting incremental elaboration. For unmarked elaborators, the corresponding snapshot bundle \
    field in the elaboration context is unset so as to prevent accidental, incorrect reuse."

builtin_initialize builtinIncrementalElabs : IO.Ref NameSet ← IO.mkRef {}

def addBuiltinIncrementalElab (decl : Name) : IO Unit := do
  builtinIncrementalElabs.modify fun s => s.insert decl

@[inherit_doc incrementalAttr, builtin_doc]
builtin_initialize
  registerBuiltinAttribute {
    name := `builtin_incremental
    descr := s!"(builtin) {incrementalAttr.attr.descr}"
    applicationTime := .afterCompilation
    add := fun decl stx kind => do
      Attribute.Builtin.ensureNoArgs stx
      unless kind == AttributeKind.global do
        throwError "invalid attribute 'builtin_incremental', must be global"
      declareBuiltin decl <| mkApp (mkConst ``addBuiltinIncrementalElab) (toExpr decl)
  }

/-- Checks whether a declaration is annotated with `[builtin_incremental]` or `[incremental]`. -/
def isIncrementalElab [Monad m] [MonadEnv m] [MonadLiftT IO m] (decl : Name) : m Bool :=
  (return (← builtinIncrementalElabs.get (m := IO)).contains decl) <||>
  (return incrementalAttr.hasTag (← getEnv) decl)

export Term (TermElabM)

builtin_initialize
  registerTraceClass `Elab.implicitForall

end Lean.Elab
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Time.lean
ADDED
|
@@ -0,0 +1,26 @@
/-
Copyright (c) 2021 Mario Carneiro. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mario Carneiro
-/
prelude
import Lean.Elab.Command

/-!
# Defines `#time` command.

Time the elaboration of a command, and print the result (in milliseconds).
-/

namespace Lean.Elab.Time

open Command

@[builtin_command_elab Lean.Parser.timeCmd] def elabTimeCmd : CommandElab
  | `(#time%$tk $stx:command) => do
    let start ← IO.monoMsNow
    elabCommand stx
    logInfoAt tk m!"time: {(← IO.monoMsNow) - start}ms"
  | _ => throwUnsupportedSyntax

end Lean.Elab.Time
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Util.lean
ADDED
|
@@ -0,0 +1,249 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Parser.Command
import Lean.KeyedDeclsAttribute
import Lean.Elab.Exception

namespace Lean

def Syntax.prettyPrint (stx : Syntax) : Format :=
  match stx.unsetTrailing.reprint with -- TODO use syntax pretty printer
  | some str => format str.toFormat
  | none     => format stx

def MacroScopesView.format (view : MacroScopesView) (mainModule : Name) : Format :=
  Std.format <|
    if view.scopes.isEmpty then
      view.name
    else if view.mainModule == mainModule then
      view.scopes.foldl Name.mkNum (view.name ++ view.imported)
    else
      view.scopes.foldl Name.mkNum (view.name ++ view.imported ++ view.mainModule)

/--
Two names are from the same lexical scope if their scoping information modulo `MacroScopesView.name`
is equal.
-/
def MacroScopesView.equalScope (a b : MacroScopesView) : Bool :=
  a.scopes == b.scopes && a.mainModule == b.mainModule && a.imported == b.imported

namespace Elab

def expandOptNamedPrio (stx : Syntax) : MacroM Nat :=
  if stx.isNone then
    return eval_prio default
  else match stx[0] with
    | `(Parser.Command.namedPrio| (priority := $prio)) => evalPrio prio
    | _ => Macro.throwUnsupported

structure MacroStackElem where
  before : Syntax
  after  : Syntax

abbrev MacroStack := List MacroStackElem

/-- If `ref` does not have position information, then try to use macroStack -/
def getBetterRef (ref : Syntax) (macroStack : MacroStack) : Syntax :=
  match ref.getPos? with
  | some _ => ref
  | none   =>
    match macroStack.find? (·.before.getPos? != none) with
    | some elem => elem.before
    | none      => ref

register_builtin_option pp.macroStack : Bool := {
  defValue := false
  group    := "pp"
  descr    := "display macro expansion stack"
}

def addMacroStack {m} [Monad m] [MonadOptions m] (msgData : MessageData) (macroStack : MacroStack) : m MessageData := do
  if !pp.macroStack.get (← getOptions) then pure msgData else
  match macroStack with
  | [] => pure msgData
  | stack@(top::_) =>
    let msgData := msgData ++ Format.line ++ "with resulting expansion" ++ indentD top.after
    pure $ stack.foldl
      (fun (msgData : MessageData) (elem : MacroStackElem) =>
        msgData ++ Format.line ++ "while expanding" ++ indentD elem.before)
      msgData

def checkSyntaxNodeKind [Monad m] [MonadEnv m] [MonadError m] (k : Name) : m Name := do
  if Parser.isValidSyntaxNodeKind (← getEnv) k then pure k
  else throwError "failed"

def checkSyntaxNodeKindAtNamespaces [Monad m] [MonadEnv m] [MonadError m] (k : Name) : Name → m Name
  | n@(.str p _) => checkSyntaxNodeKind (n ++ k) <|> checkSyntaxNodeKindAtNamespaces k p
  | .anonymous   => checkSyntaxNodeKind k
  | _            => throwError "failed"

def checkSyntaxNodeKindAtCurrentNamespaces (k : Name) : AttrM Name := do
  let ctx ← read
  checkSyntaxNodeKindAtNamespaces k ctx.currNamespace

def syntaxNodeKindOfAttrParam (defaultParserNamespace : Name) (stx : Syntax) : AttrM SyntaxNodeKind := do
  let k ← Attribute.Builtin.getId stx
  checkSyntaxNodeKindAtCurrentNamespaces k
  <|>
  checkSyntaxNodeKind (defaultParserNamespace ++ k)
  <|>
  throwError "invalid syntax node kind '{k}'"

private unsafe def evalSyntaxConstantUnsafe (env : Environment) (opts : Options) (constName : Name) : ExceptT String Id Syntax :=
  env.evalConstCheck Syntax opts `Lean.Syntax constName

@[implemented_by evalSyntaxConstantUnsafe]
opaque evalSyntaxConstant (env : Environment) (opts : Options) (constName : Name) : ExceptT String Id Syntax := throw ""

unsafe def mkElabAttribute (γ) (attrBuiltinName attrName : Name) (parserNamespace : Name) (typeName : Name) (kind : String)
    (attrDeclName : Name := by exact decl_name%) : IO (KeyedDeclsAttribute γ) :=
  KeyedDeclsAttribute.init {
    builtinName   := attrBuiltinName
    name          := attrName
    descr         := kind ++ " elaborator"
    valueTypeName := typeName
    evalKey       := fun _ stx => do
      let kind ← syntaxNodeKindOfAttrParam parserNamespace stx
      /- Recall that a `SyntaxNodeKind` is often the name of the parser, but this is not always true, and we must check it. -/
      if (← getEnv).contains kind && (← getInfoState).enabled then
        addConstInfo stx[1] kind none
      return kind
    onAdded       := fun builtin declName => do
      if builtin then
        declareBuiltinDocStringAndRanges declName
  } attrDeclName

unsafe def mkMacroAttributeUnsafe (ref : Name) : IO (KeyedDeclsAttribute Macro) :=
  mkElabAttribute Macro `builtin_macro `macro Name.anonymous `Lean.Macro "macro" ref

@[implemented_by mkMacroAttributeUnsafe]
opaque mkMacroAttribute (ref : Name) : IO (KeyedDeclsAttribute Macro)

/--
Registers a macro expander for a given syntax node kind.

A macro expander should have type `Lean.Macro` (which is `Lean.Syntax → Lean.MacroM Lean.Syntax`),
i.e. should take syntax of the given syntax node kind as a parameter and produce different syntax
in the same syntax category.

The `macro_rules` and `macro` commands should usually be preferred over using this attribute
directly.
-/
@[builtin_doc]
builtin_initialize macroAttribute : KeyedDeclsAttribute Macro ← mkMacroAttribute decl_name%

/--
Try to expand macro at syntax tree root and return macro declaration name and new syntax if successful.
Return none if all macros threw `Macro.Exception.unsupportedSyntax`.
-/
def expandMacroImpl? (env : Environment) : Syntax → MacroM (Option (Name × Except Macro.Exception Syntax)) := fun stx => do
  for e in macroAttribute.getEntries env stx.getKind do
    try
      let stx' ← withFreshMacroScope (e.value stx)
      return (e.declName, Except.ok stx')
    catch
      | Macro.Exception.unsupportedSyntax => pure ()
      | ex => return (e.declName, Except.error ex)
  return none

class MonadMacroAdapter (m : Type → Type) where
  getCurrMacroScope : m MacroScope
  getNextMacroScope : m MacroScope
  setNextMacroScope : MacroScope → m Unit

@[always_inline]
instance (m n) [MonadLift m n] [MonadMacroAdapter m] : MonadMacroAdapter n := {
  getCurrMacroScope := liftM (MonadMacroAdapter.getCurrMacroScope : m _)
  getNextMacroScope := liftM (MonadMacroAdapter.getNextMacroScope : m _)
  setNextMacroScope := fun s => liftM (MonadMacroAdapter.setNextMacroScope s : m _)
}

def liftMacroM [Monad m] [MonadMacroAdapter m] [MonadEnv m] [MonadRecDepth m] [MonadError m] [MonadResolveName m] [MonadTrace m] [MonadOptions m] [AddMessageContext m] [MonadLiftT IO m] (x : MacroM α) : m α := do
  let env ← getEnv
  let currNamespace ← getCurrNamespace
  let openDecls ← getOpenDecls
  let methods := Macro.mkMethods {
    -- TODO: record recursive expansions in info tree?
    expandMacro?      := fun stx => do
      match (← expandMacroImpl? env stx) with
      | some (_, stx?) => liftExcept stx?
      | none => return none
    hasDecl           := fun declName => return env.contains declName
    getCurrNamespace  := return currNamespace
    resolveNamespace  := fun n => return ResolveName.resolveNamespace env currNamespace openDecls n
    resolveGlobalName := fun n => return ResolveName.resolveGlobalName env currNamespace openDecls n
  }
  match x { methods := methods
            ref := ← getRef
            currMacroScope := ← MonadMacroAdapter.getCurrMacroScope
            mainModule := env.mainModule
            currRecDepth := ← MonadRecDepth.getRecDepth
            maxRecDepth := ← MonadRecDepth.getMaxRecDepth
          } { macroScope := (← MonadMacroAdapter.getNextMacroScope) } with
  | EStateM.Result.error Macro.Exception.unsupportedSyntax _ => throwUnsupportedSyntax
  | EStateM.Result.error (Macro.Exception.error ref msg) _ =>
    if msg == maxRecDepthErrorMessage then
      -- Make sure we can detect exception using `Exception.isMaxRecDepth`
      throwMaxRecDepthAt ref
    else
      throwErrorAt ref msg
  | EStateM.Result.ok a s =>
    MonadMacroAdapter.setNextMacroScope s.macroScope
    s.traceMsgs.reverse.forM fun (clsName, msg) => trace clsName fun _ => msg
    return a

@[inline] def adaptMacro {m : Type → Type} [Monad m] [MonadMacroAdapter m] [MonadEnv m] [MonadRecDepth m] [MonadError m] [MonadResolveName m] [MonadTrace m] [MonadOptions m] [AddMessageContext m] [MonadLiftT IO m] (x : Macro) (stx : Syntax) : m Syntax :=
  liftMacroM (x stx)

partial def mkUnusedBaseName (baseName : Name) : MacroM Name := do
  let currNamespace ← Macro.getCurrNamespace
  if ← Macro.hasDecl (currNamespace ++ baseName) then
    let rec loop (idx : Nat) := do
      let name := baseName.appendIndexAfter idx
      if ← Macro.hasDecl (currNamespace ++ name) then
        loop (idx+1)
      else
        return name
    loop 1
  else
    return baseName

def logException [Monad m] [MonadLog m] [AddMessageContext m] [MonadOptions m] [MonadLiftT IO m] (ex : Exception) : m Unit := do
  match ex with
  | Exception.error ref msg => logErrorAt ref msg
  | Exception.internal id _ =>
    unless isAbortExceptionId id || ex.isInterrupt do
      let name ← id.getName
      logError m!"internal exception: {name}"

/--
If `x` throws an exception, catch it and turn it into a log message (using `logException`).
-/
def withLogging [Monad m] [MonadLog m] [MonadExcept Exception m] [AddMessageContext m] [MonadOptions m] [MonadLiftT IO m]
    (x : m Unit) : m Unit := do
  try x catch ex => logException ex

def nestedExceptionToMessageData [Monad m] [MonadLog m] (ex : Exception) : m MessageData := do
  let pos ← getRefPos
  match ex.getRef.getPos? with
  | none => return ex.toMessageData
  | some exPos =>
    if pos == exPos then
      return ex.toMessageData
    else
      let exPosition := (← getFileMap).toPosition exPos
      return m!"{exPosition.line}:{exPosition.column} {ex.toMessageData}"

def throwErrorWithNestedErrors [MonadError m] [Monad m] [MonadLog m] (msg : MessageData) (exs : Array Exception) : m α := do
  throwError "{msg}, errors {toMessageList (← exs.mapM fun | ex => nestedExceptionToMessageData ex)}"

builtin_initialize
  registerTraceClass `Elab
  registerTraceClass `Elab.step
  registerTraceClass `Elab.step.result (inherited := true)

end Lean.Elab
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/WhereFinally.lean
ADDED
|
@@ -0,0 +1,29 @@
/-
Copyright (c) 2025 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Sebastian Graf
-/
prelude
import Lean.Parser.Term

namespace Lean.Elab

structure WhereFinallyView where
  ref : Syntax
  tactic : TSyntax ``Lean.Parser.Tactic.tacticSeq
  deriving Inhabited

def WhereFinallyView.none : WhereFinallyView := { ref := .missing, tactic := ⟨.missing⟩ }

def WhereFinallyView.isNone (o : WhereFinallyView) : Bool := o.ref.isMissing && o.tactic.raw.isMissing

/-- Creates a view of the `finally` section of a `whereDecls` syntax object -/
def mkWhereFinallyView {m} [Monad m] [MonadError m] (stx : TSyntax ``Parser.Term.whereDecls) : m WhereFinallyView := do
  -- Fail gracefully upon partial parses/missing where or finally sections
  let whereFinally := stx.raw[2][0]
  if whereFinally.isMissing then
    return { ref := stx, tactic := ⟨.missing⟩ }
  if !whereFinally[2][0].isMissing then
    throwErrorAt stx "`where ... finally` does not currently support any named sub-sections `| sectionName => ...`"
  let tactic := ⟨whereFinally[1]⟩
  return { ref := whereFinally, tactic }
external/alphageometry/.venv-ag/Lib/site-packages/absl/app.py
ADDED
|
@@ -0,0 +1,488 @@
| 1 |
+
# Copyright 2017 The Abseil Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Generic entry point for Abseil Python applications.
|
| 16 |
+
|
| 17 |
+
To use this module, define a ``main`` function with a single ``argv`` argument
|
| 18 |
+
and call ``app.run(main)``. For example::
|
| 19 |
+
|
| 20 |
+
def main(argv):
|
| 21 |
+
if len(argv) > 1:
|
| 22 |
+
raise app.UsageError('Too many command-line arguments.')
|
| 23 |
+
|
| 24 |
+
if __name__ == '__main__':
|
| 25 |
+
app.run(main)
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import collections
|
| 29 |
+
import errno
|
| 30 |
+
import os
|
| 31 |
+
import pdb
|
| 32 |
+
import sys
|
| 33 |
+
import textwrap
|
| 34 |
+
import traceback
|
| 35 |
+
|
| 36 |
+
from absl import command_name
|
| 37 |
+
from absl import flags
|
| 38 |
+
from absl import logging
|
| 39 |
+
|
| 40 |
+
try:
|
| 41 |
+
import faulthandler
|
| 42 |
+
except ImportError:
|
| 43 |
+
faulthandler = None
|
| 44 |
+
|
| 45 |
+
FLAGS = flags.FLAGS
|
| 46 |
+
|
| 47 |
+
flags.DEFINE_boolean('run_with_pdb', False, 'Set to true for PDB debug mode')
|
| 48 |
+
flags.DEFINE_boolean('pdb_post_mortem', False,
|
| 49 |
+
'Set to true to handle uncaught exceptions with PDB '
|
| 50 |
+
'post mortem.')
|
| 51 |
+
flags.DEFINE_alias('pdb', 'pdb_post_mortem')
|
| 52 |
+
flags.DEFINE_boolean('run_with_profiling', False,
|
| 53 |
+
'Set to true for profiling the script. '
|
| 54 |
+
'Execution will be slower, and the output format might '
|
| 55 |
+
'change over time.')
|
| 56 |
+
flags.DEFINE_string('profile_file', None,
|
| 57 |
+
'Dump profile information to a file (for python -m '
|
| 58 |
+
'pstats). Implies --run_with_profiling.')
|
| 59 |
+
flags.DEFINE_boolean('use_cprofile_for_profiling', True,
|
| 60 |
+
'Use cProfile instead of the profile module for '
|
| 61 |
+
'profiling. This has no effect unless '
|
| 62 |
+
'--run_with_profiling is set.')
|
| 63 |
+
flags.DEFINE_boolean('only_check_args', False,
|
| 64 |
+
'Set to true to validate args and exit.',
|
| 65 |
+
allow_hide_cpp=True)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# If main() exits via an abnormal exception, call into these
|
| 69 |
+
# handlers before exiting.
|
| 70 |
+
EXCEPTION_HANDLERS = []
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Error(Exception):
|
| 74 |
+
pass
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class UsageError(Error):
|
| 78 |
+
"""Exception raised when the arguments supplied by the user are invalid.
|
| 79 |
+
|
| 80 |
+
Raise this when the arguments supplied are invalid from the point of
|
| 81 |
+
view of the application. For example when two mutually exclusive
|
| 82 |
+
flags have been supplied or when there are not enough non-flag
|
| 83 |
+
arguments. It is distinct from flags.Error which covers the lower
|
| 84 |
+
level of parsing and validating individual flags.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def __init__(self, message, exitcode=1):
|
| 88 |
+
super().__init__(message)
|
| 89 |
+
self.exitcode = exitcode
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class HelpFlag(flags.BooleanFlag):
|
| 93 |
+
"""Special boolean flag that displays usage and raises SystemExit."""
|
| 94 |
+
NAME = 'help'
|
| 95 |
+
SHORT_NAME = '?'
|
| 96 |
+
|
| 97 |
+
def __init__(self):
|
| 98 |
+
super().__init__(
|
| 99 |
+
self.NAME,
|
| 100 |
+
False,
|
| 101 |
+
'show this help',
|
| 102 |
+
short_name=self.SHORT_NAME,
|
| 103 |
+
allow_hide_cpp=True,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
def parse(self, arg):
|
| 107 |
+
if self._parse(arg):
|
| 108 |
+
usage(shorthelp=True, writeto_stdout=True)
|
| 109 |
+
# Advertise --helpfull on stdout, since usage() was on stdout.
|
| 110 |
+
print()
|
| 111 |
+
print('Try --helpfull to get a list of all flags.')
|
| 112 |
+
sys.exit(1)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class HelpshortFlag(HelpFlag):
|
| 116 |
+
"""--helpshort is an alias for --help."""
|
| 117 |
+
NAME = 'helpshort'
|
| 118 |
+
SHORT_NAME = None
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class HelpfullFlag(flags.BooleanFlag):
|
| 122 |
+
"""Display help for flags in the main module and all dependent modules."""
|
| 123 |
+
|
| 124 |
+
def __init__(self):
|
| 125 |
+
super().__init__('helpfull', False, 'show full help', allow_hide_cpp=True)
|
| 126 |
+
|
| 127 |
+
def parse(self, arg):
|
| 128 |
+
if self._parse(arg):
|
| 129 |
+
usage(writeto_stdout=True)
|
| 130 |
+
sys.exit(1)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class HelpXMLFlag(flags.BooleanFlag):
|
| 134 |
+
"""Similar to HelpfullFlag, but generates output in XML format."""
|
| 135 |
+
|
| 136 |
+
def __init__(self):
|
| 137 |
+
super().__init__(
|
| 138 |
+
'helpxml',
|
| 139 |
+
False,
|
| 140 |
+
'like --helpfull, but generates XML output',
|
| 141 |
+
allow_hide_cpp=True,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
def parse(self, arg):
|
| 145 |
+
if self._parse(arg):
|
| 146 |
+
flags.FLAGS.write_help_in_xml_format(sys.stdout)
|
| 147 |
+
sys.exit(1)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def parse_flags_with_usage(args):
|
| 151 |
+
"""Tries to parse the flags, print usage, and exit if unparsable.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
args: [str], a non-empty list of the command line arguments including
|
| 155 |
+
program name.
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
[str], a non-empty list of remaining command line arguments after parsing
|
| 159 |
+
flags, including program name.
|
| 160 |
+
"""
|
| 161 |
+
try:
|
| 162 |
+
return FLAGS(args)
|
| 163 |
+
except flags.Error as error:
|
| 164 |
+
message = str(error)
|
| 165 |
+
if '\n' in message:
|
| 166 |
+
final_message = 'FATAL Flags parsing error:\n%s\n' % textwrap.indent(
|
| 167 |
+
message, ' ')
|
| 168 |
+
else:
|
| 169 |
+
final_message = 'FATAL Flags parsing error: %s\n' % message
|
| 170 |
+
sys.stderr.write(final_message)
|
| 171 |
+
sys.stderr.write('Pass --helpshort or --helpfull to see help on flags.\n')
|
| 172 |
+
sys.exit(1)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
_define_help_flags_called = False
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def define_help_flags():
|
| 179 |
+
"""Registers help flags. Idempotent."""
|
| 180 |
+
# Use a global to ensure idempotence.
|
| 181 |
+
global _define_help_flags_called
|
| 182 |
+
|
| 183 |
+
if not _define_help_flags_called:
|
| 184 |
+
flags.DEFINE_flag(HelpFlag())
|
| 185 |
+
flags.DEFINE_flag(HelpshortFlag()) # alias for --help
|
| 186 |
+
flags.DEFINE_flag(HelpfullFlag())
|
| 187 |
+
flags.DEFINE_flag(HelpXMLFlag())
|
| 188 |
+
_define_help_flags_called = True
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _register_and_parse_flags_with_usage(
|
| 192 |
+
argv=None,
|
| 193 |
+
flags_parser=parse_flags_with_usage,
|
| 194 |
+
):
|
| 195 |
+
"""Registers help flags, parses arguments and shows usage if appropriate.
|
| 196 |
+
|
| 197 |
+
This also calls sys.exit(0) if flag --only_check_args is True.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
argv: [str], a non-empty list of the command line arguments including
|
| 201 |
+
program name, sys.argv is used if None.
|
| 202 |
+
flags_parser: Callable[[List[str]], Any], the function used to parse flags.
|
| 203 |
+
The return value of this function is passed to `main` untouched. It must
|
| 204 |
+
guarantee FLAGS is parsed after this function is called.
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
The return value of `flags_parser`. When using the default `flags_parser`,
|
| 208 |
+
it returns the following:
|
| 209 |
+
[str], a non-empty list of remaining command line arguments after parsing
|
| 210 |
+
flags, including program name.
|
| 211 |
+
|
| 212 |
+
Raises:
|
| 213 |
+
Error: Raised when flags_parser is called, but FLAGS is not parsed.
|
| 214 |
+
SystemError: Raised when it's called more than once.
|
| 215 |
+
"""
|
| 216 |
+
# fmt: on
|
| 217 |
+
if _register_and_parse_flags_with_usage.done:
|
| 218 |
+
raise SystemError('Flag registration can be done only once.')
|
| 219 |
+
|
| 220 |
+
define_help_flags()
|
| 221 |
+
|
| 222 |
+
original_argv = sys.argv if argv is None else argv
|
| 223 |
+
args_to_main = flags_parser(original_argv)
|
| 224 |
+
if not FLAGS.is_parsed():
|
| 225 |
+
raise Error('FLAGS must be parsed after flags_parser is called.')
|
| 226 |
+
|
| 227 |
+
# Exit when told so.
|
| 228 |
+
if FLAGS.only_check_args:
|
| 229 |
+
sys.exit(0)
|
| 230 |
+
# Immediately after flags are parsed, bump verbosity to INFO if the flag has
|
| 231 |
+
# not been set.
|
| 232 |
+
if FLAGS['verbosity'].using_default_value:
|
| 233 |
+
FLAGS.verbosity = 0
|
| 234 |
+
_register_and_parse_flags_with_usage.done = True
|
| 235 |
+
|
| 236 |
+
return args_to_main
|
| 237 |
+
|
| 238 |
+
_register_and_parse_flags_with_usage.done = False
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def _run_main(main, argv):
|
| 242 |
+
"""Calls main, optionally with pdb or profiler."""
|
| 243 |
+
if FLAGS.run_with_pdb:
|
| 244 |
+
sys.exit(pdb.runcall(main, argv))
|
| 245 |
+
elif FLAGS.run_with_profiling or FLAGS.profile_file:
|
| 246 |
+
# Avoid import overhead since most apps (including performance-sensitive
|
| 247 |
+
# ones) won't be run with profiling.
|
| 248 |
+
# pylint: disable=g-import-not-at-top
|
| 249 |
+
import atexit
|
| 250 |
+
if FLAGS.use_cprofile_for_profiling:
|
| 251 |
+
import cProfile as profile
|
| 252 |
+
else:
|
| 253 |
+
import profile
|
| 254 |
+
profiler = profile.Profile()
|
| 255 |
+
if FLAGS.profile_file:
|
| 256 |
+
atexit.register(profiler.dump_stats, FLAGS.profile_file)
|
| 257 |
+
else:
|
| 258 |
+
atexit.register(profiler.print_stats)
|
| 259 |
+
sys.exit(profiler.runcall(main, argv))
|
| 260 |
+
else:
|
| 261 |
+
sys.exit(main(argv))
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def _call_exception_handlers(exception):
|
| 265 |
+
"""Calls any installed exception handlers."""
|
| 266 |
+
for handler in EXCEPTION_HANDLERS:
|
| 267 |
+
try:
|
| 268 |
+
if handler.wants(exception):
|
| 269 |
+
handler.handle(exception)
|
| 270 |
+
except: # pylint: disable=bare-except
|
| 271 |
+
try:
|
| 272 |
+
# We don't want to stop for exceptions in the exception handlers but
|
| 273 |
+
# we shouldn't hide them either.
|
| 274 |
+
logging.error(traceback.format_exc())
|
| 275 |
+
except: # pylint: disable=bare-except
|
| 276 |
+
# In case even the logging statement fails, ignore.
|
| 277 |
+
pass
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def run(
|
| 281 |
+
main,
|
| 282 |
+
argv=None,
|
| 283 |
+
flags_parser=parse_flags_with_usage,
|
| 284 |
+
):
|
| 285 |
+
"""Begins executing the program.
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
main: The main function to execute. It takes an single argument "argv",
|
| 289 |
+
which is a list of command line arguments with parsed flags removed.
|
| 290 |
+
The return value is passed to `sys.exit`, and so for example
|
| 291 |
+
a return value of 0 or None results in a successful termination, whereas
|
| 292 |
+
a return value of 1 results in abnormal termination.
|
| 293 |
+
For more details, see https://docs.python.org/3/library/sys#sys.exit
|
| 294 |
+
argv: A non-empty list of the command line arguments including program name,
|
| 295 |
+
sys.argv is used if None.
|
| 296 |
+
flags_parser: Callable[[List[str]], Any], the function used to parse flags.
|
| 297 |
+
The return value of this function is passed to `main` untouched.
|
| 298 |
+
It must guarantee FLAGS is parsed after this function is called.
|
| 299 |
+
Should be passed as a keyword-only arg which will become mandatory in a
|
| 300 |
+
future release.
|
| 301 |
+
- Parses command line flags with the flag module.
|
| 302 |
+
- If there are any errors, prints usage().
|
| 303 |
+
- Calls main() with the remaining arguments.
|
| 304 |
+
- If main() raises a UsageError, prints usage and the error message.
|
| 305 |
+
"""
|
| 306 |
+
# fmt: on
|
| 307 |
+
try:
|
| 308 |
+
args = _run_init(
|
| 309 |
+
sys.argv if argv is None else argv,
|
| 310 |
+
flags_parser,
|
| 311 |
+
)
|
| 312 |
+
while _init_callbacks:
|
| 313 |
+
callback = _init_callbacks.popleft()
|
| 314 |
+
callback()
|
| 315 |
+
try:
|
| 316 |
+
_run_main(main, args)
|
| 317 |
+
except UsageError as error:
|
| 318 |
+
usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode)
|
| 319 |
+
except:
|
| 320 |
+
exc = sys.exc_info()[1]
|
| 321 |
+
# Don't try to post-mortem debug successful SystemExits, since those
|
| 322 |
+
# mean there wasn't actually an error. In particular, the test framework
|
| 323 |
+
# raises SystemExit(False) even if all tests passed.
|
| 324 |
+
if isinstance(exc, SystemExit) and not exc.code:
|
| 325 |
+
raise
|
| 326 |
+
|
| 327 |
+
# Check the tty so that we don't hang waiting for input in an
|
| 328 |
+
# non-interactive scenario.
|
| 329 |
+
if FLAGS.pdb_post_mortem and sys.stdout.isatty():
|
| 330 |
+
traceback.print_exc()
|
| 331 |
+
print()
|
| 332 |
+
print(' *** Entering post-mortem debugging ***')
|
| 333 |
+
print()
|
| 334 |
+
pdb.post_mortem()
|
| 335 |
+
raise
|
| 336 |
+
except Exception as e:
|
| 337 |
+
_call_exception_handlers(e)
|
| 338 |
+
raise
|
| 339 |
+
|
| 340 |
+
# Callbacks which have been deferred until after _run_init has been called.
|
| 341 |
+
_init_callbacks = collections.deque()
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def call_after_init(callback):
|
| 345 |
+
"""Calls the given callback only once ABSL has finished initialization.
|
| 346 |
+
|
| 347 |
+
If ABSL has already finished initialization when ``call_after_init`` is
|
| 348 |
+
called then the callback is executed immediately, otherwise `callback` is
|
| 349 |
+
stored to be executed after ``app.run`` has finished initializing (aka. just
|
| 350 |
+
before the main function is called).
|
| 351 |
+
|
| 352 |
+
If called after ``app.run``, this is equivalent to calling ``callback()`` in
|
| 353 |
+
the caller thread. If called before ``app.run``, callbacks are run
|
| 354 |
+
sequentially (in an undefined order) in the same thread as ``app.run``.
|
| 355 |
+
|
| 356 |
+
Args:
|
| 357 |
+
callback: a callable to be called once ABSL has finished initialization.
|
| 358 |
+
This may be immediate if initialization has already finished. It
|
| 359 |
+
takes no arguments and returns nothing.
|
| 360 |
+
"""
|
| 361 |
+
if _run_init.done:
|
| 362 |
+
callback()
|
| 363 |
+
else:
|
| 364 |
+
_init_callbacks.append(callback)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def _run_init(
|
| 368 |
+
argv,
|
| 369 |
+
flags_parser,
|
| 370 |
+
):
|
| 371 |
+
"""Does one-time initialization and re-parses flags on rerun."""
|
| 372 |
+
if _run_init.done:
|
| 373 |
+
return flags_parser(argv)
|
| 374 |
+
command_name.make_process_name_useful()
|
| 375 |
+
# Set up absl logging handler.
|
| 376 |
+
logging.use_absl_handler()
|
| 377 |
+
args = _register_and_parse_flags_with_usage(
|
| 378 |
+
argv=argv,
|
| 379 |
+
flags_parser=flags_parser,
|
| 380 |
+
)
|
| 381 |
+
if faulthandler:
|
| 382 |
+
try:
|
| 383 |
+
faulthandler.enable()
|
| 384 |
+
except Exception: # pylint: disable=broad-except
|
| 385 |
+
# Some tests verify stderr output very closely, so don't print anything.
|
| 386 |
+
# Disabled faulthandler is a low-impact error.
|
| 387 |
+
pass
|
| 388 |
+
_run_init.done = True
|
| 389 |
+
return args
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
_run_init.done = False
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def usage(shorthelp=False, writeto_stdout=False, detailed_error=None,
|
| 396 |
+
exitcode=None):
|
| 397 |
+
"""Writes __main__'s docstring to stderr with some help text.
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
shorthelp: bool, if True, prints only flags from the main module,
|
| 401 |
+
rather than all flags.
|
| 402 |
+
writeto_stdout: bool, if True, writes help message to stdout,
|
| 403 |
+
rather than to stderr.
|
| 404 |
+
detailed_error: str, additional detail about why usage info was presented.
|
| 405 |
+
exitcode: optional integer, if set, exits with this status code after
|
| 406 |
+
writing help.
|
| 407 |
+
"""
|
| 408 |
+
if writeto_stdout:
|
| 409 |
+
stdfile = sys.stdout
|
| 410 |
+
else:
|
| 411 |
+
stdfile = sys.stderr
|
| 412 |
+
|
| 413 |
+
doc = sys.modules['__main__'].__doc__
|
| 414 |
+
if not doc:
|
| 415 |
+
doc = '\nUSAGE: %s [flags]\n' % sys.argv[0]
|
| 416 |
+
doc = flags.text_wrap(doc, indent=' ', firstline_indent='')
|
| 417 |
+
else:
|
| 418 |
+
# Replace all '%s' with sys.argv[0], and all '%%' with '%'.
|
| 419 |
+
num_specifiers = doc.count('%') - 2 * doc.count('%%')
|
| 420 |
+
try:
|
| 421 |
+
doc %= (sys.argv[0],) * num_specifiers
|
| 422 |
+
except (OverflowError, TypeError, ValueError):
|
| 423 |
+
# Just display the docstring as-is.
|
| 424 |
+
pass
|
| 425 |
+
if shorthelp:
|
| 426 |
+
flag_str = FLAGS.main_module_help()
|
| 427 |
+
else:
|
| 428 |
+
flag_str = FLAGS.get_help()
|
| 429 |
+
try:
|
| 430 |
+
stdfile.write(doc)
|
| 431 |
+
if flag_str:
|
| 432 |
+
stdfile.write('\nflags:\n')
|
| 433 |
+
stdfile.write(flag_str)
|
| 434 |
+
stdfile.write('\n')
|
| 435 |
+
if detailed_error is not None:
|
| 436 |
+
stdfile.write('\n%s\n' % detailed_error)
|
| 437 |
+
except OSError as e:
|
| 438 |
+
# We avoid printing a huge backtrace if we get EPIPE, because
|
| 439 |
+
# "foo.par --help | less" is a frequent use case.
|
| 440 |
+
if e.errno != errno.EPIPE:
|
| 441 |
+
raise
|
| 442 |
+
if exitcode is not None:
|
| 443 |
+
sys.exit(exitcode)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class ExceptionHandler:
|
| 447 |
+
"""Base exception handler from which other may inherit."""
|
| 448 |
+
|
| 449 |
+
def wants(self, exc):
|
| 450 |
+
"""Returns whether this handler wants to handle the exception or not.
|
| 451 |
+
|
| 452 |
+
This base class returns True for all exceptions by default. Override in
|
| 453 |
+
subclass if it wants to be more selective.
|
| 454 |
+
|
| 455 |
+
Args:
|
| 456 |
+
exc: Exception, the current exception.
|
| 457 |
+
"""
|
| 458 |
+
del exc # Unused.
|
| 459 |
+
return True
|
| 460 |
+
|
| 461 |
+
def handle(self, exc):
|
| 462 |
+
"""Do something with the current exception.
|
| 463 |
+
|
| 464 |
+
Args:
|
| 465 |
+
exc: Exception, the current exception
|
| 466 |
+
|
| 467 |
+
This method must be overridden.
|
| 468 |
+
"""
|
| 469 |
+
raise NotImplementedError()
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def install_exception_handler(handler):
|
| 473 |
+
"""Installs an exception handler.
|
| 474 |
+
|
| 475 |
+
Args:
|
| 476 |
+
handler: ExceptionHandler, the exception handler to install.
|
| 477 |
+
|
| 478 |
+
Raises:
|
| 479 |
+
TypeError: Raised when the handler was not of the correct type.
|
| 480 |
+
|
| 481 |
+
All installed exception handlers will be called if main() exits via
|
| 482 |
+
an abnormal exception, i.e. not one of SystemExit, KeyboardInterrupt,
|
| 483 |
+
FlagsError or UsageError.
|
| 484 |
+
"""
|
| 485 |
+
if not isinstance(handler, ExceptionHandler):
|
| 486 |
+
raise TypeError('handler of type %s does not inherit from ExceptionHandler'
|
| 487 |
+
% type(handler))
|
| 488 |
+
EXCEPTION_HANDLERS.append(handler)
|
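A minimal end-to-end sketch of the entry-point pattern described in the module docstring above; the `--name` flag and the greeting are illustrative only, not part of absl itself:

```python
# Minimal absl.app program: define flags, define main(argv), hand control to app.run.
from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('name', 'world', 'Who to greet.')  # illustrative flag


def main(argv):
  # argv still contains the program name plus any non-flag arguments.
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  print('Hello, %s!' % FLAGS.name)


if __name__ == '__main__':
  app.run(main)  # parses flags, then calls main and passes its result to sys.exit
```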
external/alphageometry/.venv-ag/Lib/site-packages/absl/app.pyi
ADDED
|
@@ -0,0 +1,88 @@
from typing import Any, Callable, Collection, Iterable, List, NoReturn, Optional, TypeVar, Union, overload

from absl.flags import _flag

_MainArgs = TypeVar('_MainArgs')
_Exc = TypeVar('_Exc', bound=Exception)

class ExceptionHandler():

  def wants(self, exc: _Exc) -> bool:
    ...

  def handle(self, exc: _Exc):
    ...

EXCEPTION_HANDLERS: List[ExceptionHandler] = ...

class HelpFlag(_flag.BooleanFlag):
  def __init__(self):
    ...

class HelpshortFlag(HelpFlag):
  ...

class HelpfullFlag(_flag.BooleanFlag):
  def __init__(self):
    ...

class HelpXMLFlag(_flag.BooleanFlag):
  def __init__(self):
    ...

def define_help_flags() -> None:
  ...

@overload
def usage(shorthelp: Union[bool, int] = ...,
          writeto_stdout: Union[bool, int] = ...,
          detailed_error: Optional[Any] = ...,
          exitcode: None = ...) -> None:
  ...

@overload
def usage(shorthelp: Union[bool, int],
          writeto_stdout: Union[bool, int],
          detailed_error: Optional[Any],
          exitcode: int) -> NoReturn:
  ...

@overload
def usage(shorthelp: Union[bool, int] = ...,
          writeto_stdout: Union[bool, int] = ...,
          detailed_error: Optional[Any] = ...,
          *,
          exitcode: int) -> NoReturn:
  ...

def install_exception_handler(handler: ExceptionHandler) -> None:
  ...

class Error(Exception):
  ...

class UsageError(Error):
  exitcode: int

def parse_flags_with_usage(args: List[str]) -> List[str]:
  ...

def call_after_init(callback: Callable[[], Any]) -> None:
  ...

# Without the flag_parser argument, `main` should require a List[str].
@overload
def run(
    main: Callable[[List[str]], Any],
    argv: Optional[List[str]] = ...,
) -> NoReturn:
  ...

@overload
def run(
    main: Callable[[_MainArgs], Any],
    argv: Optional[List[str]] = ...,
    *,
    flags_parser: Callable[[List[str]], _MainArgs],
) -> NoReturn:
  ...
external/alphageometry/.venv-ag/Lib/site-packages/absl/command_name.py
ADDED
|
@@ -0,0 +1,63 @@
# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A tiny stand alone library to change the kernel process name on Linux."""

import os
import sys

# This library must be kept small and stand alone. It is used by small things
# that require no extension modules.


def make_process_name_useful():
  """Sets the process name to something better than 'python' if possible."""
  set_kernel_process_name(os.path.basename(sys.argv[0]))


def set_kernel_process_name(name):
  """Changes the Kernel's /proc/self/status process name on Linux.

  The kernel name is NOT what will be shown by the ps or top command.
  It is a 15 character string stored in the kernel's process table that
  is included in the kernel log when a process is OOM killed.
  The first 15 bytes of name are used. Non-ASCII unicode is replaced with '?'.

  Does nothing if /proc/self/comm cannot be written or prctl() fails.

  Args:
    name: bytes|unicode, the Linux kernel's command name to set.
  """
  if not isinstance(name, bytes):
    name = name.encode('ascii', 'replace')
  try:
    # This is preferred to using ctypes to try and call prctl() when possible.
    with open('/proc/self/comm', 'wb') as proc_comm:
      proc_comm.write(name[:15])
  except OSError:
    try:
      import ctypes  # pylint: disable=g-import-not-at-top
    except ImportError:
      return  # No ctypes.
    try:
      libc = ctypes.CDLL('libc.so.6')
    except OSError:
      return  # No libc.so.6.
    pr_set_name = ctypes.c_ulong(15)  # linux/prctl.h PR_SET_NAME value.
    zero = ctypes.c_ulong(0)
    try:
      libc.prctl(pr_set_name, name, zero, zero, zero)
      # Ignore the prctl return value. Nothing we can do if it errored.
    except AttributeError:
      return  # No prctl.
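An illustrative caller of the two helpers above; on non-Linux systems, or when `/proc/self/comm` and `prctl()` are unavailable, both calls are no-ops exactly as documented (the explicit name 'my-worker' is made up for the example):

```python
# Rename the kernel-visible process name so OOM-killer log lines identify the program.
from absl import command_name

command_name.make_process_name_useful()            # uses basename of sys.argv[0]
command_name.set_kernel_process_name('my-worker')  # explicit name, truncated to 15 bytes
```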
external/alphageometry/.venv-ag/Lib/site-packages/absl/py.typed
ADDED
|
File without changes
|
external/alphageometry/.venv-ag/Lib/site-packages/distutils-precedence.pth
ADDED
|
@@ -0,0 +1 @@
import os; var = 'SETUPTOOLS_USE_DISTUTILS'; enabled = os.environ.get(var, 'local') == 'local'; enabled and __import__('_distutils_hack').add_shim();
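A small sketch of the switch this one-liner checks: the shim is installed only when `SETUPTOOLS_USE_DISTUTILS` is unset or set to `local`; the `stdlib` value used here is the usual opt-out (the print is illustrative, and the sketch deliberately does not import `_distutils_hack` itself):

```python
# Reproduce the .pth file's decision without installing the shim.
import os

os.environ.setdefault("SETUPTOOLS_USE_DISTUTILS", "stdlib")  # opt out of setuptools' distutils
enabled = os.environ.get("SETUPTOOLS_USE_DISTUTILS", "local") == "local"
print("setuptools distutils shim would be enabled:", enabled)
```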
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax.experimental.x64_context import (
    enable_x64 as enable_x64,
    disable_x64 as disable_x64,
)
from jax._src.callback import (
    io_callback as io_callback
)
from jax._src.earray import (
    EArray as EArray
)
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/__init__.py
ADDED
|
@@ -0,0 +1,213 @@
| 1 |
+
# Copyright 2023 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module includes experimental JAX support for the `Python array API standard`_.
Support for this is currently experimental and not fully complete.

Example Usage::

  >>> from jax.experimental import array_api as xp

  >>> xp.__array_api_version__
  '2023.12'

  >>> arr = xp.arange(1000)

  >>> arr.sum()
  Array(499500, dtype=int32)

The ``xp`` namespace is the array API compliant analog of :mod:`jax.numpy`,
and implements most of the API listed in the standard.

.. _Python array API standard: https://data-apis.org/array-api/latest/
"""

from __future__ import annotations

from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__

from jax.experimental.array_api import fft as fft
from jax.experimental.array_api import linalg as linalg

from jax.numpy import (
  abs as abs,
  acos as acos,
  acosh as acosh,
  add as add,
  all as all,
  any as any,
  argmax as argmax,
  argmin as argmin,
  argsort as argsort,
  asin as asin,
  asinh as asinh,
  atan as atan,
  atan2 as atan2,
  atanh as atanh,
  bitwise_and as bitwise_and,
  bitwise_invert as bitwise_invert,
  bitwise_left_shift as bitwise_left_shift,
  bitwise_or as bitwise_or,
  bitwise_right_shift as bitwise_right_shift,
  bitwise_xor as bitwise_xor,
  bool as bool,
  broadcast_arrays as broadcast_arrays,
  broadcast_to as broadcast_to,
  can_cast as can_cast,
  complex128 as complex128,
  complex64 as complex64,
  concat as concat,
  conj as conj,
  copysign as copysign,
  cos as cos,
  cosh as cosh,
  cumulative_sum as cumulative_sum,
  divide as divide,
  e as e,
  empty as empty,
  empty_like as empty_like,
  equal as equal,
  exp as exp,
  expand_dims as expand_dims,
  expm1 as expm1,
  flip as flip,
  float32 as float32,
  float64 as float64,
  floor_divide as floor_divide,
  from_dlpack as from_dlpack,
  full as full,
  full_like as full_like,
  greater as greater,
  greater_equal as greater_equal,
  iinfo as iinfo,
  imag as imag,
  inf as inf,
  int16 as int16,
  int32 as int32,
  int64 as int64,
  int8 as int8,
  isdtype as isdtype,
  isfinite as isfinite,
  isinf as isinf,
  isnan as isnan,
  less as less,
  less_equal as less_equal,
  log as log,
  log10 as log10,
  log1p as log1p,
  log2 as log2,
  logaddexp as logaddexp,
  logical_and as logical_and,
  logical_not as logical_not,
  logical_or as logical_or,
  logical_xor as logical_xor,
  matmul as matmul,
  matrix_transpose as matrix_transpose,
  max as max,
  maximum as maximum,
  mean as mean,
  meshgrid as meshgrid,
  min as min,
  minimum as minimum,
  moveaxis as moveaxis,
  multiply as multiply,
  nan as nan,
  negative as negative,
  newaxis as newaxis,
  nonzero as nonzero,
  not_equal as not_equal,
  ones as ones,
  ones_like as ones_like,
  permute_dims as permute_dims,
  pi as pi,
  positive as positive,
  pow as pow,
  prod as prod,
  real as real,
  remainder as remainder,
  repeat as repeat,
  result_type as result_type,
  roll as roll,
  round as round,
  searchsorted as searchsorted,
  sign as sign,
  signbit as signbit,
  sin as sin,
  sinh as sinh,
  sort as sort,
  sqrt as sqrt,
  square as square,
  squeeze as squeeze,
  stack as stack,
  subtract as subtract,
  sum as sum,
  take as take,
  tan as tan,
  tanh as tanh,
  tensordot as tensordot,
  tile as tile,
  tril as tril,
  triu as triu,
  uint16 as uint16,
  uint32 as uint32,
  uint64 as uint64,
  uint8 as uint8,
  unique_all as unique_all,
  unique_counts as unique_counts,
  unique_inverse as unique_inverse,
  unique_values as unique_values,
  unstack as unstack,
  vecdot as vecdot,
  where as where,
  zeros as zeros,
  zeros_like as zeros_like,
)

from jax.experimental.array_api._manipulation_functions import (
  reshape as reshape,
)

from jax.experimental.array_api._creation_functions import (
  arange as arange,
  asarray as asarray,
  eye as eye,
  linspace as linspace,
)

from jax.experimental.array_api._data_type_functions import (
  astype as astype,
  finfo as finfo,
)

from jax.experimental.array_api._elementwise_functions import (
  ceil as ceil,
  clip as clip,
  floor as floor,
  hypot as hypot,
  trunc as trunc,
)

from jax.experimental.array_api._statistical_functions import (
  std as std,
  var as var,
)

from jax.experimental.array_api._utility_functions import (
  __array_namespace_info__ as __array_namespace_info__,
)

from jax.experimental.array_api import _array_methods
_array_methods.add_array_object_methods()
del _array_methods

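A minimal usage sketch of the namespace assembled above, illustrating array-API-agnostic code (the helper name `mean_center` is hypothetical and not part of the diff):

import jax.numpy as jnp
import jax.experimental.array_api  # registers __array_namespace__ on JAX arrays

def mean_center(x):
  # Works for any array-API-compliant library: fetch the namespace from the array itself.
  xp = x.__array_namespace__()
  return x - xp.mean(x)

print(mean_center(jnp.arange(4.0)))  # [-1.5 -0.5  0.5  1.5]
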
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_array_methods.py
ADDED
@@ -0,0 +1,45 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

import jax
from jax._src.array import ArrayImpl
from jax.experimental.array_api._version import __array_api_version__
from jax.sharding import Sharding

from jax._src.lib import xla_extension as xe


def _array_namespace(self, /, *, api_version: None | str = None):
  if api_version is not None and api_version != __array_api_version__:
    raise ValueError(f"{api_version=!r} is not available; "
                     f"available versions are: {[__array_api_version__]}")
  return jax.experimental.array_api


def _to_device(self, device: xe.Device | Sharding | None, *,
               stream: int | Any | None = None):
  if stream is not None:
    raise NotImplementedError("stream argument of array.to_device()")
  return jax.device_put(self, device)


def add_array_object_methods():
  # TODO(jakevdp): set on tracers as well?
  setattr(ArrayImpl, "__array_namespace__", _array_namespace)
  setattr(ArrayImpl, "to_device", _to_device)
  setattr(ArrayImpl, "device", property(lambda self: self.sharding))

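The methods patched onto ArrayImpl above can be exercised as follows (a sketch, assuming a single default device):

import jax
import jax.numpy as jnp
import jax.experimental.array_api  # runs add_array_object_methods()

x = jnp.ones((4,))
ns = x.__array_namespace__()       # returns the jax.experimental.array_api module
y = x.to_device(jax.devices()[0])  # equivalent to jax.device_put(x, device)
print(ns.__array_api_version__, y.device)
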
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_creation_functions.py
ADDED
@@ -0,0 +1,31 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import jax
import jax.numpy as jnp

# TODO(micky774): Deprecate after adding device argument to jax.numpy functions
def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
  return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)

def asarray(obj, /, *, dtype=None, device=None, copy=None):
  return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device)

def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
  return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)

def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True):
  return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device)

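These wrappers only add a `device` keyword on top of jax.numpy; a small sketch:

import jax
from jax.experimental import array_api as xp

a = xp.arange(0, 10, 2, dtype=xp.int32, device=jax.devices()[0])
b = xp.linspace(0.0, 1.0, 5, endpoint=True)  # device=None leaves placement to JAX
print(a, b.shape)
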
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_data_type_functions.py
ADDED
@@ -0,0 +1,78 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import builtins
from typing import NamedTuple
import numpy as np

import jax.numpy as jnp

from jax._src.lib import xla_client as xc
from jax._src.sharding import Sharding
from jax._src import dtypes as _dtypes

# TODO(micky774): Update jax.numpy dtypes to dtype *objects*
bool = np.dtype('bool')
int8 = np.dtype('int8')
int16 = np.dtype('int16')
int32 = np.dtype('int32')
int64 = np.dtype('int64')
uint8 = np.dtype('uint8')
uint16 = np.dtype('uint16')
uint32 = np.dtype('uint32')
uint64 = np.dtype('uint64')
float32 = np.dtype('float32')
float64 = np.dtype('float64')
complex64 = np.dtype('complex64')
complex128 = np.dtype('complex128')


# TODO(micky774): Remove when jax.numpy.astype is deprecation is completed
def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None):
  src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x)
  if (
    src_dtype is not None
    and _dtypes.isdtype(src_dtype, "complex floating")
    and _dtypes.isdtype(dtype, ("integral", "real floating"))
  ):
    raise ValueError(
      "Casting from complex to non-complex dtypes is not permitted. Please "
      "first use jnp.real or jnp.imag to take the real/imaginary component of "
      "your input."
    )
  return jnp.astype(x, dtype, copy=copy, device=device)


class FInfo(NamedTuple):
  bits: int
  eps: float
  max: float
  min: float
  smallest_normal: float
  dtype: jnp.dtype

# TODO(micky774): Update jax.numpy.finfo so that its attributes are python
# floats
def finfo(type, /) -> FInfo:
  info = jnp.finfo(type)
  return FInfo(
    bits=info.bits,
    eps=float(info.eps),
    max=float(info.max),
    min=float(info.min),
    smallest_normal=float(info.smallest_normal),
    dtype=jnp.dtype(type)
  )

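A short sketch of the data type helpers defined above:

from jax.experimental import array_api as xp

x = xp.asarray([1.0, 2.5, 3.0])
print(xp.astype(x, xp.int32))    # plain cast via jnp.astype
print(xp.finfo(xp.float32).eps)  # attributes are Python floats (see FInfo)
# xp.astype(xp.asarray([1j]), xp.float32) raises ValueError: complex -> real casts are rejected.
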
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_elementwise_functions.py
ADDED
@@ -0,0 +1,75 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
from jax.numpy import isdtype
from jax._src.dtypes import issubdtype
from jax._src.numpy.util import promote_args


# TODO(micky774): Update jnp.ceil to preserve integral dtype
def ceil(x, /):
  """Rounds each element x_i of the input array x to the smallest (i.e., closest to -infinity) integer-valued number that is not less than x_i."""
  x, = promote_args("ceil", x)
  if isdtype(x.dtype, "integral"):
    return x
  return jax.numpy.ceil(x)


# TODO(micky774): Remove when jnp.clip deprecation is completed
# (began 2024-4-2) and default behavior is Array API 2023 compliant
def clip(x, /, min=None, max=None):
  """Returns the complex conjugate for each element x_i of the input array x."""
  x, = promote_args("clip", x)

  if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
    raise ValueError(
      "Clip received a complex value either through the input or the min/max "
      "keywords. Complex values have no ordering and cannot be clipped. "
      "Please convert to a real value or array by taking the real or "
      "imaginary components via jax.numpy.real/imag respectively."
    )
  return jax.numpy.clip(x, min=min, max=max)


# TODO(micky774): Update jnp.floor to preserve integral dtype
def floor(x, /):
  """Rounds each element x_i of the input array x to the greatest (i.e., closest to +infinity) integer-valued number that is not greater than x_i."""
  x, = promote_args("floor", x)
  if isdtype(x.dtype, "integral"):
    return x
  return jax.numpy.floor(x)


# TODO(micky774): Remove when jnp.hypot deprecation is completed
# (began 2024-4-14) and default behavior is Array API 2023 compliant
def hypot(x1, x2, /):
  """Computes the square root of the sum of squares for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
  x1, x2 = promote_args("hypot", x1, x2)

  if issubdtype(x1.dtype, jax.numpy.complexfloating):
    raise ValueError(
      "hypot does not support complex-valued inputs. Please convert to real "
      "values first, such as by using jnp.real or jnp.imag to take the real "
      "or imaginary components respectively.")
  return jax.numpy.hypot(x1, x2)


# TODO(micky774): Update jnp.trunc to preserve integral dtype
def trunc(x, /):
  """Rounds each element x_i of the input array x to the nearest integer-valued number that is closer to zero than x_i."""
  x, = promote_args("trunc", x)
  if isdtype(x.dtype, "integral"):
    return x
  return jax.numpy.trunc(x)

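A sketch of the integral-dtype and complex-input guards implemented above:

from jax.experimental import array_api as xp

print(xp.ceil(xp.asarray([1, 2, 3])).dtype)  # integral input is returned unchanged
print(xp.clip(xp.asarray([-2.0, 0.5, 3.0]), min=-1.0, max=1.0))
# xp.hypot on complex inputs raises ValueError, matching the check above.
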
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_fft_functions.py
ADDED
@@ -0,0 +1,25 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp

# TODO(micky774): Remove after adding device parameter to corresponding jnp.fft
# functions.
def fftfreq(n, /, *, d=1.0, device=None):
  """Returns the discrete Fourier transform sample frequencies."""
  return jnp.fft.fftfreq(n, d=d).to_device(device)

def rfftfreq(n, /, *, d=1.0, device=None):
  """Returns the discrete Fourier transform sample frequencies (for rfft and irfft)."""
  return jnp.fft.rfftfreq(n, d=d).to_device(device)

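A sketch of the frequency helpers, which forward to jnp.fft and then move the result with .to_device:

from jax.experimental import array_api as xp

freqs = xp.fft.fftfreq(8, d=0.5)  # device=None keeps the default placement
print(freqs.shape)
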
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_linear_algebra_functions.py
ADDED
@@ -0,0 +1,28 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax

# TODO(micky774): Remove after deprecation is completed (began 2024-5-14)
def matrix_rank(x, /, *, rtol=None):
  """
  Returns the rank (i.e., number of non-zero singular values) of a matrix (or a stack of matrices).
  """
  return jax.numpy.linalg.matrix_rank(x, rtol)

def pinv(x, /, *, rtol=None):
  """
  Returns the (Moore-Penrose) pseudo-inverse of a matrix (or a stack of matrices) x.
  """
  return jax.numpy.linalg.pinv(x, rtol)

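A small sketch of the rtol-based wrappers above:

from jax.experimental import array_api as xp

m = xp.eye(3)
print(xp.linalg.matrix_rank(m), xp.linalg.pinv(m).shape)
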
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_manipulation_functions.py
ADDED
@@ -0,0 +1,25 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import jax
from jax import Array


# TODO(micky774): Implement copy
def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:
  """Reshapes an array without changing its data."""
  del copy  # unused
  return jax.numpy.reshape(x, shape)

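The reshape wrapper accepts (and currently ignores) the standard's copy argument; for example:

from jax.experimental import array_api as xp

x = xp.arange(6)
print(xp.reshape(x, (2, 3), copy=None).shape)  # (2, 3)
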
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_statistical_functions.py
ADDED
@@ -0,0 +1,25 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax


def std(x, /, *, axis=None, correction=0.0, keepdims=False):
  """Calculates the standard deviation of the input array x."""
  return jax.numpy.std(x, axis=axis, correction=correction, keepdims=keepdims)


def var(x, /, *, axis=None, correction=0.0, keepdims=False):
  """Calculates the variance of the input array x."""
  return jax.numpy.var(x, axis=axis, correction=correction, keepdims=keepdims)

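A sketch showing the array API `correction` parameter (degrees-of-freedom adjustment) handled above:

from jax.experimental import array_api as xp

x = xp.asarray([1.0, 2.0, 3.0, 4.0])
print(xp.var(x), xp.var(x, correction=1.0))  # population vs. sample variance
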
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_utility_functions.py
ADDED
@@ -0,0 +1,86 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import jax
from typing import Tuple
from jax._src.sharding import Sharding
from jax._src.lib import xla_client as xc
from jax._src import dtypes as _dtypes, config

# TODO(micky774): Add to jax.numpy.util when finalizing jax.experimental.array_api
# deprecation
class __array_namespace_info__:

  def __init__(self):
    self._capabilities = {
      "boolean indexing": True,
      "data-dependent shapes": False,
    }


  def _build_dtype_dict(self):
    array_api_types = {
      "bool", "int8", "int16",
      "int32", "uint8", "uint16",
      "uint32", "float32", "complex64"
    }
    if config.enable_x64.value:
      array_api_types |= {"int64", "uint64", "float64", "complex128"}
    return {category: {t.name: t for t in types if t.name in array_api_types}
            for category, types in _dtypes._dtype_kinds.items()}

  def default_device(self):
    # By default JAX arrays are uncommitted (device=None), meaning that
    # JAX is free to choose the most efficient device placement.
    return None

  def devices(self):
    return jax.devices()

  def capabilities(self):
    return self._capabilities

  def default_dtypes(self, *, device: xc.Device | Sharding | None = None):
    # Array API supported dtypes are device-independent in JAX
    del device
    default_dtypes = {
      "real floating": "f",
      "complex floating": "c",
      "integral": "i",
      "indexing": "i",
    }
    return {
      dtype_name: _dtypes.canonicalize_dtype(
        _dtypes._default_types.get(kind)
      ) for dtype_name, kind in default_dtypes.items()
    }

  def dtypes(
      self, *,
      device: xc.Device | Sharding | None = None,
      kind: str | Tuple[str, ...] | None = None):
    # Array API supported dtypes are device-independent in JAX
    del device
    data_types = self._build_dtype_dict()
    if kind is None:
      out_dict = data_types["numeric"] | data_types["bool"]
    elif isinstance(kind, tuple):
      out_dict = {}
      for _kind in kind:
        out_dict |= data_types[_kind]
    else:
      out_dict = data_types[kind]
    return out_dict

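The inspection namespace above can be queried like this (a sketch; the reported dtypes depend on whether x64 is enabled):

from jax.experimental import array_api as xp

info = xp.__array_namespace_info__()
print(info.capabilities())
print(sorted(info.dtypes(kind="integral")))
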
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_version.py
ADDED
@@ -0,0 +1,15 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__array_api_version__ = '2023.12'

external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/fft.py
ADDED
@@ -0,0 +1,33 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax.numpy.fft import (
  fft as fft,
  fftn as fftn,
  fftshift as fftshift,
  hfft as hfft,
  ifft as ifft,
  ifftn as ifftn,
  ifftshift as ifftshift,
  ihfft as ihfft,
  irfft as irfft,
  irfftn as irfftn,
  rfft as rfft,
  rfftn as rfftn,
)

from jax.experimental.array_api._fft_functions import (
  fftfreq as fftfreq,
  rfftfreq as rfftfreq,
)

external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/linalg.py
ADDED
@@ -0,0 +1,43 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax.numpy.linalg import (
  cholesky as cholesky,
  cross as cross,
  det as det,
  diagonal as diagonal,
  eigh as eigh,
  eigvalsh as eigvalsh,
  inv as inv,
  matmul as matmul,
  matrix_norm as matrix_norm,
  matrix_power as matrix_power,
  matrix_transpose as matrix_transpose,
  outer as outer,
  qr as qr,
  slogdet as slogdet,
  solve as solve,
  svd as svd,
  svdvals as svdvals,
  tensordot as tensordot,
  vecdot as vecdot,
  vector_norm as vector_norm,
)

from jax.numpy.linalg import trace as trace

from jax.experimental.array_api._linear_algebra_functions import (
  matrix_rank as matrix_rank,
  pinv as pinv,
)

external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_serialization/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

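The serialization module added next exposes get_tensorstore_spec, run_serialization and run_deserialization; a single-process round-trip sketch (the path '/tmp/ckpt/x' is only illustrative, and tensorstore must be installed):

import numpy as np
import jax
from jax.experimental.array_serialization import serialization

x = jax.numpy.arange(16.0).reshape(4, 4)
spec = serialization.get_tensorstore_spec('/tmp/ckpt/x')
serialization.run_serialization([x], [spec])

sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
restored, = serialization.run_deserialization([sharding], [spec])
print(np.allclose(x, restored))
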
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_serialization/serialization.py
ADDED
@@ -0,0 +1,635 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Array serialization and deserialization."""

from __future__ import annotations

import abc
import asyncio
from collections.abc import Awaitable, Sequence
from functools import partial
import itertools
import logging
import os
import re
import sys
import threading
import time
from typing import Any, Callable, Optional, Union

import jax
from jax._src import array
from jax._src import distributed
from jax._src import sharding
from jax._src import sharding_impls
from jax._src.layout import Layout, DeviceLocalLayout as DLL
from jax._src import typing
from jax._src import util
from jax._src.lib import xla_extension as xe
import jax.numpy as jnp
import numpy as np
import tensorstore as ts


TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}})
_REMOVED_VALUE = 'Value removed'
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
_module_unique_count = itertools.count()
_DEFAULT_DRIVER = 'file'
_DISTRIBUTED_SYSTEM_MSG = (
    'Please initialize the distributed system via '
    '`jax.distributed.initialize()` at the start of your program.')
_REMOTE_URL_PREFIXES = ['gs://', 's3://']
_REMOTE_DRIVER_VALIDATIONS = [
    {'driver': 'gcs', 'path_regex': None},
    {'driver': 's3', 'path_regex': None},
]

class BarrierTimeoutException(Exception):
  pass

_BARRIER_TIMED_OUT_MSG = (
    "Suggestions for possible fixes:\n"
    "* Check the logs to see if one or more processes failed.\n"
    "* Make sure the training and checkpointing endpoints are close geographically.\n"
    "* Try increasing the timeout you pass to GlobalAsyncCheckpointManager.")

logger = logging.getLogger(__name__)


async def create_async_array_from_callback(
    global_shape: array.Shape,
    inp_sharding: jax.sharding.Sharding,
    data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]],
):
  device_to_index_map = inp_sharding.devices_indices_map(global_shape)
  addressable_da = inp_sharding._addressable_device_assignment
  future_arrays = [data_callback(device_to_index_map[d], d)
                   for d in addressable_da]
  dbs = await asyncio.gather(*future_arrays)
  return array.make_array_from_single_device_arrays(
      global_shape, inp_sharding, dbs)


def _get_metadata(arr):
  local_shape = arr.addressable_data(0).shape
  return {
      'compressor': {'id': 'zstd'},
      'shape': arr.shape,
      'chunks': np.array(np.maximum(1, local_shape)),
  }


def _spec_has_metadata(tree):
  if not isinstance(tree, dict):
    return False
  return 'metadata' in tree or any(
      _spec_has_metadata(subtree) for _, subtree in tree.items())

def _get_kvstore_for_gcs(ckpt_path: str):
  m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL)
  if m is None:
    raise ValueError('The ckpt_path should contain the bucket name and the '
                     f'file path inside the bucket. Got: {ckpt_path}')
  gcs_bucket = m.group(1)
  path_without_bucket = m.group(2)
  return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket}

def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
  # Normalize path to exclude trailing '/'. In GCS path case, we will need to
  # fix the path prefix to add back the stripped '/'.
  ckpt_path = os.path.normpath(ckpt_path).replace('gs:/', 'gs://')
  is_gcs_path = ckpt_path.startswith('gs://')
  spec = {'driver': 'zarr', 'kvstore': {}}
  if ocdbt:
    if not is_gcs_path and not os.path.isabs(ckpt_path):
      raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}')
    base_path = os.path.dirname(ckpt_path)
    spec['kvstore'] = {
        'driver': 'ocdbt',
        'base': base_path if is_gcs_path else f'{_DEFAULT_DRIVER}://{base_path}',
        'path': os.path.basename(ckpt_path),
    }
  else:
    if is_gcs_path:
      spec['kvstore'] = _get_kvstore_for_gcs(ckpt_path)
    else:
      spec['kvstore'] = {'driver': _DEFAULT_DRIVER, 'path': ckpt_path}

  return spec


def is_remote_storage(tspec: Union[dict[str, Any], str]) -> bool:
  """Detect if user is using cloud storages.

  This can detect common defines and unable to detect some corner cases such as
  using gcsfuse.
  """
  if isinstance(tspec, str):
    # KvStoreUrl
    if re.match(rf'^({"|".join(_REMOTE_URL_PREFIXES)})', tspec):
      return True
    else:
      return False

  for key in ('base', 'kvstore'):
    if key in tspec:
      return is_remote_storage(tspec[key])

  if 'driver' in tspec:
    for rule in _REMOTE_DRIVER_VALIDATIONS:
      if tspec['driver'] == rule['driver']:
        if rule['path_regex'] is None:
          return True

        # check if path matches the regex.
        if re.match(rule['path_regex'], tspec['path']):
          return True

  return False


# Lifted from T5X.
class _LimitInFlightBytes:
  """Limits in-flight bytes when reading/writing checkpoints per process."""

  def __init__(self, num_bytes):
    self._max_bytes = num_bytes
    self._available_bytes = num_bytes
    self._cv = asyncio.Condition(lock=asyncio.Lock())

  async def wait_for_bytes(self, requested_bytes):
    if requested_bytes >= self._max_bytes:
      raise ValueError('Requested more bytes than we reserved space for: '
                       f'{requested_bytes} > {self._max_bytes}')
    async with self._cv:
      await self._cv.wait_for(lambda: self._available_bytes > requested_bytes)
      self._available_bytes -= requested_bytes
      assert self._available_bytes >= 0

  async def release_bytes(self, requested_bytes):
    async with self._cv:
      self._available_bytes += requested_bytes
      assert self._available_bytes <= self._max_bytes
      self._cv.notify_all()


async def async_serialize(
    arr_inp,
    tensorstore_spec,
    commit_future=None,
    context=TS_CONTEXT,
    primary_host: Optional[int] = 0,
    replica_id: int = 0,
):
  """Serialize an array using TensorStore.

  Args:
    arr_inp: The array to serialize.
    tensorstore_spec: The tensorstore spec to use.
    commit_future: A list of futures that will be appended to. The futures can
      be awaited asynchronously. If None, the futures will be awaited
      synchronously by this method.
    context: ts.Context instance.
    primary_host: Primary host, which indicates the host that will be treated as
      the "leader". If None, all hosts are treated as the primary. DO NOT USE
      unless you are sure you know what you are doing.
    replica_id: Allows overriding the shard replica id that will be saved.
      DO NOT USE unless you are sure you know what you are doing.
  """
  if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
      arr_inp.is_fully_addressable):
    raise ValueError(
        f'Passing fully addressable arrays to a multiprocess '
        f'serialization is not allowed, as this may lead to a race condition '
        f'between processes. Serialization have failed for the array with '
        f'the path "{tensorstore_spec["kvstore"]["path"]}".')

  # 'metadata' may not be present at the top level (for example, if we are using
  # a 'cast' driver).
  if not _spec_has_metadata(tensorstore_spec):
    tensorstore_spec['metadata'] = _get_metadata(arr_inp)

  # Set dtype if it's not in spec
  if 'dtype' not in tensorstore_spec:
    tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name

  # If primary_host is None, all hosts will checkpoint. This is used
  # for checkpointing to local filesystem.
  if primary_host is None or jax.process_index() == primary_host:
    open_future = ts.open(
        ts.Spec(tensorstore_spec),
        create=True,
        open=True,
        context=context,
    )
    # Asynchronous case.
    if commit_future is not None:
      assert isinstance(commit_future, list)
      commit_future.append(open_future)
    else:
      await open_future

  # `ts.open` runs twice for process `primary_host` because for the first time,
  # we just get the future to be awaited upon in the background thread. The
  # second one runs with `assume_metadata=True` which does no I/O operation and
  # returns the tensorstore object.
  # For every process other than `primary_host`, we open with
  # `assume_metadata=True`.
  t = await ts.open(
      ts.Spec(tensorstore_spec),
      open=True,
      assume_metadata=True,
      context=context,
  )

  async def _write_array(shard):
    if shard.replica_id == replica_id:
      write_future = t[shard.index].write(shard.data)
      if commit_future is not None:
        assert isinstance(commit_future, list)
        commit_future.append(write_future.commit)
        await write_future.copy
      else:
        await write_future.commit

  local_shards = arr_inp.addressable_shards
  future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
  return await asyncio.gather(*future_write_state)


def run_serialization(arrays, tensorstore_specs):
  async def _run_serializer():
    future_writer = jax.tree_util.tree_map(async_serialize, arrays, tensorstore_specs)
    return await asyncio.gather(*future_writer)
  asyncio.run(_run_serializer())


def estimate_read_memory_footprint(t: ts.TensorStore,
                                   domain: ts.IndexDomain) -> int:
  rank = t.rank
  num_bytes = t.dtype.numpy_dtype.itemsize
  chunk_template = t.chunk_layout.read_chunk_template
  if domain is None:
    domain = t.domain
  origin = domain.origin
  shape = domain.shape
  chunk_origin = chunk_template.origin
  chunk_shape = chunk_template.shape

  # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver.
  # For those, instead of returning a near-infinite memory footprint, estimate
  # the footprint as the entire shape.
  for i in range(rank):
    if not chunk_template[i].finite:
      return domain.size * num_bytes

  # Otherwise, if we have a chunked driver, estimate based on chunk size.
  for i in range(rank):
    origin_value = origin[i]
    chunk_origin_value = chunk_origin[i]
    chunk_size = chunk_shape[i]
    lower = origin_value - chunk_origin_value
    upper = origin_value + shape[i] - chunk_origin_value
    lower_aligned = lower // chunk_size * chunk_size
    upper_aligned = -(-upper // chunk_size) * chunk_size
    num_bytes *= (upper_aligned - lower_aligned)

  return num_bytes


async def async_deserialize(
    user_in_sharding: jax.sharding.Sharding | Layout,
    tensorstore_spec: ts.Spec | dict[str, Any],
    global_shape: Sequence[int] | None = None,
    dtype=None,
    byte_limiter: _LimitInFlightBytes | None = None,
    context=TS_CONTEXT,
    assume_metadata: bool = False,
):
  in_sharding = (user_in_sharding.sharding
                 if isinstance(user_in_sharding, Layout) else user_in_sharding)
  if not isinstance(in_sharding, jax.sharding.Sharding):
    raise ValueError(
        'sharding passed to deserialization should be specified, concrete and'
        f' an instance of `jax.sharding.Sharding`. Got {in_sharding}')
  dll = (user_in_sharding.device_local_layout
         if isinstance(user_in_sharding, Layout) else None)
  t = await ts.open(
      tensorstore_spec,
      open=True,
      assume_metadata=assume_metadata,
      context=context,
  )
  shape = t.shape if global_shape is None else global_shape
  new_shard_shape = in_sharding.shard_shape(tuple(shape))

  async def cb(index: array.Index, device: jax.Device):
    requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
    restricted_domain = t.domain.intersect(requested_domain)
    requested_bytes = estimate_read_memory_footprint(t, restricted_domain)
    # Limit the bytes read for every shard.
    if byte_limiter is not None:
      await byte_limiter.wait_for_bytes(requested_bytes)
    # This maybe needed because the shape the array was saved with is smaller
    # than the requested shape of the array in which it will be reloaded. So
    # the extra values will be filled with 0s.
    out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
    await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][
        restricted_domain].write(t[restricted_domain])
    if dtype is not None:
      # Cast while reloading on process to avoid 2 copies on device if the
      # casting is done on device.
      out = out.astype(dtype)
    # Convert to jnp array so that layouts are initialized properly for
    # sub-byte dtypes.
    # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to
    # make this work.
    if out.dtype == jnp.int4:
      out = jnp.asarray(out)  # type: ignore
    result = jax.device_put(
        out, Layout(dll, jax.sharding.SingleDeviceSharding(device)))
    if byte_limiter is not None:
      # NB: `out` actually might not be ready for garbage collection by the
      # time we call release_bytes . Thus peak memory usage still might grow
      # beyond what byte_limiter limit suggests it should. The simplest option
      # would be to call `result.block_until_ready()`` here. However it
      # also comes with ~15-20% perf penalty as we would be waiting for CPU->GPU
      # transfer instead of loading data. In the future, if memory pressure
      # becomes a problem, we can instead instrument bytelimiter to
      # keep track of all in-flight tensors and only block_until_ready, if byte
      # limiter hits the limit to get reduced memory usage, without losing
      # performance in common use cases.
      await byte_limiter.release_bytes(requested_bytes)
    return result

  return await create_async_array_from_callback(tuple(shape), in_sharding, cb)


def run_deserialization(shardings: Sequence[sharding.Sharding | Layout],
                        tensorstore_specs: Sequence[dict[str, Any]],
                        global_shapes: Sequence[array.Shape] | None = None,
                        dtypes: Sequence[typing.DTypeLike] | None = None,
                        concurrent_gb: int = 32):
  concurrent_bytes = concurrent_gb * 10**9

  async def _run_deserializer():
    # Object should be created once per process.
    byte_limiter = _LimitInFlightBytes(concurrent_bytes)

    future_arrays = jax.tree_util.tree_map(
        partial(async_deserialize, byte_limiter=byte_limiter),
        shardings, tensorstore_specs,
        [None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
        [None] * len(tensorstore_specs) if dtypes is None else dtypes)
    return await asyncio.gather(*future_arrays)
  return asyncio.run(_run_deserializer())


def _get_key(key: int):
  return f'tensorstore_checkpoint_{key}'


class GlobalAsyncCheckpointManagerBase(util.StrictABC):
  """Interface for checkpointing GDAs asynchronously.

  This class manages the state of an ongoing asynchronous checkpoint.

  For example, say a checkpoint happens on every step. If you checkpoint on
  step 1 and after some computation the model is on checkpoint 2. But step 1's
  checkpoint hasn't finished committing to the storage layer yet. So until that
  is finished, checkpoint for step 2 will need to be blocked. Maintaining a
  class allows to maintain that state.

  Example:

  Below is a simplified training loop:

  ```
  # Call this at the start of your program.
  jax.distributed.initialize()

  manager = GlobalAsyncCheckpointManager()

  # Restore checkpoint if available or initialize the train_state from
  # init_fn().
  train_state = manager.deserialize(...)

  while ...:
    if step % num_steps_between_checkpoints == 0:
      manager.serialize(train_state, temp_checkpoint_dir=...,
                        final_checkpoint_dir=...)
    train_state = train_step(train_state, input)
    # This is a non-blocking call.
    manager.check_for_errors()

  manager.serialize(train_state, temp_checkpoint_dir=...,
                    final_checkpoint_dir=...)
  # Wait before the end of the program for the checkpoint to finish. This is a
  # blocking call.
  manager.wait_until_finished()
  ```
  """

  @abc.abstractmethod
  def check_for_errors(self):
    """Checks if any errors have been raised in the child thread.

    This is a non-blocking call that can be called in the main thread.
    """

  @abc.abstractmethod
  def wait_until_finished(self):
    """Blocks until serialization has finished."""

  @abc.abstractmethod
  def serialize(self, arrays, tensorstore_specs, *,
                on_commit_callback: Callable[[], None]):
    """Serializes GDAs to TensorStore."""

  @abc.abstractmethod
  def deserialize(self, shardings: Sequence[sharding.Sharding],
                  tensorstore_specs: Sequence[dict[str, Any]],
                  global_shapes: Sequence[array.Shape] | None = None,
                  dtypes: Sequence[typing.DTypeLike] | None = None):
    """Deserializes GDAs from TensorStore."""


class AsyncManager:

  def __init__(self, timeout_secs=300):
    self._timeout_secs = timeout_secs
    self._timeout_in_ms = self._timeout_secs * 1000

    self._commit_futures = None
    self._thread = None
    self._exception = None

    if jax.process_count() > 1 and distributed.global_state.client is None:
      raise ValueError(_DISTRIBUTED_SYSTEM_MSG)
    if jax.process_count() > 1:
      self._client = distributed.global_state.client
    self._count = None

  def __del__(self):
    if self._thread is not None and self._thread.is_alive():
      logger.warning('Please add `.wait_until_finished()` in the main thread '
                     'before your program finishes because there is a '
                     'possibility of losing errors raised if the '
                     'this class is deleted before writing is completed.')

  def _thread_func(self):
    try:
      current_process = jax.process_index()
      process_count = jax.process_count()
      logger.info('Starting commit to storage layer by process: %s',
                  current_process)
      thread_start_time = time.time()
      for future in self._commit_futures:
        future.result()
      logger.info('Finished committing to storage layer by process: %s',
                  current_process)

      if process_count > 1:
        # All processes will wait at the barrier. When all processes are at the
        # barrier, the barrier will be satisfied. If not, then it will timeout.
        key_for_barrier = _get_key(self._count)
        logger.info('Key used for barrier is %s for process %s',
                    key_for_barrier, current_process)
        self._client.wait_at_barrier(key_for_barrier, self._timeout_in_ms)
        logger.info('Finished waiting at barrier for process %s',
                    current_process)

      if current_process == 0:
        self._on_commit_callback()
        logger.info('on_commit_callback successfully ran!')
        if process_count > 1:
          self._client.key_value_set(key_for_barrier, _CHECKPOINT_SUCCESS)
          logger.info('Process 0 successfully set key %s in the kv store',
                      key_for_barrier)

      jax.monitoring.record_event_duration_secs(
          '/jax/checkpoint/write/async/thread_duration_sec',
          time.time() - thread_start_time)

    except Exception as e:
      self._exception = e

  def _start_async_commit(self, on_commit_callback):
    self._count = next(_module_unique_count)

    self._on_commit_callback = on_commit_callback
    self._thread = threading.Thread(target=self._thread_func)
    self._thread.start()

  def check_for_errors(self):
    if self._exception is not None:
      # Clears self._exception so it is only raised once.
      exception = self._exception
      self._exception = None
      if (isinstance(exception, xe.XlaRuntimeError) and
          'DEADLINE_EXCEEDED: Barrier timed out' in str(exception)):
        raise BarrierTimeoutException(
            '\n'.join([str(exception), _BARRIER_TIMED_OUT_MSG]))
      raise exception  # pylint: disable=raising-bad-type

  def wait_until_finished(self):
    if self._thread is not None:
      self._thread.join()
      self._thread = None
      logger.info('Thread joined successfully')

    self.check_for_errors()
    logger.info('Error check finished successfully')

    if jax.process_count() > 1 and self._count is not None:
|
| 557 |
+
# Block until process 0 writes success value to the key value store.
|
| 558 |
+
# If it fails to write it, then `blocking_key_value_get` will time out.
|
| 559 |
+
get_key = _get_key(self._count)
|
| 560 |
+
self._client.blocking_key_value_get(get_key, self._timeout_in_ms)
|
| 561 |
+
logger.info('blocking_key_value_get on key %s was successfully '
|
| 562 |
+
'completed.', get_key)
|
| 563 |
+
|
| 564 |
+
def _add_futures(self, futures: Sequence[asyncio.Future]):
|
| 565 |
+
self._commit_futures = futures
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBase):
|
| 569 |
+
"""Responsible for serializing GDAs via TensorStore."""
|
| 570 |
+
|
| 571 |
+
def serialize(self, arrays, tensorstore_specs, *, on_commit_callback):
|
| 572 |
+
"""Serializes Arrays or Arrays via TensorStore asynchronously.
|
| 573 |
+
|
| 574 |
+
TensorStore writes to a storage layer in 2 steps:
|
| 575 |
+
* Reading/copying from the source after which the source can be modified.
|
| 576 |
+
* Returns a copy future.
|
| 577 |
+
* Writing/committing to the storage layer.
|
| 578 |
+
* Returns a commit future.
|
| 579 |
+
|
| 580 |
+
In asynchronous mode, the serialization waits for the commit future to
|
| 581 |
+
finish in a separate thread allowing other computation to proceed.
|
| 582 |
+
|
| 583 |
+
Args:
|
| 584 |
+
arrays: Arrays or Arrays that should be serialized.
|
| 585 |
+
tensorstore_specs: TensorStore specs that are used to serialize GDAs or
|
| 586 |
+
Arrays.
|
| 587 |
+
on_commit_callback: This callback will be executed after all processes
|
| 588 |
+
have finished writing their checkpoints to disk. Filesystems where
|
| 589 |
+
atomic rename operations are supported, you can rename from the
|
| 590 |
+
temporary directory to the final directory. On GCS, you write to the
|
| 591 |
+
final directory directly and in `on_commit_callback` you write a
|
| 592 |
+
success file indicating that the serialization was successful because
|
| 593 |
+
GCS does not support atomic rename operations.
|
| 594 |
+
"""
|
| 595 |
+
logger.info('Waiting for previous serialization to finish.')
|
| 596 |
+
self.wait_until_finished()
|
| 597 |
+
|
| 598 |
+
commit_futures = [[] for _ in range(len(tensorstore_specs))]
|
| 599 |
+
|
| 600 |
+
async def _run_serializer():
|
| 601 |
+
future_writer = jax.tree_util.tree_map(
|
| 602 |
+
async_serialize, arrays, tensorstore_specs, commit_futures)
|
| 603 |
+
return await asyncio.gather(*future_writer)
|
| 604 |
+
|
| 605 |
+
asyncio.run(_run_serializer())
|
| 606 |
+
|
| 607 |
+
self._add_futures(jax.tree_util.tree_flatten(commit_futures)[0])
|
| 608 |
+
|
| 609 |
+
# Used in wait_until_finished to check on process != 0, if the checkpoint
|
| 610 |
+
# has finished writing.
|
| 611 |
+
self._start_async_commit(on_commit_callback)
|
| 612 |
+
|
| 613 |
+
def serialize_with_paths(self, arrays: Sequence[jax.Array],
|
| 614 |
+
paths: Sequence[str], *, on_commit_callback):
|
| 615 |
+
tspecs = jax.tree.map(get_tensorstore_spec, paths)
|
| 616 |
+
self.serialize(arrays, tspecs, on_commit_callback=on_commit_callback)
|
| 617 |
+
|
| 618 |
+
def deserialize(self, shardings: Sequence[sharding.Sharding | Layout],
|
| 619 |
+
tensorstore_specs: Sequence[dict[str, Any]],
|
| 620 |
+
global_shapes: Sequence[array.Shape] | None = None,
|
| 621 |
+
dtypes: Sequence[typing.DTypeLike] | None = None,
|
| 622 |
+
concurrent_gb: int = 32):
|
| 623 |
+
self.wait_until_finished()
|
| 624 |
+
return run_deserialization(shardings, tensorstore_specs,
|
| 625 |
+
global_shapes, dtypes, concurrent_gb)
|
| 626 |
+
|
| 627 |
+
def deserialize_with_paths(
|
| 628 |
+
self, shardings: Sequence[sharding.Sharding],
|
| 629 |
+
paths: Sequence[str],
|
| 630 |
+
global_shapes: Sequence[array.Shape] | None = None,
|
| 631 |
+
dtypes: Sequence[typing.DTypeLike] | None = None,
|
| 632 |
+
concurrent_gb: int = 32):
|
| 633 |
+
tspecs = jax.tree.map(get_tensorstore_spec, paths)
|
| 634 |
+
return self.deserialize(shardings, tspecs, global_shapes, dtypes,
|
| 635 |
+
concurrent_gb)
|
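A minimal, hypothetical single-host sketch of how the manager above is meant to be used, based on the docstrings in this file; the paths, mesh layout, and the rename-on-commit callback are illustrative assumptions, not part of the vendored source.

# Hypothetical usage sketch of GlobalAsyncCheckpointManager (single host).
import os

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.array_serialization import serialization

mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ('x', 'y'))
arr = jax.device_put(np.arange(16, dtype=np.int32).reshape(8, 2),
                     NamedSharding(mesh, P('x', 'y')))

tmp_dir = '/tmp/ckpt.tmp'   # written first
final_dir = '/tmp/ckpt'     # renamed to on commit (atomic on POSIX filesystems)

manager = serialization.GlobalAsyncCheckpointManager()
manager.serialize_with_paths(
    [arr], [tmp_dir],
    on_commit_callback=lambda: os.rename(tmp_dir, final_dir))
manager.wait_until_finished()  # block before reading the checkpoint back

restored, = manager.deserialize_with_paths(
    [NamedSharding(mesh, P('x', 'y'))], [final_dir])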
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_serialization/serialization_test.py
ADDED
@@ -0,0 +1,493 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for serialization and deserialization of GDA."""

import asyncio
import contextlib
import math
from functools import partial
import re
import os
import pathlib
import tracemalloc as tm

from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src import array
from jax._src import xla_bridge as xb
from jax.sharding import NamedSharding, GSPMDSharding
from jax.sharding import PartitionSpec as P
from jax.experimental.array_serialization import serialization
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL
import numpy as np
import tensorstore as ts

jax.config.parse_flags_with_absl()
_exit_stack = contextlib.ExitStack()

def setUpModule():
  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))

def tearDownModule():
  _exit_stack.close()


pattern = re.compile(r"\{(.*?):")

def extract_minor_to_major(l):
  match = re.search(pattern, str(l))
  return tuple(int(i) for i in match.groups()[0].split(','))


class CheckpointTest(jtu.JaxTestCase):

  def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir):
    os.rename(temp_ckpt_dir, final_ckpt_dir)

  @jtu.skip_on_devices('cpu')
  def test_memory_consumption(self):
    global_mesh = jtu.create_global_mesh((2, 4), ('x', 'y'))
    inp_shape = (2_048, 4_096)
    pspec = P('x', 'y')
    num = math.prod(inp_shape)
    sharding = NamedSharding(global_mesh, pspec)
    src = jnp.arange(num, dtype=np.int32).reshape(inp_shape)  # 8e9
    inp = array.make_array_from_callback(
        inp_shape, sharding,
        lambda idx: src[idx])
    ckpt_dir = pathlib.Path(self.create_tempdir('memprof').full_path)
    tspec = serialization.get_tensorstore_spec(str(ckpt_dir))

    manager = serialization.GlobalAsyncCheckpointManager()
    manager.serialize(
        [inp], [tspec],
        on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
    manager.wait_until_finished()

    async def deserialize_with_byte_limit():
      r = await serialization.async_deserialize(
          sharding, tspec, inp_shape,
          byte_limiter=serialization._LimitInFlightBytes(4_200_000))
      r.block_until_ready()

    tm.start()
    asyncio.run(deserialize_with_byte_limit())
    unused_current, peak = tm.get_traced_memory()
    # NB: some padding + tensorstore overhead. It should always be
    # less than array size (2048 * 4096 * 4 = 32M)
    self.assertLess(peak, 10_000_000)
    deserialize_wo_limit = serialization.async_deserialize(
        sharding, tspec, inp_shape)
    tm.clear_traces()
    # NB: call block_until_ready() is important here and above
    # because otherwise this leads to racing condition and segfault with
    # tensorstore attempting to dealloc using tracemalloc which is already
    # destroyed.
    asyncio.run(deserialize_wo_limit).block_until_ready()

    unused_current, peak = tm.get_traced_memory()
    # We load entire array in memory here.
    self.assertGreater(peak, 30_000_000)
    tm.stop()

  def test_checkpointing_with_path_variant(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    inp_shape = (8, 2)
    pspec = P('x', 'y')
    num = math.prod(inp_shape)

    # First Array
    global_input_data1 = np.arange(num, dtype=np.int32).reshape(inp_shape)
    a1 = array.make_array_from_callback(
        inp_shape, NamedSharding(global_mesh, pspec),
        lambda idx: global_input_data1[idx])
    ckpt_dir = pathlib.Path(self.create_tempdir('ckpt_variant').full_path)
    ckpt_path1 = pathlib.Path(
        self.create_tempdir(f'{ckpt_dir}/first').full_path)

    ckpt_paths = [str(ckpt_path1)]
    manager = serialization.GlobalAsyncCheckpointManager()
    manager.serialize_with_paths(
        [a1], ckpt_paths,
        on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
    manager.wait_until_finished()

    m1, = manager.deserialize_with_paths(
        [NamedSharding(global_mesh, pspec)], ckpt_paths)
    self.assertIsInstance(m1, array.ArrayImpl)
    self.assertArraysEqual(np.asarray(m1.addressable_shards[0].data),
                           np.array([[0], [2]], dtype=np.int32))
    self.assertArraysEqual(np.asarray(m1.addressable_shards[1].data),
                           np.array([[1], [3]], dtype=np.int32))
    self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
    self.assertEqual(m1.dtype, np.int32)

  def test_checkpointing_jax_array(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    inp_shape = (8, 2)
    pspec = P('x', 'y')
    num = math.prod(inp_shape)

    # First Array
    global_input_data1 = np.arange(num, dtype=np.int32).reshape(inp_shape)
    a1 = array.make_array_from_callback(
        inp_shape, NamedSharding(global_mesh, pspec),
        lambda idx: global_input_data1[idx])
    ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path)
    ckpt_path1 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path)

    # Second Array
    global_input_data2 = np.arange(
        num, num + num, dtype=np.int32).reshape(inp_shape)
    a2 = array.make_array_from_callback(
        inp_shape, NamedSharding(global_mesh, pspec),
        lambda idx: global_input_data2[idx])
    ckpt_path2 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/second').full_path)

    # Third Array
    def cb3(_):
      return np.array([], dtype=np.float32)
    global_mesh1d = jtu.create_global_mesh((8,), ('x',))
    a3 = array.make_array_from_callback(
        (0,), NamedSharding(global_mesh1d, P(None)), cb3)
    ckpt_path3 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/third').full_path)

    ckpt_paths = [str(ckpt_path1), str(ckpt_path2), str(ckpt_path3)]
    tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

    manager = serialization.GlobalAsyncCheckpointManager()
    manager.serialize(
        [a1, a2, a3], tspecs,
        on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
    manager.wait_until_finished()

    m1, m2, m3 = serialization.run_deserialization(
        [NamedSharding(global_mesh, pspec),
         NamedSharding(global_mesh, P('x')),
         NamedSharding(global_mesh1d, P(None))],
        tspecs)

    self.assertIsInstance(m1, array.ArrayImpl)
    self.assertArraysEqual(np.asarray(m1.addressable_shards[0].data),
                           np.array([[0], [2]], dtype=np.int32))
    self.assertArraysEqual(np.asarray(m1.addressable_shards[1].data),
                           np.array([[1], [3]], dtype=np.int32))
    self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
    self.assertEqual(m1.dtype, np.int32)

    self.assertIsInstance(m2, array.ArrayImpl)
    self.assertArraysEqual(np.asarray(m2.addressable_shards[0].data),
                           np.array([[16, 17], [18, 19]], dtype=np.int32))
    self.assertArraysEqual(np.asarray(m2.addressable_shards[1].data),
                           np.array([[16, 17], [18, 19]], dtype=np.int32))
    self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
    self.assertEqual(m2.dtype, np.int32)

    self.assertIsInstance(m3, array.ArrayImpl)
    for i, s in enumerate(m3.addressable_shards):
      self.assertEqual(s.index, (slice(None),))
      self.assertEqual(s.replica_id, i)
      self.assertArraysEqual(np.asarray(s.data), np.array([], dtype=np.float32))
    self.assertEqual(m3.dtype, np.float32)

  @parameterized.product(input_dtype=[np.int32, jnp.bfloat16])
  def test_checkpointing_with_bigger_shape_jax_array(self, input_dtype):
    global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    num = math.prod(global_input_shape)

    global_input_data1 = np.arange(num, dtype=input_dtype).reshape(
        global_input_shape
    )
    def cb1(index):
      return global_input_data1[index]
    arr = array.make_array_from_callback(
        global_input_shape, NamedSharding(global_mesh, P('x', 'y')), cb1)
    ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path)

    ckpt_paths = [str(ckpt_dir)]
    tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

    manager = serialization.GlobalAsyncCheckpointManager()
    manager.serialize(
        [arr], tspecs,
        on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
    manager.wait_until_finished()

    ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))

    m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)],
                                            [np.float32])

    expected_data = {
        0: np.array([[0], [2], [4]], dtype=np.float32),
        1: np.array([[1], [3], [5]], dtype=np.float32),
        2: np.array([[6], [8], [10]], dtype=np.float32),
        3: np.array([[7], [9], [11]], dtype=np.float32),
        4: np.array([[12], [14], [0]], dtype=np.float32),
        5: np.array([[13], [15], [0]], dtype=np.float32),
        6: np.array([[0], [0], [0]], dtype=np.float32),
        7: np.array([[0], [0], [0]], dtype=np.float32),
    }

    for l in m1.addressable_shards:
      self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])

    new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat))
    m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [np.float32])
    for l in m2.addressable_shards:
      self.assertArraysEqual(l.data, global_input_data1.astype('float32'))

  @parameterized.product(input_dtype=[jnp.int4, jnp.int8])
  def test_checkpointing_with_int4(self, input_dtype):
    global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    num = math.prod(global_input_shape)

    global_input_data = np.arange(num, dtype=input_dtype).reshape(
        global_input_shape
    )
    def cb(index):
      return global_input_data[index]
    arr = array.make_array_from_callback(
        global_input_shape, NamedSharding(global_mesh, P('x', 'y')), cb)
    ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path)

    ckpt_paths = [str(ckpt_dir)]
    tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

    manager = serialization.GlobalAsyncCheckpointManager()
    manager.serialize(
        [arr], tspecs,
        on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
    manager.wait_until_finished()

    ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))

    target_dtype = jnp.dtype('int4')
    m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)],
                                            [target_dtype])

    # values bigger than 7 are converted properly.
    expected_data = {
        0: jnp.array([[0], [2], [4]], dtype=target_dtype),
        1: jnp.array([[1], [3], [5]], dtype=target_dtype),
        2: jnp.array([[6], [8], [10]], dtype=target_dtype),
        3: jnp.array([[7], [9], [11]], dtype=target_dtype),
        4: jnp.array([[12], [14], [0]], dtype=target_dtype),
        5: jnp.array([[13], [15], [0]], dtype=target_dtype),
        6: jnp.array([[0], [0], [0]], dtype=target_dtype),
        7: jnp.array([[0], [0], [0]], dtype=target_dtype),
    }

    for l in m1.addressable_shards:
      self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])

    new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat))
    m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [target_dtype])
    for l in m2.addressable_shards:
      self.assertArraysEqual(l.data, global_input_data.astype(target_dtype))

  def test_checkpointing_scalar_jax_array(self):
    global_mesh = jtu.create_global_mesh((2,), ('x'))
    global_input_shape = ()
    data = np.array(4)
    s = NamedSharding(global_mesh, P(None))
    array1 = array.make_array_from_callback(
        global_input_shape, s, lambda idx: data[idx])
    ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path)

    ckpt_paths = [str(ckpt_dir)]
    tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

    manager = serialization.GlobalAsyncCheckpointManager()
    manager.serialize(
        [array1], tspecs,
        on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
    manager.wait_until_finished()

    ds = NamedSharding(jtu.create_global_mesh((2,), ('x')), P(None))

    m1, = serialization.run_deserialization(
        [ds],
        tspecs,
        [()],
        [np.float32]
    )

    for l in m1.addressable_shards:
      self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))

  def test_deserialize_tensorstore_array_jax_array(self):
    global_mesh = jtu.create_global_mesh((2,), ('x'))
    data = np.arange(1024)
    tspec = ts.array(data).spec()
    m1, = serialization.run_deserialization(
        [NamedSharding(global_mesh, P(None))],
        [tspec]
    )
    for l in m1.addressable_shards:
      self.assertArraysEqual(np.asarray(l.data), data)

  def test_spec_has_metadata(self):
    spec = {
        'a': {
            'b': 1,
            'c': 2,
        },
        'd': 3,
        'e': {
            'a': 2,
            'metadata': 3
        },
        'f': 4
    }
    self.assertTrue(serialization._spec_has_metadata(spec))
    self.assertTrue(
        serialization._spec_has_metadata({
            'driver': 'zarr',
            'kvstore': 'gfile',
            'metadata': {
                'chunks': 4,
                'shape': (32, 64)
            },
            'one_more': 'thing'
        }))

  def test_spec_has_no_metadata(self):
    spec = {
        'a': {
            'b': 1,
            'c': 2,
        },
        'd': 3,
        'e': {
            'a': 2,
        },
        'f': 4
    }
    self.assertFalse(serialization._spec_has_metadata(spec))

  def test_empty_spec_has_no_metadata(self):
    spec = {}
    self.assertFalse(serialization._spec_has_metadata(spec))

  @parameterized.named_parameters(
      ('gcs', 'gs://my/ckpt/dir/path'),
      ('file', '/my/ckpt/dir/path')
  )
  def test_get_tensorstore_spec_ocdbt(self, path):
    spec = serialization.get_tensorstore_spec(path, ocdbt=True)
    is_gcs_path = path.startswith('gs://')
    if is_gcs_path:
      self.assertEqual(spec['kvstore']['base'], os.path.dirname(path))
    else:
      self.assertEqual(spec['kvstore']['base'],
                       f'{serialization._DEFAULT_DRIVER}://{os.path.dirname(path)}')
    self.assertEqual(spec['kvstore']['path'], 'path')

  def test_get_tensorstore_spec_not_absolute_path(self):
    path = 'my/ckpt/path'
    with self.assertRaisesRegex(ValueError,
                                "Checkpoint path should be absolute"):
      serialization.get_tensorstore_spec(path, ocdbt=True)

  def test_maybe_cloud_storage(self):
    gs_path = 'gs://some-buck/path'
    gs_spec = serialization.get_tensorstore_spec(gs_path, ocdbt=True)
    self.assertTrue(serialization.is_remote_storage(gs_spec))

    local_path = '/tmp/checkpoint'
    local_spec = serialization.get_tensorstore_spec(local_path, ocdbt=True)
    self.assertFalse(serialization.is_remote_storage(local_spec))

    nested_tspec = {
        'driver': 'cast',
        'dtype': 'int32',
        'base': {
            'driver': 'zarr',
            'kvstore': {'driver': 'ocdbt', 'base': 's3://some-bucket/path'},
        },
    }
    self.assertTrue(serialization.is_remote_storage(nested_tspec))

  def test_load_with_layout(self):
    if not jtu.test_device_matches(['tpu']):
      self.skipTest('Layouts are only supported on TPUs')

    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    np_inp = np.arange(32).reshape(8, 4)
    s = NamedSharding(mesh, P('x', 'y'))
    arr = jax.device_put(np_inp, s)

    out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower(
        arr).compile().output_layouts()
    self.assertEqual(extract_minor_to_major(arr.layout),
                     extract_minor_to_major(out_layout)[::-1])

    ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path)
    ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path)
    tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, [ckpt_path])

    manager = serialization.GlobalAsyncCheckpointManager()
    manager.serialize(
        [arr], tspecs,
        on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
    manager.wait_until_finished()

    out, = serialization.run_deserialization([out_layout], tspecs)

    self.assertEqual(out.layout, out_layout)
    self.assertIsInstance(out, array.ArrayImpl)
    self.assertArraysEqual(out, np_inp)
    for s in out.addressable_shards:
      self.assertArraysEqual(s.data, np_inp[s.index])

  def test_deserialization_with_int4(self):
    dtype = jnp.int4
    shape = (8, 2)
    arr = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype)

    ckpt_dir = pathlib.Path(self.create_tempdir('test_ckpt').full_path)

    # Run serialization.
    sharding = jax.sharding.GSPMDSharding.get_replicated(jax.devices())
    tspecs = jax.tree_util.tree_map(
        serialization.get_tensorstore_spec, [ckpt_dir]
    )
    manager = serialization.GlobalAsyncCheckpointManager()
    manager.serialize(
        [arr],
        tspecs,
        on_commit_callback=lambda: None,
    )
    manager.wait_until_finished()

    # Run deserialization.
    deserialized_arr, = serialization.run_deserialization(
        shardings=[sharding],
        tensorstore_specs=tspecs,
        global_shapes=[shape],
        dtypes=[dtype],
    )

    out = deserialized_arr.astype(jnp.int8)  # doesn't crash
    self.assertEqual(out.dtype, jnp.int8)
    self.assertArraysEqual(out + out, out * 2)

if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/compilation_cache/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/compilation_cache/compilation_cache.py
ADDED
@@ -0,0 +1,20 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax._src.compilation_cache import (
    is_initialized as is_initialized,  # deprecated
    initialize_cache as initialize_cache,  # deprecated; use set_cache_dir instead
    set_cache_dir as set_cache_dir,
    reset_cache as reset_cache,
)
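A brief, hypothetical sketch of how the re-exported compilation-cache helpers above are typically used; the cache directory path is an arbitrary example.

# Illustrative only: persist compiled executables across processes.
from jax.experimental.compilation_cache import compilation_cache as cc

cc.set_cache_dir("/tmp/jax_compilation_cache")  # example path, not from the source

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  return jnp.sin(x) * 2.0

f(jnp.arange(4.0))  # first call compiles and writes the executable to the cache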
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/export/__init__.py
ADDED
@@ -0,0 +1,36 @@
# Copyright 2023 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from jax._src.export._export import (
    minimum_supported_serialization_version,
    maximum_supported_serialization_version,
    Exported,
    export,
    call_exported,  # TODO: deprecate
    call,
    DisabledSafetyCheck,
    default_lowering_platform,
)
from jax._src.export.shape_poly import (
    is_symbolic_dim,
    symbolic_shape,
    symbolic_args_specs,
    SymbolicScope,
)
from jax._src.export.serialization import (
    serialize,
    deserialize,
)
from jax._src.export import shape_poly_decision
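A rough usage sketch of the names re-exported above (export, call, symbolic_shape, serialize, deserialize), under the assumption that this experimental API behaves as in JAX releases of this era; the function and shapes are made-up examples.

# Sketch only; not part of the vendored source.
import jax
import jax.numpy as jnp
from jax.experimental import export

def f(x):
  return jnp.sum(x ** 2)

# Export with a symbolic batch dimension, then call it back from JAX.
spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.float32)
exported = export.export(jax.jit(f))(spec)

y = export.call(exported)(jnp.ones((3, 4), jnp.float32))

# Exported artifacts can be round-tripped as bytes.
blob = export.serialize(exported)
restored = export.deserialize(blob)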
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/__init__.py
ADDED
@@ -0,0 +1,23 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax.experimental.jax2tf.jax2tf import (
    convert as convert,
    eval_polymorphic_shape as eval_polymorphic_shape,
    dtype_of_val as dtype_of_val,
    split_to_logical_devices as split_to_logical_devices,
    DisabledSafetyCheck as DisabledSafetyCheck,
    PolyShape as PolyShape  # TODO: deprecate
)
from jax.experimental.jax2tf.call_tf import call_tf as call_tf
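A minimal sketch of the JAX-to-TensorFlow direction exposed above (convert), assuming TensorFlow is installed; the function and tensor shapes are arbitrary examples, not taken from the source.

# Sketch only; wraps a JAX function so TensorFlow can run it.
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def loss(params, x):
  return jnp.mean((x @ params) ** 2)

loss_tf = tf.function(jax2tf.convert(loss), autograph=False,
                      jit_compile=True)

params = tf.ones((4, 1))
x = tf.ones((8, 4))
print(loss_tf(params, x))  # runs the JAX computation inside TensorFlow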
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/call_tf.py
ADDED
@@ -0,0 +1,682 @@
| 1 |
+
# Copyright 2021 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Allows JAX to call TensorFlow functions with support for autodiff.
|
| 15 |
+
|
| 16 |
+
**Experimental: please give feedback, and expect changes.**
|
| 17 |
+
|
| 18 |
+
This module introduces the function :func:`call_tf` that allows JAX to call
|
| 19 |
+
TensorFlow functions.
|
| 20 |
+
|
| 21 |
+
For examples and details, see
|
| 22 |
+
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax.
|
| 23 |
+
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
from collections.abc import Sequence
|
| 29 |
+
import dataclasses
|
| 30 |
+
import functools
|
| 31 |
+
from typing import Any, Callable, Optional
|
| 32 |
+
|
| 33 |
+
from absl import logging
|
| 34 |
+
import jax
|
| 35 |
+
from jax import dlpack
|
| 36 |
+
from jax import dtypes
|
| 37 |
+
from jax import numpy as jnp
|
| 38 |
+
from jax import tree_util
|
| 39 |
+
from jax._src import ad_util
|
| 40 |
+
from jax._src import core
|
| 41 |
+
from jax._src import effects
|
| 42 |
+
from jax._src import util
|
| 43 |
+
from jax._src.lib import xla_client
|
| 44 |
+
from jax._src.lib.mlir import ir
|
| 45 |
+
from jax._src.lib.mlir.dialects import func as func_dialect
|
| 46 |
+
from jax._src.lib.mlir.dialects import hlo
|
| 47 |
+
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
|
| 48 |
+
from jax.interpreters import mlir
|
| 49 |
+
import numpy as np
|
| 50 |
+
import tensorflow as tf
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
map = util.safe_map
|
| 54 |
+
zip = util.safe_zip
|
| 55 |
+
|
| 56 |
+
TfConcreteFunction = Any
|
| 57 |
+
TfVal = jax2tf_internal.TfVal
|
| 58 |
+
|
| 59 |
+
# The platforms for which to use DLPack to avoid copying (only works on GPU
|
| 60 |
+
# and CPU at the moment, and only for Array). For CPU we don't need
|
| 61 |
+
# DLPack, if we are careful.
|
| 62 |
+
_DLPACK_PLATFORMS = ("gpu",)
|
| 63 |
+
|
| 64 |
+
class UnspecifiedOutputShapeDtype:
|
| 65 |
+
pass
|
| 66 |
+
|
| 67 |
+
def call_tf(
|
| 68 |
+
callable_tf: Callable,
|
| 69 |
+
has_side_effects=True,
|
| 70 |
+
ordered=False,
|
| 71 |
+
output_shape_dtype=UnspecifiedOutputShapeDtype(),
|
| 72 |
+
call_tf_graph=False,
|
| 73 |
+
) -> Callable:
|
| 74 |
+
"""Calls a TensorFlow function from JAX, with support for reverse autodiff.
|
| 75 |
+
|
| 76 |
+
The ``callable_tf`` will be called with TensorFlow-compatible arguments (
|
| 77 |
+
numpy.ndarray, ``tf.Tensor`` or ``tf.Variable``) or pytrees thereof. The
|
| 78 |
+
function must return the same type of results.
|
| 79 |
+
|
| 80 |
+
If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
|
| 81 |
+
or :func:`jax.pmap`, or :func:`jax.xmap`, or a control-flow primitive) then
|
| 82 |
+
``callable_tf`` will be compiled with ``tf.function(callable_tf,
|
| 83 |
+
jit_compile=True)``
|
| 84 |
+
and the resulting XLA computation will be embedded in JAX's XLA computation.
|
| 85 |
+
|
| 86 |
+
If ``call_tf`` appears outside a JAX staging context, it will be called inline
|
| 87 |
+
using TensorFlow eager mode.
|
| 88 |
+
|
| 89 |
+
The ``call_tf`` supports JAX's reverse-mode autodiff, in which case the
|
| 90 |
+
``callable_tf`` will be differentiated using ``tf.GradientTape``. This means
|
| 91 |
+
that the gradient will be TensorFlow-accurate, e.g., will respect the
|
| 92 |
+
custom gradients that may be defined for the code in ``callable_tf``.
|
| 93 |
+
|
| 94 |
+
For an example and more details see the
|
| 95 |
+
`README
|
| 96 |
+
<https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow
|
| 100 |
+
arguments.
|
| 101 |
+
has_side_effects: if True then it ensures that instances of this primitive
|
| 102 |
+
are not removed or replicated by JAX optimizations such as dead-code
|
| 103 |
+
elimination.
|
| 104 |
+
ordered: If true, calls are modeled as having ordered effects.
|
| 105 |
+
output_shape_dtype: An optional declaration of the expected shape and dtype
|
| 106 |
+
of the result of the called TensorFlow function. If given it will be used
|
| 107 |
+
during JAX tracing to form the abstract values of the results of the
|
| 108 |
+
`call_tf`. If not given then we form a `tf.Graph` for the called
|
| 109 |
+
TensorFlow function and we use the TensorFlow-inferred shapes and types.
|
| 110 |
+
Must be a pytree matching the structure of the nested structure returned
|
| 111 |
+
from the TensorFlow function, containing objects with `.shape` and
|
| 112 |
+
`.dtype` attributes, e.g., `jax.ShapeDtypeStruct` or `jax.Array`.
|
| 113 |
+
call_tf_graph: EXPERIMENTAL, DO NOT USE. We may change the name in the
|
| 114 |
+
future.
|
| 115 |
+
|
| 116 |
+
Returns: a JAX callable that can be invoked with JAX pytree arguments, in
|
| 117 |
+
op-by-op mode or in a staged context. This callable can be used with JAX's
|
| 118 |
+
reverse-mode autodiff (:func:`jax.grad`).
|
| 119 |
+
"""
|
| 120 |
+
@jax.custom_vjp
|
| 121 |
+
def make_call(*args_jax):
|
| 122 |
+
"""We wrap it all in `make_call` so that we can attach custom VJP."""
|
| 123 |
+
|
| 124 |
+
args_flat_jax, args_treedef = tree_util.tree_flatten(args_jax)
|
| 125 |
+
# Canonicalize the arguments; e.g., makes them x32 if JAX is in 32-bit mode
|
| 126 |
+
def canonical_arg(v):
|
| 127 |
+
v = v if getattr(v, "dtype", None) else np.asarray(v)
|
| 128 |
+
dtype = dtypes.canonicalize_dtype(v.dtype)
|
| 129 |
+
if dtype != v.dtype:
|
| 130 |
+
v = v.astype(dtype)
|
| 131 |
+
return v
|
| 132 |
+
|
| 133 |
+
args_flat_jax = tuple(map(canonical_arg, args_flat_jax))
|
| 134 |
+
def make_tensorspec(a_jax):
|
| 135 |
+
a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
|
| 136 |
+
a_tf_shape = [d if core.is_constant_dim(d) else None for d in a_jax.shape]
|
| 137 |
+
return tf.TensorSpec(a_tf_shape, a_tf_dtype)
|
| 138 |
+
args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))
|
| 139 |
+
|
| 140 |
+
if not isinstance(output_shape_dtype, UnspecifiedOutputShapeDtype):
|
| 141 |
+
output_shape_dtype_flat, output_shape_dtype_tree = tree_util.tree_flatten(output_shape_dtype)
|
| 142 |
+
output_avals = tuple(core.ShapedArray(st.shape, st.dtype) for st in output_shape_dtype_flat)
|
| 143 |
+
else:
|
| 144 |
+
output_avals, output_shape_dtype_tree = None, None
|
| 145 |
+
|
| 146 |
+
res_treedef = None # We'll store here the result treedef
|
| 147 |
+
res_tf_flat = None # For error reporting
|
| 148 |
+
# The function below will be called at least once, either in eager
|
| 149 |
+
# mode during jax2tf_call_tf or in graph mode during _get_concrete_function_tf()
|
| 150 |
+
def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]:
|
| 151 |
+
args_tf = args_treedef.unflatten(args_tf_flat)
|
| 152 |
+
res_tf = callable_tf(*args_tf)
|
| 153 |
+
|
| 154 |
+
# b/279454591: When `callable_tf` is a tf function with zero outputs, it
|
| 155 |
+
# returns a `StatefulPartitionedCall` (if the function is stateful) or
|
| 156 |
+
# `PartitionedCall` (if the function is stateless) op instead of
|
| 157 |
+
# tf.Tensors. We work around this issue by replacing the output `res_tf`
|
| 158 |
+
# with an empty list.
|
| 159 |
+
|
| 160 |
+
if isinstance(res_tf, tf.Operation):
|
| 161 |
+
assert (
|
| 162 |
+
res_tf.type == "StatefulPartitionedCall"
|
| 163 |
+
or res_tf.type == "PartitionedCall"
|
| 164 |
+
)
|
| 165 |
+
t_out = res_tf.get_attr("Tout")
|
| 166 |
+
# t_out should be an empty list.
|
| 167 |
+
assert not t_out, (
|
| 168 |
+
"The TF function returned an unexpected result, please check its"
|
| 169 |
+
f" function body. res_tf = {res_tf}"
|
| 170 |
+
)
|
| 171 |
+
res_tf = t_out
|
| 172 |
+
|
| 173 |
+
nonlocal res_treedef, res_tf_flat
|
| 174 |
+
res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf)
|
| 175 |
+
assert res_treedef is None or res_treedef == res_treedef_now, (
|
| 176 |
+
f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}")
|
| 177 |
+
res_treedef = res_treedef_now
|
| 178 |
+
if output_avals is not None:
|
| 179 |
+
if res_treedef != output_shape_dtype_tree:
|
| 180 |
+
raise ValueError(
|
| 181 |
+
"The pytree of the TensorFlow function results does not match the "
|
| 182 |
+
"pytree of the declared output_shape_dtype:\n"
|
| 183 |
+
f"results pytree: {res_treedef}\noutput_shape_dtype tree: {output_shape_dtype_tree}")
|
| 184 |
+
assert len(output_avals) == len(res_tf_flat)
|
| 185 |
+
|
| 186 |
+
checked_res_tf_flat = [
|
| 187 |
+
check_tf_result(i, r_tf, r_aval)
|
| 188 |
+
for i, (r_tf, r_aval) in enumerate(
|
| 189 |
+
zip(res_tf_flat,
|
| 190 |
+
(output_avals
|
| 191 |
+
if output_avals is not None
|
| 192 |
+
else (None,) * len(res_tf_flat))))]
|
| 193 |
+
return checked_res_tf_flat
|
| 194 |
+
|
| 195 |
+
# Prepare a tf.function ahead of time, to cache the concrete functions. This
|
| 196 |
+
# won't be used in op-by-op execution mode.
|
| 197 |
+
function_flat_tf = tf.function(
|
| 198 |
+
callable_flat_tf, autograph=False, jit_compile=not call_tf_graph)
|
| 199 |
+
|
| 200 |
+
res_jax_flat = call_tf_p.bind(
|
| 201 |
+
*args_flat_jax,
|
| 202 |
+
# Carry the actual function such that op-by-op call can call in TF eager mode.
|
| 203 |
+
callable_flat_tf=callable_flat_tf,
|
| 204 |
+
function_flat_tf=function_flat_tf,
|
| 205 |
+
args_flat_sig_tf=args_flat_sig_tf,
|
| 206 |
+
output_avals=output_avals,
|
| 207 |
+
has_side_effects=has_side_effects,
|
| 208 |
+
ordered=ordered,
|
| 209 |
+
call_tf_graph=call_tf_graph,
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# We must have called callable_flat_tf by nοw
|
| 213 |
+
assert res_treedef is not None
|
| 214 |
+
return res_treedef.unflatten(res_jax_flat)
|
| 215 |
+
|
| 216 |
+
# Define the fwd and bwd custom_vjp functions
|
| 217 |
+
def make_call_vjp_fwd(*args_jax):
|
| 218 |
+
# Return the primal arguments as the residual
|
| 219 |
+
return make_call(*args_jax), args_jax
|
| 220 |
+
|
| 221 |
+
def make_call_vjp_bwd(residual_jax, ct_res_jax):
|
| 222 |
+
args_jax = residual_jax # residual is the primal argument
|
| 223 |
+
|
| 224 |
+
def tf_vjp_fun(args_tf, ct_res_tf):
|
| 225 |
+
"""Invoke TF gradient."""
|
| 226 |
+
|
| 227 |
+
# TF does not like us to watch non-float vars
|
| 228 |
+
def replace_non_float(arg_tf):
|
| 229 |
+
if arg_tf.dtype.is_floating or arg_tf.dtype.is_complex:
|
| 230 |
+
return arg_tf
|
| 231 |
+
else:
|
| 232 |
+
# When watched, this will be ignored. When used in results it will
|
| 233 |
+
# result in a floating 0. gradient, which JAX will ignore (and
|
| 234 |
+
# replace it with a float0)
|
| 235 |
+
return tf.zeros((), dtype=tf.float32)
|
| 236 |
+
|
| 237 |
+
watched_args_tf = tf.nest.map_structure(replace_non_float, args_tf)
|
| 238 |
+
with tf.GradientTape(persistent=True) as tape:
|
| 239 |
+
tape.watch(watched_args_tf)
|
| 240 |
+
res = callable_tf(*args_tf)
|
| 241 |
+
|
| 242 |
+
tf.nest.assert_same_structure(res, ct_res_tf)
|
| 243 |
+
dres_darg = tape.gradient(
|
| 244 |
+
tf.nest.map_structure(replace_non_float, res),
|
| 245 |
+
sources=watched_args_tf,
|
| 246 |
+
output_gradients=ct_res_tf,
|
| 247 |
+
unconnected_gradients=tf.UnconnectedGradients.ZERO)
|
| 248 |
+
|
| 249 |
+
dres_darg = tree_util.tree_map(
|
| 250 |
+
lambda x: x if x is None else tf.convert_to_tensor(x),
|
| 251 |
+
dres_darg,
|
| 252 |
+
)
|
| 253 |
+
tf.nest.assert_same_structure(dres_darg, args_tf)
|
| 254 |
+
return dres_darg
|
| 255 |
+
|
| 256 |
+
# Use call_tf to call the VJP function
|
| 257 |
+
ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax)
|
| 258 |
+
# We must make the float0s that JAX expects
|
| 259 |
+
def fix_float0(arg_jax, ct_arg_jax):
|
| 260 |
+
arg_dtype = dtypes.result_type(arg_jax) # May be scalar
|
| 261 |
+
ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype)
|
| 262 |
+
if ct_arg_dtype != ct_arg_jax.dtype:
|
| 263 |
+
return ad_util.zeros_like_aval(core.ShapedArray(np.shape(arg_jax),
|
| 264 |
+
ct_arg_dtype))
|
| 265 |
+
return ct_arg_jax
|
| 266 |
+
|
| 267 |
+
ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax)
|
| 268 |
+
return ct_args_jax_fixed
|
| 269 |
+
|
| 270 |
+
make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
|
| 271 |
+
return util.wraps(callable_tf)(make_call)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def check_tf_result(idx: int, r_tf: TfVal, r_aval: core.ShapedArray | None) -> TfVal:
|
| 275 |
+
# Check that the TF function returns values of expected types. This
|
| 276 |
+
# improves error reporting, preventing hard-to-diagnose errors downstream
|
| 277 |
+
try:
|
| 278 |
+
jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf)
|
| 279 |
+
except Exception as e:
|
| 280 |
+
msg = ("The called TF function returns a result that is not "
|
| 281 |
+
f"convertible to JAX: {r_tf}.")
|
| 282 |
+
raise ValueError(msg) from e
|
| 283 |
+
|
| 284 |
+
if r_aval is None:
|
| 285 |
+
return r_tf
|
| 286 |
+
# We convert to TF type, and canonicalize to 32-bit if necessary
|
| 287 |
+
r_aval_dtype_tf = jax2tf_internal._to_tf_dtype(r_aval.dtype)
|
| 288 |
+
# Checking shapes is trickier in presence of dynamic shapes. I wish we could
|
| 289 |
+
# check at runtime that the returned shape matches the declared shape. I wish
|
| 290 |
+
# that tf.ensure_shape did this, but it can only take shapes that contain None
|
| 291 |
+
# not computed shapes. However, in eager mode we should be able to resolve
|
| 292 |
+
# the declared shapes to constants and we get better checking.
|
| 293 |
+
if tf.executing_eagerly():
|
| 294 |
+
r_aval_shape_tf = jax2tf_internal._eval_shape(r_aval.shape)
|
| 295 |
+
else:
|
| 296 |
+
r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval)
|
| 297 |
+
# We do as much checking as we can here, instead of relying on tf.ensure_shape
|
| 298 |
+
# because the latter gives different errors in eager vs. compiled mode.
|
| 299 |
+
# TODO(b/279454591): This strange error comes from TF. An eager function is supposed
|
| 300 |
+
# to return a tf value with a concrete shape, but here it does not. We downgrade the
|
| 301 |
+
# exception to a warning and bypass it. This case needs to be revisited on the TF side.
|
| 302 |
+
try:
|
| 303 |
+
_ = len(r_tf.shape)
|
| 304 |
+
except ValueError as e:
|
| 305 |
+
msg = (
|
| 306 |
+
"The shape check test cannot be performed because the shape of the"
|
| 307 |
+
"`r_tf` tensor cannot be obtained."
|
| 308 |
+
f"r_tf = {r_tf}, r_aval = {r_aval}"
|
| 309 |
+
)
|
| 310 |
+
msg += str(e)
|
| 311 |
+
logging.warning(msg)
|
| 312 |
+
return r_tf
|
| 313 |
+
if (r_tf.dtype != r_aval_dtype_tf or
|
| 314 |
+
len(r_tf.shape) != len(r_aval_shape_tf) or
|
| 315 |
+
any(r_aval_d is not None and r_tf_d is not None and r_aval_d != r_tf_d
|
| 316 |
+
for r_tf_d, r_aval_d in zip(r_tf.shape, r_aval_shape_tf))):
|
| 317 |
+
msg = ("The shapes or dtypes returned by the TensorFlow function "
|
| 318 |
+
"do not match the declared output_shape_dtype:\n"
|
| 319 |
+
f"Result[{idx}] is {r_tf.dtype}[{r_tf.shape}] vs. expected {r_aval_dtype_tf}[{r_aval_shape_tf}]")
|
| 320 |
+
raise ValueError(msg)
|
| 321 |
+
# At this point tf.ensure_shape does not do much; it should never throw an
|
| 322 |
+
# error, although it may refine the shape a bit.
|
| 323 |
+
return tf.ensure_shape(r_tf, r_aval_shape_tf)
|
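The helper above validates each TF result against the aval declared via `output_shape_dtype`. As a hedged sketch (assuming the public `jax2tf.call_tf` wrapper routes the declared avals through `check_tf_result`, and using the standard `jax.ShapeDtypeStruct`), a deliberately wrong declaration surfaces the ValueError constructed above:

import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# Declared f32[3], but the TF function actually returns f32[6].
f = jax2tf.call_tf(
    lambda x: tf.concat([x, x], axis=0),
    output_shape_dtype=jax.ShapeDtypeStruct((3,), jnp.float32))
try:
  f(jnp.zeros((3,), jnp.float32))
except ValueError as e:
  print(e)  # expected: "The shapes or dtypes returned by the TensorFlow function do not match ..."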
| 324 |
+
|
| 325 |
+
|
| 326 |
+
call_tf_p = core.Primitive("call_tf")
|
| 327 |
+
call_tf_p.multiple_results = True
|
| 328 |
+
|
| 329 |
+
# The impl will be used in op-by-op mode and calls callable_tf in TF eager mode.
|
| 330 |
+
def _call_tf_impl(*args_jax_flat, callable_flat_tf, **_):
|
| 331 |
+
# On GPU we use dlpack to avoid copies of data to the host.
|
| 332 |
+
def _arg_jax_to_tf(arg_jax):
|
| 333 |
+
if (isinstance(arg_jax, jax.Array) and
|
| 334 |
+
list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and
|
| 335 |
+
arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES):
|
| 336 |
+
arg_dlpack = jax.dlpack.to_dlpack(arg_jax)
|
| 337 |
+
return tf.experimental.dlpack.from_dlpack(arg_dlpack)
|
| 338 |
+
# The following avoids copies to the host on CPU, always for Array
|
| 339 |
+
# and even for ndarray if they are sufficiently aligned.
|
| 340 |
+
# TODO(necula): on TPU this copies to the host!
|
| 341 |
+
if getattr(arg_jax, 'dtype', None) == dtypes.float0:
|
| 342 |
+
return tf.zeros(shape=arg_jax.shape,
|
| 343 |
+
dtype=jax2tf_internal._tf_np_dtype_for_float0)
|
| 344 |
+
return tf.constant(np.asarray(arg_jax))
|
| 345 |
+
|
| 346 |
+
args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
|
| 347 |
+
with jax2tf_internal.inside_call_tf():
|
| 348 |
+
# Call in TF eager mode
|
| 349 |
+
res_tf_flat = callable_flat_tf(*args_tf_flat)
|
| 350 |
+
|
| 351 |
+
def _res_tf_to_jax(res_tf: TfVal):
|
| 352 |
+
res_tf, jax_dtype = jax2tf_internal._tfval_to_tensor_jax_dtype(res_tf)
|
| 353 |
+
if isinstance(res_tf, tf.Tensor) and jax_dtype.type in dlpack.SUPPORTED_DTYPES:
|
| 354 |
+
res_tf_platform = tf.DeviceSpec.from_string(res_tf.backing_device).device_type
|
| 355 |
+
res_jax_platform = res_tf_platform.lower()
|
| 356 |
+
if res_jax_platform in _DLPACK_PLATFORMS:
|
| 357 |
+
res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
|
| 358 |
+
return jax.dlpack.from_dlpack(res_dlpack)
|
| 359 |
+
|
| 360 |
+
# When working with a bfloat16 scalar tf.Tensor, np.asarray() can fail.
|
| 361 |
+
# To handle this special case, we create a numpy copy.
|
| 362 |
+
if res_tf.shape == tf.TensorShape([]) and res_tf.dtype == tf.bfloat16:
|
| 363 |
+
return jax.device_put(jnp.array(res_tf.numpy()))
|
| 364 |
+
else:
|
| 365 |
+
return jax.device_put(np.asarray(res_tf))
|
| 366 |
+
|
| 367 |
+
return list(map(_res_tf_to_jax, res_tf_flat))
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
call_tf_p.def_impl(_call_tf_impl)
|
| 371 |
+
|
| 372 |
+
@functools.lru_cache(maxsize=128)
|
| 373 |
+
def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf): # -> tf.ConcreteFunction
|
| 374 |
+
with jax2tf_internal.inside_call_tf():
|
| 375 |
+
return function_flat_tf.get_concrete_function(*args_flat_sig_tf)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
# Mark the effectful instances of call_tf
|
| 379 |
+
@dataclasses.dataclass(frozen=True)
|
| 380 |
+
class CallTfEffect(effects.Effect):
|
| 381 |
+
__str__ = lambda _: "CallTfEffect"
|
| 382 |
+
|
| 383 |
+
call_tf_effect = CallTfEffect()
|
| 384 |
+
|
| 385 |
+
effects.lowerable_effects.add_type(CallTfEffect)
|
| 386 |
+
effects.control_flow_allowed_effects.add_type(CallTfEffect)
|
| 387 |
+
effects.remat_allowed_effects.add_type(CallTfEffect)
|
| 388 |
+
effects.custom_derivatives_allowed_effects.add_type(CallTfEffect)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class CallTfOrderedEffect(effects.Effect):
|
| 392 |
+
__str__ = lambda _: "CallTfOrderedEffect"
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
call_tf_ordered_effect = CallTfOrderedEffect()
|
| 396 |
+
|
| 397 |
+
effects.lowerable_effects.add_type(CallTfOrderedEffect)
|
| 398 |
+
effects.control_flow_allowed_effects.add_type(CallTfOrderedEffect)
|
| 399 |
+
effects.remat_allowed_effects.add_type(CallTfOrderedEffect)
|
| 400 |
+
effects.custom_derivatives_allowed_effects.add_type(CallTfOrderedEffect)
|
| 401 |
+
effects.ordered_effects.add_type(CallTfOrderedEffect)
|
| 402 |
+
effects.shardable_ordered_effects.add_type(CallTfOrderedEffect)
|
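The two effect classes above distinguish plain side effects from ordered ones. A hedged sketch of how they are reached (assuming `has_side_effects` and `ordered` are exposed as keyword arguments on the public `jax2tf.call_tf` wrapper, mirroring the primitive parameters used here); note that, per the lowering rule later in this file, `ordered=True` under `jit` additionally requires `call_tf_graph=True`:

import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def tf_log_and_double(x):
  tf.print("call_tf saw:", x)   # a visible TF side effect
  return x * 2.0

# ordered=True attaches CallTfOrderedEffect, keeping repeated callbacks in program order.
f = jax2tf.call_tf(tf_log_and_double, ordered=True)
print(f(jnp.float32(3.0)))      # op-by-op call, runs the TF function eagerly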
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def _call_tf_abstract_eval(
|
| 406 |
+
*args_flat_avals,
|
| 407 |
+
function_flat_tf,
|
| 408 |
+
args_flat_sig_tf,
|
| 409 |
+
has_side_effects,
|
| 410 |
+
ordered,
|
| 411 |
+
output_avals,
|
| 412 |
+
call_tf_graph,
|
| 413 |
+
**__,
|
| 414 |
+
):
|
| 415 |
+
# Called only when we form a Jaxpr, i.e., under jit, scan, etc.
|
| 416 |
+
effects = set()
|
| 417 |
+
if ordered:
|
| 418 |
+
effects.add(call_tf_ordered_effect)
|
| 419 |
+
elif has_side_effects:
|
| 420 |
+
effects.add(call_tf_effect)
|
| 421 |
+
|
| 422 |
+
# If no output_avals is given, then we ask TF to infer the output shapes.
|
| 423 |
+
# We call this even if output_avals is given because it will ensure that
|
| 424 |
+
# callable_flat_tf is called. Since _get_concrete_function_tf is cached
|
| 425 |
+
# there is only a small cost to calling it more often than needed.
|
| 426 |
+
concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf,
|
| 427 |
+
args_flat_sig_tf)
|
| 428 |
+
|
| 429 |
+
# In the case that the tf.function has no return value
|
| 430 |
+
if len(concrete_function_flat_tf.outputs) == 0:
|
| 431 |
+
return (), effects
|
| 432 |
+
|
| 433 |
+
if output_avals is not None:
|
| 434 |
+
return output_avals, effects
|
| 435 |
+
|
| 436 |
+
def is_fully_known_shape(s):
|
| 437 |
+
return s.rank is not None and all(d is not None for d in s)
|
| 438 |
+
|
| 439 |
+
if all(is_fully_known_shape(s)
|
| 440 |
+
for s in concrete_function_flat_tf.output_shapes):
|
| 441 |
+
avals_from_tf = tuple(
|
| 442 |
+
# We convert to JAX type, and canonicalize to 32-bit if necessary
|
| 443 |
+
core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
|
| 444 |
+
for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
|
| 445 |
+
concrete_function_flat_tf.output_shapes))
|
| 446 |
+
return avals_from_tf, effects
|
| 447 |
+
|
| 448 |
+
msg = ("call_tf cannot call functions whose output has dynamic shape. "
|
| 449 |
+
f"Found output shapes: {concrete_function_flat_tf.output_shapes}. "
|
| 450 |
+
"Consider using the `output_shape_dtype` argument to call_tf. "
|
| 451 |
+
"\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
|
| 452 |
+
" for a discussion.")
|
| 453 |
+
raise ValueError(msg)
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def _call_tf_lowering(
|
| 460 |
+
ctx: mlir.LoweringRuleContext,
|
| 461 |
+
*args_op,
|
| 462 |
+
platform,
|
| 463 |
+
function_flat_tf,
|
| 464 |
+
args_flat_sig_tf,
|
| 465 |
+
has_side_effects,
|
| 466 |
+
ordered,
|
| 467 |
+
call_tf_graph,
|
| 468 |
+
output_avals,
|
| 469 |
+
**_,
|
| 470 |
+
):
|
| 471 |
+
# We use the same TF lowering device as for the embedding JAX computation.
|
| 472 |
+
# One example when this is needed is when the code refers to variables on one
|
| 473 |
+
# device. Or, for sharding annotations (only supported on TPU).
|
| 474 |
+
|
| 475 |
+
if platform in ["cpu", "tpu"]:
|
| 476 |
+
tf_platform = platform.upper()
|
| 477 |
+
elif platform == "cuda":
|
| 478 |
+
tf_platform = "GPU"
|
| 479 |
+
else:
|
| 480 |
+
raise ValueError("platform {platform} not supported")
|
| 481 |
+
|
| 482 |
+
concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf)
|
| 483 |
+
|
| 484 |
+
captured_inputs = []
|
| 485 |
+
if concrete_function_flat_tf.captured_inputs:
|
| 486 |
+
# The function uses either captured variables or tensors.
|
| 487 |
+
msg = (
|
| 488 |
+
"call_tf works best with a TensorFlow function that does not capture "
|
| 489 |
+
"variables or tensors from the context. "
|
| 490 |
+
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. "
|
| 491 |
+
f"The following captures were found {concrete_function_flat_tf.captured_inputs}")
|
| 492 |
+
logging.warning(msg)
|
| 493 |
+
for inp in concrete_function_flat_tf.captured_inputs:
|
| 494 |
+
if inp.dtype == tf.resource: # A variable; lookup by handle
|
| 495 |
+
inp_vars = [v for v in concrete_function_flat_tf.variables if inp is v.handle]
|
| 496 |
+
assert len(inp_vars) == 1, f"Found {inp_vars}"
|
| 497 |
+
captured_inputs.append(inp_vars[0])
|
| 498 |
+
else:
|
| 499 |
+
captured_inputs.append(inp)
|
| 500 |
+
|
| 501 |
+
captured_ops = tuple(
|
| 502 |
+
mlir.ir_constant(np.asarray(inp))
|
| 503 |
+
for inp in captured_inputs
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
if call_tf_graph:
|
| 507 |
+
with jax2tf_internal.inside_call_tf():
|
| 508 |
+
return emit_tf_embedded_graph_custom_call(
|
| 509 |
+
ctx,
|
| 510 |
+
concrete_function_flat_tf,
|
| 511 |
+
tuple(args_op) + captured_ops,
|
| 512 |
+
has_side_effects,
|
| 513 |
+
ordered,
|
| 514 |
+
output_avals,
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
def convert_to_spec(x):
|
| 518 |
+
if isinstance(x, tf.TensorSpec):
|
| 519 |
+
return x
|
| 520 |
+
else:
|
| 521 |
+
return tf.TensorSpec.from_tensor(x)
|
| 522 |
+
|
| 523 |
+
args_tf_flat = [convert_to_spec(a) for a in args_flat_sig_tf]
|
| 524 |
+
|
| 525 |
+
with jax2tf_internal.inside_call_tf():
|
| 526 |
+
try:
|
| 527 |
+
func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(
|
| 528 |
+
*args_tf_flat
|
| 529 |
+
)(stage="hlo_serialized", platform_name=tf_platform)
|
| 530 |
+
except Exception as e:
|
| 531 |
+
msg = ("Error compiling TensorFlow function (see below for the caught exception)." +
|
| 532 |
+
"\ncall_tf can used " +
|
| 533 |
+
"in a staged context (under jax.jit, lax.scan, etc.) only with " +
|
| 534 |
+
"compilable functions with static output shapes.\n" +
|
| 535 |
+
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." +
|
| 536 |
+
"\n\nCaught TensorFlow exception: " + str(e))
|
| 537 |
+
raise ValueError(msg) from e
|
| 538 |
+
|
| 539 |
+
xla_comp = xla_client.XlaComputation(func_tf_hlo)
|
| 540 |
+
|
| 541 |
+
# Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
|
| 542 |
+
def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray:
|
| 543 |
+
if not res_shape.is_static():
|
| 544 |
+
msg = ("Compiled TensorFlow function has dynamic output shape " +
|
| 545 |
+
f"{res_shape}. call_tf can used " +
|
| 546 |
+
"in a staged context (under jax.jit, lax.scan, etc.) only with " +
|
| 547 |
+
"compilable functions with static output shapes. " +
|
| 548 |
+
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.")
|
| 549 |
+
raise ValueError(msg)
|
| 550 |
+
|
| 551 |
+
res_dtype = res_shape.numpy_dtype()
|
| 552 |
+
jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
|
| 553 |
+
return core.ShapedArray(res_shape.dimensions(), jax_res_dtype)
|
| 554 |
+
|
| 555 |
+
result_shape = xla_comp.program_shape().result_shape()
|
| 556 |
+
if not result_shape.is_tuple():
|
| 557 |
+
# TF does not wrap singletons as tuples, but JAX expects tuples because
|
| 558 |
+
# call_tf is a multiple_results primitive.
|
| 559 |
+
result_shapes = (result_shape,)
|
| 560 |
+
else:
|
| 561 |
+
result_shapes = result_shape.tuple_shapes() # type: ignore
|
| 562 |
+
|
| 563 |
+
result_avals = tuple(map(canonical_res_aval, result_shapes))
|
| 564 |
+
|
| 565 |
+
submodule = mlir.xla_computation_to_mlir_module(xla_comp)
|
| 566 |
+
symtab = ir.SymbolTable(submodule.operation)
|
| 567 |
+
callee_result_types = symtab["main"].type.results
|
| 568 |
+
fn = mlir.merge_mlir_modules(ctx.module_context.module,
|
| 569 |
+
f"call_tf_{function_flat_tf.name}",
|
| 570 |
+
submodule,
|
| 571 |
+
dst_symtab=ctx.module_context.symbol_table)
|
| 572 |
+
call = func_dialect.CallOp(callee_result_types,
|
| 573 |
+
ir.FlatSymbolRefAttr.get(fn),
|
| 574 |
+
tuple(args_op) + captured_ops)
|
| 575 |
+
if result_shape.is_tuple():
|
| 576 |
+
flat_results = [hlo.get_tuple_element(call, mlir.i32_attr(i))
|
| 577 |
+
for i in range(len(result_shapes))]
|
| 578 |
+
else:
|
| 579 |
+
flat_results = call.results
|
| 580 |
+
|
| 581 |
+
if ordered:
|
| 582 |
+
raise NotImplementedError(
|
| 583 |
+
"ordered=True is not supported in the jitted context without"
|
| 584 |
+
" `call_tf_graph=True`"
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
outputs = []
|
| 588 |
+
for op, res_aval, res_shape in zip(flat_results, result_avals,
|
| 589 |
+
result_shapes):
|
| 590 |
+
if res_aval.dtype != res_shape.numpy_dtype():
|
| 591 |
+
op = hlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result
|
| 592 |
+
outputs.append(op)
|
| 593 |
+
return outputs
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def _register_call_lowering(platform):
|
| 597 |
+
mlir.register_lowering(call_tf_p, functools.partial(_call_tf_lowering,
|
| 598 |
+
platform=platform),
|
| 599 |
+
platform=platform)
|
| 600 |
+
for platform in ("cpu", "cuda", "tpu"):
|
| 601 |
+
_register_call_lowering(platform)
|
| 602 |
+
|
| 603 |
+
# Support the call_tf under jax2tf.convert in eager mode
|
| 604 |
+
def _jax2tf_call_tf(*args: TfVal,
|
| 605 |
+
callable_flat_tf: Callable,
|
| 606 |
+
**_) -> TfVal:
|
| 607 |
+
with jax2tf_internal.inside_call_tf():
|
| 608 |
+
res_tf_flat = callable_flat_tf(*args)
|
| 609 |
+
return res_tf_flat
|
| 610 |
+
|
| 611 |
+
jax2tf_internal.tf_impl[call_tf_p] = _jax2tf_call_tf
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
def emit_tf_embedded_graph_custom_call(
|
| 615 |
+
ctx: mlir.LoweringRuleContext,
|
| 616 |
+
concrete_function_flat_tf,
|
| 617 |
+
operands: Sequence[ir.Value],
|
| 618 |
+
has_side_effects,
|
| 619 |
+
ordered,
|
| 620 |
+
output_avals,
|
| 621 |
+
):
|
| 622 |
+
"""Emits a custom call referencing a tf.Graph embedding of the TF function.
|
| 623 |
+
|
| 624 |
+
All information about the called TF function is stored in tf.metadata.
|
| 625 |
+
This includes:
|
| 626 |
+
(1) The called function name: This name will be used by the runtime to execute
|
| 627 |
+
the callback.
|
| 628 |
+
(2) The called function index in the XLACallModule `function_list` attribute.
|
| 629 |
+
"""
|
| 630 |
+
call_tf_concrete_function_list = jax2tf_internal.get_thread_local_state_call_tf_concrete_function_list()
|
| 631 |
+
if call_tf_concrete_function_list is None:
|
| 632 |
+
raise ValueError(
|
| 633 |
+
"call_tf_graph=True only support exporting by jax2tf.convert currently."
|
| 634 |
+
)
|
| 635 |
+
# TODO(necula): It is dangerous to modify global state when lowering because
|
| 636 |
+
# there are a number of lowering caches that only cache the StableHLO.
|
| 637 |
+
# See call_tf_test.py:test_multi_platform_call_tf_graph.
|
| 638 |
+
called_index = add_to_call_tf_concrete_function_list(
|
| 639 |
+
concrete_function_flat_tf, call_tf_concrete_function_list)
|
| 640 |
+
tf_backend_config = {
|
| 641 |
+
"has_token_input_output": ir.BoolAttr.get(ordered),
|
| 642 |
+
"called_index": mlir.i64_attr(called_index),
|
| 643 |
+
}
|
| 644 |
+
result_avals = ctx.avals_out if ctx.avals_out is not None else ()
|
| 645 |
+
|
| 646 |
+
operands = list(operands)
|
| 647 |
+
result_types = list(
|
| 648 |
+
util.flatten([mlir.aval_to_ir_types(aval) for aval in result_avals])
|
| 649 |
+
)
|
| 650 |
+
if ordered:
|
| 651 |
+
operands.insert(0, ctx.tokens_in.get(call_tf_ordered_effect)[0])
|
| 652 |
+
result_types.insert(0, mlir.token_type()[0])
|
| 653 |
+
|
| 654 |
+
custom_call = hlo.CustomCallOp(
|
| 655 |
+
result_types,
|
| 656 |
+
operands,
|
| 657 |
+
call_target_name=ir.StringAttr.get("tf.call_tf_function"),
|
| 658 |
+
has_side_effect=ir.BoolAttr.get(has_side_effects),
|
| 659 |
+
api_version=mlir.i32_attr(2),
|
| 660 |
+
called_computations=ir.ArrayAttr.get([]),
|
| 661 |
+
backend_config=ir.StringAttr.get(""),
|
| 662 |
+
)
|
| 663 |
+
# Store TF metadata in unregistered attribute
|
| 664 |
+
custom_call.attributes["tf.backend_config"] = ir.DictAttr.get(
|
| 665 |
+
tf_backend_config
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
results = list(custom_call.results)
|
| 669 |
+
if ordered:
|
| 670 |
+
token = results.pop(0)
|
| 671 |
+
ctx.set_tokens_out(mlir.TokenSet({call_tf_ordered_effect: (token,)}))
|
| 672 |
+
|
| 673 |
+
return results
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
def add_to_call_tf_concrete_function_list(concrete_tf_fn: Any, call_tf_concrete_function_list: list[Any]) -> int:
|
| 677 |
+
try:
|
| 678 |
+
called_index = call_tf_concrete_function_list.index(concrete_tf_fn)
|
| 679 |
+
except ValueError:
|
| 680 |
+
called_index = len(call_tf_concrete_function_list)
|
| 681 |
+
call_tf_concrete_function_list.append(concrete_tf_fn)
|
| 682 |
+
return called_index
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
| 1 |
+
# Copyright 2021 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/keras_reuse_main.py
ADDED
|
@@ -0,0 +1,78 @@
|
| 1 |
+
# Copyright 2020 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Demonstrates reuse of a jax2tf model in Keras.
|
| 15 |
+
|
| 16 |
+
Includes the flags from saved_model_main.py.
|
| 17 |
+
|
| 18 |
+
See README.md.
|
| 19 |
+
"""
|
| 20 |
+
import logging
|
| 21 |
+
from absl import app
|
| 22 |
+
from absl import flags
|
| 23 |
+
from jax.experimental.jax2tf.examples import mnist_lib
|
| 24 |
+
from jax.experimental.jax2tf.examples import saved_model_main
|
| 25 |
+
import tensorflow as tf
|
| 26 |
+
import tensorflow_datasets as tfds # type: ignore
|
| 27 |
+
import tensorflow_hub as hub # type: ignore
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
FLAGS = flags.FLAGS
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def main(_):
|
| 34 |
+
FLAGS.model_classifier_layer = False # We only need the features
|
| 35 |
+
# Train the model and save the feature extractor
|
| 36 |
+
saved_model_main.train_and_save()
|
| 37 |
+
|
| 38 |
+
tf_accelerator, _ = saved_model_main.tf_accelerator_and_tolerances()
|
| 39 |
+
feature_model_dir = saved_model_main.savedmodel_dir()
|
| 40 |
+
|
| 41 |
+
# With Keras, we use the tf.distribute.OneDeviceStrategy as the high-level
|
| 42 |
+
# analogue of the tf.device(...) placement seen above.
|
| 43 |
+
# It works on CPU, GPU and TPU.
|
| 44 |
+
# Actual high-performance training would use the appropriately replicated
|
| 45 |
+
# TF Distribution Strategy.
|
| 46 |
+
strategy = tf.distribute.OneDeviceStrategy(tf_accelerator)
|
| 47 |
+
with strategy.scope():
|
| 48 |
+
images = tf.keras.layers.Input(
|
| 49 |
+
mnist_lib.input_shape, batch_size=mnist_lib.train_batch_size)
|
| 50 |
+
keras_feature_extractor = hub.KerasLayer(feature_model_dir, trainable=True)
|
| 51 |
+
features = keras_feature_extractor(images)
|
| 52 |
+
predictor = tf.keras.layers.Dense(10, activation="softmax")
|
| 53 |
+
predictions = predictor(features)
|
| 54 |
+
keras_model = tf.keras.Model(images, predictions)
|
| 55 |
+
|
| 56 |
+
keras_model.compile(
|
| 57 |
+
loss=tf.keras.losses.categorical_crossentropy,
|
| 58 |
+
optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
|
| 59 |
+
metrics=["accuracy"])
|
| 60 |
+
logging.info(keras_model.summary())
|
| 61 |
+
|
| 62 |
+
train_ds = mnist_lib.load_mnist(
|
| 63 |
+
tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
|
| 64 |
+
test_ds = mnist_lib.load_mnist(
|
| 65 |
+
tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
|
| 66 |
+
keras_model.fit(train_ds, epochs=FLAGS.num_epochs, validation_data=test_ds)
|
| 67 |
+
|
| 68 |
+
if saved_model_main.SHOW_IMAGES.value:
|
| 69 |
+
mnist_lib.plot_images(
|
| 70 |
+
test_ds,
|
| 71 |
+
1,
|
| 72 |
+
5,
|
| 73 |
+
f"Keras inference with reuse of {saved_model_main.model_description()}",
|
| 74 |
+
inference_fn=lambda images: keras_model(tf.convert_to_tensor(images)))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
app.run(main)
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/keras_reuse_main_test.py
ADDED
|
@@ -0,0 +1,50 @@
|
| 1 |
+
# Copyright 2020 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from absl import flags
|
| 17 |
+
from absl.testing import absltest
|
| 18 |
+
from absl.testing import parameterized
|
| 19 |
+
import jax
|
| 20 |
+
from jax._src import test_util as jtu
|
| 21 |
+
|
| 22 |
+
from jax.experimental.jax2tf.examples import keras_reuse_main
|
| 23 |
+
from jax.experimental.jax2tf.tests import tf_test_util
|
| 24 |
+
|
| 25 |
+
jax.config.parse_flags_with_absl()
|
| 26 |
+
FLAGS = flags.FLAGS
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class KerasReuseMainTest(tf_test_util.JaxToTfTestCase):
|
| 30 |
+
|
| 31 |
+
def setUp(self):
|
| 32 |
+
super().setUp()
|
| 33 |
+
FLAGS.model_path = os.path.join(absltest.get_default_test_tmpdir(),
|
| 34 |
+
"saved_models")
|
| 35 |
+
FLAGS.num_epochs = 1
|
| 36 |
+
FLAGS.test_savedmodel = True
|
| 37 |
+
FLAGS.mock_data = True
|
| 38 |
+
FLAGS.show_images = False
|
| 39 |
+
FLAGS.serving_batch_size = 1
|
| 40 |
+
|
| 41 |
+
@parameterized.named_parameters(
|
| 42 |
+
dict(testcase_name=f"_{model}", model=model)
|
| 43 |
+
for model in ["mnist_pure_jax", "mnist_flax"])
|
| 44 |
+
def test_keras_reuse(self, model="mnist_pure_jax"):
|
| 45 |
+
FLAGS.model = model
|
| 46 |
+
keras_reuse_main.main(None)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
|
| 50 |
+
absltest.main(testLoader=jtu.JaxTestLoader())
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/mnist_lib.py
ADDED
|
@@ -0,0 +1,324 @@
|
| 1 |
+
# Copyright 2020 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Definitions of two versions of MNIST (model and training code ).
|
| 15 |
+
|
| 16 |
+
One definition uses pure JAX (for those who prefer an example with fewer
|
| 17 |
+
moving parts, at the expense of code size), and the other uses Flax.
|
| 18 |
+
|
| 19 |
+
See README.md for how these are used.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
from collections.abc import Sequence
|
| 25 |
+
import functools
|
| 26 |
+
import logging
|
| 27 |
+
import re
|
| 28 |
+
import time
|
| 29 |
+
from typing import Any, Callable, Optional
|
| 30 |
+
from absl import flags
|
| 31 |
+
|
| 32 |
+
import flax
|
| 33 |
+
from flax import linen as nn
|
| 34 |
+
|
| 35 |
+
import jax
|
| 36 |
+
import jax.numpy as jnp
|
| 37 |
+
|
| 38 |
+
from matplotlib import pyplot as plt
|
| 39 |
+
import numpy as np
|
| 40 |
+
import optax
|
| 41 |
+
import tensorflow as tf
|
| 42 |
+
import tensorflow_datasets as tfds # type: ignore
|
| 43 |
+
|
| 44 |
+
_MOCK_DATA = flags.DEFINE_boolean("mock_data", False,
|
| 45 |
+
"Use fake data, for testing.")
|
| 46 |
+
|
| 47 |
+
#### Model parameters
|
| 48 |
+
|
| 49 |
+
# For fun, let's use different batch sizes for training and for evaluation.
|
| 50 |
+
train_batch_size = 128
|
| 51 |
+
test_batch_size = 16
|
| 52 |
+
|
| 53 |
+
# Define common parameters for both the JAX and the Flax models.
|
| 54 |
+
input_shape = (28, 28, 1) # Excluding batch_size
|
| 55 |
+
layer_sizes = [784, 512, 512, 10] # 10 is the number of classes
|
| 56 |
+
param_scale = 0.1
|
| 57 |
+
step_size = 0.001
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def load_mnist(split: tfds.Split, batch_size: int):
|
| 61 |
+
"""Loads either training or test MNIST data.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
split: either tfds.Split.TRAIN or tfds.Split.TEST.
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
an iterator with pairs (images, labels). The images have shape
|
| 68 |
+
(B, 28, 28, 1) and the labels have shape (B, 10), where B is the batch_size.
|
| 69 |
+
"""
|
| 70 |
+
if _MOCK_DATA.value:
|
| 71 |
+
with tfds.testing.mock_data(num_examples=batch_size):
|
| 72 |
+
try:
|
| 73 |
+
ds = tfds.load("mnist", split=split)
|
| 74 |
+
except Exception as e:
|
| 75 |
+
m = re.search(r'metadata files were not found in (.+/)mnist/', str(e))
|
| 76 |
+
if m:
|
| 77 |
+
msg = ("TFDS mock_data is missing the mnist metadata files. Run the "
|
| 78 |
+
"`saved_model_main.py` binary and see where TFDS downloads "
|
| 79 |
+
"the mnist data set (typically ~/tensorflow_datasets/mnist). "
|
| 80 |
+
f"Copy the `mnist` directory to {m.group(1)} and re-run the test")
|
| 81 |
+
raise ValueError(msg) from e
|
| 82 |
+
else:
|
| 83 |
+
raise e
|
| 84 |
+
else:
|
| 85 |
+
ds = tfds.load("mnist", split=split)
|
| 86 |
+
|
| 87 |
+
def _prepare_example(x):
|
| 88 |
+
image = tf.cast(x["image"], tf.float32) / 255.0
|
| 89 |
+
label = tf.one_hot(x["label"], 10)
|
| 90 |
+
return (image, label)
|
| 91 |
+
|
| 92 |
+
ds = ds.map(_prepare_example)
|
| 93 |
+
# drop_remainder=True is important for use with Keras
|
| 94 |
+
ds = ds.cache().shuffle(1000).batch(batch_size, drop_remainder=True)
|
| 95 |
+
return ds
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class PureJaxMNIST:
|
| 99 |
+
"""An MNIST model written using pure JAX.
|
| 100 |
+
|
| 101 |
+
There is an option for the model to skip the classifier layer, for
|
| 102 |
+
demonstrating reuse of the classifier-less model in a larger model.
|
| 103 |
+
See README.md.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
name = "mnist_pure_jax"
|
| 107 |
+
|
| 108 |
+
@staticmethod
|
| 109 |
+
def predict(params: Sequence[tuple[Any, Any]], inputs, with_classifier=True):
|
| 110 |
+
"""The prediction function.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
params: a list with pairs of weights and biases for each layer.
|
| 114 |
+
inputs: the batch of images (B, 28, 28, 1)
|
| 115 |
+
with_classifier: whether to include the classifier layer.
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
either the predictions (B, 10) if with_classifier=True, or the
|
| 119 |
+
final set of logits of shape (B, 512).
|
| 120 |
+
"""
|
| 121 |
+
x = inputs.reshape((inputs.shape[0], -1)) # flatten to f32[B, 784]
|
| 122 |
+
for w, b in params[:-1]:
|
| 123 |
+
x = jnp.dot(x, w) + b
|
| 124 |
+
x = jnp.tanh(x)
|
| 125 |
+
|
| 126 |
+
if not with_classifier:
|
| 127 |
+
return x
|
| 128 |
+
final_w, final_b = params[-1]
|
| 129 |
+
logits = jnp.dot(x, final_w) + final_b
|
| 130 |
+
return logits - jax.scipy.special.logsumexp(
|
| 131 |
+
logits, axis=1, keepdims=True)
|
| 132 |
+
|
| 133 |
+
@staticmethod
|
| 134 |
+
def loss(params, inputs, labels):
|
| 135 |
+
predictions = PureJaxMNIST.predict(params, inputs, with_classifier=True)
|
| 136 |
+
return -jnp.mean(jnp.sum(predictions * labels, axis=1))
|
| 137 |
+
|
| 138 |
+
@staticmethod
|
| 139 |
+
def accuracy(predict: Callable, params, dataset):
|
| 140 |
+
|
| 141 |
+
@jax.jit
|
| 142 |
+
def _per_batch(inputs, labels):
|
| 143 |
+
target_class = jnp.argmax(labels, axis=1)
|
| 144 |
+
predicted_class = jnp.argmax(predict(params, inputs), axis=1)
|
| 145 |
+
return jnp.mean(predicted_class == target_class)
|
| 146 |
+
|
| 147 |
+
batched = [
|
| 148 |
+
_per_batch(inputs, labels) for inputs, labels in tfds.as_numpy(dataset)
|
| 149 |
+
]
|
| 150 |
+
return jnp.mean(jnp.stack(batched))
|
| 151 |
+
|
| 152 |
+
@staticmethod
|
| 153 |
+
def update(params, inputs, labels):
|
| 154 |
+
grads = jax.grad(PureJaxMNIST.loss)(params, inputs, labels)
|
| 155 |
+
return [(w - step_size * dw, b - step_size * db)
|
| 156 |
+
for (w, b), (dw, db) in zip(params, grads)]
|
| 157 |
+
|
| 158 |
+
@staticmethod
|
| 159 |
+
def train(train_ds, test_ds, num_epochs, with_classifier=True):
|
| 160 |
+
"""Trains a pure JAX MNIST predictor.
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
a tuple with two elements:
|
| 164 |
+
- a predictor function with signature "(Params, ImagesBatch) ->
|
| 165 |
+
Predictions".
|
| 166 |
+
If `with_classifier=False` then the output of the predictor function
|
| 167 |
+
is the last layer of logits.
|
| 168 |
+
- the parameters "Params" for the predictor function
|
| 169 |
+
"""
|
| 170 |
+
rng = jax.random.PRNGKey(0)
|
| 171 |
+
params = [(param_scale * jax.random.normal(rng, (m, n)),
|
| 172 |
+
param_scale * jax.random.normal(rng, (n,)))
|
| 173 |
+
for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]
|
| 174 |
+
|
| 175 |
+
for epoch in range(num_epochs):
|
| 176 |
+
start_time = time.time()
|
| 177 |
+
for inputs, labels in tfds.as_numpy(train_ds):
|
| 178 |
+
params = jax.jit(PureJaxMNIST.update)(params, inputs, labels)
|
| 179 |
+
epoch_time = time.time() - start_time
|
| 180 |
+
train_acc = PureJaxMNIST.accuracy(PureJaxMNIST.predict, params, train_ds)
|
| 181 |
+
test_acc = PureJaxMNIST.accuracy(PureJaxMNIST.predict, params, test_ds)
|
| 182 |
+
logging.info("%s: Epoch %d in %0.2f sec", PureJaxMNIST.name, epoch,
|
| 183 |
+
epoch_time)
|
| 184 |
+
logging.info("%s: Training set accuracy %0.2f%%", PureJaxMNIST.name,
|
| 185 |
+
100. * train_acc)
|
| 186 |
+
logging.info("%s: Test set accuracy %0.2f%%", PureJaxMNIST.name,
|
| 187 |
+
100. * test_acc)
|
| 188 |
+
|
| 189 |
+
return (functools.partial(
|
| 190 |
+
PureJaxMNIST.predict, with_classifier=with_classifier), params)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class FlaxMNIST:
|
| 194 |
+
"""An MNIST model using Flax."""
|
| 195 |
+
|
| 196 |
+
name = "mnist_flax"
|
| 197 |
+
|
| 198 |
+
class Module(nn.Module):
|
| 199 |
+
"""A simple CNN model for MNIST.
|
| 200 |
+
|
| 201 |
+
There is an option for the model to skip the classifier layer, for
|
| 202 |
+
demonstrating reuse of the classifier-less model in a larger model.
|
| 203 |
+
See README.md.
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
@nn.compact
|
| 207 |
+
def __call__(self, x, with_classifier=True):
|
| 208 |
+
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
|
| 209 |
+
x = nn.relu(x)
|
| 210 |
+
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
|
| 211 |
+
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
|
| 212 |
+
x = nn.relu(x)
|
| 213 |
+
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
|
| 214 |
+
x = x.reshape((x.shape[0], -1)) # flatten
|
| 215 |
+
x = nn.Dense(features=256)(x)
|
| 216 |
+
x = nn.relu(x)
|
| 217 |
+
if not with_classifier:
|
| 218 |
+
return x
|
| 219 |
+
x = nn.Dense(features=10)(x)
|
| 220 |
+
x = nn.log_softmax(x)
|
| 221 |
+
return x
|
| 222 |
+
|
| 223 |
+
# Create the model and save it
|
| 224 |
+
model = Module()
|
| 225 |
+
|
| 226 |
+
@staticmethod
|
| 227 |
+
def predict(params, inputs, with_classifier=True):
|
| 228 |
+
return FlaxMNIST.model.apply({"params": params},
|
| 229 |
+
inputs,
|
| 230 |
+
with_classifier=with_classifier)
|
| 231 |
+
|
| 232 |
+
@staticmethod
|
| 233 |
+
def loss(params, inputs, labels): # Same as the pure JAX example
|
| 234 |
+
# Must use the classifier layer because the labels are classes
|
| 235 |
+
predictions = FlaxMNIST.predict(params, inputs, with_classifier=True)
|
| 236 |
+
return -jnp.mean(jnp.sum(predictions * labels, axis=1))
|
| 237 |
+
|
| 238 |
+
@staticmethod
|
| 239 |
+
def update(tx, params, opt_state, inputs, labels):
|
| 240 |
+
grad = jax.grad(FlaxMNIST.loss)(params, inputs, labels)
|
| 241 |
+
updates, opt_state = tx.update(grad, opt_state)
|
| 242 |
+
params = optax.apply_updates(params, updates)
|
| 243 |
+
return params, opt_state
|
| 244 |
+
|
| 245 |
+
@staticmethod
|
| 246 |
+
def train(train_ds, test_ds, num_epochs, with_classifier=True):
|
| 247 |
+
"""Trains a pure JAX MNIST predictor.
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
a tuple with two elements:
|
| 251 |
+
- a predictor function with signature "(Params, ImagesBatch) ->
|
| 252 |
+
Predictions".
|
| 253 |
+
If `with_classifier=False` then the output of the predictor function
|
| 254 |
+
is the last layer of logits.
|
| 255 |
+
- the parameters "Params" for the predictor function
|
| 256 |
+
"""
|
| 257 |
+
rng = jax.random.PRNGKey(0)
|
| 258 |
+
momentum_mass = 0.9
|
| 259 |
+
|
| 260 |
+
init_shape = jnp.ones((1,) + input_shape, jnp.float32)
|
| 261 |
+
params = FlaxMNIST.model.init(rng, init_shape)["params"]
|
| 262 |
+
tx = optax.sgd(learning_rate=step_size, momentum=momentum_mass)
|
| 263 |
+
opt_state = tx.init(params)
|
| 264 |
+
|
| 265 |
+
for epoch in range(num_epochs):
|
| 266 |
+
start_time = time.time()
|
| 267 |
+
for inputs, labels in tfds.as_numpy(train_ds):
|
| 268 |
+
params, opt_state = jax.jit(FlaxMNIST.update,
|
| 269 |
+
static_argnums=0)(tx, params, opt_state,
|
| 270 |
+
inputs, labels)
|
| 271 |
+
epoch_time = time.time() - start_time
|
| 272 |
+
# Same accuracy function as for the pure JAX example
|
| 273 |
+
train_acc = PureJaxMNIST.accuracy(FlaxMNIST.predict, params,
|
| 274 |
+
train_ds)
|
| 275 |
+
test_acc = PureJaxMNIST.accuracy(FlaxMNIST.predict, params,
|
| 276 |
+
test_ds)
|
| 277 |
+
logging.info("%s: Epoch %d in %0.2f sec", FlaxMNIST.name, epoch,
|
| 278 |
+
epoch_time)
|
| 279 |
+
logging.info("%s: Training set accuracy %0.2f%%", FlaxMNIST.name,
|
| 280 |
+
100. * train_acc)
|
| 281 |
+
logging.info("%s: Test set accuracy %0.2f%%", FlaxMNIST.name,
|
| 282 |
+
100. * test_acc)
|
| 283 |
+
|
| 284 |
+
# See discussion in README.md for packaging Flax models for conversion
|
| 285 |
+
predict_fn = functools.partial(FlaxMNIST.predict,
|
| 286 |
+
with_classifier=with_classifier)
|
| 287 |
+
return (predict_fn, params)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def plot_images(ds,
|
| 291 |
+
nr_rows: int,
|
| 292 |
+
nr_cols: int,
|
| 293 |
+
title: str,
|
| 294 |
+
inference_fn: Callable | None = None):
|
| 295 |
+
"""Plots a grid of images with their predictions.
|
| 296 |
+
|
| 297 |
+
Params:
|
| 298 |
+
ds: a tensorflow dataset from where to pick the images and labels.
|
| 299 |
+
nr_rows, nr_cols: the size of the grid to plot
|
| 300 |
+
title: the title of the plot
|
| 301 |
+
inference_fn: if None then print the existing label, else use this function
|
| 302 |
+
on the batch of images to produce a batch of inference results, which
|
| 303 |
+
get printed.
|
| 304 |
+
inference_batch_size: the size of the batch of images passed to
|
| 305 |
+
`inference_fn`.
|
| 306 |
+
"""
|
| 307 |
+
count = nr_rows * nr_cols
|
| 308 |
+
fig = plt.figure(figsize=(8., 4.), num=title)
|
| 309 |
+
# Get the first batch
|
| 310 |
+
(images, labels), = list(tfds.as_numpy(ds.take(1)))
|
| 311 |
+
if inference_fn:
|
| 312 |
+
inferred_labels = inference_fn(images)
|
| 313 |
+
for i, image in enumerate(images[:count]):
|
| 314 |
+
digit = fig.add_subplot(nr_rows, nr_cols, i + 1)
|
| 315 |
+
if inference_fn:
|
| 316 |
+
digit_title = f"infer: {np.argmax(inferred_labels[i])}\n"
|
| 317 |
+
else:
|
| 318 |
+
digit_title = ""
|
| 319 |
+
digit_title += f"label: {np.argmax(labels[i])}"
|
| 320 |
+
digit.set_title(digit_title)
|
| 321 |
+
plt.imshow(
|
| 322 |
+
(np.reshape(image, (28, 28)) * 255).astype(np.uint8),
|
| 323 |
+
interpolation="nearest")
|
| 324 |
+
plt.show()
|
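A short sketch of how this module's pieces are typically combined (as in saved_model_main.py), assuming absl flags have already been parsed; the epoch count is arbitrary:

import tensorflow_datasets as tfds

train_ds = load_mnist(tfds.Split.TRAIN, batch_size=train_batch_size)
test_ds = load_mnist(tfds.Split.TEST, batch_size=test_batch_size)

# Train the pure-JAX model; returns a predictor function and its parameters.
predict_fn, params = PureJaxMNIST.train(train_ds, test_ds, num_epochs=2)

images, labels = next(iter(tfds.as_numpy(test_ds)))
print(predict_fn(params, images).shape)   # (test_batch_size, 10)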
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/saved_model_lib.py
ADDED
|
@@ -0,0 +1,154 @@
|
| 1 |
+
# Copyright 2020 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Defines a helper function for creating a SavedModel from a jax2tf trained model.
|
| 15 |
+
|
| 16 |
+
This has been tested with TensorFlow Hub, TensorFlow JavaScript,
|
| 17 |
+
and TensorFlow Serving.
|
| 18 |
+
|
| 19 |
+
Note that the code in this file is provided only as an example. The functions
|
| 20 |
+
generated by `jax2tf.convert` are standard TensorFlow functions and you can
|
| 21 |
+
save them in a SavedModel using standard TensorFlow code. This decoupling
|
| 22 |
+
of jax2tf from SavedModel is important, because it allows the user to have full
|
| 23 |
+
control over what metadata is saved in the SavedModel. Please copy and
|
| 24 |
+
customize this function as needed.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
from __future__ import annotations
|
| 28 |
+
|
| 29 |
+
from collections.abc import Sequence
|
| 30 |
+
from typing import Any, Callable, Optional, Union
|
| 31 |
+
|
| 32 |
+
from jax.experimental import jax2tf
|
| 33 |
+
import tensorflow as tf
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def convert_and_save_model(
|
| 37 |
+
jax_fn: Callable[[Any, Any], Any],
|
| 38 |
+
params,
|
| 39 |
+
model_dir: str,
|
| 40 |
+
*,
|
| 41 |
+
input_signatures: Sequence[tf.TensorSpec],
|
| 42 |
+
polymorphic_shapes: str | None = None,
|
| 43 |
+
with_gradient: bool = False,
|
| 44 |
+
enable_xla: bool = True,
|
| 45 |
+
compile_model: bool = True,
|
| 46 |
+
saved_model_options: tf.saved_model.SaveOptions | None = None):
|
| 47 |
+
"""Convert a JAX function and saves a SavedModel.
|
| 48 |
+
|
| 49 |
+
This is an example; we do not promise backwards compatibility for this code.
|
| 50 |
+
For serious uses, please copy and expand it as needed (see note at the top
|
| 51 |
+
of the module).
|
| 52 |
+
|
| 53 |
+
Use this function if you have a trained ML model that has both a prediction
|
| 54 |
+
function and trained parameters, which you want to save separately from the
|
| 55 |
+
function graph as variables (e.g., to avoid limits on the size of the
|
| 56 |
+
GraphDef, or to enable fine-tuning.) If you don't have such parameters,
|
| 57 |
+
you can still use this library function but probably don't need it
|
| 58 |
+
(see jax2tf/README.md for some simple examples).
|
| 59 |
+
|
| 60 |
+
In order to use this wrapper you must first convert your model to a function
|
| 61 |
+
with two arguments: the parameters and the input on which you want to do
|
| 62 |
+
inference. Both arguments may be np.ndarray or (nested)
|
| 63 |
+
tuples/lists/dictionaries thereof.
|
| 64 |
+
|
| 65 |
+
See the README.md for a discussion of how to prepare Flax and Haiku models.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
jax_fn: a JAX function taking two arguments, the parameters and the inputs.
|
| 69 |
+
Both arguments may be (nested) tuples/lists/dictionaries of np.ndarray.
|
| 70 |
+
params: the parameters, to be used as first argument for `jax_fn`. These
|
| 71 |
+
must be (nested) tuples/lists/dictionaries of np.ndarray, and will be
|
| 72 |
+
saved as the variables of the SavedModel.
|
| 73 |
+
model_dir: the directory where the model should be saved.
|
| 74 |
+
input_signatures: the input signatures for the second argument of `jax_fn`
|
| 75 |
+
(the input). A signature must be a `tensorflow.TensorSpec` instance, or a
|
| 76 |
+
(nested) tuple/list/dictionary thereof with a structure matching the
|
| 77 |
+
second argument of `jax_fn`. The first input_signature will be saved as
|
| 78 |
+
the default serving signature. The additional signatures will be used
|
| 79 |
+
only to ensure that the `jax_fn` is traced and converted to TF for the
|
| 80 |
+
corresponding input shapes.
|
| 81 |
+
with_gradient: the value to use for the `with_gradient` parameter for
|
| 82 |
+
`jax2tf.convert`.
|
| 83 |
+
enable_xla: the value to use for the `enable_xla` parameter for
|
| 84 |
+
`jax2tf.convert`.
|
| 85 |
+
compile_model: use TensorFlow jit_compiler on the SavedModel. This
|
| 86 |
+
is needed if the SavedModel will be used for TensorFlow serving.
|
| 87 |
+
polymorphic_shapes: if given then it will be used as the
|
| 88 |
+
`polymorphic_shapes` argument to jax2tf.convert for the second parameter of
|
| 89 |
+
`jax_fn`. In this case, a single `input_signatures` is supported, and
|
| 90 |
+
should have `None` in the polymorphic dimensions.
|
| 91 |
+
saved_model_options: options to pass to savedmodel.save.
|
| 92 |
+
"""
|
| 93 |
+
if not input_signatures:
|
| 94 |
+
raise ValueError("At least one input_signature must be given")
|
| 95 |
+
if polymorphic_shapes is not None:
|
| 96 |
+
if len(input_signatures) > 1:
|
| 97 |
+
raise ValueError("For shape-polymorphic conversion a single "
|
| 98 |
+
"input_signature is supported.")
|
| 99 |
+
tf_fn = jax2tf.convert(
|
| 100 |
+
jax_fn,
|
| 101 |
+
with_gradient=with_gradient,
|
| 102 |
+
polymorphic_shapes=[None, polymorphic_shapes],
|
| 103 |
+
enable_xla=enable_xla)
|
| 104 |
+
|
| 105 |
+
# Create tf.Variables for the parameters. If you want more useful variable
|
| 106 |
+
# names, you can use `tree.map_structure_with_path` from the `dm-tree` package
|
| 107 |
+
param_vars = tf.nest.map_structure(
|
| 108 |
+
lambda param: tf.Variable(param, trainable=with_gradient),
|
| 109 |
+
params)
|
| 110 |
+
tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
|
| 111 |
+
autograph=False,
|
| 112 |
+
jit_compile=compile_model)
|
| 113 |
+
|
| 114 |
+
signatures = {}
|
| 115 |
+
# This signature is needed for TensorFlow Serving use.
|
| 116 |
+
signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
|
| 117 |
+
tf_graph.get_concrete_function(input_signatures[0])
|
| 118 |
+
for input_signature in input_signatures[1:]:
|
| 119 |
+
# If there are more signatures, trace and cache a TF function for each one
|
| 120 |
+
tf_graph.get_concrete_function(input_signature)
|
| 121 |
+
wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
|
| 122 |
+
if with_gradient:
|
| 123 |
+
if not saved_model_options:
|
| 124 |
+
saved_model_options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
|
| 125 |
+
else:
|
| 126 |
+
saved_model_options.experimental_custom_gradients = True
|
| 127 |
+
tf.saved_model.save(wrapper, model_dir, signatures=signatures,
|
| 128 |
+
options=saved_model_options)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class _ReusableSavedModelWrapper(tf.train.Checkpoint):
|
| 132 |
+
"""Wraps a function and its parameters for saving to a SavedModel.
|
| 133 |
+
|
| 134 |
+
Implements the interface described at
|
| 135 |
+
https://www.tensorflow.org/hub/reusable_saved_models.
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
def __init__(self, tf_graph, param_vars):
|
| 139 |
+
"""Args:
|
| 140 |
+
|
| 141 |
+
tf_graph: a tf.function taking one argument (the inputs), which can be
|
| 142 |
+
tuples/lists/dictionaries of np.ndarray or tensors. The function
|
| 143 |
+
may have references to the tf.Variables in `param_vars`.
|
| 144 |
+
param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,
|
| 145 |
+
to be saved as the variables of the SavedModel.
|
| 146 |
+
"""
|
| 147 |
+
super().__init__()
|
| 148 |
+
# Implement the interface from https://www.tensorflow.org/hub/reusable_saved_models
|
| 149 |
+
self.variables = tf.nest.flatten(param_vars)
|
| 150 |
+
self.trainable_variables = [v for v in self.variables if v.trainable]
|
| 151 |
+
# If you intend to prescribe regularization terms for users of the model,
|
| 152 |
+
# add them as @tf.functions with no inputs to this list. Else drop this.
|
| 153 |
+
self.regularization_losses = []
|
| 154 |
+
self.__call__ = tf_graph
|
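A brief usage sketch for convert_and_save_model, following the batch-polymorphic path described in the docstring; the toy jax_fn, params, and model directory below are placeholders:

import numpy as np
import tensorflow as tf

def jax_fn(params, inputs):                  # (Params, Inputs) -> Predictions
  return inputs * params["scale"]

params = {"scale": np.float32(2.0)}
convert_and_save_model(
    jax_fn,
    params,
    "/tmp/jax2tf/toy_model/1",
    input_signatures=[tf.TensorSpec((None, 4), tf.float32)],
    polymorphic_shapes="(batch, ...)",
    with_gradient=False)

restored = tf.saved_model.load("/tmp/jax2tf/toy_model/1")
print(restored(tf.ones((3, 4), tf.float32)))   # 2.0 everywhere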
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/saved_model_main.py
ADDED
|
@@ -0,0 +1,210 @@
|
| 1 |
+
# Copyright 2020 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Demonstrates training models and saving the result as a SavedModel.
|
| 15 |
+
|
| 16 |
+
By default, uses a pure JAX implementation of MNIST. There are flags to choose
|
| 17 |
+
a Flax CNN version of MNIST, or to skip the training and just test a
|
| 18 |
+
previously saved SavedModel. It is possible to save a batch-polymorphic
|
| 19 |
+
version of the model, or a model prepared for specific batch sizes.
|
| 20 |
+
|
| 21 |
+
Try --help to see all flags.
|
| 22 |
+
|
| 23 |
+
This file is used both as an executable, and as a library in two other examples.
|
| 24 |
+
See discussion in README.md.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
import logging
|
| 28 |
+
import os
|
| 29 |
+
|
| 30 |
+
from absl import app
|
| 31 |
+
from absl import flags
|
| 32 |
+
|
| 33 |
+
from jax.experimental.jax2tf.examples import mnist_lib
|
| 34 |
+
from jax.experimental.jax2tf.examples import saved_model_lib
|
| 35 |
+
|
| 36 |
+
import numpy as np
|
| 37 |
+
import tensorflow as tf
|
| 38 |
+
import tensorflow_datasets as tfds # type: ignore
|
| 39 |
+
|
| 40 |
+
_MODEL = flags.DEFINE_enum(
|
| 41 |
+
"model", "mnist_flax", ["mnist_flax", "mnist_pure_jax"],
|
| 42 |
+
"Which model to use.")
|
| 43 |
+
_MODEL_CLASSIFIER_LAYER = flags.DEFINE_boolean("model_classifier_layer", True,
|
| 44 |
+
("The model should include the classifier layer, or just "
|
| 45 |
+
"the last layer of logits. Set this to False when you "
|
| 46 |
+
"want to reuse the classifier-less model in a larger "
|
| 47 |
+
"model. See keras_reuse_main.py and README.md."))
|
| 48 |
+
_MODEL_PATH = flags.DEFINE_string("model_path", "/tmp/jax2tf/saved_models",
|
| 49 |
+
"Path under which to save the SavedModel.")
|
| 50 |
+
_MODEL_VERSION = flags.DEFINE_integer("model_version", 1,
|
| 51 |
+
("The version number for the SavedModel. Needed for "
|
| 52 |
+
"serving, larger versions will take precedence"),
|
| 53 |
+
lower_bound=1)
|
| 54 |
+
_SERVING_BATCH_SIZE = flags.DEFINE_integer("serving_batch_size", 1,
|
| 55 |
+
"For what batch size to prepare the serving signature. "
|
| 56 |
+
"Use -1 for converting and saving with batch polymorphism.")
|
| 57 |
+
flags.register_validator(
|
| 58 |
+
"serving_batch_size",
|
| 59 |
+
lambda serving_batch_size: serving_batch_size > 0
|
| 60 |
+
or serving_batch_size == -1,
|
| 61 |
+
message="--serving_batch_size must be either -1 or a positive integer.",
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
_NUM_EPOCHS = flags.DEFINE_integer("num_epochs", 3,
|
| 65 |
+
"For how many epochs to train.",
|
| 66 |
+
lower_bound=1)
|
| 67 |
+
_GENERATE_MODEL = flags.DEFINE_boolean(
|
| 68 |
+
"generate_model", True,
|
| 69 |
+
"Train and save a new model. Otherwise, use an existing SavedModel.")
|
| 70 |
+
_COMPILE_MODEL = flags.DEFINE_boolean(
|
| 71 |
+
"compile_model", True,
|
| 72 |
+
"Enable TensorFlow jit_compiler for the SavedModel. This is "
|
| 73 |
+
"necessary if you want to use the model for TensorFlow serving.")
|
| 74 |
+
_SHOW_MODEL = flags.DEFINE_boolean("show_model", True,
|
| 75 |
+
"Show details of saved SavedModel.")
|
| 76 |
+
SHOW_IMAGES = flags.DEFINE_boolean(
|
| 77 |
+
"show_images", False,
|
| 78 |
+
"Plot some sample images with labels and inference results.")
|
| 79 |
+
_TEST_SAVEDMODEL = flags.DEFINE_boolean(
|
| 80 |
+
"test_savedmodel", True,
|
| 81 |
+
"Test TensorFlow inference using the SavedModel w.r.t. the JAX model.")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def train_and_save():
|
| 85 |
+
logging.info("Loading the MNIST TensorFlow dataset")
|
| 86 |
+
train_ds = mnist_lib.load_mnist(
|
| 87 |
+
tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
|
| 88 |
+
test_ds = mnist_lib.load_mnist(
|
| 89 |
+
tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
|
| 90 |
+
|
| 91 |
+
if SHOW_IMAGES.value:
|
| 92 |
+
mnist_lib.plot_images(train_ds, 1, 5, "Training images", inference_fn=None)
|
| 93 |
+
|
| 94 |
+
the_model_class = pick_model_class()
|
| 95 |
+
model_dir = savedmodel_dir(with_version=True)
|
| 96 |
+
|
| 97 |
+
if _GENERATE_MODEL.value:
|
| 98 |
+
model_descr = model_description()
|
| 99 |
+
logging.info("Generating model for %s", model_descr)
|
| 100 |
+
(predict_fn, predict_params) = the_model_class.train(
|
| 101 |
+
train_ds,
|
| 102 |
+
test_ds,
|
| 103 |
+
num_epochs=_NUM_EPOCHS.value,
|
| 104 |
+
with_classifier=_MODEL_CLASSIFIER_LAYER.value)
|
| 105 |
+
|
| 106 |
+
if _SERVING_BATCH_SIZE.value == -1:
|
| 107 |
+
# Batch-polymorphic SavedModel
|
| 108 |
+
input_signatures = [
|
| 109 |
+
tf.TensorSpec((None,) + mnist_lib.input_shape, tf.float32),
|
| 110 |
+
]
|
| 111 |
+
polymorphic_shapes = "(batch, ...)"
|
| 112 |
+
else:
|
| 113 |
+
input_signatures = [
|
| 114 |
+
# The first one will be the serving signature
|
| 115 |
+
tf.TensorSpec((_SERVING_BATCH_SIZE.value,) + mnist_lib.input_shape,
|
| 116 |
+
tf.float32),
|
| 117 |
+
tf.TensorSpec((mnist_lib.train_batch_size,) + mnist_lib.input_shape,
|
| 118 |
+
tf.float32),
|
| 119 |
+
tf.TensorSpec((mnist_lib.test_batch_size,) + mnist_lib.input_shape,
|
| 120 |
+
tf.float32),
|
| 121 |
+
]
|
| 122 |
+
polymorphic_shapes = None
|
| 123 |
+
|
| 124 |
+
logging.info("Saving model for %s", model_descr)
|
| 125 |
+
saved_model_lib.convert_and_save_model(
|
| 126 |
+
predict_fn,
|
| 127 |
+
predict_params,
|
| 128 |
+
model_dir,
|
| 129 |
+
with_gradient=True,
|
| 130 |
+
input_signatures=input_signatures,
|
| 131 |
+
polymorphic_shapes=polymorphic_shapes,
|
| 132 |
+
compile_model=_COMPILE_MODEL.value)
|
| 133 |
+
|
| 134 |
+
if _TEST_SAVEDMODEL.value:
|
| 135 |
+
tf_accelerator, tolerances = tf_accelerator_and_tolerances()
|
| 136 |
+
with tf.device(tf_accelerator):
|
| 137 |
+
logging.info("Testing savedmodel")
|
| 138 |
+
pure_restored_model = tf.saved_model.load(model_dir)
|
| 139 |
+
|
| 140 |
+
if SHOW_IMAGES.value and _MODEL_CLASSIFIER_LAYER.value:
|
| 141 |
+
mnist_lib.plot_images(
|
| 142 |
+
test_ds,
|
| 143 |
+
1,
|
| 144 |
+
5,
|
| 145 |
+
f"Inference results for {model_descr}",
|
| 146 |
+
inference_fn=pure_restored_model)
|
| 147 |
+
|
| 148 |
+
test_input = np.ones(
|
| 149 |
+
(mnist_lib.test_batch_size,) + mnist_lib.input_shape,
|
| 150 |
+
dtype=np.float32)
|
| 151 |
+
np.testing.assert_allclose(
|
| 152 |
+
pure_restored_model(tf.convert_to_tensor(test_input)),
|
| 153 |
+
predict_fn(predict_params, test_input), **tolerances)
|
| 154 |
+
|
| 155 |
+
if _SHOW_MODEL.value:
|
| 156 |
+
def print_model(model_dir: str):
|
| 157 |
+
cmd = f"saved_model_cli show --all --dir {model_dir}"
|
| 158 |
+
print(cmd)
|
| 159 |
+
os.system(cmd)
|
| 160 |
+
|
| 161 |
+
print_model(model_dir)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def pick_model_class():
|
| 165 |
+
"""Picks one of PureJaxMNIST or FlaxMNIST."""
|
| 166 |
+
if _MODEL.value == "mnist_pure_jax":
|
| 167 |
+
return mnist_lib.PureJaxMNIST
|
| 168 |
+
elif _MODEL.value == "mnist_flax":
|
| 169 |
+
return mnist_lib.FlaxMNIST
|
| 170 |
+
else:
|
| 171 |
+
raise ValueError(f"Unrecognized model: {_MODEL.value}")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def model_description() -> str:
|
| 175 |
+
"""A short description of the picked model."""
|
| 176 |
+
res = pick_model_class().name
|
| 177 |
+
if not _MODEL_CLASSIFIER_LAYER.value:
|
| 178 |
+
res += " (features_only)"
|
| 179 |
+
return res
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def savedmodel_dir(with_version: bool = True) -> str:
|
| 183 |
+
"""The directory where we save the SavedModel."""
|
| 184 |
+
model_dir = os.path.join(
|
| 185 |
+
_MODEL_PATH.value,
|
| 186 |
+
_MODEL.value + ('' if _MODEL_CLASSIFIER_LAYER.value else '_features')
|
| 187 |
+
)
|
| 188 |
+
if with_version:
|
| 189 |
+
model_dir = os.path.join(model_dir, str(_MODEL_VERSION.value))
|
| 190 |
+
return model_dir
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def tf_accelerator_and_tolerances():
|
| 194 |
+
"""Picks the TF accelerator to use and the tolerances for numerical checks."""
|
| 195 |
+
tf_accelerator = (tf.config.list_logical_devices("TPU") +
|
| 196 |
+
tf.config.list_logical_devices("GPU") +
|
| 197 |
+
tf.config.list_logical_devices("CPU"))[0]
|
| 198 |
+
logging.info("Using tf_accelerator = %s", tf_accelerator)
|
| 199 |
+
if tf_accelerator.device_type == "TPU":
|
| 200 |
+
tolerances = dict(atol=1e-6, rtol=1e-6)
|
| 201 |
+
elif tf_accelerator.device_type == "GPU":
|
| 202 |
+
tolerances = dict(atol=1e-6, rtol=1e-4)
|
| 203 |
+
elif tf_accelerator.device_type == "CPU":
|
| 204 |
+
tolerances = dict(atol=1e-5, rtol=1e-5)
|
| 205 |
+
logging.info("Using tolerances %s", tolerances)
|
| 206 |
+
return tf_accelerator, tolerances
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
if __name__ == "__main__":
|
| 210 |
+
app.run(lambda _: train_and_save())
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/saved_model_main_test.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Tests for mnist_lib, saved_model_lib, saved_model_main."""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
from absl import flags
|
| 18 |
+
from absl.testing import absltest
|
| 19 |
+
from absl.testing import parameterized
|
| 20 |
+
|
| 21 |
+
from jax._src import config
|
| 22 |
+
from jax._src import test_util as jtu
|
| 23 |
+
|
| 24 |
+
from jax.experimental.jax2tf.examples import saved_model_main
|
| 25 |
+
from jax.experimental.jax2tf.tests import tf_test_util
|
| 26 |
+
|
| 27 |
+
config.parse_flags_with_absl()
|
| 28 |
+
FLAGS = flags.FLAGS
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class SavedModelMainTest(tf_test_util.JaxToTfTestCase):
|
| 32 |
+
|
| 33 |
+
def setUp(self):
|
| 34 |
+
super().setUp()
|
| 35 |
+
FLAGS.model_path = os.path.join(absltest.get_default_test_tmpdir(),
|
| 36 |
+
"saved_models")
|
| 37 |
+
FLAGS.num_epochs = 1
|
| 38 |
+
FLAGS.test_savedmodel = True
|
| 39 |
+
FLAGS.mock_data = True
|
| 40 |
+
|
| 41 |
+
@parameterized.named_parameters(
|
| 42 |
+
dict(
|
| 43 |
+
testcase_name=f"_{model}_batch={serving_batch_size}",
|
| 44 |
+
model=model,
|
| 45 |
+
serving_batch_size=serving_batch_size)
|
| 46 |
+
for model in ["mnist_pure_jax", "mnist_flax"]
|
| 47 |
+
for serving_batch_size in [1, -1])
|
| 48 |
+
def test_train_and_save_full(self,
|
| 49 |
+
model="mnist_flax",
|
| 50 |
+
serving_batch_size=-1):
|
| 51 |
+
if (serving_batch_size == -1 and
|
| 52 |
+
config.jax2tf_default_native_serialization.value and
|
| 53 |
+
not config.dynamic_shapes.value):
|
| 54 |
+
self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.")
|
| 55 |
+
FLAGS.model = model
|
| 56 |
+
FLAGS.model_classifier_layer = True
|
| 57 |
+
FLAGS.serving_batch_size = serving_batch_size
|
| 58 |
+
saved_model_main.train_and_save()
|
| 59 |
+
|
| 60 |
+
@parameterized.named_parameters(
|
| 61 |
+
dict(testcase_name=f"_{model}", model=model)
|
| 62 |
+
for model in ["mnist_pure_jax", "mnist_flax"])
|
| 63 |
+
def test_train_and_save_features(self, model="mnist_flax"):
|
| 64 |
+
FLAGS.model = model
|
| 65 |
+
FLAGS.model_classifier_layer = False
|
| 66 |
+
saved_model_main.train_and_save()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
if __name__ == "__main__":
|
| 70 |
+
absltest.main(testLoader=jtu.JaxTestLoader())
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/serving/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/serving/model_server_request.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Demonstrates using jax2tf with TensorFlow model server.
|
| 15 |
+
|
| 16 |
+
See README.md for instructions.
|
| 17 |
+
"""
|
| 18 |
+
import grpc # type: ignore
|
| 19 |
+
import json
|
| 20 |
+
import logging
|
| 21 |
+
import requests
|
| 22 |
+
|
| 23 |
+
from absl import app
|
| 24 |
+
from absl import flags
|
| 25 |
+
|
| 26 |
+
from jax.experimental.jax2tf.examples import mnist_lib
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
import tensorflow as tf
|
| 30 |
+
import tensorflow_datasets as tfds # type: ignore[import-not-found]
|
| 31 |
+
from tensorflow_serving.apis import predict_pb2 # type: ignore[import-not-found]
|
| 32 |
+
from tensorflow_serving.apis import prediction_service_pb2_grpc
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
_USE_GRPC = flags.DEFINE_boolean(
|
| 36 |
+
"use_grpc", True,
|
| 37 |
+
"Use the gRPC API (default), or the HTTP REST API.")
|
| 38 |
+
|
| 39 |
+
_MODEL_SPEC_NAME = flags.DEFINE_string(
|
| 40 |
+
"model_spec_name", "",
|
| 41 |
+
"The name you used to export your model to model server (e.g., mnist_flax).")
|
| 42 |
+
|
| 43 |
+
_PREDICTION_SERVICE_ADDR = flags.DEFINE_string(
|
| 44 |
+
"prediction_service_addr",
|
| 45 |
+
"localhost:8500",
|
| 46 |
+
"Stubby endpoint for the prediction service. If you serve your model "
|
| 47 |
+
"locally using TensorFlow model server, then you can use \"localhost:8500\""
|
| 48 |
+
"for the gRPC server and \"localhost:8501\" for the HTTP REST server.")
|
| 49 |
+
|
| 50 |
+
_SERVING_BATCH_SIZE = flags.DEFINE_integer(
|
| 51 |
+
"serving_batch_size",
|
| 52 |
+
1,
|
| 53 |
+
"Batch size for the serving request. Must match the "
|
| 54 |
+
"batch size at which the model was saved. Must divide "
|
| 55 |
+
"--count_images",
|
| 56 |
+
lower_bound=1,
|
| 57 |
+
)
|
| 58 |
+
_COUNT_IMAGES = flags.DEFINE_integer(
|
| 59 |
+
"count_images", 16, "How many images to test.", lower_bound=1
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def serving_call_mnist(images):
|
| 64 |
+
"""Send an RPC or REST request to the model server.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
images: A numpy.ndarray of shape [B, 28, 28, 1] with the batch of images to
|
| 68 |
+
perform inference on.
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
A numpy.ndarray of shape [B, 10] with the one-hot inference response.
|
| 72 |
+
"""
|
| 73 |
+
if _USE_GRPC.value:
|
| 74 |
+
channel = grpc.insecure_channel(_PREDICTION_SERVICE_ADDR.value)
|
| 75 |
+
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
|
| 76 |
+
|
| 77 |
+
request = predict_pb2.PredictRequest()
|
| 78 |
+
request.model_spec.name = _MODEL_SPEC_NAME.value
|
| 79 |
+
request.model_spec.signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
| 80 |
+
# You can see the name of the input ("inputs") in the SavedModel dump.
|
| 81 |
+
request.inputs["inputs"].CopyFrom(
|
| 82 |
+
tf.make_tensor_proto(images, dtype=images.dtype, shape=images.shape))
|
| 83 |
+
response = stub.Predict(request)
|
| 84 |
+
# We could also use response.outputs["output_0"], where "output_0" is the
|
| 85 |
+
# name of the output (which you can see in the SavedModel dump.)
|
| 86 |
+
# Alternatively, we just get the first output.
|
| 87 |
+
outputs, = response.outputs.values()
|
| 88 |
+
return tf.make_ndarray(outputs)
|
| 89 |
+
else:
|
| 90 |
+
# Use the HTTP REST api
|
| 91 |
+
images_json = json.dumps(images.tolist())
|
| 92 |
+
# You can see the name of the input ("inputs") in the SavedModel dump.
|
| 93 |
+
data = f'{{"inputs": {images_json}}}'
|
| 94 |
+
predict_url = f"http://{_PREDICTION_SERVICE_ADDR.value}/v1/models/{_MODEL_SPEC_NAME.value}:predict"
|
| 95 |
+
response = requests.post(predict_url, data=data)
|
| 96 |
+
if response.status_code != 200:
|
| 97 |
+
msg = (f"Received error response {response.status_code} from model "
|
| 98 |
+
f"server: {response.text}")
|
| 99 |
+
raise ValueError(msg)
|
| 100 |
+
outputs = response.json()["outputs"]
|
| 101 |
+
return np.array(outputs)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def main(_):
|
| 105 |
+
if _COUNT_IMAGES.value % _SERVING_BATCH_SIZE.value != 0:
|
| 106 |
+
raise ValueError(f"The count_images ({_COUNT_IMAGES.value}) must be a "
|
| 107 |
+
"multiple of "
|
| 108 |
+
f"serving_batch_size ({_SERVING_BATCH_SIZE.value})")
|
| 109 |
+
test_ds = mnist_lib.load_mnist(tfds.Split.TEST,
|
| 110 |
+
batch_size=_SERVING_BATCH_SIZE.value)
|
| 111 |
+
images_and_labels = tfds.as_numpy(test_ds.take(
|
| 112 |
+
_COUNT_IMAGES.value // _SERVING_BATCH_SIZE.value))
|
| 113 |
+
|
| 114 |
+
accurate_count = 0
|
| 115 |
+
for batch_idx, (images, labels) in enumerate(images_and_labels):
|
| 116 |
+
predictions_one_hot = serving_call_mnist(images)
|
| 117 |
+
predictions_digit = np.argmax(predictions_one_hot, axis=1)
|
| 118 |
+
labels_digit = np.argmax(labels, axis=1)
|
| 119 |
+
accurate_count += np.sum(labels_digit == predictions_digit)
|
| 120 |
+
running_accuracy = (
|
| 121 |
+
100. * accurate_count / (1 + batch_idx) / _SERVING_BATCH_SIZE.value)
|
| 122 |
+
logging.info(
|
| 123 |
+
" predicted digits = %s labels %s. Running accuracy %.3f%%",
|
| 124 |
+
predictions_digit, labels_digit, running_accuracy)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
if __name__ == "__main__":
|
| 128 |
+
app.run(main)
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/impl_no_xla.py
ADDED
|
@@ -0,0 +1,1287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Workarounds for jax2tf transforms when XLA is not linked in."""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import builtins
|
| 19 |
+
from collections.abc import Sequence
|
| 20 |
+
import dataclasses
|
| 21 |
+
from functools import partial, wraps
|
| 22 |
+
import math
|
| 23 |
+
import string
|
| 24 |
+
from typing import Any, Callable, Optional
|
| 25 |
+
|
| 26 |
+
from jax._src import core
|
| 27 |
+
from jax import lax
|
| 28 |
+
from jax._src.lax import slicing as lax_slicing
|
| 29 |
+
from jax._src import dtypes
|
| 30 |
+
from jax._src import util
|
| 31 |
+
|
| 32 |
+
from jax.experimental.jax2tf import jax2tf
|
| 33 |
+
|
| 34 |
+
import numpy as np
|
| 35 |
+
import tensorflow as tf
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Implementation rules for primitives when XLA is not linked in. These
|
| 39 |
+
# implementations are workarounds, making use of TF ops that do work when XLA is
|
| 40 |
+
# not linked in. They are only used when the argument `enable_xla=False` when
|
| 41 |
+
# calling jax2tf.convert().
|
| 42 |
+
tf_impl_no_xla: dict[core.Primitive, Callable[..., Any]] = {}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
TfVal = Any
|
| 46 |
+
DType = Any
|
| 47 |
+
PrecisionType = Any
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _error(primitive_name: str, suffix_msg: str = "") -> Exception:
|
| 51 |
+
msg = f"Call to {primitive_name} cannot be converted with enable_xla=False."
|
| 52 |
+
if suffix_msg:
|
| 53 |
+
msg += (f" {suffix_msg} - See source code for the precise conditions under "
|
| 54 |
+
"which it can be converted without XLA.")
|
| 55 |
+
return NotImplementedError(msg)
|
| 56 |
+
|
| 57 |
+
_conv_error = lambda msg: _error("conv_general_dilated", msg)
|
| 58 |
+
_reduce_error = lambda msg: _error("reduce_window", msg)
|
| 59 |
+
_scatter_error = lambda msg: _error("scatter_(update/add/multiply/min/max)", msg
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
def _unimplemented(name):
|
| 63 |
+
|
| 64 |
+
def op(*arg, **kwargs):
|
| 65 |
+
raise _error(name)
|
| 66 |
+
|
| 67 |
+
return op
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# TODO(marcvanzee): Remove this function and use `tf.math.invert_permutation`
|
| 71 |
+
# once it is implemented by TFjs:
|
| 72 |
+
# https://github.com/tensorflow/tfjs/issues/6395.
|
| 73 |
+
def _invert_permutation(perm):
|
| 74 |
+
return tuple(perm.index(i) for i in range(len(perm)))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _transpose_with_shape(x: TfVal, x_shape: core.Shape, permutation) -> tuple[TfVal, core.Shape]:
|
| 78 |
+
"""Computes transposition of x and its shape.
|
| 79 |
+
|
| 80 |
+
x_shape matches x.shape in the known dimensions, and it has dimension
|
| 81 |
+
polynomials elsewhere, while x.shape has None.
|
| 82 |
+
"""
|
| 83 |
+
return tf.transpose(x, perm=permutation), tuple(x_shape[i] for i in permutation)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _transpose_for_tf_conv(lhs, lhs_shape: core.Shape,
|
| 87 |
+
rhs, rhs_shape: core.Shape, dimension_numbers):
|
| 88 |
+
"""Transposes lhs and rhs to respectively NHWC and HWIO so they can be passed to TF functions.
|
| 89 |
+
|
| 90 |
+
The shapes passed in and returned may contain polynomials, and thus may
|
| 91 |
+
be different than lhs.shape and rhs.shape.
|
| 92 |
+
"""
|
| 93 |
+
# TODO(marcvanzee): Add tests for this ops for shape polymorphism.
|
| 94 |
+
lhs_perm, rhs_perm, _ = dimension_numbers
|
| 95 |
+
|
| 96 |
+
# TODO(marcvanzee): Consider merging transposes if we want to optimize.
|
| 97 |
+
# For `lhs_perm` / `output_perm`, perm (0, 1, 2, 3) corresponds to "NCHW".
|
| 98 |
+
lhs, lhs_shape = _transpose_with_shape(lhs, lhs_shape, lhs_perm) # lhs --> "NCHW"
|
| 99 |
+
if len(lhs_perm) == 3:
|
| 100 |
+
# For 1D convolution, we add a trivial "W" dimension, so that 2D Convolution
|
| 101 |
+
# logic can be applied downstream.
|
| 102 |
+
lhs = lhs[:, :, :, np.newaxis]
|
| 103 |
+
lhs_shape = tuple(lhs_shape) + (1,)
|
| 104 |
+
# However, the TF ops only support "NHWC" on CPU, so we transpose again.
|
| 105 |
+
lhs, lhs_shape = _transpose_with_shape(lhs, lhs_shape, (0, 2, 3, 1)) # "NCHW" --> "NHWC"
|
| 106 |
+
|
| 107 |
+
# For `rhs_perm`, perm (0, 1, 2, 3) corresponds to "OIHW".
|
| 108 |
+
rhs, rhs_shape = _transpose_with_shape(rhs, rhs_shape, rhs_perm) # rhs --> "OIHW"
|
| 109 |
+
# Handle conv1d case.
|
| 110 |
+
if len(rhs_perm) == 3:
|
| 111 |
+
rhs = rhs[:, :, :, np.newaxis]
|
| 112 |
+
rhs_shape = tuple(rhs_shape) + (1,)
|
| 113 |
+
# For the tf ops, rhs is expected to be "OIHW".
|
| 114 |
+
rhs, rhs_shape = _transpose_with_shape(rhs, rhs_shape, (2, 3, 1, 0)) # "OIHW" --> "HWIO"
|
| 115 |
+
jax2tf._assert_matching_abstract_shape(lhs, lhs_shape)
|
| 116 |
+
jax2tf._assert_matching_abstract_shape(rhs, rhs_shape)
|
| 117 |
+
return lhs, lhs_shape, rhs, rhs_shape
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str:
|
| 121 |
+
for pad_str in ["VALID", "SAME"]:
|
| 122 |
+
pads = lax.padtype_to_pads(in_shape, window_shape, window_strides, pad_str)
|
| 123 |
+
if list(pads) == list(padding):
|
| 124 |
+
return pad_str
|
| 125 |
+
return "EXPLICIT"
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _pad_spatial_dims(x, x_shape, padding):
|
| 129 |
+
"""Pads `x` using `padding`, which specifies padding for the spatial dimensions."""
|
| 130 |
+
padding = tuple(padding)
|
| 131 |
+
if len(padding) == len(x_shape) - 2:
|
| 132 |
+
# If necessary, add empty padding for batch and feature dimensions.
|
| 133 |
+
no_pad = ((0, 0),)
|
| 134 |
+
padding = no_pad + padding + no_pad
|
| 135 |
+
x = tf.pad(x, padding)
|
| 136 |
+
assert len(x.shape) == len(padding)
|
| 137 |
+
x_shape = tuple(p0 + xs + p1 for xs, (p0, p1) in zip(x_shape, padding))
|
| 138 |
+
jax2tf._assert_matching_abstract_shape(x, x_shape)
|
| 139 |
+
return x, x_shape
|
| 140 |
+
|
| 141 |
+
def _check_pad_spatial_dims(x, x_shape, padding):
|
| 142 |
+
"""Pads `x` using `padding`, which specifies padding for the spatial dimensions."""
|
| 143 |
+
padding = tuple(padding)
|
| 144 |
+
if len(padding) == len(x_shape) - 2:
|
| 145 |
+
# If necessary, add empty padding for batch and feature dimensions.
|
| 146 |
+
no_pad = ((0, 0),)
|
| 147 |
+
padding = no_pad + padding + no_pad
|
| 148 |
+
assert len(x.shape) == len(padding)
|
| 149 |
+
x_shape = tuple(p0 + xs + p1 for xs, (p0, p1) in zip(x_shape, padding))
|
| 150 |
+
return x, x_shape, padding
|
| 151 |
+
|
| 152 |
+
def _conv_transpose_pads_to_padtype(kernel_sdims, lhs_dilation, padding):
|
| 153 |
+
"""Finds the padding type for a transpose convolution."""
|
| 154 |
+
# This is simply checking agreement with lax._conv_transpose_padding.
|
| 155 |
+
is_valid = True
|
| 156 |
+
is_same = True
|
| 157 |
+
if not len(kernel_sdims) == len(lhs_dilation) == len(padding):
|
| 158 |
+
raise ValueError(f'Found different lengths for '
|
| 159 |
+
f'kernel_sdims ({kernel_sdims}), '
|
| 160 |
+
f'lhs_dilation ({lhs_dilation}), '
|
| 161 |
+
f'and padding ({padding}).')
|
| 162 |
+
for k, s, (begin, end) in zip(kernel_sdims, lhs_dilation, padding):
|
| 163 |
+
# Check for VALID padding.
|
| 164 |
+
pad_len_valid = k + s - 2 + builtins.max(k - s, 0)
|
| 165 |
+
pad_a = k - 1
|
| 166 |
+
pad_b = pad_len_valid - pad_a
|
| 167 |
+
if begin != pad_a or end != pad_b:
|
| 168 |
+
is_valid = False
|
| 169 |
+
|
| 170 |
+
# Check for SAME padding.
|
| 171 |
+
pad_len_same = k + s - 2
|
| 172 |
+
if s > k - 1:
|
| 173 |
+
pad_a = k - 1
|
| 174 |
+
else:
|
| 175 |
+
pad_a = int(np.ceil(pad_len_same / 2))
|
| 176 |
+
pad_b = pad_len_same - pad_a
|
| 177 |
+
if begin != pad_a or end != pad_b:
|
| 178 |
+
is_same = False
|
| 179 |
+
|
| 180 |
+
if is_valid:
|
| 181 |
+
return 'VALID'
|
| 182 |
+
elif is_same:
|
| 183 |
+
return 'SAME'
|
| 184 |
+
raise ValueError('Transpose convolution padding mode must be '
|
| 185 |
+
'`SAME` or `VALID`.')
|
| 186 |
+
|
| 187 |
+
def _validate_spatial_dimensions(lhs: TfVal, lhs_shape: core.Shape,
|
| 188 |
+
rhs: TfVal, rhs_shape: core.Shape):
|
| 189 |
+
"""Check spatial dimension support."""
|
| 190 |
+
jax2tf._assert_matching_abstract_shape(lhs, lhs_shape)
|
| 191 |
+
jax2tf._assert_matching_abstract_shape(rhs, rhs_shape)
|
| 192 |
+
|
| 193 |
+
nr_spatial_dimensions = len(lhs_shape) - 2
|
| 194 |
+
# Currently we only support 1D+2D convolutions because it keeps the code
|
| 195 |
+
# relatively simple and covers most cases.
|
| 196 |
+
if nr_spatial_dimensions > 2:
|
| 197 |
+
raise _conv_error(
|
| 198 |
+
"We only support 1D or 2D convolutions, but found "
|
| 199 |
+
f"{nr_spatial_dimensions}.")
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def _normalize_padding_and_dilations(
|
| 203 |
+
padding, lhs_dilation, rhs_dilation, is_conv1d):
|
| 204 |
+
if is_conv1d:
|
| 205 |
+
lhs_dilation = list(lhs_dilation) + [1]
|
| 206 |
+
rhs_dilation = list(rhs_dilation) + [1]
|
| 207 |
+
# Empty padding in the dummy dimension.
|
| 208 |
+
# Note that when kernel_size=stride=1, padding of (0, 0) is both 'VALID' and
|
| 209 |
+
# 'SAME'. So the inferred padding type will still register according to the
|
| 210 |
+
# first dimension padding.
|
| 211 |
+
padding = list(padding) + [(0, 0)]
|
| 212 |
+
return padding, lhs_dilation, rhs_dilation
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _normalize_window_strides(window_strides):
|
| 216 |
+
"""Ensure window_strides has length 4."""
|
| 217 |
+
# Some TF ops require len(window_strides) == 4 while others do not. We simply
|
| 218 |
+
# ensure it always has len(4).
|
| 219 |
+
if len(window_strides) == 1:
|
| 220 |
+
# This is the Conv1D case. We add a dummy dimension to allow using 2D ops,
|
| 221 |
+
# and use stride=1 on the dummy dimension.
|
| 222 |
+
window_strides = list(window_strides) + [1]
|
| 223 |
+
if len(window_strides) == 2:
|
| 224 |
+
window_strides = [1] + list(window_strides) + [1]
|
| 225 |
+
return window_strides
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def _validate_conv_features(
|
| 229 |
+
is_transpose, is_atrous, is_depthwise, feature_group_count,
|
| 230 |
+
batch_group_count, preferred_element_type, lhs_dtype):
|
| 231 |
+
if feature_group_count > 1 and not is_depthwise:
|
| 232 |
+
raise _conv_error("Grouped convolutions are unsupported")
|
| 233 |
+
if (is_depthwise and is_atrous) and not is_transpose:
|
| 234 |
+
# We allow dilated depthwise convolutions.
|
| 235 |
+
pass
|
| 236 |
+
elif [is_depthwise, is_atrous, is_transpose].count(True) > 1:
|
| 237 |
+
raise _conv_error(
|
| 238 |
+
f"Can only do one of depthwise ({is_depthwise}), atrous ({is_atrous}) "
|
| 239 |
+
f"and transposed convolutions ({is_transpose})")
|
| 240 |
+
|
| 241 |
+
# We can implement batch grouping when there is a need for it.
|
| 242 |
+
if batch_group_count != 1:
|
| 243 |
+
raise _conv_error("Unimplemented support for batch_group_count != 1 "
|
| 244 |
+
f"(found {batch_group_count})")
|
| 245 |
+
|
| 246 |
+
if (preferred_element_type is not None and
|
| 247 |
+
preferred_element_type != lhs_dtype):
|
| 248 |
+
raise _conv_error("Unimplemented support for preferred_element_type")
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _conv_general_dilated(
|
| 252 |
+
lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
|
| 253 |
+
dimension_numbers: lax.ConvDimensionNumbers, feature_group_count: int,
|
| 254 |
+
batch_group_count: int,
|
| 255 |
+
precision: tuple[PrecisionType, PrecisionType] | None,
|
| 256 |
+
preferred_element_type: DType | None,
|
| 257 |
+
_in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray):
|
| 258 |
+
"""Implementation of lax.conv_general_dilated_p using XlaConv."""
|
| 259 |
+
# In presence of shape polymorphism, lhs.shape and rhs.shape may contain
|
| 260 |
+
# None. The actual dimension polynomial shapes are in _in_avals.
|
| 261 |
+
del precision # Unused arguments.
|
| 262 |
+
lhs_shape, rhs_shape = _in_avals[0].shape, _in_avals[1].shape
|
| 263 |
+
out_shape = _out_aval.shape
|
| 264 |
+
_validate_spatial_dimensions(lhs, lhs_shape, rhs, rhs_shape)
|
| 265 |
+
is_conv1d = len(lhs_shape) - 2 == 1
|
| 266 |
+
|
| 267 |
+
tf_window_strides = _normalize_window_strides(window_strides)
|
| 268 |
+
padding, lhs_dilation, rhs_dilation = _normalize_padding_and_dilations(
|
| 269 |
+
padding, lhs_dilation, rhs_dilation, is_conv1d)
|
| 270 |
+
|
| 271 |
+
lhs, lhs_shape, rhs, rhs_shape = _transpose_for_tf_conv(lhs, lhs_shape,
|
| 272 |
+
rhs, rhs_shape,
|
| 273 |
+
dimension_numbers)
|
| 274 |
+
in_channels = lhs_shape[-1]
|
| 275 |
+
*rhs_spatial_shapes, _, rhs_out_channel = rhs_shape
|
| 276 |
+
|
| 277 |
+
is_transpose = any(d != 1 for d in lhs_dilation)
|
| 278 |
+
is_atrous = any(d != 1 for d in rhs_dilation)
|
| 279 |
+
is_depthwise = in_channels == feature_group_count and feature_group_count > 1
|
| 280 |
+
_validate_conv_features(is_transpose, is_atrous, is_depthwise,
|
| 281 |
+
feature_group_count, batch_group_count,
|
| 282 |
+
preferred_element_type, lhs.dtype.as_numpy_dtype)
|
| 283 |
+
|
| 284 |
+
rhs_dilated_shape = [
|
| 285 |
+
(k - 1) * r + 1 for k, r in zip(rhs_spatial_shapes, rhs_dilation)
|
| 286 |
+
]
|
| 287 |
+
output_perm = dimension_numbers[2]
|
| 288 |
+
|
| 289 |
+
if is_transpose:
|
| 290 |
+
padding_type = _conv_transpose_pads_to_padtype(
|
| 291 |
+
rhs_spatial_shapes, lhs_dilation, padding)
|
| 292 |
+
else:
|
| 293 |
+
padding_type = pads_to_padtype(
|
| 294 |
+
lhs_shape[1:3], rhs_dilated_shape, window_strides, padding)
|
| 295 |
+
# We only manually pad if we aren't using a transposed convolutions.
|
| 296 |
+
if padding_type == "EXPLICIT":
|
| 297 |
+
lhs, lhs_shape, padding = _check_pad_spatial_dims(lhs, lhs_shape, padding)
|
| 298 |
+
padding_type = padding
|
| 299 |
+
|
| 300 |
+
if padding_type != "SAME" and any(l < r for l, r in zip(lhs_shape[1:3], rhs_dilated_shape)):
|
| 301 |
+
# If the input shape is smaller than the filter shape in a spatial dimension,
|
| 302 |
+
# lax returns only zeros while tf.conv2d returns an error.
|
| 303 |
+
# We thus return zeros to make sure the behavior is consistent.
|
| 304 |
+
return tf.broadcast_to(tf.constant(0, dtype=tf.float32),
|
| 305 |
+
jax2tf._eval_shape(out_shape))
|
| 306 |
+
|
| 307 |
+
if is_depthwise:
|
| 308 |
+
# Reshape filter from
|
| 309 |
+
# [filter_height, filter_width, 1, in_channels * channel_multiplier] to
|
| 310 |
+
# [filter_height, filter_width, in_channels, channel_multiplier].
|
| 311 |
+
new_rhs_shape = tuple(rhs_spatial_shapes) + (in_channels,
|
| 312 |
+
rhs_out_channel // in_channels)
|
| 313 |
+
output = tf.nn.depthwise_conv2d(
|
| 314 |
+
input=lhs,
|
| 315 |
+
filter=tf.reshape(rhs, jax2tf._eval_shape(new_rhs_shape)),
|
| 316 |
+
strides=tf_window_strides,
|
| 317 |
+
padding=padding_type,
|
| 318 |
+
dilations=rhs_dilation)
|
| 319 |
+
|
| 320 |
+
elif is_transpose:
|
| 321 |
+
# tf.nn.conv2d_transpose requires a transposed filter.
|
| 322 |
+
rhs_t = tf.reverse(rhs, [0, 1])
|
| 323 |
+
rhs_t = tf.transpose(rhs_t, (0, 1, 3, 2))
|
| 324 |
+
|
| 325 |
+
# We should transpose `out_shape` to "NHWC", which is what TF expects.
|
| 326 |
+
# First transpose to "NCHW".
|
| 327 |
+
if is_conv1d:
|
| 328 |
+
tf_out_shape = tuple(out_shape[i] for i in output_perm) + (1,)
|
| 329 |
+
else:
|
| 330 |
+
tf_out_shape = tuple(out_shape[i] for i in output_perm)
|
| 331 |
+
# Then transpose "NCHW" to "NHWC".
|
| 332 |
+
tf_out_shape = tuple(tf_out_shape[i] for i in (0, 2, 3, 1))
|
| 333 |
+
output = tf.nn.conv2d_transpose(
|
| 334 |
+
input=lhs,
|
| 335 |
+
filters=rhs_t,
|
| 336 |
+
output_shape=jax2tf._eval_shape(tf_out_shape),
|
| 337 |
+
strides=lhs_dilation,
|
| 338 |
+
padding=padding_type)
|
| 339 |
+
|
| 340 |
+
else:
|
| 341 |
+
output = tf.nn.conv2d(
|
| 342 |
+
input=lhs,
|
| 343 |
+
filters=rhs,
|
| 344 |
+
strides=tf_window_strides,
|
| 345 |
+
padding=padding_type,
|
| 346 |
+
dilations=rhs_dilation)
|
| 347 |
+
|
| 348 |
+
# TF outputs in format "NHWC", so convert to "NCHW", which is lax's default
|
| 349 |
+
# format.
|
| 350 |
+
output = tf.transpose(output, (0, 3, 1, 2)) # "NHWC" --> "NCHW"
|
| 351 |
+
if is_conv1d:
|
| 352 |
+
output = output[:, :, :, 0]
|
| 353 |
+
# To determine the right permutation, we compute the inverse permutation of
|
| 354 |
+
# `output_perm`, so that when `output_perm` is applied to `output`, we obtain
|
| 355 |
+
# the outpt in NCHW format.
|
| 356 |
+
inverse_perm = _invert_permutation(output_perm)
|
| 357 |
+
output = tf.transpose(output, inverse_perm) # "NCHW" -> desired output shape.
|
| 358 |
+
return output
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
tf_impl_no_xla[lax.conv_general_dilated_p] = _conv_general_dilated
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def _dot_general(lhs, rhs, *, dimension_numbers,
|
| 365 |
+
precision: tuple[PrecisionType, PrecisionType] | None,
|
| 366 |
+
preferred_element_type: DType | None,
|
| 367 |
+
_in_avals: Sequence[core.ShapedArray],
|
| 368 |
+
_out_aval: core.ShapedArray):
|
| 369 |
+
"""Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
|
| 370 |
+
# Unused arguments.
|
| 371 |
+
del precision
|
| 372 |
+
del preferred_element_type
|
| 373 |
+
|
| 374 |
+
lhs, rhs, convert_result = jax2tf._dot_general_convert_to_common_dtype(
|
| 375 |
+
lhs, _in_avals[0], rhs, _in_avals[1], _out_aval)
|
| 376 |
+
|
| 377 |
+
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
| 378 |
+
lhs_ndim, rhs_ndim = len(lhs.shape), len(rhs.shape)
|
| 379 |
+
|
| 380 |
+
# This condition ensures that:
|
| 381 |
+
# 1) the batch dimensions are ordered in the same way in lhs and rhs (this is
|
| 382 |
+
# not strictly necessary, but we would have to reshape the array if that
|
| 383 |
+
# were not the case;
|
| 384 |
+
# 2) lhs and rhs have the same number of dimensions +/- 1
|
| 385 |
+
# 3) the number of non-batch dimensions in both tensors is either 1 or 2
|
| 386 |
+
# 4) the contracting dimensions are consistent with those of a classic
|
| 387 |
+
# matrix/matrix, vector/matrix or matrix/vector multiplication.
|
| 388 |
+
if (lhs_batch == rhs_batch == tuple(range(len(lhs_batch))) and
|
| 389 |
+
lhs_ndim - rhs_ndim in [-1, 0, 1] and
|
| 390 |
+
1 <= lhs_ndim - len(lhs_batch) <= 2 and
|
| 391 |
+
1 <= rhs_ndim - len(rhs_batch) <= 2 and
|
| 392 |
+
lhs_contracting == (len(lhs.shape) - 1,) and
|
| 393 |
+
rhs_contracting == (len(lhs_batch),)):
|
| 394 |
+
# All the inputs to tf.linalg.matmul must have 2 inner dimensions,
|
| 395 |
+
# after their batch dimensions, so we need to expand the dimensions
|
| 396 |
+
# appropriately. We can get to this branch with three combinations of
|
| 397 |
+
# inner shapes:
|
| 398 |
+
# - lhs.inner_shape == [a, b], rhs.inner_shape == [b, c]
|
| 399 |
+
# - in this case, the resulting inner shape is [a, c];
|
| 400 |
+
# - lhs.inner_shape == [b] , rhs.inner_shape == [b, c]
|
| 401 |
+
# - in this case, we need to expand lhs to [1, b], and the resulting
|
| 402 |
+
# shape is [c]. We need to squeeze the result of tf.linalg.matmul
|
| 403 |
+
# as it will have shape [1, c];
|
| 404 |
+
# - lhs.shape == [batch] + [a, b], rhs.shape == [batch] + [b]
|
| 405 |
+
# - in this case, we need to expand rhs to [b, 1], and the resulting
|
| 406 |
+
# shape is [a]. We need to squeeze the result of tf.linalg.matmul
|
| 407 |
+
# as it will have shape [a, 1];
|
| 408 |
+
# - lhs.shape == [batch] + [b] , rhs.shape == [batch] + [b]
|
| 409 |
+
# - in this case, we need to expand lhs to [1, b] and rhs to [b, 1],
|
| 410 |
+
# and the resulting shape is (). We need to squeeze the result of
|
| 411 |
+
# tf.linalg.matmul as it will have shape [1, 1].
|
| 412 |
+
squeeze_idxs = []
|
| 413 |
+
if lhs_ndim - len(lhs_batch) == 1:
|
| 414 |
+
lhs = tf.expand_dims(lhs, lhs_ndim - 1)
|
| 415 |
+
squeeze_idxs.append(len(lhs.shape) - 2)
|
| 416 |
+
if rhs_ndim - len(rhs_batch) == 1:
|
| 417 |
+
rhs = tf.expand_dims(rhs, rhs_ndim)
|
| 418 |
+
squeeze_idxs.append(len(rhs.shape) - 1)
|
| 419 |
+
result = tf.linalg.matmul(lhs, rhs)
|
| 420 |
+
if len(squeeze_idxs) != 0:
|
| 421 |
+
assert all(result.shape[i] == 1 for i in squeeze_idxs)
|
| 422 |
+
result = tf.squeeze(result, squeeze_idxs)
|
| 423 |
+
return convert_result(result)
|
| 424 |
+
|
| 425 |
+
new_id = iter(string.ascii_letters)
|
| 426 |
+
lhs_axis_ids = [next(new_id) for _ in lhs.shape]
|
| 427 |
+
rhs_axis_ids = [next(new_id) for _ in rhs.shape]
|
| 428 |
+
lhs_out_axis_ids = lhs_axis_ids[:]
|
| 429 |
+
rhs_out_axis_ids = rhs_axis_ids[:]
|
| 430 |
+
|
| 431 |
+
for lhs_axis, rhs_axis in zip(lhs_contracting, rhs_contracting):
|
| 432 |
+
shared_id = next(new_id)
|
| 433 |
+
lhs_axis_ids[lhs_axis] = shared_id
|
| 434 |
+
rhs_axis_ids[rhs_axis] = shared_id
|
| 435 |
+
lhs_out_axis_ids[lhs_axis] = None # type: ignore[call-overload]
|
| 436 |
+
rhs_out_axis_ids[rhs_axis] = None # type: ignore[call-overload]
|
| 437 |
+
|
| 438 |
+
batch_ids = []
|
| 439 |
+
for lhs_axis, rhs_axis in zip(lhs_batch, rhs_batch):
|
| 440 |
+
shared_id = next(new_id)
|
| 441 |
+
lhs_axis_ids[lhs_axis] = shared_id
|
| 442 |
+
rhs_axis_ids[rhs_axis] = shared_id
|
| 443 |
+
lhs_out_axis_ids[lhs_axis] = None # type: ignore[call-overload]
|
| 444 |
+
rhs_out_axis_ids[rhs_axis] = None # type: ignore[call-overload]
|
| 445 |
+
batch_ids.append(shared_id)
|
| 446 |
+
|
| 447 |
+
not_none = lambda x: x is not None
|
| 448 |
+
out_axis_ids = list(
|
| 449 |
+
filter(not_none, batch_ids + lhs_out_axis_ids + rhs_out_axis_ids))
|
| 450 |
+
assert lhs.dtype == rhs.dtype
|
| 451 |
+
spec = "{},{}->{}".format("".join(lhs_axis_ids), "".join(rhs_axis_ids),
|
| 452 |
+
"".join(out_axis_ids))
|
| 453 |
+
return convert_result(tf.linalg.einsum(spec, lhs, rhs))
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
tf_impl_no_xla[lax.dot_general_p] = _dot_general
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def _interior_padding(operand, padding_value, padding_config, operand_shape):
|
| 460 |
+
# Used only when enable_xla=False
|
| 461 |
+
# Applies only the interior padding from the padding_config.
|
| 462 |
+
# We do this somewhat inefficiently, as a scatter.
|
| 463 |
+
# For each dimension we compute the indices_by_dim as [0, f, 2f, 3f, ...] where
|
| 464 |
+
# f is the dilation factor for the dimension, i.e., 1 + interior_padding.
|
| 465 |
+
# Then we compute the cartesian production of the indices (using broadcast
|
| 466 |
+
# and concat).
|
| 467 |
+
|
| 468 |
+
# We could make this code more complex and do all the padding at once, but
|
| 469 |
+
# we prefer to keep it simple.
|
| 470 |
+
indices_by_dim = []
|
| 471 |
+
indices_shape = operand_shape + (1,)
|
| 472 |
+
output_shape = [] # considering only interior padding
|
| 473 |
+
for d, (dsz, (_, _, i)) in enumerate(zip(operand_shape, padding_config)):
|
| 474 |
+
dilation_factor = i + 1
|
| 475 |
+
output_shape.append(dsz * dilation_factor - i)
|
| 476 |
+
indices = tf.range(dsz) * dilation_factor
|
| 477 |
+
expansion = [None] * (1 + len(operand_shape))
|
| 478 |
+
expansion[d] = slice(None, None, None)
|
| 479 |
+
indices_by_dim.append(tf.broadcast_to(indices[expansion], indices_shape))
|
| 480 |
+
|
| 481 |
+
indices_cartesian = tf.concat(indices_by_dim, axis=len(operand_shape))
|
| 482 |
+
scattered = tf.scatter_nd(indices_cartesian, operand, output_shape)
|
| 483 |
+
# What elements from the output array we use from
|
| 484 |
+
mask = tf.scatter_nd(indices_cartesian, tf.ones_like(operand, dtype=np.bool_),
|
| 485 |
+
output_shape)
|
| 486 |
+
return tf.where(mask, scattered, padding_value)
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def _pad(operand, padding_value, *, padding_config,
|
| 490 |
+
_in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray):
|
| 491 |
+
# Do only the interior padding first. This is rarely needed.
|
| 492 |
+
if any(i != 0 for _, _, i in padding_config):
|
| 493 |
+
operand = _interior_padding(operand, padding_value, padding_config,
|
| 494 |
+
jax2tf._eval_shape(_in_avals[0].shape))
|
| 495 |
+
|
| 496 |
+
# Now do the non-negative edge padding. This is the common case, use tf.pad.
|
| 497 |
+
non_negative_padding = [((lo if lo >= 0 else 0), (hi if hi >= 0 else 0))
|
| 498 |
+
for lo, hi, _ in padding_config]
|
| 499 |
+
operand = tf.pad(
|
| 500 |
+
operand,
|
| 501 |
+
non_negative_padding,
|
| 502 |
+
mode="CONSTANT",
|
| 503 |
+
constant_values=padding_value)
|
| 504 |
+
# Now the negative edge padding (this is also rare)
|
| 505 |
+
if any(lo < 0 or hi < 0 for lo, hi, _ in padding_config):
|
| 506 |
+
output_shape = jax2tf._eval_shape(_out_aval.shape)
|
| 507 |
+
begins = [(-lo if lo < 0 else 0) for lo, _, _ in padding_config]
|
| 508 |
+
operand = tf.slice(operand, begins, output_shape)
|
| 509 |
+
|
| 510 |
+
return operand
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
tf_impl_no_xla[lax.pad_p] = _pad
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int],
|
| 517 |
+
index_dtype: DType, _in_avals: Sequence[core.ShapedArray],
|
| 518 |
+
_out_aval: core.ShapedArray):
|
| 519 |
+
# The following is known to diverge from JAX behavior for NaN.
|
| 520 |
+
axis, = axes
|
| 521 |
+
output_type = tf.int32
|
| 522 |
+
if dtypes.iinfo(index_dtype).bits > 32:
|
| 523 |
+
output_type = tf.int64
|
| 524 |
+
# TODO(phawkins): handle axes larger than 2^31.
|
| 525 |
+
fn = tf.math.argmin if is_min else tf.math.argmax
|
| 526 |
+
result = fn(operand, axis=axis, output_type=output_type)
|
| 527 |
+
return tf.cast(result, jax2tf._to_tf_dtype(index_dtype))
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
tf_impl_no_xla[lax.argmin_p] = partial(_argminmax, True)
|
| 531 |
+
tf_impl_no_xla[lax.argmax_p] = partial(_argminmax, False)
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
def _validate_reduce_window_inputs(operand_shape, computation_name, dtype,
|
| 535 |
+
window_dimensions, window_strides,
|
| 536 |
+
base_dilation, window_dilation):
|
| 537 |
+
if computation_name not in ["min", "max", "add"]:
|
| 538 |
+
raise _reduce_error("Reduction function should be either min, max, or add.")
|
| 539 |
+
if computation_name in ["min", "max"] and dtype in [
|
| 540 |
+
tf.bool, tf.uint32, tf.uint64, tf.complex64, tf.complex128
|
| 541 |
+
]:
|
| 542 |
+
raise _reduce_error("Min/max pool does not support operands of type "
|
| 543 |
+
f"{dtype}")
|
| 544 |
+
if computation_name == "min" and dtype in [tf.uint8, tf.uint16]:
|
| 545 |
+
# TODO(marcvanzee): We currently implement min pooling by negating the
|
| 546 |
+
# input, but this doesn't work for uint. We could work around it using
|
| 547 |
+
# tf.math.reduce_min.
|
| 548 |
+
raise _reduce_error(f"Min pool does not support operands of type {dtype}")
|
| 549 |
+
if computation_name == "add" and dtype not in [
|
| 550 |
+
tf.bfloat16,
|
| 551 |
+
tf.float16,
|
| 552 |
+
tf.float32,
|
| 553 |
+
tf.float64,
|
| 554 |
+
tf.int16,
|
| 555 |
+
tf.int32,
|
| 556 |
+
]:
|
| 557 |
+
raise _reduce_error("Add pooling does not support operands of type "
|
| 558 |
+
f"{dtype}")
|
| 559 |
+
|
| 560 |
+
if (len(operand_shape) != len(window_dimensions) != len(window_strides) !=
|
| 561 |
+
len(window_dilation)):
|
| 562 |
+
raise _reduce_error("Input shapes, window dimensions, window stride "
|
| 563 |
+
"dimensions, and window dilation dimensions should "
|
| 564 |
+
"match.")
|
| 565 |
+
|
| 566 |
+
has_only_spatial_dims = True
|
| 567 |
+
if len(operand_shape) > 4:
|
| 568 |
+
raise _reduce_error("Only 1D or 2D input are supported.")
|
| 569 |
+
if len(operand_shape) > 2:
|
| 570 |
+
# operand_shape = (batch, spatial_dims, ..., channel).
|
| 571 |
+
has_only_spatial_dims = False
|
| 572 |
+
|
| 573 |
+
for name, value in [("window_dimensions", window_dimensions),
|
| 574 |
+
("window_strides", window_strides),
|
| 575 |
+
("window_dilation", window_dilation)]:
|
| 576 |
+
if value[0] != value[-1] != 1:
|
| 577 |
+
raise _reduce_error("Only 1D or 2D input are supported, expected "
|
| 578 |
+
f"{name}=(1, spatial_dims, ..., 1), but got "
|
| 579 |
+
f"{value}")
|
| 580 |
+
|
| 581 |
+
if list(base_dilation) != [1] * len(operand_shape):
|
| 582 |
+
# TODO(marcvanzee): Add support for base dilations. We can do this using
|
| 583 |
+
# a scatter on operand.
|
| 584 |
+
raise _reduce_error("Unimplemented support for base dilation.")
|
| 585 |
+
|
| 586 |
+
return has_only_spatial_dims
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
def _padding_reduce_window(operand, operand_shape, computation_name,
|
| 590 |
+
window_dimensions, window_strides, padding):
|
| 591 |
+
padding_type = pads_to_padtype(operand_shape, window_dimensions,
|
| 592 |
+
window_strides, padding)
|
| 593 |
+
|
| 594 |
+
# https://github.com/google/jax/issues/11874.
|
| 595 |
+
needs_manual_padding = (
|
| 596 |
+
padding_type == "SAME" and computation_name == "add" and
|
| 597 |
+
window_dimensions != [1] * len(operand_shape))
|
| 598 |
+
|
| 599 |
+
if needs_manual_padding or padding_type == "EXPLICIT":
|
| 600 |
+
operand, operand_shape = _pad_spatial_dims(operand, operand_shape, padding)
|
| 601 |
+
padding_type = "VALID"
|
| 602 |
+
|
| 603 |
+
return operand, operand_shape, padding_type
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def _reshape_reduce_window(operand, operand_shape, window_dimensions,
|
| 607 |
+
window_strides, window_dilation, *,
|
| 608 |
+
has_only_spatial_dims):
|
| 609 |
+
# Reshape inputs so they are accepted by tf.nn.pool, which expects batch and
|
| 610 |
+
# channel dimensions for operand but not for any of the other inputs.
|
| 611 |
+
if has_only_spatial_dims: # len(operand_shape) <= 2
|
| 612 |
+
# Call eval_shape on a shape that may contain polynomials, otherwise TF does
|
| 613 |
+
# not know what to do with polynomials in the shape.
|
| 614 |
+
operand_shape = jax2tf._eval_shape(operand_shape)
|
| 615 |
+
# Add batch and channel dimensions to operand.
|
| 616 |
+
operand = tf.reshape(operand, (1,) + operand_shape + (1,))
|
| 617 |
+
else:
|
| 618 |
+
# This branch assumes operand.shape = (batch, spatial_dims, ..., channel),
|
| 619 |
+
# and dimensions, strides, dilation are all (1, spatial_values, ..., 1).
|
| 620 |
+
# Input validation for this is done in _validate_reduce_window_inputs.
|
| 621 |
+
window_dimensions = window_dimensions[1:-1]
|
| 622 |
+
window_strides = window_strides[1:-1]
|
| 623 |
+
window_dilation = window_dilation[1:-1]
|
| 624 |
+
|
| 625 |
+
return operand, window_dimensions, window_strides, window_dilation
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def _reduce_monoid(operand, window_dimensions, window_strides, padding,
|
| 629 |
+
base_dilation, window_dilation, computation_name,
|
| 630 |
+
_in_avals: Sequence[core.ShapedArray],
|
| 631 |
+
_out_aval: core.ShapedArray):
|
| 632 |
+
dtype = operand.dtype
|
| 633 |
+
# In presence of shape polymorphism, operand.shape may contain None. The
|
| 634 |
+
# actual dimension polynomial shapes are in _in_avals.
|
| 635 |
+
operand_shape = _in_avals[0].shape
|
| 636 |
+
|
| 637 |
+
# TODO(marcvanzee): Put reduce_window arguments into dataclass, similar to
|
| 638 |
+
# Gather, to simplify function calls.
|
| 639 |
+
has_only_spatial_dims = _validate_reduce_window_inputs(
|
| 640 |
+
operand_shape, computation_name, dtype, window_dimensions, window_strides,
|
| 641 |
+
base_dilation, window_dilation)
|
| 642 |
+
|
| 643 |
+
operand, operand_shape, padding_type = _padding_reduce_window(
|
| 644 |
+
operand, operand_shape, computation_name, window_dimensions,
|
| 645 |
+
window_strides, padding)
|
| 646 |
+
|
| 647 |
+
operand, window_dimensions, window_strides, dilations = _reshape_reduce_window(
|
| 648 |
+
operand,
|
| 649 |
+
operand_shape,
|
| 650 |
+
window_dimensions,
|
| 651 |
+
window_strides,
|
| 652 |
+
window_dilation,
|
| 653 |
+
has_only_spatial_dims=has_only_spatial_dims)
|
| 654 |
+
|
| 655 |
+
def tf_pool(inputs, pooling_type):
|
| 656 |
+
if any(not core.is_constant_shape(s) for s in
|
| 657 |
+
(window_dimensions, window_strides, dilations)):
|
| 658 |
+
raise NotImplementedError(
|
| 659 |
+
f"TODO: use tf.nn.pool with dynamic shapes¨{window_dimensions=} "
|
| 660 |
+
f" {window_strides=} {dilations=}")
|
| 661 |
+
# tf.nn.pool() currently does not suport tf.int32 and so we cast back and
|
| 662 |
+
# forth in order to be able to convert.
|
| 663 |
+
if (inputs.dtype in [tf.int16, tf.int32]) and computation_name == "add":
|
| 664 |
+
original_dtype = inputs.dtype
|
| 665 |
+
inputs = tf.cast(inputs, dtype=tf.float32)
|
| 666 |
+
else:
|
| 667 |
+
original_dtype = None
|
| 668 |
+
result = tf.nn.pool(
|
| 669 |
+
inputs,
|
| 670 |
+
window_shape=window_dimensions,
|
| 671 |
+
pooling_type=pooling_type,
|
| 672 |
+
padding=padding_type,
|
| 673 |
+
strides=window_strides,
|
| 674 |
+
dilations=dilations)
|
| 675 |
+
if original_dtype:
|
| 676 |
+
result = tf.cast(result, dtype=original_dtype)
|
| 677 |
+
|
| 678 |
+
if has_only_spatial_dims:
|
| 679 |
+
# If the input only had spatial dimensions we need to contract the batch
|
| 680 |
+
# and channel dimensions before returning the output.
|
| 681 |
+
result = tf.squeeze(result, [0, -1])
|
| 682 |
+
|
| 683 |
+
jax2tf._assert_matching_abstract_shape(result, _out_aval.shape)
|
| 684 |
+
return result
|
| 685 |
+
|
| 686 |
+
negate = lambda x: tf.multiply(x, tf.constant(-1, dtype))
|
| 687 |
+
if computation_name == "max":
|
| 688 |
+
return tf_pool(operand, "MAX")
|
| 689 |
+
elif computation_name == "min":
|
| 690 |
+
return negate(tf_pool(negate(operand), "MAX"))
|
| 691 |
+
elif computation_name == "add":
|
| 692 |
+
# TODO(marcvanzee): This may give very large deviations on TPU when using
|
| 693 |
+
# floats as inputs. Alternatively, we could implement this using a
|
| 694 |
+
# convolution with an all-1's kernel.
|
| 695 |
+
return tf.multiply(tf_pool(operand, "AVG"), math.prod(window_dimensions))
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
def _reduce_window(*args, jaxpr, consts, window_dimensions,
|
| 699 |
+
window_strides, padding, base_dilation, window_dilation,
|
| 700 |
+
_in_avals: Sequence[core.ShapedArray],
|
| 701 |
+
_out_aval: tuple[core.ShapedArray, ...]
|
| 702 |
+
) -> tuple[TfVal, ...]:
|
| 703 |
+
assert len(consts) == 0, "Reduction computation cannot have constants"
|
| 704 |
+
operands, init_values = util.split_list(args, [len(args) // 2])
|
| 705 |
+
|
| 706 |
+
if len(operands) != 1 or len(init_values) != 1:
|
| 707 |
+
raise _reduce_error("jax2tf does not support variadic reduce_window")
|
| 708 |
+
|
| 709 |
+
operand, init_value = operands[0], init_values[0]
|
| 710 |
+
# Infer operation type from jaxpr.
|
| 711 |
+
if (len(jaxpr.eqns) != 1 or
|
| 712 |
+
len(jaxpr.eqns[0].invars) != 2 or
|
| 713 |
+
len(jaxpr.eqns[0].outvars) != 1 or
|
| 714 |
+
jaxpr.eqns[0].primitive.name not in ["min", "max", "add"]):
|
| 715 |
+
raise _reduce_error("Reduction function should be either min, max, or add.")
|
| 716 |
+
|
| 717 |
+
computation_name = jaxpr.eqns[0].primitive.name
|
| 718 |
+
result = _reduce_monoid(operand,
|
| 719 |
+
window_dimensions=window_dimensions,
|
| 720 |
+
window_strides=window_strides,
|
| 721 |
+
padding=padding,
|
| 722 |
+
base_dilation=base_dilation,
|
| 723 |
+
window_dilation=window_dilation,
|
| 724 |
+
computation_name=computation_name,
|
| 725 |
+
_in_avals=(_in_avals[0],), # Don't pass init_value.
|
| 726 |
+
_out_aval=_out_aval[0]) # Returns single value.
|
| 727 |
+
|
| 728 |
+
reduce_fn = {
|
| 729 |
+
"min": tf.minimum,
|
| 730 |
+
"max": tf.maximum,
|
| 731 |
+
"add": tf.add,
|
| 732 |
+
}[computation_name]
|
| 733 |
+
result = reduce_fn(result, init_value)
|
| 734 |
+
|
| 735 |
+
# The output is expected to be wrapped in a tuple, and since we don't use
|
| 736 |
+
# variadic reductions, this tuple always contains a single element.
|
| 737 |
+
return (result,)
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
tf_impl_no_xla[lax.reduce_window_min_p] = (
|
| 741 |
+
partial(_reduce_monoid, computation_name="min"))
|
| 742 |
+
tf_impl_no_xla[lax.reduce_window_max_p] = (
|
| 743 |
+
partial(_reduce_monoid, computation_name="max"))
|
| 744 |
+
tf_impl_no_xla[lax.reduce_window_sum_p] = (
|
| 745 |
+
partial(_reduce_monoid, computation_name="add"))
|
| 746 |
+
|
| 747 |
+
tf_impl_no_xla[lax.reduce_window_p] = _reduce_window
|
| 748 |
+
|
| 749 |
+
tf_impl_no_xla[lax.reduce_p] = _unimplemented("reduce")
|
| 750 |
+
|
| 751 |
+
tf_impl_no_xla[lax.select_and_scatter_add_p] = _unimplemented(
|
| 752 |
+
"select_and_scatter_add")
|
| 753 |
+
|
| 754 |
+
tf_impl_no_xla[lax.rng_bit_generator_p] = _unimplemented("rng_bit_generator")
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
def _clip(max_indices: Sequence[TfVal], start_indices: Sequence[TfVal],
|
| 758 |
+
slice_sizes: Sequence[TfVal]):
|
| 759 |
+
"""Simulates XLA clipping behavior with TF ops.
|
| 760 |
+
|
| 761 |
+
Various TF ops have different clipping behavior than XLA:
|
| 762 |
+
* If `start_indices` is out-of-bounds, then TF fails but XLA clips the indices
|
| 763 |
+
to
|
| 764 |
+
[0, max_len].
|
| 765 |
+
* If `start_indices + slice_size` is out-of-bounds, then TF fails, but XLA
|
| 766 |
+
adjust
|
| 767 |
+
`start_indices` so that a full slice is returned.
|
| 768 |
+
This function clips the start indices correctly.
|
| 769 |
+
"""
|
| 770 |
+
# We cast both arguments to `tf.clip_by_value` to int32. Otherwise, this
|
| 771 |
+
# function may return uint32 which is not always compatible with TF ops, so
|
| 772 |
+
# this may result in type errors.
|
| 773 |
+
max_start = tf.cast(tf.subtract(max_indices, slice_sizes), dtype=tf.int32)
|
| 774 |
+
return tf.clip_by_value(tf.cast(start_indices, dtype=tf.int32), 0, max_start)
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
@dataclasses.dataclass
|
| 778 |
+
class GatherArgs:
|
| 779 |
+
operand: TfVal
|
| 780 |
+
start_indices: TfVal
|
| 781 |
+
dnums: lax.GatherDimensionNumbers
|
| 782 |
+
slice_sizes: TfVal
|
| 783 |
+
op_shape: core.Shape
|
| 784 |
+
start_indices_shape: core.Shape
|
| 785 |
+
out_aval: core.ShapedArray
|
| 786 |
+
|
| 787 |
+
def __post_init__(self):
|
| 788 |
+
assert len(self.op_shape) == len(self.slice_sizes)
|
| 789 |
+
|
| 790 |
+
def __repr__(self):
|
| 791 |
+
return (f"operand shape={self.op_shape}, "
|
| 792 |
+
f"start_indices={self.start_indices}, "
|
| 793 |
+
f"dimension_numbes={self.dnums}, "
|
| 794 |
+
f"slice_sizes={self.slice_sizes}")
|
| 795 |
+
@property
|
| 796 |
+
def batch_dims(self):
|
| 797 |
+
return tuple(x for x in range(len(self.out_aval.shape))
|
| 798 |
+
if x not in self.dnums.offset_dims)
|
| 799 |
+
|
| 800 |
+
def gather_precondition(precondition_fn: Callable[[GatherArgs], None]):
|
| 801 |
+
"""Decorator for specifying a precondition function.
|
| 802 |
+
|
| 803 |
+
This decorator should be put on a function with argument `arg` of type
|
| 804 |
+
`GatherArgs`. It will first call `precondition_fn` with `arg` (which may throw
|
| 805 |
+
an exception), and then call the function it is decorating with `arg` as well.
|
| 806 |
+
"""
|
| 807 |
+
|
| 808 |
+
def decorator(gather_fn: Callable[[GatherArgs], Any]):
|
| 809 |
+
|
| 810 |
+
@wraps(gather_fn)
|
| 811 |
+
def wrapper(args: GatherArgs):
|
| 812 |
+
# Call `precondition_fn`; we assume it may throw an exception.
|
| 813 |
+
precondition_fn(args)
|
| 814 |
+
return gather_fn(args)
|
| 815 |
+
|
| 816 |
+
return wrapper
|
| 817 |
+
|
| 818 |
+
return decorator
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
def _pre_gather_for_scalar_indexing(args: GatherArgs):
|
| 822 |
+
"""Returns True if this call to gather represents scalar indexing into arrays.
|
| 823 |
+
|
| 824 |
+
E.g., op[2], op[:, :5, :], jnp.take(op, 0, axis=0).
|
| 825 |
+
"""
|
| 826 |
+
# TODO(marcvanzee): Add more assumptions here, because this is currently too
|
| 827 |
+
# permissive.
|
| 828 |
+
if len(args.start_indices_shape) != 1:
|
| 829 |
+
raise ValueError("start_indices shape should be 1")
|
| 830 |
+
|
| 831 |
+
|
| 832 |
+
@gather_precondition(_pre_gather_for_scalar_indexing)
|
| 833 |
+
def _gather_for_scalar_indexing(args: GatherArgs):
|
| 834 |
+
"""Implements 'scalar indexing into arrays' cases of lax.gather using tf.slice.
|
| 835 |
+
|
| 836 |
+
E.g., op[2], op[:, :5, :], jnp.take(op, 0, axis=0).
|
| 837 |
+
"""
|
| 838 |
+
indices = tf.expand_dims(args.dnums.start_index_map, 1)
|
| 839 |
+
# lax.gather uses an "index map" which maps `start_indices` to the right axes
|
| 840 |
+
# in `operand`. Since tf.strided_slice uses a single array for specifying the
|
| 841 |
+
# start indices, we use a scatter to map the start indices to the right axes.
|
| 842 |
+
op_shape = jax2tf._eval_shape(args.op_shape)
|
| 843 |
+
slice_sizes_tf = jax2tf._eval_shape(args.slice_sizes)
|
| 844 |
+
# TODO(marcvanzee): Consider transposing `operand`, which is probably more
|
| 845 |
+
# optimization friendly.
|
| 846 |
+
begin = tf.scatter_nd(indices, args.start_indices, [len(op_shape)])
|
| 847 |
+
begin = _clip(op_shape, begin, slice_sizes_tf)
|
| 848 |
+
end = slice_sizes_tf + begin
|
| 849 |
+
|
| 850 |
+
# `collapsed_slice_dims` is a tuple of dimensions to collapse, e.g. (0, 2).
|
| 851 |
+
# `tf.strided_slice` expects a binary mask to specify the shrink axes, i.e.,
|
| 852 |
+
# if we want to shrink axis 0 and 2, this corresponds to binary mask 101,
|
| 853 |
+
# which is 5 in decimals. The following line converts the lax representation
|
| 854 |
+
# to the one used by `tf.strided_slice`.
|
| 855 |
+
shrink_mask = sum(2**x for x in args.dnums.collapsed_slice_dims)
|
| 856 |
+
res = tf.strided_slice(args.operand, begin, end, shrink_axis_mask=shrink_mask)
|
| 857 |
+
# Shape inference doesn't work for tf.strided_slice.
|
| 858 |
+
res = jax2tf._ensure_tf_shape_if_dynamic(
|
| 859 |
+
res, jax2tf._aval_to_tf_shape(args.out_aval)
|
| 860 |
+
)
|
| 861 |
+
return res
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
def _pre_gather_for_multidim_indexing(args: GatherArgs):
|
| 865 |
+
"""Returns True if this call to gather represents multi-dimensional indexing.
|
| 866 |
+
|
| 867 |
+
E.g., jnp.take(op, [[0], [1]], axis=0).
|
| 868 |
+
Note we currently only support multi-dimensional indexing if the last
|
| 869 |
+
dimension is 1.
|
| 870 |
+
"""
|
| 871 |
+
# Handle only the case when tf.gather argument batch_dims=0.
|
| 872 |
+
# Find axis to match the tf.gather semantics
|
| 873 |
+
# Let I = len(start_indices_shape)
|
| 874 |
+
# let O = len(op_shape)
|
| 875 |
+
# slice_sizes == op_shape[:axis] + (1,) + op_shape[axis+1:]
|
| 876 |
+
# collapsed_slice_dims == (axis,)
|
| 877 |
+
# start_index_map == (axis,)
|
| 878 |
+
# offset_dims == (0, 1, ..., axis - 1, axis + I, ..., O + I - 1)
|
| 879 |
+
# We added a trailing dimension of size 1
|
| 880 |
+
op_shape = args.op_shape
|
| 881 |
+
start_index_map = args.dnums.start_index_map
|
| 882 |
+
collapsed_slice_dims = args.dnums.collapsed_slice_dims
|
| 883 |
+
offset_dims = args.dnums.offset_dims
|
| 884 |
+
if not (len(op_shape) >= 1 and len(start_index_map) == 1 and
|
| 885 |
+
len(collapsed_slice_dims) == 1 and collapsed_slice_dims[0]
|
| 886 |
+
== start_index_map[0] and len(offset_dims) == len(op_shape) - 1):
|
| 887 |
+
raise ValueError("unsupported dimension numbers")
|
| 888 |
+
# We added a trailing dimension of size 1
|
| 889 |
+
if not core.definitely_equal(args.start_indices_shape[-1], 1):
|
| 890 |
+
raise ValueError("start_indices shape[-1] should be 1")
|
| 891 |
+
# Guess the axis
|
| 892 |
+
axis = collapsed_slice_dims[0]
|
| 893 |
+
index_dims = len(args.start_indices_shape) - 1
|
| 894 |
+
expected_offset_dims = tuple(
|
| 895 |
+
list(range(axis)) +
|
| 896 |
+
list(range(axis + index_dims,
|
| 897 |
+
len(op_shape) + index_dims - 1)))
|
| 898 |
+
if offset_dims != expected_offset_dims:
|
| 899 |
+
raise ValueError("unsupported offset_dims")
|
| 900 |
+
expected_slice_sizes = op_shape[:axis] + (1,) + op_shape[axis + 1:] # type: ignore
|
| 901 |
+
if not core.definitely_equal_shape(args.slice_sizes, expected_slice_sizes):
|
| 902 |
+
raise ValueError("unsupported slice_sizes")
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
@gather_precondition(_pre_gather_for_multidim_indexing)
|
| 906 |
+
def _gather_for_multidim_indexing(args: GatherArgs):
|
| 907 |
+
"""Implements 'multi-dimensional indexing into arrays' cases of lax.gather using tf.gather.
|
| 908 |
+
|
| 909 |
+
E.g., jnp.take(op, [[0], [1]], axis=0).
|
| 910 |
+
"""
|
| 911 |
+
# Guess the axis.
|
| 912 |
+
axis = args.dnums.collapsed_slice_dims[0]
|
| 913 |
+
squeezed_indices = tf.squeeze(args.start_indices, -1)
|
| 914 |
+
op_shape = jax2tf._eval_shape(args.op_shape)
|
| 915 |
+
start_indices = _clip((op_shape[axis],), squeezed_indices, (1,))
|
| 916 |
+
return tf.gather(args.operand, start_indices, axis=axis, batch_dims=0)
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
def _pre_gather_with_batch_dim(args: GatherArgs):
|
| 920 |
+
"""Returns True if this call to gather has non-empty batch dimensions.
|
| 921 |
+
|
| 922 |
+
This is for instance triggered when doing jax.vmap(lax.dynamic_slice).
|
| 923 |
+
"""
|
| 924 |
+
# We assume exactly one batch (and one or more non-batch dimensions).
|
| 925 |
+
if len(args.batch_dims) != 1:
|
| 926 |
+
raise ValueError(f"batch_dims is {len(args.batch_dims)} but should be 1")
|
| 927 |
+
|
| 928 |
+
# `start_index_map` maps indices in `start_indices` to indices in `operand`.
|
| 929 |
+
# For simplicity, we currently only consider the case where this mapping is
|
| 930 |
+
# the identity function, i.e., [2, 3] in `start_indices` maps to
|
| 931 |
+
# `operand[2, 3]`.
|
| 932 |
+
if args.dnums.start_index_map != tuple(range(args.start_indices_shape[-1])):
|
| 933 |
+
raise ValueError("unsupported start_index_map")
|
| 934 |
+
|
| 935 |
+
# The batch dims in `start_indices` and `operand` should agree.
|
| 936 |
+
if not core.definitely_equal(args.op_shape[0], args.start_indices_shape[0]):
|
| 937 |
+
raise ValueError("Batch dimensions in operand and start_indices don't "
|
| 938 |
+
"agree")
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
def _pre_gather_with_batch_dims(args: GatherArgs):
|
| 942 |
+
"""Returns True if this call to gather has non-empty 2D batch dimensions.
|
| 943 |
+
|
| 944 |
+
This is for instance triggered when doing
|
| 945 |
+
jax.vmap(jax.vmap(lax.dynamic_slice)).
|
| 946 |
+
"""
|
| 947 |
+
if len(args.dnums.collapsed_slice_dims) != 0:
|
| 948 |
+
# NOTE: this can be relaxed in _gather_with_batch_dims but we might
|
| 949 |
+
# also need to re-work the output reshaping
|
| 950 |
+
raise ValueError("only len(collapsed_slice_dims) == 0 is supported")
|
| 951 |
+
|
| 952 |
+
# NOTE: This supports higher dimensions than listed (the highest dimension
|
| 953 |
+
# in the tests is 3D so it is limited to that, but the implementation is
|
| 954 |
+
# designed to handle higher dimensions (N-Dimensional)).
|
| 955 |
+
if len(args.batch_dims) not in [1, 2, 3]:
|
| 956 |
+
raise ValueError(
|
| 957 |
+
f"Size of batch_dims is {len(args.batch_dims)} but should be up to 3"
|
| 958 |
+
)
|
| 959 |
+
|
| 960 |
+
@gather_precondition(_pre_gather_with_batch_dim)
|
| 961 |
+
def _gather_with_batch_dim(args: GatherArgs):
|
| 962 |
+
"""Implements call to gather with non-empty batch dimensions.
|
| 963 |
+
|
| 964 |
+
E.g., when doing `jax.vmap(lax.dynamic_slice).
|
| 965 |
+
"""
|
| 966 |
+
op_shape = jax2tf._eval_shape(args.op_shape)
|
| 967 |
+
start_indices = _clip(op_shape, args.start_indices, args.slice_sizes)
|
| 968 |
+
result = tf.map_fn(
|
| 969 |
+
lambda idxs: tf.slice(args.operand, begin=idxs, size=args.slice_sizes),
|
| 970 |
+
start_indices,
|
| 971 |
+
fn_output_signature=jax2tf._to_tf_dtype(args.operand.dtype)
|
| 972 |
+
)
|
| 973 |
+
result = tf.reshape(result, jax2tf._eval_shape(args.out_aval.shape))
|
| 974 |
+
return result
|
| 975 |
+
|
| 976 |
+
|
| 977 |
+
def _gather_generate_indices(shape: tuple[int, ...]):
|
| 978 |
+
"""
|
| 979 |
+
Returns the indices of the according to `shape`:
|
| 980 |
+
each element in the output is the index of an element of an array
|
| 981 |
+
of the provided shape. The result's shape is (math.prod(shape), len(shape))
|
| 982 |
+
|
| 983 |
+
For example, given shape (2,2) it returns (0,0),(0,1),(1,0),(1,1)
|
| 984 |
+
"""
|
| 985 |
+
return tf.reshape(
|
| 986 |
+
tf.stack(
|
| 987 |
+
tf.meshgrid(
|
| 988 |
+
*[tf.range(start=0, limit=x) for x in shape], indexing="ij"
|
| 989 |
+
),
|
| 990 |
+
axis=-1,
|
| 991 |
+
),
|
| 992 |
+
(-1, len(shape)),
|
| 993 |
+
)
|
| 994 |
+
|
| 995 |
+
|
| 996 |
+
@gather_precondition(_pre_gather_with_batch_dims)
|
| 997 |
+
def _gather_with_batch_dims(args: GatherArgs):
|
| 998 |
+
"""Implements call to gather with non-empty 2D batch dimensions."""
|
| 999 |
+
op_shape = jax2tf._eval_shape(args.op_shape)
|
| 1000 |
+
output_shape = jax2tf._eval_shape(args.out_aval.shape)
|
| 1001 |
+
# Used to map the start_indices w.r.t start_index_map
|
| 1002 |
+
indices = tf.expand_dims(args.dnums.start_index_map, 1)
|
| 1003 |
+
|
| 1004 |
+
# batch_indices is shaped (N,d) where N is the number of slices and d is
|
| 1005 |
+
# the number of batch_dims; batch_indices_size equals to N
|
| 1006 |
+
batch_indices = _gather_generate_indices(
|
| 1007 |
+
tuple(output_shape[i] for i in args.batch_dims)
|
| 1008 |
+
)
|
| 1009 |
+
batch_indices_size = jax2tf._eval_shape(batch_indices.shape)[0]
|
| 1010 |
+
# offset_indices is shaped (K,d) where K is the number of elements in each
|
| 1011 |
+
# slice and d is the number of offset_dims; offset_indices_size equals to K
|
| 1012 |
+
offset_indices = _gather_generate_indices(
|
| 1013 |
+
tuple(output_shape[i] for i in args.dnums.offset_dims)
|
| 1014 |
+
)
|
| 1015 |
+
offset_indices_size = jax2tf._eval_shape(offset_indices.shape)[0]
|
| 1016 |
+
|
| 1017 |
+
# After we compute the result we need to reshape the axes with respect to
|
| 1018 |
+
# the output batch_dims and offset_dims.
|
| 1019 |
+
dim_mask = args.batch_dims + args.dnums.offset_dims
|
| 1020 |
+
mask_output_shape = tuple(output_shape[x] for x in dim_mask)
|
| 1021 |
+
|
| 1022 |
+
def get_scatter_indices(indices, batch_indices_size, size_of_index_map):
|
| 1023 |
+
"""Generate the start indices of each slice, which index into the operand."""
|
| 1024 |
+
# Tile indices batch_indices_size times
|
| 1025 |
+
tiled_indices = tf.tile(
|
| 1026 |
+
tf.expand_dims(indices, 0), [batch_indices_size, 1, 1]
|
| 1027 |
+
)
|
| 1028 |
+
# The above tiles need to index the proper element of batch_indices
|
| 1029 |
+
# To do this generate a repeated sequence of numbers
|
| 1030 |
+
temp_batch_indices = tf.repeat(
|
| 1031 |
+
tf.range(start=0, limit=batch_indices_size), size_of_index_map
|
| 1032 |
+
)
|
| 1033 |
+
# Reshape the above sequence so it follows the same shape of tiled_indices
|
| 1034 |
+
temp_batch_indices = tf.reshape(
|
| 1035 |
+
temp_batch_indices, (batch_indices_size, size_of_index_map, 1)
|
| 1036 |
+
)
|
| 1037 |
+
# Now we concatenate to create indices offset by the temp_batch_indices
|
| 1038 |
+
return tf.concat([temp_batch_indices, tiled_indices], axis=-1)
|
| 1039 |
+
|
| 1040 |
+
slice_start_indices = tf.gather_nd(args.start_indices, batch_indices)
|
| 1041 |
+
# TODO: In the case where start_index_map is the identity we can skip this.
|
| 1042 |
+
scatter_indices = get_scatter_indices(
|
| 1043 |
+
indices, batch_indices_size, len(args.dnums.start_index_map)
|
| 1044 |
+
)
|
| 1045 |
+
# We map the scatter_indices w.r.t start_index_map
|
| 1046 |
+
indices_in_operand = tf.scatter_nd(
|
| 1047 |
+
scatter_indices, slice_start_indices, [batch_indices_size, len(op_shape)]
|
| 1048 |
+
)
|
| 1049 |
+
|
| 1050 |
+
# We clip the indices as OOB cases are possible when offsetting past
|
| 1051 |
+
# the operand boundaries
|
| 1052 |
+
clipped_start_indices = _clip(op_shape, indices_in_operand, args.slice_sizes)
|
| 1053 |
+
# Here we need to broadcast clipped_start_indices and add each of the offsets
|
| 1054 |
+
# which will generate a large index tensor of shape (T,d) where T is the
|
| 1055 |
+
# number of slices times the size of each slice (i.e total number of items
|
| 1056 |
+
# across all sices); d is rank(operand)
|
| 1057 |
+
slice_element_indices = tf.add(
|
| 1058 |
+
tf.repeat(clipped_start_indices, offset_indices_size, axis=0),
|
| 1059 |
+
tf.tile(offset_indices, (batch_indices_size, 1)),
|
| 1060 |
+
)
|
| 1061 |
+
results = tf.gather_nd(args.operand, slice_element_indices)
|
| 1062 |
+
|
| 1063 |
+
# Here results comes shaped as (N,1). Because collapsed_slice_dims is 0,
|
| 1064 |
+
# offset_dims is effectviely slice_sizes.
|
| 1065 |
+
# We reshape to mask_output_shape because if we directly reshape to the
|
| 1066 |
+
# output shape and our batch_dims are non-contiguous we will produce the
|
| 1067 |
+
# wrong shape. Reshaping to mask_output_shape gives (...,*slice_sizes),
|
| 1068 |
+
# which we then transpose to permute the axes in the proper way.
|
| 1069 |
+
# Note that if the batch_dims are contiguous this won't change the output.
|
| 1070 |
+
temp = tf.reshape(results, shape=mask_output_shape)
|
| 1071 |
+
return tf.transpose(temp, perm=tf.math.invert_permutation(dim_mask))
|
| 1072 |
+
|
| 1073 |
+
def _gather(operand, start_indices, *, dimension_numbers,
|
| 1074 |
+
slice_sizes: core.Shape, indices_are_sorted, unique_indices, mode,
|
| 1075 |
+
fill_value, _in_avals: Sequence[core.ShapedArray],
|
| 1076 |
+
_out_aval: core.ShapedArray):
|
| 1077 |
+
"""Tensorflow implementation of gather."""
|
| 1078 |
+
if mode == lax.GatherScatterMode.FILL_OR_DROP:
|
| 1079 |
+
gather_fill_fn = jax2tf._convert_jax_impl(lax_slicing._gather_fill,
|
| 1080 |
+
multiple_results=False)
|
| 1081 |
+
return gather_fill_fn(
|
| 1082 |
+
operand, start_indices, dimension_numbers=dimension_numbers,
|
| 1083 |
+
slice_sizes=slice_sizes, unique_indices=unique_indices,
|
| 1084 |
+
indices_are_sorted=indices_are_sorted, fill_value=fill_value,
|
| 1085 |
+
output_shape=_out_aval.shape, _in_avals=_in_avals, _out_aval=_out_aval)
|
| 1086 |
+
|
| 1087 |
+
# TODO(marcvanzee): Check if we need more tests in shape_poly for gather with
|
| 1088 |
+
# enable_xla=False.
|
| 1089 |
+
gather_args = GatherArgs(
|
| 1090 |
+
operand=operand,
|
| 1091 |
+
start_indices=start_indices,
|
| 1092 |
+
dnums=dimension_numbers,
|
| 1093 |
+
slice_sizes=slice_sizes,
|
| 1094 |
+
op_shape=_in_avals[0].shape,
|
| 1095 |
+
start_indices_shape=_in_avals[1].shape,
|
| 1096 |
+
out_aval=_out_aval)
|
| 1097 |
+
|
| 1098 |
+
errors = []
|
| 1099 |
+
|
| 1100 |
+
for gather_fn in [
|
| 1101 |
+
_gather_for_scalar_indexing,
|
| 1102 |
+
_gather_for_multidim_indexing,
|
| 1103 |
+
_gather_with_batch_dim,
|
| 1104 |
+
_gather_with_batch_dims,
|
| 1105 |
+
]:
|
| 1106 |
+
try:
|
| 1107 |
+
return gather_fn(gather_args)
|
| 1108 |
+
except ValueError as e:
|
| 1109 |
+
errors.append(f"{gather_fn}: {e!r}")
|
| 1110 |
+
|
| 1111 |
+
error_msg = (f"Unsupported arguments for gather: {gather_args}, errors:\n" +
|
| 1112 |
+
"\n".join(errors))
|
| 1113 |
+
|
| 1114 |
+
raise _error("gather", error_msg)
|
| 1115 |
+
|
| 1116 |
+
|
| 1117 |
+
tf_impl_no_xla[lax.gather_p] = _gather
|
| 1118 |
+
|
| 1119 |
+
|
| 1120 |
+
def _dynamic_slice(operand, *start_indices, slice_sizes: core.Shape,
|
| 1121 |
+
_in_avals: Sequence[core.ShapedArray],
|
| 1122 |
+
_out_aval: core.ShapedArray):
|
| 1123 |
+
start_indices = tf.stack(start_indices)
|
| 1124 |
+
slice_sizes_tf = jax2tf._eval_shape(slice_sizes)
|
| 1125 |
+
|
| 1126 |
+
operand_shape = jax2tf._eval_shape(_in_avals[0].shape)
|
| 1127 |
+
start_indices = _clip(operand_shape, start_indices, slice_sizes_tf)
|
| 1128 |
+
return tf.slice(operand, start_indices, size=slice_sizes_tf)
|
| 1129 |
+
|
| 1130 |
+
|
| 1131 |
+
tf_impl_no_xla[lax.dynamic_slice_p] = _dynamic_slice
|
| 1132 |
+
|
| 1133 |
+
|
| 1134 |
+
def _dynamic_update_slice(operand, update, *start_indices,
|
| 1135 |
+
_in_avals: Sequence[core.ShapedArray],
|
| 1136 |
+
_out_aval: core.ShapedArray):
|
| 1137 |
+
start_indices = tf.stack(start_indices)
|
| 1138 |
+
|
| 1139 |
+
op_shape = jax2tf._eval_shape(_in_avals[0].shape)
|
| 1140 |
+
op_size = tf.size(operand)
|
| 1141 |
+
update_shape_tf = jax2tf._eval_shape(_in_avals[1].shape)
|
| 1142 |
+
|
| 1143 |
+
start_indices = _clip(op_shape, start_indices, update_shape_tf)
|
| 1144 |
+
end_indices = tf.add(start_indices, update_shape_tf)
|
| 1145 |
+
|
| 1146 |
+
# Get the cells to update in `operand` as an array of ids.
|
| 1147 |
+
id_tensor = tf.reshape(tf.range(op_size), op_shape)
|
| 1148 |
+
scattered_indices = tf.strided_slice(id_tensor, start_indices, end_indices)
|
| 1149 |
+
|
| 1150 |
+
# Create an array containing updates at scattered_indices and zeros otherwise.
|
| 1151 |
+
flat_indices = tf.expand_dims(tf.nest.flatten(scattered_indices), -1)
|
| 1152 |
+
flat_update = tf.nest.flatten(update)
|
| 1153 |
+
update = tf.scatter_nd(flat_indices, flat_update, (op_size,))
|
| 1154 |
+
update = tf.reshape(update, op_shape)
|
| 1155 |
+
|
| 1156 |
+
# Create a bool mask that is True only where `operand` should be updated.
|
| 1157 |
+
update_mask = tf.ones_like(flat_update, dtype=tf.bool)
|
| 1158 |
+
update_mask = tf.scatter_nd(flat_indices, update_mask, (op_size,))
|
| 1159 |
+
update_mask = tf.reshape(update_mask, op_shape)
|
| 1160 |
+
|
| 1161 |
+
# Use the mask to only update `operand` with `update`.
|
| 1162 |
+
return tf.where(update_mask, update, operand)
|
| 1163 |
+
|
| 1164 |
+
|
| 1165 |
+
tf_impl_no_xla[lax.dynamic_update_slice_p] = _dynamic_update_slice
|
| 1166 |
+
|
| 1167 |
+
|
| 1168 |
+
def shift_axes_forward(operand,
|
| 1169 |
+
axes: tuple[int, ...],
|
| 1170 |
+
inverse: bool = False,
|
| 1171 |
+
forward: bool = True):
|
| 1172 |
+
"""Shifts the tuple of axes to the front of an array"""
|
| 1173 |
+
other_axes = tuple(i for i in range(len(operand.shape)) if i not in axes)
|
| 1174 |
+
fwd_order = axes + other_axes if forward else other_axes + axes
|
| 1175 |
+
order = fwd_order if not inverse else _invert_permutation(fwd_order)
|
| 1176 |
+
return tf.transpose(operand, order)
|
| 1177 |
+
|
| 1178 |
+
def convert_scatter_jax_to_tf(update_op, unsorted_segment_op=None):
|
| 1179 |
+
|
| 1180 |
+
def _sparse_scatter(operand, scatter_indices, updates, unique_indices, mode,
|
| 1181 |
+
_in_avals: Sequence[core.ShapedArray],
|
| 1182 |
+
_out_aval: core.ShapedArray):
|
| 1183 |
+
"""Implementation of scatter specialised to indexing from the front axes.
|
| 1184 |
+
|
| 1185 |
+
This covers unique indices and non-unique indices of single depth.
|
| 1186 |
+
Note on unique indices: `tf.tensor_scatter_nd_update` interprets indices
|
| 1187 |
+
thusly: every axis except the final one encodes a batch dimension, the final
|
| 1188 |
+
axis encoding the actual indices to scatter in to. It enforces, at least
|
| 1189 |
+
one, batch dimension so we add an empty dimension to indices and updates if
|
| 1190 |
+
lacking.
|
| 1191 |
+
|
| 1192 |
+
Note on non-unique indices: There is no tf op for non-single depth indexing,
|
| 1193 |
+
but if indexing is single depth, this can be viewed as a segment op.
|
| 1194 |
+
"""
|
| 1195 |
+
# Infer unique indices from lack of batch dimension
|
| 1196 |
+
unique_indices = unique_indices or (len(scatter_indices.shape) == 1)
|
| 1197 |
+
if unique_indices:
|
| 1198 |
+
suboperand = tf.gather_nd(operand, scatter_indices)
|
| 1199 |
+
updated_suboperand = update_op(suboperand, updates)
|
| 1200 |
+
# add a batch dim if none exist
|
| 1201 |
+
if len(scatter_indices.shape) == 1:
|
| 1202 |
+
scatter_indices = scatter_indices[None]
|
| 1203 |
+
updated_suboperand = updated_suboperand[None]
|
| 1204 |
+
y = tf.tensor_scatter_nd_update(operand, scatter_indices, updated_suboperand)
|
| 1205 |
+
else:
|
| 1206 |
+
if (scatter_indices.shape[-1] == 1) and unsorted_segment_op:
|
| 1207 |
+
# If only indexing into the first dimension, it's a segment op
|
| 1208 |
+
operand_update = unsorted_segment_op(updates,
|
| 1209 |
+
tf.squeeze(scatter_indices, -1),
|
| 1210 |
+
operand.shape[0])
|
| 1211 |
+
y = update_op(operand, operand_update)
|
| 1212 |
+
else:
|
| 1213 |
+
raise _scatter_error(
|
| 1214 |
+
"Scatter only supports non-unique "
|
| 1215 |
+
"indices with indexing into only one dimension for (add, mul, min, "
|
| 1216 |
+
"max)")
|
| 1217 |
+
return y
|
| 1218 |
+
|
| 1219 |
+
def sparse_scatter(operand, scatter_indices, updates, update_jaxpr,
|
| 1220 |
+
update_consts, dimension_numbers, indices_are_sorted: bool,
|
| 1221 |
+
unique_indices: bool, mode,
|
| 1222 |
+
_in_avals: Sequence[core.ShapedArray],
|
| 1223 |
+
_out_aval: core.ShapedArray):
|
| 1224 |
+
"""
|
| 1225 |
+
Wrapper around the scatter function.
|
| 1226 |
+
The underlying tf ops `tf.tensor_scatter_nd_update` and
|
| 1227 |
+
`tf.math.unsorted_segment_*` index from the front dimensions.
|
| 1228 |
+
`tf.math.unsorted_segment_*` indexes to a depth 1 from the front.
|
| 1229 |
+
`tf.tensor_scatter_nd_update` indexes from the front dimensions onwards,
|
| 1230 |
+
with no ability to skip a dimension. This function shifts the axes to be
|
| 1231 |
+
indexed to the front then calls a front-specific implementation, then
|
| 1232 |
+
inverse-shifts the output.
|
| 1233 |
+
|
| 1234 |
+
scatter_dims_to_operand_dims: dimensions which the scatter indexes in to.
|
| 1235 |
+
We shift these to the front to match tf syntax. All other dims are batch
|
| 1236 |
+
update_window_dims: dimensions which are not batch dimensions. We shift
|
| 1237 |
+
these to the back as the remaining dimensions are batch dimensions.
|
| 1238 |
+
"""
|
| 1239 |
+
del update_jaxpr, update_consts, indices_are_sorted # Unused arguments
|
| 1240 |
+
|
| 1241 |
+
update_window_dims = dimension_numbers.update_window_dims
|
| 1242 |
+
inserted_window_dims = dimension_numbers.inserted_window_dims
|
| 1243 |
+
scatter_to_operand_dims = dimension_numbers.scatter_dims_to_operand_dims
|
| 1244 |
+
|
| 1245 |
+
dtype = operand.dtype # assume updates has same dtype as operand
|
| 1246 |
+
if dtype in [tf.bool, tf.complex64]:
|
| 1247 |
+
raise _scatter_error(f"Scatter does not support operands of type {dtype}")
|
| 1248 |
+
|
| 1249 |
+
if inserted_window_dims != scatter_to_operand_dims:
|
| 1250 |
+
raise _scatter_error("Complex scatters are not supported")
|
| 1251 |
+
|
| 1252 |
+
if (mode != lax.GatherScatterMode.FILL_OR_DROP and
|
| 1253 |
+
mode != lax.GatherScatterMode.PROMISE_IN_BOUNDS):
|
| 1254 |
+
# The OOB behavior for tf.scatter is as follows:
|
| 1255 |
+
# - When running in eager or graph mode, it throws an error.
|
| 1256 |
+
# TODO(marcvanzee): Fix this case by removing the OOB indices.
|
| 1257 |
+
# - When running in compile mode, the OOB indices are dropped, which is
|
| 1258 |
+
# the same behavior as FILL_OR_DROP and PROMISE_IN_BOUNDS.
|
| 1259 |
+
# To ensure correctness, we disallow CLIP mode for now.
|
| 1260 |
+
raise _scatter_error("Only scatter modes `FILL_OR_DROP` and "
|
| 1261 |
+
"`PROMISE_IN_BOUNDS` are supported.")
|
| 1262 |
+
|
| 1263 |
+
# Shift axes to the front to match tf syntax, inverse afterwards
|
| 1264 |
+
fwd = partial(shift_axes_forward, axes=scatter_to_operand_dims)
|
| 1265 |
+
inv = partial(fwd, inverse=True)
|
| 1266 |
+
|
| 1267 |
+
# Shift update value axes to the back, so batch are at the front
|
| 1268 |
+
updates_shifted = shift_axes_forward(
|
| 1269 |
+
updates, axes=update_window_dims, forward=False)
|
| 1270 |
+
|
| 1271 |
+
return inv(
|
| 1272 |
+
_sparse_scatter(
|
| 1273 |
+
fwd(operand), scatter_indices, updates_shifted, unique_indices,
|
| 1274 |
+
mode, _in_avals, _out_aval))
|
| 1275 |
+
return sparse_scatter
|
| 1276 |
+
|
| 1277 |
+
|
| 1278 |
+
tf_impl_no_xla[lax.scatter_p] = convert_scatter_jax_to_tf(
|
| 1279 |
+
lambda x, y: y) # just replace with the update
|
| 1280 |
+
tf_impl_no_xla[lax.scatter_add_p] = convert_scatter_jax_to_tf(tf.add, tf.math.unsorted_segment_sum)
|
| 1281 |
+
tf_impl_no_xla[lax.scatter_mul_p] = convert_scatter_jax_to_tf(tf.multiply, tf.math.unsorted_segment_prod)
|
| 1282 |
+
tf_impl_no_xla[lax.scatter_min_p] = convert_scatter_jax_to_tf(tf.minimum, tf.math.unsorted_segment_min)
|
| 1283 |
+
tf_impl_no_xla[lax.scatter_max_p] = convert_scatter_jax_to_tf(tf.maximum, tf.math.unsorted_segment_max)
|
| 1284 |
+
|
| 1285 |
+
tf_impl_no_xla[lax.sort_p] = _unimplemented("sort")
|
| 1286 |
+
|
| 1287 |
+
tf_impl_no_xla[lax.reduce_precision_p] = _unimplemented("reduce_precision")
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/jax2tf.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/tests/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/tests/back_compat_tf_test.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Tests for backwards compatibility of custom calls involving TensorFlow.
|
| 15 |
+
|
| 16 |
+
See the back_compat_test_util module docstring for how to setup and update
|
| 17 |
+
these tests.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import base64
|
| 23 |
+
from collections.abc import Sequence
|
| 24 |
+
import io
|
| 25 |
+
import os
|
| 26 |
+
import tarfile
|
| 27 |
+
from typing import Callable, Optional
|
| 28 |
+
|
| 29 |
+
from absl.testing import absltest
|
| 30 |
+
import jax
|
| 31 |
+
from jax._src import test_util as jtu
|
| 32 |
+
from jax._src.internal_test_util import export_back_compat_test_util as bctu
|
| 33 |
+
from jax._src.lib import xla_extension
|
| 34 |
+
from jax.experimental import jax2tf
|
| 35 |
+
from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function
|
| 36 |
+
import jax.numpy as jnp
|
| 37 |
+
import tensorflow as tf
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
jax.config.parse_flags_with_absl()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def serialize_directory(directory_path):
|
| 44 |
+
"""Seriliaze the directory as a string."""
|
| 45 |
+
tar_buffer = io.BytesIO()
|
| 46 |
+
with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
|
| 47 |
+
tar.add(directory_path, arcname=os.path.basename(directory_path))
|
| 48 |
+
|
| 49 |
+
# Convert the binary data to a base64-encoded string
|
| 50 |
+
serialized_string = base64.b64encode(tar_buffer.getvalue())
|
| 51 |
+
return serialized_string
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def deserialize_directory(serialized_string, output_directory):
|
| 55 |
+
"""Deserialize the string to the directory."""
|
| 56 |
+
# Convert the base64-encoded string back to binary data
|
| 57 |
+
tar_data = base64.b64decode(serialized_string)
|
| 58 |
+
|
| 59 |
+
# Extract the tar archive to the output directory
|
| 60 |
+
with tarfile.open(fileobj=io.BytesIO(tar_data), mode="r") as tar:
|
| 61 |
+
tar.extractall(output_directory)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class CompatTensoflowTest(bctu.CompatTestBase):
|
| 65 |
+
"""Compatibility tests that use TF.
|
| 66 |
+
|
| 67 |
+
Uses tf.Graph to serialize and run the functions; expects that `func`
|
| 68 |
+
contains a `jax2tf.call_tf` and uses `jax2tf.convert` to generate a
|
| 69 |
+
`tf.Graph` containing a XlaCallModule with the actual MLIR module.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
def run_current(self, func: Callable, data: bctu.CompatTestData):
|
| 73 |
+
# Here we use tf.saved_model and provide string serialize/deserialize methods
|
| 74 |
+
# for the whole directory.
|
| 75 |
+
@tf.function(autograph=False, jit_compile=True)
|
| 76 |
+
def tf_func(the_input): # Use recognizable names for input and result
|
| 77 |
+
res = jax2tf.convert(func, native_serialization=True)(the_input)
|
| 78 |
+
return tf.identity(res, name="the_result")
|
| 79 |
+
|
| 80 |
+
self.tf_func = tf_func
|
| 81 |
+
return tf_func(*data.inputs)
|
| 82 |
+
|
| 83 |
+
def serialize(
|
| 84 |
+
self,
|
| 85 |
+
func: Callable,
|
| 86 |
+
data: bctu.CompatTestData,
|
| 87 |
+
polymorphic_shapes: Sequence[str] | None = None,
|
| 88 |
+
allow_unstable_custom_call_targets: Sequence[str] = (),
|
| 89 |
+
):
|
| 90 |
+
# We serialize as a tf.Graph
|
| 91 |
+
assert len(data.inputs) == 1 # We only support a single input now
|
| 92 |
+
tf_graph = self.tf_func.get_concrete_function(*data.inputs).graph
|
| 93 |
+
for op in tf_graph.get_operations():
|
| 94 |
+
if op.type == "XlaCallModule":
|
| 95 |
+
serialized_module = op.get_attr("module")
|
| 96 |
+
module_str = xla_extension.mlir.deserialize_portable_artifact(
|
| 97 |
+
serialized_module
|
| 98 |
+
)
|
| 99 |
+
module_version = op.get_attr("version")
|
| 100 |
+
break
|
| 101 |
+
else:
|
| 102 |
+
raise ValueError("Cannot find an XlaCallModule")
|
| 103 |
+
tf_graph_def = tf_graph.as_graph_def()
|
| 104 |
+
# module_str is just for human readability, add both the MLIR module
|
| 105 |
+
# and the tf.Graph
|
| 106 |
+
module_str = (
|
| 107 |
+
"# First the MLIR module:\n"
|
| 108 |
+
+ module_str
|
| 109 |
+
+ "\n# Then the tf.Graph:\n"
|
| 110 |
+
+ str(tf_graph_def)
|
| 111 |
+
)
|
| 112 |
+
# serialized = tf_graph_def.SerializeToString()
|
| 113 |
+
module = tf.Module()
|
| 114 |
+
module.call = self.tf_func.get_concrete_function(*data.inputs)
|
| 115 |
+
root_dir = self.create_tempdir()
|
| 116 |
+
saved_model_dir = os.path.join(root_dir, "saved_model")
|
| 117 |
+
os.mkdir(saved_model_dir)
|
| 118 |
+
tf.saved_model.save(
|
| 119 |
+
module,
|
| 120 |
+
saved_model_dir,
|
| 121 |
+
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
|
| 122 |
+
)
|
| 123 |
+
serialized = serialize_directory(saved_model_dir)
|
| 124 |
+
nr_devices = 1
|
| 125 |
+
return serialized, module_str, module_version, nr_devices
|
| 126 |
+
|
| 127 |
+
def run_serialized(
|
| 128 |
+
self,
|
| 129 |
+
data: bctu.CompatTestData,
|
| 130 |
+
polymorphic_shapes: Sequence[str] | None = None,
|
| 131 |
+
):
|
| 132 |
+
root_dir = self.create_tempdir()
|
| 133 |
+
deserialize_directory(data.mlir_module_serialized, root_dir)
|
| 134 |
+
saved_model_dir = os.path.join(root_dir, "saved_model")
|
| 135 |
+
loaded_model = tf.saved_model.load(saved_model_dir)
|
| 136 |
+
return (loaded_model.call(*data.inputs).numpy(),)
|
| 137 |
+
|
| 138 |
+
def test_tf_call_tf_function(self):
|
| 139 |
+
# A custom call tf.call_tf_function is generated when we lower call_tf
|
| 140 |
+
# with the call_tf_graph=True option.
|
| 141 |
+
def func(x):
|
| 142 |
+
def func_tf(x):
|
| 143 |
+
return tf.math.sin(x)
|
| 144 |
+
|
| 145 |
+
return jnp.cos(
|
| 146 |
+
jax2tf.call_tf(func_tf, output_shape_dtype=x, call_tf_graph=True)(x)
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
data = self.load_testdata(tf_call_tf_function.data_2023_07_29)
|
| 150 |
+
self.run_one_test(func, data)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
if __name__ == "__main__":
|
| 154 |
+
absltest.main(testLoader=jtu.JaxTestLoader())
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/tests/call_tf_test.py
ADDED
|
@@ -0,0 +1,1821 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Tests for call_tf."""
|
| 15 |
+
|
| 16 |
+
import contextlib
|
| 17 |
+
from functools import partial
|
| 18 |
+
import os
|
| 19 |
+
from typing import Callable
|
| 20 |
+
import unittest
|
| 21 |
+
|
| 22 |
+
from absl import logging
|
| 23 |
+
from absl.testing import absltest
|
| 24 |
+
from absl.testing import parameterized
|
| 25 |
+
import jax
|
| 26 |
+
from jax import dlpack
|
| 27 |
+
from jax import dtypes
|
| 28 |
+
from jax import lax
|
| 29 |
+
from jax import numpy as jnp
|
| 30 |
+
from jax._src import config
|
| 31 |
+
from jax._src import test_util as jtu
|
| 32 |
+
from jax._src.lib.mlir import ir
|
| 33 |
+
from jax._src.lib.mlir.dialects import hlo
|
| 34 |
+
from jax.experimental import export
|
| 35 |
+
from jax.experimental import jax2tf
|
| 36 |
+
from jax.experimental.jax2tf.tests import tf_test_util
|
| 37 |
+
import numpy as np
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
import tensorflow as tf
|
| 41 |
+
except ImportError:
|
| 42 |
+
tf = None
|
| 43 |
+
|
| 44 |
+
jax.config.parse_flags_with_absl()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _maybe_jit(with_jit: bool, func: Callable) -> Callable:
|
| 48 |
+
if with_jit:
|
| 49 |
+
return jax.jit(func)
|
| 50 |
+
else:
|
| 51 |
+
return func
|
| 52 |
+
|
| 53 |
+
def _maybe_tf_jit(with_jit: bool, func: Callable) -> Callable:
|
| 54 |
+
if with_jit:
|
| 55 |
+
return tf.function(func, autograph=False, jit_compile=True)
|
| 56 |
+
else:
|
| 57 |
+
return func
|
| 58 |
+
|
| 59 |
+
def _named_test(**kwargs):
|
| 60 |
+
return dict(kwargs,
|
| 61 |
+
testcase_name = "_".join([f"{k}={kwargs[k]}" for k in sorted(kwargs.keys())]))
|
| 62 |
+
|
| 63 |
+
_parameterized_jit = parameterized.named_parameters(
|
| 64 |
+
_named_test(with_jit=with_jit)
|
| 65 |
+
for with_jit in [True, False])
|
| 66 |
+
|
| 67 |
+
_call_tf_non_compilable_error = "Error compiling TensorFlow function"
|
| 68 |
+
_call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has dynamic shape"
|
| 69 |
+
|
| 70 |
+
class CallTfTest(tf_test_util.JaxToTfTestCase):
|
| 71 |
+
|
| 72 |
+
@classmethod
|
| 73 |
+
def setUpClass(cls):
|
| 74 |
+
# One TF device of each device_type
|
| 75 |
+
cls.tf_devices = []
|
| 76 |
+
for tf_device in tf.config.list_logical_devices():
|
| 77 |
+
if tf_device.device_type == "TPU_SYSTEM":
|
| 78 |
+
continue # A virtual device
|
| 79 |
+
if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
|
| 80 |
+
cls.tf_devices.append(tf_device)
|
| 81 |
+
|
| 82 |
+
super().setUpClass()
|
| 83 |
+
|
| 84 |
+
def setUp(self):
|
| 85 |
+
if tf is None:
|
| 86 |
+
raise unittest.SkipTest("Test requires tensorflow")
|
| 87 |
+
# TODO(b/171320191): this line works around a missing context initialization
|
| 88 |
+
# bug in TensorFlow.
|
| 89 |
+
_ = tf.add(1, 1)
|
| 90 |
+
super().setUp()
|
| 91 |
+
|
| 92 |
+
@_parameterized_jit
|
| 93 |
+
def test_eval_scalar_arg(self, with_jit=True):
|
| 94 |
+
def f_tf(x):
|
| 95 |
+
return tf.math.sin(x)
|
| 96 |
+
x = 3.
|
| 97 |
+
res = _maybe_jit(with_jit, jax2tf.call_tf(f_tf))(x)
|
| 98 |
+
self.assertAllClose(jnp.sin(x), res)
|
| 99 |
+
|
| 100 |
+
@_parameterized_jit
|
| 101 |
+
def test_eval_scalar_res(self, with_jit=True):
|
| 102 |
+
x = 3.
|
| 103 |
+
res = _maybe_jit(with_jit, jax2tf.call_tf(lambda x: 4.))(x)
|
| 104 |
+
self.assertAllClose(4., res, check_dtypes=False)
|
| 105 |
+
|
| 106 |
+
@_parameterized_jit
|
| 107 |
+
def test_eval_numpy_arg(self, with_jit=True):
|
| 108 |
+
x = np.ones((2, 3), dtype=np.float32)
|
| 109 |
+
res = _maybe_jit(with_jit, jax2tf.call_tf(tf.math.sin))(x)
|
| 110 |
+
self.assertAllClose(jnp.sin(x), res)
|
| 111 |
+
|
| 112 |
+
@_parameterized_jit
|
| 113 |
+
def test_eval_numpy_res(self, with_jit=False):
|
| 114 |
+
x = np.ones((2, 3))
|
| 115 |
+
res = _maybe_jit(with_jit, jax2tf.call_tf(lambda _: x))(x)
|
| 116 |
+
self.assertAllClose(x, res)
|
| 117 |
+
|
| 118 |
+
@_parameterized_jit
|
| 119 |
+
def test_eval_devicearray_arg(self, with_jit=False):
|
| 120 |
+
x = jnp.ones((2, 3), dtype=np.float32)
|
| 121 |
+
res = _maybe_jit(with_jit, jax2tf.call_tf(tf.math.sin))(x)
|
| 122 |
+
self.assertAllClose(jnp.sin(x), res)
|
| 123 |
+
|
| 124 |
+
x = jnp.array(3.0, dtype=jnp.bfloat16)
|
| 125 |
+
res = jax2tf.call_tf(lambda x: x)(x)
|
| 126 |
+
self.assertAllClose(x, res)
|
| 127 |
+
# bfloat16 scalar will create a copy.
|
| 128 |
+
with self.assertRaises(AssertionError):
|
| 129 |
+
self.assertTrue(np.shares_memory(x, res))
|
| 130 |
+
|
| 131 |
+
  @_parameterized_jit
  def test_eval_pytree(self, with_jit=True):

    def fun_tf(x: dict, y: tuple) -> tuple:
      return (x["first"] * x["second"], y[0] + y[1])

    x = dict(first=np.float32(3.), second=np.float32(4.))
    y = (np.float64(5.), np.float64(6.))
    fun_jax = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))
    res = fun_jax(x, y)
    self.assertAllClose((np.float32(12.), np.float64(11.)), res)

  def test_result_tuple(self):
    x1 = np.ones(3, dtype=np.int32)
    x2 = np.ones(5, dtype=np.float32)
    def fun_tf():
      return tf.tuple([x1, x2])

    fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
    res = fun_jax()
    self.assertAllClose(res, (x1, x2))

  def test_error_non_compilable_strings(self):
    # Check that in op-by-op we call a function in eager mode.
    def f_tf_non_compilable(x):
      return tf.strings.length(tf.strings.format("Hello {}!", [x]))

    f_jax = jax2tf.call_tf(f_tf_non_compilable)
    x = np.float32(0.7)
    self.assertAllClose(f_tf_non_compilable(x).numpy(), f_jax(x))
    with self.assertRaisesRegex(ValueError,
                                _call_tf_non_compilable_error):
      jax.jit(f_jax)(x)

    with self.assertRaisesRegex(ValueError,
                                _call_tf_non_compilable_error):
      lax.cond(True, lambda x: f_jax(x), lambda x: f_jax(x), x)

  def test_error_non_compilable_dynamic_shape(self):
    # Check that in op-by-op we call a function in eager mode.
    def f_tf_non_compilable(x):
      return tf.cond(x[0], lambda: x[1:], lambda: x)

    f_jax = jax2tf.call_tf(f_tf_non_compilable)
    x = np.array([True, False], dtype=np.bool_)
    self.assertAllClose(f_tf_non_compilable(x), f_jax(x))  # Works in eager mode

    with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
      jax.jit(f_jax)(x)

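  # Illustrative note (added comment, not in the upstream file): in the test
  # above the two tf.cond branches return shapes [1] (x[1:]) and [2] (x), so
  # the output shape depends on the runtime value of x[0]. XLA compilation
  # requires static shapes, hence jit-compiling the call_tf wrapper raises
  # _call_tf_dynamic_shape_error, while the eager (op-by-op) call succeeds.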
  def test_error_bad_result_tensorarray(self):
    # Call a function that returns a tf.TensorArray. This should be detected
    # early on. If we don't detect it, the function is actually compilable but
    # returns a tuple instead of a single result.
    def fun_tf():
      ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
      ta = ta.unstack([0, 1, 2, 3, 4])
      return ta

    with self.assertRaisesRegex(ValueError,
                                "The called TF function returns a result that is not convertible to JAX"):
      fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
      fun_jax()

  def test_error_bad_result_string(self):
    def fun_tf():
      return tf.constant("foo")

    # Under jit this should fail because the string result is not convertible to JAX.
    with self.assertRaisesRegex(ValueError,
                                "The called TF function returns a result that is not convertible to JAX"):
      fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
      fun_jax()

  @_parameterized_jit
  def test_control_flow(self, with_jit=True):

    def times_5_tf(x):
      # Multiply x * 5 using a loop
      c = lambda i, acc: tf.less(i, 5)
      b = lambda i, acc: (tf.add(i, 1), tf.add(acc, x))
      _, acc = tf.while_loop(c, b, [tf.constant(0), tf.constant(0.)])
      return acc

    def fun_jax(x):
      # Calls times_5_tf 3 times in a loop
      def body(_, acc):
        return jax2tf.call_tf(times_5_tf)(acc)

      return lax.fori_loop(0, 3, body, x)

    x = np.float32(3.)
    res = _maybe_jit(with_jit, fun_jax)(x)
    self.assertAllClose(np.float32(x * 5 * 5 * 5), res)

@parameterized.named_parameters(
|
| 227 |
+
dict(
|
| 228 |
+
testcase_name=f"_{dtype.__name__}{'_jit' if with_jit else ''}",
|
| 229 |
+
dtype=dtype,
|
| 230 |
+
with_jit=with_jit)
|
| 231 |
+
for dtype in set(jtu.dtypes.all) - {np.bool_}
|
| 232 |
+
for with_jit in [True, False])
|
| 233 |
+
def test_dtypes(self, dtype=np.int32, with_jit=True):
|
| 234 |
+
|
| 235 |
+
def fun_tf(x):
|
| 236 |
+
# AddV2 supports more types
|
| 237 |
+
return tf.raw_ops.AddV2(x=x, y=tf.constant(3, dtype=dtype))
|
| 238 |
+
|
| 239 |
+
def fun_jax(x):
|
| 240 |
+
return jax2tf.call_tf(fun_tf)(x) + x
|
| 241 |
+
|
| 242 |
+
x = np.ones((3,), dtype=dtype)
|
| 243 |
+
res = _maybe_jit(with_jit, fun_jax)(x)
|
| 244 |
+
self.assertAllClose(dtype(2 * x + 3), res)
|
| 245 |
+
|
| 246 |
+
@_parameterized_jit
|
| 247 |
+
def test_bool(self, with_jit=False):
|
| 248 |
+
|
| 249 |
+
def fun_tf(x, y):
|
| 250 |
+
return tf.math.logical_and(x, y)
|
| 251 |
+
|
| 252 |
+
x = np.array([True, False, True, False], dtype=np.bool_)
|
| 253 |
+
y = np.array([True, True, False, False], dtype=np.bool_)
|
| 254 |
+
res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x, y)
|
| 255 |
+
self.assertAllClose(
|
| 256 |
+
np.array([True, False, False, False], dtype=np.bool_), res)
|
| 257 |
+
|
| 258 |
+
@_parameterized_jit
|
| 259 |
+
def test_x64_input(self, with_jit=True):
|
| 260 |
+
def f_tf(x):
|
| 261 |
+
return tf.math.sin(x)
|
| 262 |
+
|
| 263 |
+
x = 5. # TF interprets this as f64
|
| 264 |
+
res_call_tf = _maybe_jit(with_jit, jax2tf.call_tf(f_tf))(x)
|
| 265 |
+
res_jax = jnp.sin(x)
|
| 266 |
+
self.assertAllClose(res_call_tf, res_jax)
|
| 267 |
+
|
| 268 |
+
@_parameterized_jit
|
| 269 |
+
def test_x64_output(self, with_jit=True):
|
| 270 |
+
def f_tf(x):
|
| 271 |
+
return (tf.constant(3., tf.float64), x)
|
| 272 |
+
|
| 273 |
+
x = np.float32(5.)
|
| 274 |
+
res_call_tf = _maybe_jit(with_jit, jax2tf.call_tf(f_tf))(x)
|
| 275 |
+
res_jax = (3., x)
|
| 276 |
+
self.assertAllClose(res_call_tf, res_jax)
|
| 277 |
+
|
| 278 |
+
res_call_tf_jit = jax.jit(jax2tf.call_tf(f_tf))(x)
|
| 279 |
+
self.assertAllClose(res_call_tf_jit, res_jax)
|
| 280 |
+
|
| 281 |
+
@_parameterized_jit
|
| 282 |
+
def test_with_var_read(self, with_jit=True):
|
| 283 |
+
# The variable is placed on the default TF device.
|
| 284 |
+
outer_var_array = np.array([3., 4.], dtype=np.float32)
|
| 285 |
+
outer_var = tf.Variable(outer_var_array)
|
| 286 |
+
|
| 287 |
+
def fun_tf(x):
|
| 288 |
+
return x * outer_var + 1.
|
| 289 |
+
|
| 290 |
+
x = np.array([2., 5.,], dtype=np.float32)
|
| 291 |
+
res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
|
| 292 |
+
self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False)
|
| 293 |
+
|
| 294 |
+
@_parameterized_jit
|
| 295 |
+
def test_with_var_read_x64(self, with_jit=True):
|
| 296 |
+
outer_var_array = np.array([3., 4.], dtype=np.float64)
|
| 297 |
+
outer_var = tf.Variable(outer_var_array)
|
| 298 |
+
|
| 299 |
+
def fun_tf(x):
|
| 300 |
+
return x * tf.cast(outer_var, x.dtype) + 1.
|
| 301 |
+
|
| 302 |
+
x = np.array([2., 5.,], dtype=np.float32)
|
| 303 |
+
res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
|
| 304 |
+
self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False)
|
| 305 |
+
|
| 306 |
+
def test_with_var_different_shape(self):
|
| 307 |
+
# See https://github.com/google/jax/issues/6050
|
| 308 |
+
v = tf.Variable((4., 2.), dtype=tf.float32)
|
| 309 |
+
|
| 310 |
+
def tf_func(x):
|
| 311 |
+
return v + x
|
| 312 |
+
x = np.float32(123.)
|
| 313 |
+
tf_out = tf_func(x)
|
| 314 |
+
|
| 315 |
+
jax_func = jax.jit(jax2tf.call_tf(tf_func))
|
| 316 |
+
jax_out = jax_func(x)
|
| 317 |
+
|
| 318 |
+
self.assertAllClose(tf_out, jax_out, check_dtypes=False)
|
| 319 |
+
|
| 320 |
+
@_parameterized_jit
|
| 321 |
+
def test_with_var_write_error(self, with_jit=True):
|
| 322 |
+
if with_jit:
|
| 323 |
+
raise unittest.SkipTest("variable writes not yet working")
|
| 324 |
+
outer_var = tf.Variable(3., dtype=np.float32)
|
| 325 |
+
|
| 326 |
+
def fun_tf(x):
|
| 327 |
+
outer_var.assign(tf.constant(4.))
|
| 328 |
+
return x * outer_var + 1.
|
| 329 |
+
|
| 330 |
+
x = np.float32(2.)
|
| 331 |
+
res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
|
| 332 |
+
self.assertAllClose(x * 4. + 1, res, check_dtypes=False)
|
| 333 |
+
|
| 334 |
+
@_parameterized_jit
|
| 335 |
+
def test_with_tensor_capture(self, with_jit=True):
|
| 336 |
+
outer_tensor = tf.constant(3., dtype=np.float32)
|
| 337 |
+
|
| 338 |
+
def fun_tf(x):
|
| 339 |
+
return x * outer_tensor + 1.
|
| 340 |
+
|
| 341 |
+
x = np.float32(2.)
|
| 342 |
+
res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
|
| 343 |
+
self.assertAllClose(x * 3. + 1., res, check_dtypes=False)
|
| 344 |
+
|
| 345 |
+
@_parameterized_jit
|
| 346 |
+
def test_with_tensor_capture_x64(self, with_jit=True):
|
| 347 |
+
outer_tensor = tf.constant(3., dtype=np.float64)
|
| 348 |
+
|
| 349 |
+
def fun_tf(x):
|
| 350 |
+
return x * tf.cast(outer_tensor * 3.14, tf.float32) + 1.
|
| 351 |
+
|
| 352 |
+
x = np.float32(2.)
|
| 353 |
+
res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
|
| 354 |
+
self.assertAllClose(x * 3. * 3.14 + 1., res, check_dtypes=False)
|
| 355 |
+
|
| 356 |
+
@_parameterized_jit
|
| 357 |
+
def test_with_value_capture(self, with_jit=True):
|
| 358 |
+
outer_val = np.array(3., dtype=np.float32)
|
| 359 |
+
|
| 360 |
+
def fun_tf(x):
|
| 361 |
+
return x * outer_val + 1.
|
| 362 |
+
|
| 363 |
+
x = np.float32(2.)
|
| 364 |
+
res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
|
| 365 |
+
self.assertAllClose(x * 3. + 1., res, check_dtypes=False)
|
| 366 |
+
|
| 367 |
+
@_parameterized_jit
|
| 368 |
+
def test_with_multiple_capture(self, with_jit=True):
|
| 369 |
+
if jtu.test_device_matches(["gpu"]):
|
| 370 |
+
raise unittest.SkipTest("Test fails on GPU")
|
| 371 |
+
v2 = tf.Variable(2., dtype=np.float32)
|
| 372 |
+
v3 = tf.Variable(3., dtype=np.float32)
|
| 373 |
+
t4 = tf.constant(4., dtype=np.float32)
|
| 374 |
+
t5 = tf.constant(5., dtype=np.float32)
|
| 375 |
+
|
| 376 |
+
def fun_tf(x):
|
| 377 |
+
return (x * v3 + t4 + v2) * v3 + t5
|
| 378 |
+
|
| 379 |
+
x = np.float32(2.)
|
| 380 |
+
res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
|
| 381 |
+
self.assertAllClose((x * 3. + 4. + 2.) * 3. + 5., res, check_dtypes=False)
|
| 382 |
+
|
| 383 |
+
@_parameterized_jit
|
| 384 |
+
def test_grad(self, with_jit=False):
|
| 385 |
+
x = np.float32(3.)
|
| 386 |
+
res = _maybe_jit(with_jit, jax.grad(jax2tf.call_tf(tf.math.sin)))(x)
|
| 387 |
+
self.assertAllClose(np.cos(x), res)
|
| 388 |
+
|
| 389 |
+
@_parameterized_jit
|
| 390 |
+
def test_grad_pytree(self, with_jit=False):
|
| 391 |
+
|
| 392 |
+
def fun_tf(x: dict, y: tuple) -> tuple:
|
| 393 |
+
return x["first"] * x["second"] + 3. * y[0] + 4. * y[1]
|
| 394 |
+
|
| 395 |
+
x = dict(first=np.float32(3.), second=np.float32(4.))
|
| 396 |
+
y = (np.float32(5.), np.float32(6.))
|
| 397 |
+
grad_x = _maybe_jit(with_jit, jax.grad(jax2tf.call_tf(fun_tf)))(x, y)
|
| 398 |
+
self.assertAllClose(
|
| 399 |
+
dict(first=np.float32(4.), second=np.float32(3.)), grad_x)
|
| 400 |
+
|
| 401 |
+
def test_grad_nested(self):
|
| 402 |
+
# We embed the call_tf function in a larger function whose gradient we take
|
| 403 |
+
# It is relevant here that the cotangents flowing through the call_tf
|
| 404 |
+
# function are not scalars.
|
| 405 |
+
|
| 406 |
+
b = np.array([[11., 12., 13.], [21., 22., 23.]], dtype=np.float32) # [2, 3]
|
| 407 |
+
c = np.array([[31., 32.], [41., 42.], [51., 52.], [61., 62.]], dtype=np.float32) # [4, 2]
|
| 408 |
+
x_dict = dict(b=b, c=c) # b:[2, 3], c=[4, 2]
|
| 409 |
+
# res: dict(r:[4, 3], s:[4, 2])
|
| 410 |
+
def f_tf(x_dict):
|
| 411 |
+
return dict(r=tf.matmul(x_dict["c"], x_dict["b"]), s=7. * x_dict["c"])
|
| 412 |
+
|
| 413 |
+
@jax.jit # To recognize it in jaxpr
|
| 414 |
+
def f_jax(x_dict):
|
| 415 |
+
return dict(r=jnp.matmul(x_dict["c"], x_dict["b"]), s=7. * x_dict["c"])
|
| 416 |
+
|
| 417 |
+
def loss(functional, x_dict):
|
| 418 |
+
prediction = functional(x_dict) # r:[4, 3], s:[4, 2]
|
| 419 |
+
weights = np.array([1., 2., 3., 4.], dtype=np.float32) # [4]
|
| 420 |
+
weighted_pred = jnp.matmul(weights, prediction["r"]) # [3]
|
| 421 |
+
return jnp.sum(weighted_pred) + 4. * jnp.sum(prediction["s"])
|
| 422 |
+
|
| 423 |
+
g_fun_with_tf = jax.grad(partial(loss, jax2tf.call_tf(f_tf)))
|
| 424 |
+
g_fun_with_jax = jax.grad(partial(loss, f_jax))
|
| 425 |
+
|
| 426 |
+
g_tf = g_fun_with_tf(x_dict)
|
| 427 |
+
g_jax = g_fun_with_jax(x_dict)
|
| 428 |
+
self.assertAllClose(g_jax, g_tf)
|
| 429 |
+
|
| 430 |
+
def test_grad_int_argument(self):
|
| 431 |
+
# Similar to https://github.com/google/jax/issues/6975
|
| 432 |
+
# state is a pytree that contains an integer and a boolean.
|
| 433 |
+
# The function returns an integer and a boolean.
|
| 434 |
+
def f(param, state, x):
|
| 435 |
+
return param * x, state
|
| 436 |
+
|
| 437 |
+
param = np.array([0.7, 0.9], dtype=np.float32)
|
| 438 |
+
state = dict(array=np.float32(1.), counter=7, truth=True)
|
| 439 |
+
x = np.float32(3.)
|
| 440 |
+
|
| 441 |
+
# tf.function is important, without it the bug does not appear
|
| 442 |
+
f_call_tf = jax2tf.call_tf(f)
|
| 443 |
+
g_call_tf = jax.grad(lambda *args: jnp.sum(f_call_tf(*args)[0]))(param, state, x)
|
| 444 |
+
g = jax.grad(lambda *args: jnp.sum(f(*args)[0]))(param, state, x)
|
| 445 |
+
self.assertAllClose(g_call_tf, g)
|
| 446 |
+
|
| 447 |
+
def test_grad_int_argument_unused(self):
|
| 448 |
+
batch_size = 5
|
| 449 |
+
inputs = np.ones((batch_size, 3), dtype=np.float32)
|
| 450 |
+
rng = np.array([1, 2], dtype=np.uint32)
|
| 451 |
+
params = np.float32(.5)
|
| 452 |
+
|
| 453 |
+
# rng is integer, unused
|
| 454 |
+
def jax_model(params, rng, inputs):
|
| 455 |
+
return jnp.ones([batch_size, 2], dtype=jnp.float32)
|
| 456 |
+
|
| 457 |
+
tf_model = jax2tf.convert(jax_model, with_gradient=True)
|
| 458 |
+
|
| 459 |
+
def _loss_fn(inference_fn, params, rng, inputs):
|
| 460 |
+
prediction = inference_fn(params, rng, inputs)
|
| 461 |
+
return jnp.mean(prediction)
|
| 462 |
+
|
| 463 |
+
jax_loss_fn = partial(_loss_fn, jax_model)
|
| 464 |
+
jax_grad = jax.grad(jax_loss_fn)(params, rng, inputs)
|
| 465 |
+
|
| 466 |
+
paramsv = tf.Variable(params)
|
| 467 |
+
with tf.GradientTape() as tape:
|
| 468 |
+
tf_prediction = tf_model(paramsv, rng, inputs)
|
| 469 |
+
tf_loss = tf.reduce_mean(tf_prediction)
|
| 470 |
+
|
| 471 |
+
tf_grad = tape.gradient(tf_loss, paramsv)
|
| 472 |
+
self.assertAllClose(jax_grad, tf_grad.numpy())
|
| 473 |
+
|
| 474 |
+
call_tf_loss_fn = partial(_loss_fn, jax2tf.call_tf(tf_model))
|
| 475 |
+
call_tf_grad = jax.grad(call_tf_loss_fn)(params, rng, inputs)
|
| 476 |
+
self.assertAllClose(jax_grad, call_tf_grad)
|
| 477 |
+
|
| 478 |
+
def test_grad_with_float0_result(self):
|
| 479 |
+
# Gradient over integer-argument functions, with float0 result
|
| 480 |
+
def f_jax(x, y): # x is an int, y is a float; res is a (int, float)
|
| 481 |
+
return (2 * x, 2 * x + y * y)
|
| 482 |
+
def f_tf(x, y):
|
| 483 |
+
# TF needs explicit casts
|
| 484 |
+
return (2 * x, tf.cast(2 * x, dtype=y.dtype) + y * y)
|
| 485 |
+
|
| 486 |
+
def wrapper(functional, x, y): # x: i32
|
| 487 |
+
return jnp.sum(2. * functional(3 * x, 4. * y)[1])
|
| 488 |
+
|
| 489 |
+
grad_g = jax.grad(partial(wrapper, f_jax),
|
| 490 |
+
allow_int=True, argnums=(0, 1))
|
| 491 |
+
grad_g_call_tf = jax.grad(partial(wrapper, jax2tf.call_tf(f_tf)),
|
| 492 |
+
allow_int=True, argnums=(0, 1))
|
| 493 |
+
|
| 494 |
+
x = np.int32(2)
|
| 495 |
+
y = np.float32(3.)
|
| 496 |
+
g_jax = grad_g(x, y)
|
| 497 |
+
g_call_tf = grad_g_call_tf(x, y)
|
| 498 |
+
self.assertEqual(g_jax[0].dtype, dtypes.float0)
|
| 499 |
+
self.assertEqual(g_call_tf[0].dtype, dtypes.float0)
|
| 500 |
+
self.assertAllClose(g_jax[1], g_call_tf[1])
|
| 501 |
+
|
| 502 |
+
@_parameterized_jit
|
| 503 |
+
def test_grad_custom(self, with_jit=False):
|
| 504 |
+
|
| 505 |
+
@tf.custom_gradient
|
| 506 |
+
def func_square_tf(x):
|
| 507 |
+
# Like x ** 2, but with custom grad 3. * x
|
| 508 |
+
def grad(dy, variables=None):
|
| 509 |
+
# dy, = dys
|
| 510 |
+
return 3. * x * dy,
|
| 511 |
+
|
| 512 |
+
return x * x, grad
|
| 513 |
+
|
| 514 |
+
x = np.float32(4.)
|
| 515 |
+
grad_x = _maybe_jit(with_jit, jax.grad(jax2tf.call_tf(func_square_tf)))(x)
|
| 516 |
+
self.assertAllClose(np.float32(3.) * x, grad_x)
|
| 517 |
+
|
| 518 |
+
@parameterized.named_parameters(
|
| 519 |
+
dict(
|
| 520 |
+
testcase_name=f"_{degree=}{'_jit' if with_jit else ''}",
|
| 521 |
+
degree=degree,
|
| 522 |
+
with_jit=with_jit)
|
| 523 |
+
for degree in [1, 2, 3, 4]
|
| 524 |
+
for with_jit in [True, False])
|
| 525 |
+
def test_higher_order_grad(self, degree=2, with_jit=False):
|
| 526 |
+
|
| 527 |
+
def fun_tf(x):
|
| 528 |
+
return 2. * x * x * x
|
| 529 |
+
|
| 530 |
+
def fun_jax(x):
|
| 531 |
+
return 3. * _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
|
| 532 |
+
|
| 533 |
+
def fun_jax_pure(x):
|
| 534 |
+
return 3. * fun_tf(x)
|
| 535 |
+
|
| 536 |
+
grad_jax = fun_jax
|
| 537 |
+
grad_jax_pure = fun_jax_pure
|
| 538 |
+
for _ in range(degree):
|
| 539 |
+
grad_jax = jax.grad(grad_jax)
|
| 540 |
+
grad_jax_pure = jax.grad(grad_jax_pure)
|
| 541 |
+
|
| 542 |
+
res_jax = grad_jax(np.float32(5.))
|
| 543 |
+
logging.info("Grad of %s degree is %s", degree, res_jax)
|
| 544 |
+
self.assertAllClose(res_jax, grad_jax_pure(np.float32(5.)))
|
| 545 |
+
|
| 546 |
+
def test_pmap(self):
|
| 547 |
+
logging.info("Running test_pmap on %s devices", jax.local_device_count())
|
| 548 |
+
|
| 549 |
+
def plus_2_tf(x):
|
| 550 |
+
return tf.math.add(2., x)
|
| 551 |
+
|
| 552 |
+
def fun_jax(x):
|
| 553 |
+
return np.float32(3.) * jax2tf.call_tf(plus_2_tf)(x)
|
| 554 |
+
|
| 555 |
+
x = np.arange(jax.local_device_count(), dtype=np.float32)
|
| 556 |
+
res = jax.pmap(fun_jax)(x)
|
| 557 |
+
self.assertAllClose(np.float32(3. * (x + 2)), res)
|
| 558 |
+
|
| 559 |
+
def test_function_compile_time_constant_inputs(self):
|
| 560 |
+
# Call a function for which shape inference does not give an output
|
| 561 |
+
# shape.
|
| 562 |
+
x = np.array([1, 2, 3], dtype=np.int32)
|
| 563 |
+
def fun_tf(x): # x:i32[3]
|
| 564 |
+
# Indexing with a dynamic slice makes the TF shape inference return
|
| 565 |
+
# a partially known shape.
|
| 566 |
+
end_idx = x[1]
|
| 567 |
+
res = x[0:end_idx]
|
| 568 |
+
return res
|
| 569 |
+
|
| 570 |
+
# Call in eager mode. Should work!
|
| 571 |
+
res1 = jax2tf.call_tf(fun_tf)(x)
|
| 572 |
+
self.assertAllClose(x[0:x[1]], res1)
|
| 573 |
+
|
| 574 |
+
# Now under jit, should fail because the function is not compilable
|
| 575 |
+
with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
|
| 576 |
+
fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
|
| 577 |
+
fun_jax(x)
|
| 578 |
+
|
| 579 |
+
def test_experimental_get_compiler_ir_design_doc(self):
|
| 580 |
+
# Not a test of call_tf, but more of how experimental_get_compiler_ir works.
|
| 581 |
+
# Examples are from the design doc.
|
| 582 |
+
|
| 583 |
+
# Constant slice. This is the common case.
|
| 584 |
+
x = np.zeros((10,), dtype=np.int32)
|
| 585 |
+
|
| 586 |
+
def fun_tf(x):
|
| 587 |
+
begin = 0
|
| 588 |
+
return x[begin:5]
|
| 589 |
+
|
| 590 |
+
hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)()
|
| 591 |
+
self.assertIn("(arg0.1: s32[10]) -> s32[5]", hlo)
|
| 592 |
+
|
| 593 |
+
# Non-constant slice, but compile-time constant depending only on values.
|
| 594 |
+
x = np.zeros((10,), dtype=np.int32)
|
| 595 |
+
|
| 596 |
+
# Non-constant slice, but compile-time constant depending only on shapes.
|
| 597 |
+
x = np.zeros((10,), dtype=np.int32)
|
| 598 |
+
|
| 599 |
+
def fun_tf(x):
|
| 600 |
+
begin = tf.shape(x)[0] - 2 # begin is a compile-time constant, even if x is not
|
| 601 |
+
return x[begin:]
|
| 602 |
+
|
| 603 |
+
hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)()
|
| 604 |
+
self.assertIn("(arg0.1: s32[10]) -> s32[2]", hlo)
|
| 605 |
+
|
| 606 |
+
# Capture a variable
|
| 607 |
+
outer_var = tf.Variable(np.array([3.], dtype=np.float32))
|
| 608 |
+
x = np.array([2., 3., 4.], dtype=np.float32)
|
| 609 |
+
|
| 610 |
+
def fun_tf(x):
|
| 611 |
+
return x * tf.broadcast_to(outer_var, x.shape) + 1.
|
| 612 |
+
|
| 613 |
+
hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)()
|
| 614 |
+
self.assertIn("(arg0.1: f32[3], arg1.2: f32[1]) -> f32[3]", hlo)
|
| 615 |
+
|
| 616 |
+
# Capture a constant
|
| 617 |
+
outer_ct = np.array([3.], dtype=np.float32)
|
| 618 |
+
x = np.array([2., 3., 4.], dtype=np.float32)
|
| 619 |
+
|
| 620 |
+
def fun_tf(x):
|
| 621 |
+
return x * tf.broadcast_to(outer_ct, x.shape) + 1.
|
| 622 |
+
|
| 623 |
+
hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)()
|
| 624 |
+
self.assertIn("(arg0.1: f32[3]) -> f32[3]", hlo)
|
| 625 |
+
|
| 626 |
+
# Call get_compiler_ir in a function context
|
| 627 |
+
x = np.array([2., 3., 4.], dtype=np.float32)
|
| 628 |
+
|
| 629 |
+
def fun_tf_outer(x):
|
| 630 |
+
x_const = tf.constant(0, shape=x.shape, dtype=x.dtype)
|
| 631 |
+
_ = tf.function(tf.math.sin, jit_compile=True, autograph=False).experimental_get_compiler_ir(x_const)()
|
| 632 |
+
|
| 633 |
+
# TODO(b/193754660)
|
| 634 |
+
# with self.assertRaisesRegex(
|
| 635 |
+
# TypeError, "An op outside of the function building code is being passed"):
|
| 636 |
+
# tf.function(fun_tf_outer)(x)
|
| 637 |
+
#
|
| 638 |
+
# with self.assertRaisesRegex(
|
| 639 |
+
# TypeError, "An op outside of the function building code is being passed"):
|
| 640 |
+
# tf.function(fun_tf_outer, jit_compile=True)(x)
|
| 641 |
+
|
| 642 |
+
# Call get_concrete_function in a graph context
|
| 643 |
+
def fun_tf_outer_2(x):
|
| 644 |
+
_ = tf.function(tf.math.sin, jit_compile=True).get_concrete_function(tf.TensorSpec(x.shape, x.dtype))
|
| 645 |
+
return x
|
| 646 |
+
|
| 647 |
+
# Outside of a function context, this works.
|
| 648 |
+
_ = tf.function(fun_tf_outer_2)(x)
|
| 649 |
+
_ = tf.function(fun_tf_outer_2, jit_compile=True)(x)
|
| 650 |
+
|
| 651 |
+
def test_repro_193754660(self):
|
| 652 |
+
# Try to reproduce b/193754660. I can't.
|
| 653 |
+
# We have to have tf.function(jax2tf.convert(jax2tf.call_tf(f_tf))).
|
| 654 |
+
# The get_compiler_ir will indeed fail for f_tf. Then we try to use
|
| 655 |
+
# shape inference for f_tf.
|
| 656 |
+
# I thought to use a f_tf that uses an op without shape inference, e.g.,
|
| 657 |
+
# tfxla.gather. If we wash it through a saved_model I expect that shape
|
| 658 |
+
# inference would not work on it. Instead, shape inference works!!!
|
| 659 |
+
x = np.array([0, 1, 2, 3, 4, 5], dtype=np.int32)
|
| 660 |
+
def f_jax(x):
|
| 661 |
+
return x[1]
|
| 662 |
+
f_tf = jax2tf.convert(f_jax)
|
| 663 |
+
f_tf_rt, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
|
| 664 |
+
f_jax2 = jax2tf.call_tf(f_tf_rt)
|
| 665 |
+
f_tf2 = jax2tf.convert(f_jax2)
|
| 666 |
+
res = tf.function(f_tf2, autograph=False)(x)
|
| 667 |
+
self.assertAllClose(res.numpy(), f_jax(x))
|
| 668 |
+
|
| 669 |
+
def test_effectful(self):
|
| 670 |
+
x = np.ones((3,), dtype=np.float32)
|
| 671 |
+
lower_effect = jax.jit(jax2tf.call_tf(tf.math.sin, has_side_effects=True)).lower(x)
|
| 672 |
+
self.assertNotEmpty(lower_effect._lowering.compile_args["unordered_effects"])
|
| 673 |
+
|
| 674 |
+
lower_no_effect = jax.jit(jax2tf.call_tf(tf.math.sin, has_side_effects=False)).lower(x)
|
| 675 |
+
self.assertEmpty(lower_no_effect._lowering.compile_args["unordered_effects"])
|
| 676 |
+
|
| 677 |
+
  def test_module_documentation(self):
    def cos_tf(x):
      return tf.math.cos(x)

    # Compute cos with TF and sin with JAX
    def cos_tf_sin_jax(x):
      return jax.numpy.sin(jax2tf.call_tf(cos_tf)(x))

    # Calls `cos_tf` in TF eager mode
    x = np.float32(1.)
    cos_tf_sin_jax(x)

    # Compiles `cos_tf` using TF and embeds the XLA computation into the JAX
    # XLA computation (containing `sin`). The XLA compiler may even be able to
    # fuse through JAX-TF computations.
    jax.jit(cos_tf_sin_jax)(x)

    # Uses TF gradient for `cos_tf` and JAX gradient for `sin`
    jax.grad(cos_tf_sin_jax)(x)

    logging.info(jax.make_jaxpr(cos_tf_sin_jax)(x))
    logging.info(jax.xla_computation(cos_tf_sin_jax)(x).as_hlo_text())

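  # Note (added comment, assumes a newer JAX release than the one this test was
  # written against): jax.xla_computation is deprecated in recent JAX versions;
  # the same HLO text can be obtained through the AOT lowering API, e.g.:
  #   logging.info(jax.jit(cos_tf_sin_jax).lower(x).as_text())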
def test_tf_gather(self):
|
| 701 |
+
"""tf_gather gradient output is tf.IndexSlices."""
|
| 702 |
+
operand = jnp.array(np.random.uniform(size=(100, 128)))
|
| 703 |
+
indices = jnp.array(np.random.randint(low=0, high=100, size=(4000,)))
|
| 704 |
+
|
| 705 |
+
@tf.function(jit_compile=True, autograph=False)
|
| 706 |
+
def fun_tf(operand, indices):
|
| 707 |
+
return tf.experimental.numpy.std(tf.gather(operand, indices))
|
| 708 |
+
|
| 709 |
+
fun_jax = jax2tf.call_tf(fun_tf)
|
| 710 |
+
grad_fun_jax = jax.grad(fun_jax)
|
| 711 |
+
grad_res = grad_fun_jax(operand, indices)
|
| 712 |
+
self.assertEqual(grad_res.shape, (100, 128))
|
| 713 |
+
|
| 714 |
+
def test_output_shape_dtype_none(self):
|
| 715 |
+
x = jnp.zeros((10), dtype=jnp.float32)
|
| 716 |
+
|
| 717 |
+
@tf.function(jit_compile=True, autograph=False)
|
| 718 |
+
def fun_tf(x): # pylint: disable=unused-argument
|
| 719 |
+
return
|
| 720 |
+
|
| 721 |
+
fun_jax_1 = jax2tf.call_tf(fun_tf, output_shape_dtype=None)
|
| 722 |
+
fun_jax_2 = jax2tf.call_tf(fun_tf)
|
| 723 |
+
self.assertIsNone(fun_jax_1(x))
|
| 724 |
+
self.assertIsNone(fun_jax_2(x))
|
| 725 |
+
fun_jax_3 = jax2tf.call_tf(
|
| 726 |
+
fun_tf, output_shape_dtype=jax.ShapeDtypeStruct((10,), jnp.float32)
|
| 727 |
+
)
|
| 728 |
+
with self.assertRaisesRegex(
|
| 729 |
+
ValueError,
|
| 730 |
+
"The pytree of the TensorFlow function results does not match the"
|
| 731 |
+
" pytree of the declared output_shape_dtype",
|
| 732 |
+
):
|
| 733 |
+
_ = fun_jax_3(x)
|
| 734 |
+
|
| 735 |
+
def test_output_shape_dtype_not_none(self):
|
| 736 |
+
x = jnp.zeros((10), dtype=jnp.float32)
|
| 737 |
+
|
| 738 |
+
@tf.function(jit_compile=True, autograph=False)
|
| 739 |
+
def fun_tf(x):
|
| 740 |
+
return x
|
| 741 |
+
|
| 742 |
+
fun_jax_1 = jax2tf.call_tf(
|
| 743 |
+
fun_tf, output_shape_dtype=jax.ShapeDtypeStruct((10,), jnp.float32)
|
| 744 |
+
)
|
| 745 |
+
fun_jax_2 = jax2tf.call_tf(fun_tf)
|
| 746 |
+
self.assertAllClose(fun_jax_1(x), fun_jax_2(x))
|
| 747 |
+
|
| 748 |
+
fun_jax_3 = jax2tf.call_tf(fun_tf, output_shape_dtype=None)
|
| 749 |
+
with self.assertRaisesRegex(
|
| 750 |
+
ValueError,
|
| 751 |
+
"The pytree of the TensorFlow function results does not match the"
|
| 752 |
+
" pytree of the declared output_shape_dtype",
|
| 753 |
+
):
|
| 754 |
+
_ = fun_jax_3(x)
|
| 755 |
+
|
| 756 |
+
def test_multi_platform(self):
|
| 757 |
+
def tf_fun(x):
|
| 758 |
+
return tf.math.sin(x)
|
| 759 |
+
|
| 760 |
+
def f_jax(x):
|
| 761 |
+
return jnp.cos(jax2tf.call_tf(tf_fun)(jnp.cos(x)))
|
| 762 |
+
x = np.arange(12, dtype=np.float32).reshape((3, 4))
|
| 763 |
+
|
| 764 |
+
# Find platforms that are available for both JAX and TF
|
| 765 |
+
# Pick one device from each available platform
|
| 766 |
+
jax_platforms = []
|
| 767 |
+
for backend in ["cpu", "gpu", "tpu"]:
|
| 768 |
+
try:
|
| 769 |
+
devices = jax.devices(backend)
|
| 770 |
+
except RuntimeError:
|
| 771 |
+
devices = []
|
| 772 |
+
if devices:
|
| 773 |
+
jax_platforms.append(devices[0].platform)
|
| 774 |
+
|
| 775 |
+
jax_and_tf_platforms = (
|
| 776 |
+
set(jax_platforms) & {d.device_type.lower()
|
| 777 |
+
for d in self.__class__.tf_devices})
|
| 778 |
+
|
| 779 |
+
lowering_platforms = ("tpu", "cpu", "cuda")
|
| 780 |
+
|
| 781 |
+
exp = export.export(f_jax,
|
| 782 |
+
lowering_platforms=lowering_platforms)(x)
|
| 783 |
+
for jax_platform in jax_and_tf_platforms:
|
| 784 |
+
with self.subTest(jax_platform):
|
| 785 |
+
jax_device = jax.devices(jax_platform)[0]
|
| 786 |
+
x_device = jax.device_put(x, jax_device)
|
| 787 |
+
logging.info("Running harness natively on %s", jax_device)
|
| 788 |
+
native_res = f_jax(x_device)
|
| 789 |
+
logging.info("Running exported harness on %s", jax_device)
|
| 790 |
+
exported_res = export.call(exp)(x_device)
|
| 791 |
+
self.assertAllClose(native_res, exported_res)
|
| 792 |
+
|
| 793 |
+
def test_multi_platform_call_tf_graph(self):
|
| 794 |
+
def tf_fun(x):
|
| 795 |
+
return tf.math.sin(x)
|
| 796 |
+
|
| 797 |
+
def f_jax(x):
|
| 798 |
+
return jnp.cos(jax2tf.call_tf(tf_fun,
|
| 799 |
+
call_tf_graph=True,
|
| 800 |
+
ordered=True)(jnp.cos(x)))
|
| 801 |
+
x = np.arange(12, dtype=np.float32).reshape((3, 4))
|
| 802 |
+
# When we use call_tf_graph we can serialize for multiple platforms
|
| 803 |
+
lowering_platforms = ("tpu", "cpu", "cuda")
|
| 804 |
+
# We must use jax2tf.convert to run a call_tf(call_tf_graph)
|
| 805 |
+
# TODO(necula): if we remove the tf.function and we have multiple platforms
|
| 806 |
+
# then we attempt to lower call_tf multiple times and only the first
|
| 807 |
+
# lowering will have the proper side effects for the function_list.
|
| 808 |
+
f_tf = tf.function(jax2tf.convert(
|
| 809 |
+
f_jax,
|
| 810 |
+
native_serialization=True,
|
| 811 |
+
native_serialization_platforms=lowering_platforms))
|
| 812 |
+
for tf_device in self.__class__.tf_devices:
|
| 813 |
+
with self.subTest(tf_device.device_type):
|
| 814 |
+
logging.info(
|
| 815 |
+
f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
|
| 816 |
+
with tf.device(tf_device):
|
| 817 |
+
res = f_tf(x)
|
| 818 |
+
self.assertAllClose(res, f_jax(x))
|
| 819 |
+
|
| 820 |
+
@parameterized.named_parameters(
|
| 821 |
+
{"testcase_name": f"_type={type_.__name__}", "type_": type_}
|
| 822 |
+
for type_ in dlpack.SUPPORTED_DTYPES
|
| 823 |
+
)
|
| 824 |
+
def test_avoid_copy_between_gpu_and_cpu(self, type_):
|
| 825 |
+
try:
|
| 826 |
+
gpu_devices = jax.devices("gpu")
|
| 827 |
+
except RuntimeError:
|
| 828 |
+
gpu_devices = []
|
| 829 |
+
if not gpu_devices:
|
| 830 |
+
raise unittest.SkipTest("Test requires a GPU device.")
|
| 831 |
+
|
| 832 |
+
def tf_fun(x):
|
| 833 |
+
if type_ == jnp.bool_:
|
| 834 |
+
return tf.math.logical_or(x, True)
|
| 835 |
+
else:
|
| 836 |
+
return x + 1
|
| 837 |
+
|
| 838 |
+
jax_array_on_gpu = jnp.zeros([1], type_, device=gpu_devices[0])
|
| 839 |
+
|
| 840 |
+
# Since the input array is already on a GPU device, we expect that no memory
|
| 841 |
+
# copy occurs between GPU and CPU. Thus, we expect no errors raised by the
|
| 842 |
+
# transfer guard.
|
| 843 |
+
# There are two exceptions:
|
| 844 |
+
# First, when dtype is "int32". This is because almost all TensorFlow
|
| 845 |
+
# kernels for GPU devices keep int32 tensors in host memory.
|
| 846 |
+
# (https://github.com/tensorflow/tensorflow/blob/4eb3e36d1b0cd511e1677e740bd093f42365cf9f/tensorflow/python/eager/pywrap_tensor.cc#L352-L354)
|
| 847 |
+
# Hence, for "int32", we do expect a "host-to-device" copy.
|
| 848 |
+
# Second, when using PJRT C API runtime. This is because it currently skips dlpack
|
| 849 |
+
# to workaround "PJRT C API does not support GetDefaultLayout" runtime error.
|
| 850 |
+
# https://github.com/openxla/xla/blob/762bde36adf22792e91c38fe87cabe5af05bfadc/xla/pjrt/pjrt_c_api_client.h#L285-L289
|
| 851 |
+
@contextlib.contextmanager
|
| 852 |
+
def _transfer_guard(guard_level):
|
| 853 |
+
with contextlib.ExitStack() as stack:
|
| 854 |
+
stack.enter_context(jax.transfer_guard_device_to_device(guard_level))
|
| 855 |
+
stack.enter_context(jax.transfer_guard_device_to_host(guard_level))
|
| 856 |
+
if type_ != jnp.int32:
|
| 857 |
+
stack.enter_context(jax.transfer_guard_host_to_device(guard_level))
|
| 858 |
+
yield
|
| 859 |
+
|
| 860 |
+
with _transfer_guard("disallow_explicit"):
|
| 861 |
+
jax2tf.call_tf(tf_fun)(jax_array_on_gpu)
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
  "Reloading output of jax2tf into JAX with call_tf"

  def setUp(self):
    if tf is None:
      raise unittest.SkipTest("Test requires tensorflow")
    # TODO(b/171320191): this line works around a missing context initialization
    # bug in TensorFlow.
    _ = tf.add(1, 1)
    super().setUp()

  def test_simple(self):
    f_jax = jnp.sin
    f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))
    x = np.float32(0.7)
    self.assertAllClose(f_jax(x), f_jax_rt(x))

def test_pytree(self):
|
| 881 |
+
def f_jax(x): # x: dict(a=f32, b=f32)
|
| 882 |
+
return dict(a=x["a"]+1., b=x)
|
| 883 |
+
x = dict(a=0.7, b=0.8)
|
| 884 |
+
f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))
|
| 885 |
+
self.assertAllClose(f_jax(x), f_jax_rt(x))
|
| 886 |
+
|
| 887 |
+
def test_custom_grad(self):
|
| 888 |
+
@jax.custom_vjp
|
| 889 |
+
def f(x):
|
| 890 |
+
return x * x
|
| 891 |
+
|
| 892 |
+
# f_fwd: a -> (b, residual)
|
| 893 |
+
def f_fwd(x):
|
| 894 |
+
return f(x), np.float32(3.) * x
|
| 895 |
+
# f_bwd: (residual, CT b) -> [CT a]
|
| 896 |
+
def f_bwd(residual, ct_b):
|
| 897 |
+
return residual * ct_b,
|
| 898 |
+
|
| 899 |
+
f.defvjp(f_fwd, f_bwd)
|
| 900 |
+
|
| 901 |
+
f_rt = jax2tf.call_tf(jax2tf.convert(f, with_gradient=True))
|
| 902 |
+
x = np.float32(0.7)
|
| 903 |
+
self.assertAllClose(f(x), f_rt(x))
|
| 904 |
+
self.assertAllClose(jax.grad(f)(x), jax.grad(f_rt)(x))
|
| 905 |
+
|
| 906 |
+
def test_shape_poly(self):
|
| 907 |
+
f_jax = jnp.sin
|
| 908 |
+
f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax,
|
| 909 |
+
polymorphic_shapes=["(b, ...)"]))
|
| 910 |
+
x = np.array([0.7, 0.8], dtype=np.float32)
|
| 911 |
+
self.assertAllClose(f_jax(x), f_jax_rt(x))
|
| 912 |
+
|
| 913 |
+
def test_saved_model_simple(self):
|
| 914 |
+
x = np.array([0.7, 0.8], dtype=np.float32)
|
| 915 |
+
def f_jax(x):
|
| 916 |
+
return jnp.sin(x)
|
| 917 |
+
|
| 918 |
+
f_tf = jax2tf.convert(f_jax)
|
| 919 |
+
restored_tf, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
|
| 920 |
+
restored_jax = jax2tf.call_tf(restored_tf)
|
| 921 |
+
self.assertAllClose(f_jax(x), restored_jax(x))
|
| 922 |
+
|
| 923 |
+
def test_saved_model_variables(self):
|
| 924 |
+
param = np.array([1., 2.], dtype=np.float32)
|
| 925 |
+
x = np.array([0.7, 0.8], dtype=np.float32)
|
| 926 |
+
def f_jax(param, x):
|
| 927 |
+
return jnp.sin(x) + jnp.cos(param)
|
| 928 |
+
|
| 929 |
+
param_v = tf.Variable(param)
|
| 930 |
+
f_tf = jax2tf.convert(f_jax)
|
| 931 |
+
_, restored_model = tf_test_util.SaveAndLoadFunction(
|
| 932 |
+
lambda x: f_tf(param_v, x),
|
| 933 |
+
input_args=[x],
|
| 934 |
+
variables=[param_v])
|
| 935 |
+
restored_jax = jax2tf.call_tf(restored_model.f)
|
| 936 |
+
self.assertAllClose(f_jax(param, x), restored_jax(x))
|
| 937 |
+
self.assertAllClose(f_jax(param, x), jax.jit(restored_jax)(x))
|
| 938 |
+
|
| 939 |
+
def test_saved_model_shape_poly(self):
|
| 940 |
+
tracing_count = 0
|
| 941 |
+
x = np.array([0.7, 0.8], dtype=np.float32)
|
| 942 |
+
def f_jax(x):
|
| 943 |
+
nonlocal tracing_count
|
| 944 |
+
tracing_count += 1
|
| 945 |
+
return jnp.sin(x)
|
| 946 |
+
|
| 947 |
+
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
|
| 948 |
+
res_jax = f_jax(x)
|
| 949 |
+
self.assertEqual(1, tracing_count)
|
| 950 |
+
# Will trace twice, it seems. Once to get the result signature, and once again
|
| 951 |
+
# for the actual saving.
|
| 952 |
+
restored_f, _ = tf_test_util.SaveAndLoadFunction(
|
| 953 |
+
f_tf, input_signature=[tf.TensorSpec([None], x.dtype)])
|
| 954 |
+
self.assertGreaterEqual(tracing_count, 2)
|
| 955 |
+
tracing_count = 0
|
| 956 |
+
f_jax_rt = jax2tf.call_tf(restored_f)
|
| 957 |
+
self.assertAllClose(res_jax, f_jax_rt(x))
|
| 958 |
+
# Ensure that restored_f works at other batch size as well
|
| 959 |
+
y = np.concatenate([x, x])
|
| 960 |
+
self.assertEqual(0, tracing_count)
|
| 961 |
+
res_jax_y = f_jax(y)
|
| 962 |
+
self.assertEqual(1, tracing_count)
|
| 963 |
+
# No more tracing for f_jax_rt
|
| 964 |
+
self.assertAllClose(res_jax_y, f_jax_rt(y))
|
| 965 |
+
self.assertEqual(1, tracing_count)
|
| 966 |
+
|
| 967 |
+
def test_custom_grad_saved_model(self):
|
| 968 |
+
|
| 969 |
+
@jax.custom_vjp
|
| 970 |
+
def f(x):
|
| 971 |
+
return x * x
|
| 972 |
+
|
| 973 |
+
# f_fwd: a -> (b, residual)
|
| 974 |
+
def f_fwd(x):
|
| 975 |
+
return f(x), np.float32(3.) * x
|
| 976 |
+
# f_bwd: (residual, CT b) -> [CT a]
|
| 977 |
+
def f_bwd(residual, ct_b):
|
| 978 |
+
return residual * ct_b,
|
| 979 |
+
|
| 980 |
+
f.defvjp(f_fwd, f_bwd)
|
| 981 |
+
def g(x):
|
| 982 |
+
return jnp.sum(f(x))
|
| 983 |
+
|
| 984 |
+
g_tf, _ = tf_test_util.SaveAndLoadFunction(
|
| 985 |
+
jax2tf.convert(g, with_gradient=True),
|
| 986 |
+
input_signature=[tf.TensorSpec(shape=(1,), dtype=tf.float32)],
|
| 987 |
+
)
|
| 988 |
+
g_rt = jax2tf.call_tf(g_tf)
|
| 989 |
+
x = np.array([0.7], dtype=np.float32)
|
| 990 |
+
self.assertAllClose(g(x), g_rt(x))
|
| 991 |
+
self.assertAllClose(jax.grad(g)(x), jax.grad(g_rt)(x))
|
| 992 |
+
|
| 993 |
+
def test_without_gradient_saved_model(self):
|
| 994 |
+
# Explicitly with_gradient=False
|
| 995 |
+
f_jax = jnp.sum
|
| 996 |
+
|
| 997 |
+
x = np.array([0.7, 0.8], dtype=np.float32)
|
| 998 |
+
f_tf, _ = tf_test_util.SaveAndLoadFunction(
|
| 999 |
+
jax2tf.convert(f_jax, with_gradient=False),
|
| 1000 |
+
input_args=[x])
|
| 1001 |
+
f_rt = jax2tf.call_tf(f_tf)
|
| 1002 |
+
|
| 1003 |
+
self.assertAllClose(f_jax(x), f_rt(x))
|
| 1004 |
+
with self.assertRaisesRegex(Exception,
|
| 1005 |
+
"Gradient explicitly disabled.*jax2tf-converted function does not support gradients. Use `with_gradient` parameter to enable gradients"):
|
| 1006 |
+
jax.grad(f_rt)(x)
|
| 1007 |
+
|
| 1008 |
+
def test_saved_model_no_gradients(self):
|
| 1009 |
+
# Save without gradients
|
| 1010 |
+
f_jax = jnp.sum
|
| 1011 |
+
|
| 1012 |
+
x = np.array([0.7, 0.8], dtype=np.float32)
|
| 1013 |
+
f_tf, _ = tf_test_util.SaveAndLoadFunction(
|
| 1014 |
+
jax2tf.convert(f_jax, with_gradient=True), input_args=[x],
|
| 1015 |
+
save_gradients=False)
|
| 1016 |
+
f_rt = jax2tf.call_tf(f_tf)
|
| 1017 |
+
|
| 1018 |
+
self.assertAllClose(f_jax(x), f_rt(x))
|
| 1019 |
+
# TODO: clean this up b/191117111: it should fail with a clear error
|
| 1020 |
+
# The following results in a confusing error:
|
| 1021 |
+
# TypeError: tf.Graph captured an external symbolic tensor.
|
| 1022 |
+
with self.assertRaises(TypeError):
|
| 1023 |
+
_ = jax.grad(f_rt)(x)
|
| 1024 |
+
|
| 1025 |
+
def test_call_tf_under_function_context(self):
|
| 1026 |
+
def fun_jax(x, y):
|
| 1027 |
+
z = jax2tf.call_tf(tf.math.sin)(x) + jnp.cos(y)
|
| 1028 |
+
return z
|
| 1029 |
+
|
| 1030 |
+
x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
|
| 1031 |
+
y = np.array([-0.5, 0.0, 0.5], dtype=np.float32)
|
| 1032 |
+
|
| 1033 |
+
converted_fun = tf.function(
|
| 1034 |
+
jax2tf.convert(fun_jax, native_serialization=True)
|
| 1035 |
+
)
|
| 1036 |
+
expected = np.sin(x) + np.cos(y)
|
| 1037 |
+
res = tf.function(converted_fun, jit_compile=True, autograph=False)(x, y)
|
| 1038 |
+
self.assertAllClose(expected, res.numpy(), atol=1e-5, rtol=1e-5)
|
| 1039 |
+
|
| 1040 |
+
@parameterized.named_parameters(
|
| 1041 |
+
dict(
|
| 1042 |
+
testcase_name=f"_{dtype.__name__}",
|
| 1043 |
+
dtype=dtype,
|
| 1044 |
+
)
|
| 1045 |
+
for dtype in set(jtu.dtypes.all_floating)
|
| 1046 |
+
)
|
| 1047 |
+
def test_all_floating_input_gradient(self, dtype):
|
| 1048 |
+
def tf_f(x):
|
| 1049 |
+
res = tf.math.sin(x)
|
| 1050 |
+
return tf.reduce_sum(res)
|
| 1051 |
+
|
| 1052 |
+
jax_f = jax2tf.call_tf(tf_f)
|
| 1053 |
+
tf_f_rt = jax2tf.convert(jax_f)
|
| 1054 |
+
x = jnp.array([5.0, 6.0, 7.0]).astype(dtype)
|
| 1055 |
+
|
| 1056 |
+
def assert_all_close_support_bfloat16(baseline, candidate):
|
| 1057 |
+
def conversion(x):
|
| 1058 |
+
# convert scalar to array and bfloat16 to float32
|
| 1059 |
+
# to support self.assertAllClose numpy array comparison.
|
| 1060 |
+
if x.shape == tf.TensorShape([]):
|
| 1061 |
+
x = tf.convert_to_tensor([x])
|
| 1062 |
+
if dtype == jnp.float16:
|
| 1063 |
+
x = tf.cast(x, tf.float32)
|
| 1064 |
+
return x
|
| 1065 |
+
|
| 1066 |
+
baseline = jax.tree_util.tree_map(conversion, baseline)
|
| 1067 |
+
candidate = jax.tree_util.tree_map(conversion, candidate)
|
| 1068 |
+
tol = (
|
| 1069 |
+
1e-2
|
| 1070 |
+
if jtu.test_device_matches(["tpu"]) and dtype == np.float16
|
| 1071 |
+
else None
|
| 1072 |
+
)
|
| 1073 |
+
self.assertAllClose(baseline, candidate, atol=tol, rtol=tol)
|
| 1074 |
+
|
| 1075 |
+
# Eager mode
|
| 1076 |
+
assert_all_close_support_bfloat16(tf_f(x), tf_f_rt(x))
|
| 1077 |
+
|
| 1078 |
+
# Compiled function mode
|
| 1079 |
+
assert_all_close_support_bfloat16(
|
| 1080 |
+
tf.function(tf_f)(x), tf.function(tf_f_rt)(x)
|
| 1081 |
+
)
|
| 1082 |
+
|
| 1083 |
+
# Compiled function mode with jit_compiled=True
|
| 1084 |
+
assert_all_close_support_bfloat16(
|
| 1085 |
+
tf.function(tf_f, jit_compile=True)(x),
|
| 1086 |
+
tf.function(tf_f_rt, jit_compile=True)(x),
|
| 1087 |
+
)
|
| 1088 |
+
|
| 1089 |
+
# RoundTrip test for the gradient
|
| 1090 |
+
grad_fun_jax = jax.grad(jax2tf.call_tf(tf_f))
|
| 1091 |
+
grad_fun_jax_rt = jax2tf.call_tf(jax2tf.convert(grad_fun_jax))
|
| 1092 |
+
|
| 1093 |
+
# Eager mode
|
| 1094 |
+
assert_all_close_support_bfloat16(grad_fun_jax(x), grad_fun_jax_rt(x))
|
| 1095 |
+
|
| 1096 |
+
# Jit mode
|
| 1097 |
+
assert_all_close_support_bfloat16(
|
| 1098 |
+
jax.jit(grad_fun_jax)(x), jax.jit(grad_fun_jax_rt)(x)
|
| 1099 |
+
)
|
| 1100 |
+
|
| 1101 |
+
@parameterized.named_parameters(
|
| 1102 |
+
dict(
|
| 1103 |
+
testcase_name=f"_{dtype.__name__}",
|
| 1104 |
+
dtype=dtype,
|
| 1105 |
+
)
|
| 1106 |
+
for dtype in set(jtu.dtypes.complex)
|
| 1107 |
+
)
|
| 1108 |
+
def test_complex_input_gradient(self, dtype):
|
| 1109 |
+
def tf_f(x):
|
| 1110 |
+
res = tf.math.sin(x)
|
| 1111 |
+
return tf.reduce_sum(res)
|
| 1112 |
+
|
| 1113 |
+
x = jnp.array([(5.0 + 4.0j), (6.0 + 3.0j), (7.0 + 8.0j)]).astype(dtype)
|
| 1114 |
+
|
| 1115 |
+
jax_f = jax2tf.call_tf(tf_f)
|
| 1116 |
+
tf_f_rt = jax2tf.convert(jax_f)
|
| 1117 |
+
|
| 1118 |
+
# Eager mode
|
| 1119 |
+
self.assertAllClose(tf_f(x), tf_f_rt(x))
|
| 1120 |
+
|
| 1121 |
+
# tf.function context
|
| 1122 |
+
self.assertAllClose(tf.function(tf_f)(x), tf.function(tf_f_rt)(x))
|
| 1123 |
+
|
| 1124 |
+
# tf.function context with jit_compiled=True
|
| 1125 |
+
self.assertAllClose(
|
| 1126 |
+
tf.function(tf_f, jit_compile=True)(x),
|
| 1127 |
+
tf.function(tf_f_rt, jit_compile=True)(x),
|
| 1128 |
+
)
|
| 1129 |
+
|
| 1130 |
+
# RoundTrip test for the gradient
|
| 1131 |
+
grad_fun_jax = jax.grad(jax2tf.call_tf(tf_f), holomorphic=True)
|
| 1132 |
+
grad_fun_jax_rt = jax2tf.call_tf(jax2tf.convert(grad_fun_jax))
|
| 1133 |
+
|
| 1134 |
+
# Eager mode
|
| 1135 |
+
self.assertAllClose(grad_fun_jax(x), grad_fun_jax_rt(x))
|
| 1136 |
+
|
| 1137 |
+
# Jit mode
|
| 1138 |
+
self.assertAllClose(jax.jit(grad_fun_jax)(x), jax.jit(grad_fun_jax_rt)(x))
|
| 1139 |
+
|
| 1140 |
+
|
| 1141 |
+
class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
|
| 1142 |
+
"Reloading output of call_tf into TF with jax2tf."
|
| 1143 |
+
|
| 1144 |
+
def setUp(self):
|
| 1145 |
+
if tf is None:
|
| 1146 |
+
raise unittest.SkipTest("Test requires tensorflow")
|
| 1147 |
+
# TODO(b/171320191): this line works around a missing context initialization
|
| 1148 |
+
# bug in TensorFlow.
|
| 1149 |
+
_ = tf.add(1, 1)
|
| 1150 |
+
super().setUp()
|
| 1151 |
+
|
| 1152 |
+
def test_alternate(self):
|
| 1153 |
+
# Alternate sin/cos with sin in TF and cos in JAX
|
| 1154 |
+
f_tf_inner = tf.math.sin
|
| 1155 |
+
def f_jax(x_jax):
|
| 1156 |
+
y_jax = jnp.cos(x_jax)
|
| 1157 |
+
z_jax = jax2tf.call_tf(f_tf_inner)(y_jax)
|
| 1158 |
+
return jnp.cos(z_jax)
|
| 1159 |
+
def f_tf_outer(x_tf):
|
| 1160 |
+
y_tf = tf.math.sin(x_tf)
|
| 1161 |
+
z_tf = jax2tf.convert(f_jax)(y_tf)
|
| 1162 |
+
return tf.math.sin(z_tf)
|
| 1163 |
+
|
| 1164 |
+
x = np.float32(0.7)
|
| 1165 |
+
|
| 1166 |
+
self.assertAllClose(np.sin(np.cos(np.sin(np.cos(np.sin(x))))),
|
| 1167 |
+
f_tf_outer(x).numpy())
|
| 1168 |
+
xv = tf.Variable(x)
|
| 1169 |
+
with tf.GradientTape() as tape:
|
| 1170 |
+
res = f_tf_outer(xv)
|
| 1171 |
+
g_tf = tape.gradient(res, xv)
|
| 1172 |
+
_, gf = tf_test_util.ComputeTfValueAndGrad(f_tf_outer, (x,))
|
| 1173 |
+
# Eager
|
| 1174 |
+
expected_res = np.sin(np.cos(np.sin(np.cos(np.sin(x)))))
|
| 1175 |
+
self.assertAllClose(expected_res, f_tf_outer(x).numpy())
|
| 1176 |
+
|
| 1177 |
+
# Gradient
|
| 1178 |
+
expected_grad = (np.cos(np.cos(np.sin(np.cos(np.sin(x))))) *
|
| 1179 |
+
np.sin(np.sin(np.cos(np.sin(x)))) *
|
| 1180 |
+
np.cos(np.cos(np.sin(x))) *
|
| 1181 |
+
np.sin(np.sin(x)) *
|
| 1182 |
+
np.cos(x))
|
| 1183 |
+
self.assertAllClose(expected_grad, g_tf.numpy())
|
| 1184 |
+
|
| 1185 |
+
# Graph
|
| 1186 |
+
self.assertAllClose(expected_res,
|
| 1187 |
+
tf.function(f_tf_outer, autograph=False)(x).numpy())
|
| 1188 |
+
|
| 1189 |
+
# Compiled
|
| 1190 |
+
self.assertAllClose(expected_res,
|
| 1191 |
+
tf.function(f_tf_outer, autograph=False,
|
| 1192 |
+
jit_compile=True)(x).numpy())
|
| 1193 |
+
|
| 1194 |
+
def test_saved_model(self):
|
| 1195 |
+
x = np.array([.7, .8], dtype=np.float32)
|
| 1196 |
+
def fun_tf(x):
|
| 1197 |
+
return tf.math.sin(x)
|
| 1198 |
+
def fun_jax(x):
|
| 1199 |
+
return jax2tf.call_tf(fun_tf)(x)
|
| 1200 |
+
|
| 1201 |
+
# Now convert and save to SavedModel
|
| 1202 |
+
fun_tf_rt = jax2tf.convert(fun_jax)
|
| 1203 |
+
res = fun_tf_rt(x)
|
| 1204 |
+
self.assertAllClose(np.sin(x), res.numpy())
|
| 1205 |
+
|
| 1206 |
+
res = tf.function(fun_tf_rt, autograph=False)(x)
|
| 1207 |
+
self.assertAllClose(np.sin(x), res.numpy())
|
| 1208 |
+
|
| 1209 |
+
res = tf.function(fun_tf_rt, jit_compile=True, autograph=False)(x)
|
| 1210 |
+
self.assertAllClose(np.sin(x), res.numpy())
|
| 1211 |
+
|
| 1212 |
+
reloaded_f, _ = tf_test_util.SaveAndLoadFunction(
|
| 1213 |
+
fun_tf_rt, input_args=[x])
|
| 1214 |
+
res = reloaded_f(x)
|
| 1215 |
+
self.assertAllClose(np.sin(x), res.numpy())
|
| 1216 |
+
|
| 1217 |
+
def test_saved_model_polymorphic_input_static_output(self):
|
| 1218 |
+
x = np.array([.7, .8], dtype=np.float32)
|
| 1219 |
+
def fun_tf(x):
|
| 1220 |
+
return tf.math.reduce_sum(tf.math.sin(x))
|
| 1221 |
+
def fun_jax(x):
|
| 1222 |
+
return jax2tf.call_tf(fun_tf)(x)
|
| 1223 |
+
|
| 1224 |
+
# Now convert and save to SavedModel
|
| 1225 |
+
fun_tf_rt = jax2tf.convert(fun_jax)
|
| 1226 |
+
res = fun_tf_rt(x)
|
| 1227 |
+
self.assertAllClose(fun_tf(x), res.numpy())
|
| 1228 |
+
|
| 1229 |
+
res = tf.function(fun_tf_rt, autograph=False)(x)
|
| 1230 |
+
self.assertAllClose(fun_tf(x), res.numpy())
|
| 1231 |
+
|
| 1232 |
+
res = tf.function(fun_tf_rt, jit_compile=True, autograph=False)(x)
|
| 1233 |
+
self.assertAllClose(fun_tf(x), res.numpy())
|
| 1234 |
+
|
| 1235 |
+
reloaded_f, _ = tf_test_util.SaveAndLoadFunction(
|
| 1236 |
+
fun_tf_rt, input_args=[x])
|
| 1237 |
+
res = reloaded_f(x)
|
| 1238 |
+
self.assertAllClose(fun_tf(x), res.numpy())
|
| 1239 |
+
|
| 1240 |
+
def test_function_dynamic_shape(self):
|
| 1241 |
+
# Call a function for which shape inference does not give an output
|
| 1242 |
+
# shape.
|
| 1243 |
+
x = np.array([-1, 0, 1], dtype=np.int32)
|
| 1244 |
+
def fun_tf(x): # x:i32[3]
|
| 1245 |
+
# The shape depends on the value of x
|
| 1246 |
+
return tf.cond(x[0] >= 0, lambda: x, lambda: x[1:])
|
| 1247 |
+
|
| 1248 |
+
# Call in eager mode. Should work!
|
| 1249 |
+
res1 = jax2tf.call_tf(fun_tf)(x)
|
| 1250 |
+
expected = x[1:]
|
| 1251 |
+
self.assertAllClose(expected, res1, check_dtypes=False)
|
| 1252 |
+
|
| 1253 |
+
# Now under jit, should fail because the function is not compilable
|
| 1254 |
+
with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
|
| 1255 |
+
fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
|
| 1256 |
+
fun_jax(x)
|
| 1257 |
+
|
| 1258 |
+
# TODO(necula): this should work in op-by-op mode, but it fails because
|
| 1259 |
+
# jax2tf.convert does abstract evaluation.
|
| 1260 |
+
with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
|
| 1261 |
+
fun_tf_rt = jax2tf.convert(jax2tf.call_tf(fun_tf))
|
| 1262 |
+
fun_tf_rt(x)
|
| 1263 |
+
|
| 1264 |
+
@_parameterized_jit
|
| 1265 |
+
def test_shape_poly_static_output_shape(self, with_jit=True):
|
| 1266 |
+
if jax.config.jax2tf_default_native_serialization:
|
| 1267 |
+
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
|
| 1268 |
+
x = np.array([0.7, 0.8], dtype=np.float32)
|
| 1269 |
+
|
| 1270 |
+
def fun_tf(x):
|
| 1271 |
+
return tf.math.reduce_sum(tf.math.sin(x))
|
| 1272 |
+
|
| 1273 |
+
fun_jax = jax2tf.call_tf(fun_tf)
|
| 1274 |
+
fun_tf_rt = _maybe_tf_jit(with_jit,
|
| 1275 |
+
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
|
| 1276 |
+
self.assertAllClose(fun_tf(x), fun_tf_rt(x))
|
| 1277 |
+
|
| 1278 |
+
@_parameterized_jit
|
| 1279 |
+
def test_shape_poly(self, with_jit=False):
|
| 1280 |
+
if jax.config.jax2tf_default_native_serialization:
|
| 1281 |
+
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
|
| 1282 |
+
x = np.array([7, 8, 9, 10], dtype=np.float32)
|
| 1283 |
+
def fun_jax(x):
|
| 1284 |
+
y = jax2tf.call_tf(tf.math.sin,
|
| 1285 |
+
output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)
|
| 1286 |
+
z = jnp.cos(y)
|
| 1287 |
+
w = jax2tf.call_tf(lambda z: tf.concat([z, z], axis=0),
|
| 1288 |
+
output_shape_dtype=jax.ShapeDtypeStruct((2 * z.shape[0],), z.dtype))(z)
|
| 1289 |
+
assert w.shape[0] == 2 * x.shape[0]
|
| 1290 |
+
return w
|
| 1291 |
+
|
| 1292 |
+
fun_tf_rt = _maybe_tf_jit(with_jit,
|
| 1293 |
+
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
|
| 1294 |
+
res_tf = fun_tf_rt(x)
|
| 1295 |
+
self.assertAllClose(fun_jax(x), res_tf)
|
| 1296 |
+
|
| 1297 |
+
@_parameterized_jit
|
| 1298 |
+
def test_shape_poly_pytree_result(self, with_jit=True):
|
| 1299 |
+
if jax.config.jax2tf_default_native_serialization:
|
| 1300 |
+
raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
|
| 1301 |
+
x = np.array([7, 8, 9, 10], dtype=np.float32)
|
| 1302 |
+
def fun_jax(x):
|
| 1303 |
+
# Returns a tuple
|
| 1304 |
+
y = jax2tf.call_tf(lambda x: (x, tf.concat([x, x], axis=0)),
|
| 1305 |
+
output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype),
|
| 1306 |
+
jax.ShapeDtypeStruct((2 * x.shape[0],), x.dtype)))(x)
|
| 1307 |
+
assert y[0].shape[0] == x.shape[0]
|
| 1308 |
+
assert y[1].shape[0] == 2 * x.shape[0]
|
| 1309 |
+
return y
|
| 1310 |
+
|
| 1311 |
+
fun_tf_rt = _maybe_tf_jit(with_jit,
|
| 1312 |
+
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
|
| 1313 |
+
res_tf = fun_tf_rt(x)
|
| 1314 |
+
self.assertAllClose(fun_jax(x), res_tf)
|
| 1315 |
+
|
| 1316 |
+
@_parameterized_jit
|
| 1317 |
+
def test_shape_poly_error_no_output_shape_dtype(self, with_jit=True):
|
| 1318 |
+
x = np.array([7, 8, 9, 10], dtype=np.float32)
|
| 1319 |
+
def fun_jax(x):
|
| 1320 |
+
return jax2tf.call_tf(tf.math.sin)(x)
|
| 1321 |
+
|
| 1322 |
+
fun_tf_rt = _maybe_tf_jit(with_jit,
|
| 1323 |
+
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
|
| 1324 |
+
with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
|
| 1325 |
+
fun_tf_rt(x)
|
| 1326 |
+
|
| 1327 |
+
@_parameterized_jit
|
| 1328 |
+
def test_shape_poly_error_mismatch_output_shape_dtype_tree(self, with_jit=False):
|
| 1329 |
+
x = np.array([7, 8, 9, 10], dtype=np.float32)
|
| 1330 |
+
def fun_jax(x):
|
| 1331 |
+
return jax2tf.call_tf(tf.math.sin,
|
| 1332 |
+
output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype),
|
| 1333 |
+
jax.ShapeDtypeStruct(x.shape, x.dtype)))(x)
|
| 1334 |
+
|
| 1335 |
+
fun_tf_rt = _maybe_tf_jit(with_jit,
|
| 1336 |
+
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
|
| 1337 |
+
|
| 1338 |
+
with self.assertRaisesRegex(
|
| 1339 |
+
ValueError,
|
| 1340 |
+
"The pytree of the TensorFlow function results does not match the pytree of the declared output_shape_dtype"):
|
| 1341 |
+
fun_tf_rt(x)
|
| 1342 |
+
|
| 1343 |
+
@parameterized.named_parameters(
|
| 1344 |
+
_named_test(with_jit=with_jit, kind=kind)
|
| 1345 |
+
for with_jit in [True, False]
|
| 1346 |
+
for kind in ["bad_rank", "bad_dim", "bad_dtype", "bad_dtype_x64"])
|
| 1347 |
+
def test_shape_poly_error_mismatch_output_shape_dtype(self, with_jit=False, kind="bad_rank"):
|
| 1348 |
+
x = np.array([7, 8, 9, 10], dtype=np.float32)
|
| 1349 |
+
|
| 1350 |
+
if kind == "bad_rank":
|
| 1351 |
+
def fun_jax(x):
|
| 1352 |
+
return jax2tf.call_tf(lambda x: x,
|
| 1353 |
+
# Wrong shape rank
|
| 1354 |
+
output_shape_dtype=jax.ShapeDtypeStruct((), x.dtype))(x)
|
| 1355 |
+
elif kind == "bad_dim":
|
| 1356 |
+
def fun_jax(x):
|
| 1357 |
+
bad_shape = (5 + x.shape[0],)
|
| 1358 |
+
y = jax2tf.call_tf(lambda x: x,
|
| 1359 |
+
# Wrong dimension
|
| 1360 |
+
output_shape_dtype=jax.ShapeDtypeStruct(bad_shape, x.dtype))(x)
|
| 1361 |
+
# JAX will believe that the following is Ok, leading to downstream error in TF
|
| 1362 |
+
return y + jnp.ones(bad_shape, dtype=x.dtype)
|
| 1363 |
+
elif kind == "bad_dtype":
|
| 1364 |
+
def fun_jax(x):
|
| 1365 |
+
return jax2tf.call_tf(lambda x: x,
|
| 1366 |
+
output_shape_dtype=jax.ShapeDtypeStruct(x.shape, np.int32))(x)
|
| 1367 |
+
elif kind == "bad_dtype_x64":
|
| 1368 |
+
def fun_jax(x):
|
| 1369 |
+
return jax2tf.call_tf(lambda x: x * np.float64(3.),
|
| 1370 |
+
output_shape_dtype=jax.ShapeDtypeStruct(x.shape, np.float64))(x)
|
| 1371 |
+
else:
|
| 1372 |
+
assert False
|
| 1373 |
+
expect_ex = ValueError
|
| 1374 |
+
expect_error = r"The shapes or dtypes returned by the TensorFlow function do not match the declared output_shape_dtype"
|
| 1375 |
+
|
| 1376 |
+
# Call without shape polymorphism
|
| 1377 |
+
fun_tf_rt = _maybe_tf_jit(with_jit, jax2tf.convert(fun_jax))
|
| 1378 |
+
with self.assertRaisesRegex(expect_ex, expect_error):
|
| 1379 |
+
fun_tf_rt(x)
|
| 1380 |
+
|
| 1381 |
+
# Now with shape polymorphism
|
| 1382 |
+
if kind == "bad_dim" and with_jit:
|
| 1383 |
+
# TODO: in jit more the error pops up later, at AddV2
|
| 1384 |
+
expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2"
|
| 1385 |
+
if kind == "bad_dim" and jax.config.jax2tf_default_native_serialization:
|
| 1386 |
+
# TODO(b/268386622): call_tf with shape polymorphism and native serialization.
|
| 1387 |
+
expect_error = "Error compiling TensorFlow function"
|
| 1388 |
+
fun_tf_rt = _maybe_tf_jit(with_jit,
|
| 1389 |
+
jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
|
| 1390 |
+
with self.assertRaisesRegex(expect_ex, expect_error):
|
| 1391 |
+
fun_tf_rt(x)
|
| 1392 |
+
|
| 1393 |
+
def test_inner_native_serialization(self):
|
| 1394 |
+
# Two nested jax2tf, the inner one being with native serialization
|
| 1395 |
+
x = np.ones((3,), dtype=np.float32)
|
| 1396 |
+
def f_inner_jax(x):
|
| 1397 |
+
return jnp.sin(x)
|
| 1398 |
+
def f_outer_jax(x):
|
| 1399 |
+
f_inner_tf = jax2tf.convert(f_inner_jax, native_serialization=True)
|
| 1400 |
+
return jnp.cos(jax2tf.call_tf(f_inner_tf)(x))
|
| 1401 |
+
|
| 1402 |
+
f_outer_tf = tf.function(
|
| 1403 |
+
jax2tf.convert(f_outer_jax, native_serialization=False),
|
| 1404 |
+
autograph=False)
|
| 1405 |
+
f_outer_graph = str(f_outer_tf.get_concrete_function(tf.convert_to_tensor(x)).graph.as_graph_def())
|
| 1406 |
+
# Quick way to check that there is an XlaCallModule op, and a Cos op, but no Sin op
|
| 1407 |
+
self.assertIn('op: "Cos"', f_outer_graph)
|
| 1408 |
+
self.assertIn('op: "XlaCallModule"', f_outer_graph)
|
| 1409 |
+
self.assertNotIn('op: "Sin"', f_outer_graph)
|
| 1410 |
+
|
| 1411 |
+
  @parameterized.named_parameters(
      _named_test(f2_function=f2_function, f2_saved_model=f2_saved_model,
                  f4_function=f4_function, f4_saved_model=f4_saved_model)
      for f2_function in [True, False]
      for f2_saved_model in [True, False]
      for f4_function in [True, False]
      for f4_saved_model in [True, False])
  def test_several_round_trips(self,
                               f2_function=False, f2_saved_model=False,
                               f4_function=False, f4_saved_model=False):
    if (f2_saved_model and
        f4_saved_model and
        not jax.config.jax2tf_default_native_serialization):
      # TODO: Getting error Found invalid capture Tensor("jax2tf_vjp/jax2tf_arg_0:0", shape=(), dtype=float32) when saving custom gradients
      # when saving f4, but only with non-native serialization.
      raise unittest.SkipTest("TODO: error invalid capture when saving custom gradients")
    x = np.array(.7, dtype=np.float32)
    # f(n)(x) = 2. * x^n
    def f(n):
      def fn(x):
        acc = np.array(2., dtype=x.dtype)
        for i in range(n):
          acc *= x
        return acc
      return fn

    f2_tf = lambda x: x * jax2tf.convert(f(1))(x)
    if f2_function:
      f2_tf = tf.function(f2_tf, autograph=False)
    if f2_saved_model:
      f2_tf, _ = tf_test_util.SaveAndLoadFunction(f2_tf, input_args=[x])

    self.assertAllClose(f(2)(x), f2_tf(x).numpy())
    _, (g_f2_ft,) = tf_test_util.ComputeTfValueAndGrad(f2_tf, [x])
    self.assertAllClose(jax.grad(f(2))(x), g_f2_ft.numpy())

    f3_jax = lambda x: x * jax2tf.call_tf(f2_tf)(x)
    self.assertAllClose(f(3)(x), f3_jax(x))
    self.assertAllClose(f(3)(x), jax.jit(f3_jax)(x))
    self.assertAllClose(jax.grad(f(3))(x), jax.grad(f3_jax)(x))

    f4_tf = lambda x: x * jax2tf.convert(f3_jax)(x)
    self.assertAllClose(f(4)(x), f4_tf(x).numpy())
    _, (g_f4_ft,) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x])
    self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy())

    if f4_function:
      f4_tf = tf.function(f4_tf, autograph=False)
    if f4_saved_model:
      f4_tf, _ = tf_test_util.SaveAndLoadFunction(f4_tf, input_args=[x])
    self.assertAllClose(f(4)(x), f4_tf(x).numpy())
    _, (g_f4_ft,) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x])
    self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy())

  @classmethod
  def _walk_stablehlo_operations(cls, op, cb):
    """Walks the StableHLO operation recursively, invoking `cb` on each op."""
    cb(op)
    for region in op.operation.regions:
      for block in region:
        for op in block:
          cls._walk_stablehlo_operations(op, cb)

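  # Descriptive note (added): with call_tf_graph=False the lowered StableHLO
  # contains no stablehlo.custom_call; with call_tf_graph=True the TF function
  # is kept as a stablehlo.custom_call whose tf.backend_config records a
  # called_index, and lowering outside jax2tf.convert raises a ValueError.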
  def test_call_tf_graph(self):
    const = tf.Variable(0.0, dtype=tf.float32)

    @tf.function(jit_compile=True)
    def tf_func_1(x):
      return x * x + const

    @tf.function
    def tf_func_2(x, y):
      return tf_func_1(x) + y

    @tf.function
    def tf_func_3(x, y, z):
      return tf_func_2(x, y) + z, z

    x = jnp.array(3.0, dtype=jnp.float32)
    y = jnp.array(3.0, dtype=jnp.float32)
    z = jnp.array(5.0, dtype=jnp.float32)
    f_jax = jax.jit(jax2tf.call_tf(tf_func_3, call_tf_graph=False))
    stablehlo_module = f_jax.lower(x, y, z).compiler_ir("stablehlo")
    self.assertNotIn("stablehlo.custom_call", str(stablehlo_module))

    f_jax = jax.jit(
        jax2tf.call_tf(
            tf_func_3,
            call_tf_graph=True,
        )
    )
    with self.assertRaisesRegex(
        ValueError,
        "call_tf_graph=True only support exporting by jax2tf.convert currently",
    ):
      stablehlo_module = f_jax.lower(x, y, z).compiler_ir("stablehlo")
      self.assertIn("stablehlo.custom_call", str(stablehlo_module))

    called_index_list = []

    def _extract_info(op):
      if op.operation.name != "stablehlo.custom_call":
        return
      tf_backend_config = ir.DictAttr(op.attributes["tf.backend_config"])
      called_index = ir.IntegerAttr(tf_backend_config["called_index"]).value
      called_index_list.append(called_index)

    self._walk_stablehlo_operations(stablehlo_module, _extract_info)
    self.assertLen(called_index_list, 1)

  @parameterized.named_parameters(
      dict(
          testcase_name="multiple_outputs",
          tf_f=lambda x: tf.py_function(np.sin, [x], tf.float32),
          output_shape_dtype=jax.ShapeDtypeStruct((10,), jnp.float32),
      ),
      dict(
          testcase_name="zero_outputs",
          tf_f=lambda x: print(tf.strings.length(tf.constant("hello, world"))),
          output_shape_dtype=None,
      ),
  )
  def test_call_tf_graph_non_compilable(self, tf_f, output_shape_dtype):
    inputs = jnp.ones([10], dtype=jnp.float32)
    called_index_list = []
    xla_call_module_list = []

    def _extract_info(op):
      if op.operation.name != "stablehlo.custom_call":
        return
      tf_backend_config = ir.DictAttr(op.attributes["tf.backend_config"])
      called_index = ir.IntegerAttr(tf_backend_config["called_index"]).value
      called_index_list.append(called_index)

    jax_f = jax2tf.call_tf(
        tf_f,
        call_tf_graph=True,
        output_shape_dtype=output_shape_dtype,
    )

    # Eager mode
    self.assertAllClose(tf_f(inputs), jax_f(inputs))

    # Jit mode
    stablehlo_module = None
    with self.assertRaisesRegex(
        ValueError,
        "call_tf_graph=True only support exporting by jax2tf.convert currently",
    ):
      stablehlo_module = jax.jit(jax_f).lower(inputs).compiler_ir("stablehlo")
    if stablehlo_module:
      self.assertIn(
          "stablehlo.custom_call @tf.call_tf_function",
          str(stablehlo_module),
      )
      self.assertIn("tf.backend_config", str(stablehlo_module))
      self._walk_stablehlo_operations(stablehlo_module, _extract_info)
      self.assertLen(called_index_list, 1)

    # Test model exporting and reloading.
    # There is no runtime support yet, so it cannot run.
    tf_f_rt = jax2tf.convert(
        jax_f,
        native_serialization=True,
        with_gradient=False,
    )
    _, restored_model = tf_test_util.SaveAndLoadFunction(
        tf_f_rt, input_args=[inputs]
    )
    func_def = restored_model.f.concrete_functions[0]

    for node_def in func_def.graph.as_graph_def().node:
      if node_def.op == "XlaCallModule":
        xla_call_module_list.append(node_def)
    # There is only one xla_call_module in the saved model.
    self.assertLen(xla_call_module_list, 1)

    # Check the xla_call_module version and function_list attributes.
    xla_call_module = xla_call_module_list[0]
    self.assertGreaterEqual(xla_call_module.attr["version"].i, 5)
    self.assertIn("function_list", str(xla_call_module.attr))
    xla_call_module_list.clear()
    called_index_list.clear()

    # If JAX calls the same TensorFlow function by `jax2tf.call_tf` twice,
    # it should return two different TF concrete functions.
    def jax_f_2(x):
      res1 = jax2tf.call_tf(
          tf_f,
          call_tf_graph=True,
          output_shape_dtype=output_shape_dtype,
      )(x)
      res2 = jax2tf.call_tf(
          tf_f,
          call_tf_graph=True,
          output_shape_dtype=output_shape_dtype,
      )(x)
      return res1, res2
    stablehlo_module = None
    with self.assertRaisesRegex(ValueError, "call_tf_graph=True only support exporting by jax2tf.convert currently"):
      stablehlo_module = jax.jit(jax_f_2).lower(inputs).compiler_ir("stablehlo")
    if stablehlo_module:
      self._walk_stablehlo_operations(stablehlo_module, _extract_info)
      xla_call_module_list.clear()

  def test_b279454591(self):
    """Test case when the TensorFlow function returns a `StatefulPartitionedCall` op."""
    inputs = jnp.ones([10], dtype=jnp.float32)

    # With one or more outputs, it is okay.
    def tf_f(x):
      y = tf.math.sin(3.0)
      tf.print(y)
      return x

    jax_f = jax2tf.call_tf(
        tf.function(tf_f),
        call_tf_graph=True,
    )
    tf_f_rt = jax2tf.convert(
        jax_f,
        native_serialization=True,
        with_gradient=False,
    )
    _, _ = tf_test_util.SaveAndLoadFunction(tf_f_rt, input_args=[inputs])

    # With zero outputs, it returns a `StatefulPartitionedCall` op instead.
    def tf_f_2():
      y = tf.math.sin(3.0)
      tf.print(y)
      return

    jax_f_2 = jax2tf.call_tf(tf.function(tf_f_2), call_tf_graph=True)
    tf_f_rt_2 = jax2tf.convert(
        jax_f_2,
        native_serialization=True,
        with_gradient=False,
    )
    _, _ = tf_test_util.SaveAndLoadFunction(tf_f_rt_2, input_args=[])

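  # Descriptive note (added): with ordered=True the call_tf custom call must
  # thread a token operand and result (has_token_input_output) so that side
  # effects such as tf.print keep their ordering across fori_loop iterations.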
  @jtu.parameterized_filterable(
      kwargs=[dict(version=version) for version in [9]]
  )
  def test_call_tf_graph_ordered(self, *, version: int):
    with config.jax_serialization_version(version):
      logging.info(
          "Using JAX serialization version %s",
          jax.config.jax_serialization_version)

      @tf.function
      def tf_print(x):
        tf.print(x)

      call_tf_print = jax2tf.call_tf(
          tf_print,
          call_tf_graph=True,
          ordered=True,
      )

      x = jnp.array(1.0, dtype=jnp.float32)

      def body(i, x):
        call_tf_print(x)
        return x + 1

      @jax.jit
      def f_jax(x):
        return jax.lax.fori_loop(0, 4, body, x)

      num_custom_calls = 0

      def _check_mlir_ops(op):
        nonlocal num_custom_calls

        if (
            op.operation.name == "stablehlo.custom_call"
            and ir.StringAttr(op.attributes["call_target_name"]).value
            == "tf.call_tf_function"
        ):
          num_custom_calls += 1

          # The custom call op must have `has_token_input_output` attribute.
          tf_backend_config = ir.DictAttr(op.attributes["tf.backend_config"])
          self.assertTrue(
              ir.BoolAttr(tf_backend_config["has_token_input_output"]).value
          )

          # Verify that the first argument/result of the custom call op is a token
          # type. This is a calling convention defined by `has_token_input_output`.
          self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
          self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))

      stablehlo_module = None
      with self.assertRaisesRegex(
          ValueError,
          "call_tf_graph=True only support exporting by jax2tf.convert currently",
      ):
        lower = f_jax.lower(x)
        self.assertNotEmpty(lower._lowering.compile_args["ordered_effects"])
        stablehlo_module = lower.compiler_ir("stablehlo")
      if stablehlo_module:
        self._walk_stablehlo_operations(stablehlo_module, _check_mlir_ops)
        self.assertEqual(num_custom_calls, 1)

      f_tf = jax2tf.convert(
          f_jax,
          native_serialization=True,
          with_gradient=False,
      )
      _, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])

  @jtu.parameterized_filterable(
      kwargs=[dict(poly=poly, version=version)
              for poly in [True, False]
              for version in [9]]
  )
  def test_call_tf_ordered_dead_inputs(self, *, poly: bool, version: int):
    with config.jax_serialization_version(version):
      logging.info(
          "Using JAX serialization version %s",
          jax.config.jax_serialization_version)
      def f_jax(x1, x_dead, x3):
        return (x1, jax2tf.call_tf(lambda x: tf.math.sin(x), ordered=True,
                                   call_tf_graph=True)(x3))
      if poly:
        polymorphic_shapes = ["b", None, None]
      else:
        polymorphic_shapes = None
      f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes)
      x1 = np.arange(3, dtype=np.float32)
      x_dead = np.arange(4, dtype=np.float32)
      x3 = np.arange(5, dtype=np.float32)
      self.assertAllClose(f_jax(x1, x_dead, x3),
                          f_tf(x1, x_dead, x3))

  @jtu.parameterized_filterable(
      kwargs=[dict(ordered=ordered, version=version)
              for ordered in [True, False]
              for version in [9]
              ]
  )
  def test_call_tf_graph_polymorphic(self, ordered: bool, version: int):
    with config.jax_serialization_version(version):
      logging.info(
          "Using JAX serialization version %s",
          jax.config.jax_serialization_version)

      @tf.function(jit_compile=True, autograph=False)
      @partial(jax2tf.convert,
               with_gradient=False,
               native_serialization=True,
               polymorphic_shapes=["(b)"])
      @jax.jit
      def tf_f_2(x):
        tf_f = lambda x: print(tf.strings.length(tf.constant("hello, world")))
        jax2tf.call_tf(tf_f,
                       call_tf_graph=True,
                       ordered=ordered)(x)
        return x

      x = np.arange(3, dtype=np.int32)
      _ = tf.function(tf_f_2, autograph=False).get_concrete_function(x)

  # TODO(b/293927250): call_tf_graph=True only accepts a concrete_function. The
  # workaround here is to set `module.call = concrete_fn`.
  @unittest.skip(
      "The root cause here is because the XLACallModule.function_list attribute"
      " depends on JAX call_tf lowering. The 2nd time tf.SavedModel TF tracing"
      " will not trigger call_tf tracing since it was already cached. The"
      " solution is to create the `CallTFContext` to make TF tracing and JAX"
      " tracing work together correctly."
  )
  def test_call_tf_graph_save_and_load(self):
    def jax_func(x):
      def func_tf(x):
        return tf.math.sin(x)

      return jnp.cos(
          jax2tf.call_tf(func_tf, output_shape_dtype=x, call_tf_graph=True)(x)
      )
    data_inputs = (np.array([0.5, 0.7], dtype=np.float32),)

    def tf_func(the_input):
      res = jax2tf.convert(jax_func, native_serialization=True)(the_input)
      return tf.identity(res, name="the_result")

    jit_tf_func = tf.function(
        tf_func,
        autograph=False,
        jit_compile=True,
    )
    # The next line is necessary to reproduce this issue. It triggers TF
    # ConcreteFunction tracing. Otherwise, you will fail with another error,
    # `Found zero restored functions for caller function`.
    _ = jit_tf_func.get_concrete_function(*data_inputs)
    module = tf.Module()
    module.call = jit_tf_func  # Switching to concrete_function works.
    root_dir = self.create_tempdir()
    saved_model_dir = os.path.join(root_dir, "saved_model")
    tf.saved_model.save(
        module,
        saved_model_dir,
        options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
    )
    loaded_model = tf.saved_model.load(saved_model_dir)
    res = loaded_model.call(*data_inputs)
    self.assertAllClose(jax_func(*data_inputs), res)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())