ZAIDX11 committed on
Commit e868c41 · verified · 1 Parent(s): 425d1bb

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +2 -0
  2. backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/BEq.lean +155 -0
  3. backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/Basic.lean +134 -0
  4. backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/DecEq.lean +212 -0
  5. backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/FromToJson.lean +249 -0
  6. backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Term.lean +2128 -0
  7. backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Time.lean +26 -0
  8. backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Util.lean +249 -0
  9. backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/WhereFinally.lean +29 -0
  10. external/alphageometry/.venv-ag/Lib/site-packages/absl/app.py +488 -0
  11. external/alphageometry/.venv-ag/Lib/site-packages/absl/app.pyi +88 -0
  12. external/alphageometry/.venv-ag/Lib/site-packages/absl/command_name.py +63 -0
  13. external/alphageometry/.venv-ag/Lib/site-packages/absl/py.typed +0 -0
  14. external/alphageometry/.venv-ag/Lib/site-packages/distutils-precedence.pth +1 -0
  15. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/__init__.py +27 -0
  16. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/__init__.py +213 -0
  17. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_array_methods.py +45 -0
  18. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_creation_functions.py +31 -0
  19. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_data_type_functions.py +78 -0
  20. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_elementwise_functions.py +75 -0
  21. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_fft_functions.py +25 -0
  22. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_linear_algebra_functions.py +28 -0
  23. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_manipulation_functions.py +25 -0
  24. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_statistical_functions.py +25 -0
  25. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_utility_functions.py +86 -0
  26. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_version.py +15 -0
  27. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/fft.py +33 -0
  28. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/linalg.py +43 -0
  29. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_serialization/__init__.py +13 -0
  30. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_serialization/serialization.py +635 -0
  31. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_serialization/serialization_test.py +493 -0
  32. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/compilation_cache/__init__.py +13 -0
  33. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/compilation_cache/compilation_cache.py +20 -0
  34. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/export/__init__.py +36 -0
  35. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/__init__.py +23 -0
  36. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/call_tf.py +682 -0
  37. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/__init__.py +13 -0
  38. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/keras_reuse_main.py +78 -0
  39. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/keras_reuse_main_test.py +50 -0
  40. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/mnist_lib.py +324 -0
  41. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/saved_model_lib.py +154 -0
  42. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/saved_model_main.py +210 -0
  43. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/saved_model_main_test.py +70 -0
  44. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/serving/__init__.py +13 -0
  45. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/serving/model_server_request.py +128 -0
  46. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/impl_no_xla.py +1287 -0
  47. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/jax2tf.py +0 -0
  48. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/tests/__init__.py +13 -0
  49. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/tests/back_compat_tf_test.py +154 -0
  50. external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/tests/call_tf_test.py +1821 -0
.gitattributes CHANGED
@@ -4510,3 +4510,5 @@ external/alphageometry/.venv-ag/Lib/site-packages/pip/_vendor/distlib/t64-arm.ex
4510   external/alphageometry/.venv-ag/Lib/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
4511   external/alphageometry/.venv-ag/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text
4512   external/alphageometry/.venv-ag/Lib/site-packages/pip/_vendor/distlib/w64.exe filter=lfs diff=lfs merge=lfs -text
4513 + hfenv/Lib/site-packages/setuptools/cli-arm64.exe filter=lfs diff=lfs merge=lfs -text
4514 + hfenv/Lib/site-packages/setuptools/gui-arm64.exe filter=lfs diff=lfs merge=lfs -text
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/BEq.lean ADDED
@@ -0,0 +1,155 @@
1
+ /-
2
+ Copyright (c) 2020 Microsoft Corporation. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Leonardo de Moura
5
+ -/
6
+ prelude
7
+ import Lean.Meta.Transform
8
+ import Lean.Elab.Deriving.Basic
9
+ import Lean.Elab.Deriving.Util
10
+
11
+ namespace Lean.Elab.Deriving.BEq
12
+ open Lean.Parser.Term
13
+ open Meta
14
+
15
+ def mkBEqHeader (indVal : InductiveVal) : TermElabM Header := do
16
+ mkHeader `BEq 2 indVal
17
+
18
+ def mkMatch (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do
19
+ let discrs ← mkDiscrs header indVal
20
+ let alts ← mkAlts
21
+ `(match $[$discrs],* with $alts:matchAlt*)
22
+ where
23
+ mkElseAlt : TermElabM (TSyntax ``matchAltExpr) := do
24
+ let mut patterns := #[]
25
+ -- add `_` pattern for indices
26
+ for _ in [:indVal.numIndices] do
27
+ patterns := patterns.push (← `(_))
28
+ patterns := patterns.push (← `(_))
29
+ patterns := patterns.push (← `(_))
30
+ let altRhs ← `(false)
31
+ `(matchAltExpr| | $[$patterns:term],* => $altRhs:term)
32
+
33
+ mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do
34
+ let mut alts := #[]
35
+ for ctorName in indVal.ctors do
36
+ let ctorInfo ← getConstInfoCtor ctorName
37
+ let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
38
+ let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
39
+ let mut patterns := #[]
40
+ -- add `_` pattern for indices
41
+ for _ in [:indVal.numIndices] do
42
+ patterns := patterns.push (← `(_))
43
+ let mut ctorArgs1 := #[]
44
+ let mut ctorArgs2 := #[]
45
+ let mut rhs ← `(true)
46
+ let mut rhs_empty := true
47
+ for i in [:ctorInfo.numFields] do
48
+ let pos := indVal.numParams + ctorInfo.numFields - i - 1
49
+ let x := xs[pos]!
50
+ if type.containsFVar x.fvarId! then
51
+ -- If resulting type depends on this field, we don't need to compare
52
+ ctorArgs1 := ctorArgs1.push (← `(_))
53
+ ctorArgs2 := ctorArgs2.push (← `(_))
54
+ else
55
+ let a := mkIdent (← mkFreshUserName `a)
56
+ let b := mkIdent (← mkFreshUserName `b)
57
+ ctorArgs1 := ctorArgs1.push a
58
+ ctorArgs2 := ctorArgs2.push b
59
+ let xType ← inferType x
60
+ if (← isProp xType) then
61
+ continue
62
+ if xType.isAppOf indVal.name then
63
+ if rhs_empty then
64
+ rhs ← `($(mkIdent auxFunName):ident $a:ident $b:ident)
65
+ rhs_empty := false
66
+ else
67
+ rhs ← `($(mkIdent auxFunName):ident $a:ident $b:ident && $rhs)
68
+ /- If `x` appears in the type of another field, use `eq_of_beq` to
69
+ unify the types of the subsequent variables -/
70
+ else if ← xs[(pos+1)...*].anyM
71
+ (fun fvar => (Expr.containsFVar · x.fvarId!) <$> (inferType fvar)) then
72
+ rhs ← `(if h : $a:ident == $b:ident then by
73
+ cases (eq_of_beq h)
74
+ exact $rhs
75
+ else false)
76
+ rhs_empty := false
77
+ else
78
+ if rhs_empty then
79
+ rhs ← `($a:ident == $b:ident)
80
+ rhs_empty := false
81
+ else
82
+ rhs ← `($a:ident == $b:ident && $rhs)
83
+ -- add `_` for inductive parameters, they are inaccessible
84
+ for _ in [:indVal.numParams] do
85
+ ctorArgs1 := ctorArgs1.push (← `(_))
86
+ ctorArgs2 := ctorArgs2.push (← `(_))
87
+ patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs1.reverse:term*))
88
+ patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs2.reverse:term*))
89
+ `(matchAltExpr| | $[$patterns:term],* => $rhs:term)
90
+ alts := alts.push alt
91
+ alts := alts.push (← mkElseAlt)
92
+ return alts
93
+
94
+ def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
95
+ let auxFunName := ctx.auxFunNames[i]!
96
+ let indVal := ctx.typeInfos[i]!
97
+ let header ← mkBEqHeader indVal
98
+ let mut body ← mkMatch header indVal auxFunName
99
+ if ctx.usePartial then
100
+ let letDecls ← mkLocalInstanceLetDecls ctx `BEq header.argNames
101
+ body ← mkLet letDecls body
102
+ let binders := header.binders
103
+ if ctx.usePartial then
104
+ `(partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term)
105
+ else
106
+ `(@[expose] def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Bool := $body:term)
107
+
108
+ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
109
+ let mut auxDefs := #[]
110
+ for i in [:ctx.typeInfos.size] do
111
+ auxDefs := auxDefs.push (← mkAuxFunction ctx i)
112
+ `(mutual
113
+ set_option match.ignoreUnusedAlts true
114
+ $auxDefs:command*
115
+ end)
116
+
117
+ private def mkBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
118
+ let ctx ← mkContext "beq" declName
119
+ let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq #[declName])
120
+ trace[Elab.Deriving.beq] "\n{cmds}"
121
+ return cmds
122
+
123
+ private def mkBEqEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do
124
+ let auxFunName := ctx.auxFunNames[0]!
125
+ `(@[expose] def $(mkIdent auxFunName):ident (x y : $(mkIdent name)) : Bool := x.toCtorIdx == y.toCtorIdx)
126
+
127
+ private def mkBEqEnumCmd (name : Name): TermElabM (Array Syntax) := do
128
+ let ctx ← mkContext "beq" name
129
+ let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq #[name])
130
+ trace[Elab.Deriving.beq] "\n{cmds}"
131
+ return cmds
132
+
133
+ open Command
134
+
135
+ def mkBEqInstance (declName : Name) : CommandElabM Unit := do
136
+ let cmds ← liftTermElabM <|
137
+ if (← isEnumType declName) then
138
+ mkBEqEnumCmd declName
139
+ else
140
+ mkBEqInstanceCmds declName
141
+ cmds.forM elabCommand
142
+
143
+ def mkBEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
144
+ if (← declNames.allM isInductive) then
145
+ for declName in declNames do
146
+ mkBEqInstance declName
147
+ return true
148
+ else
149
+ return false
150
+
151
+ builtin_initialize
152
+ registerDerivingHandler `BEq mkBEqInstanceHandler
153
+ registerTraceClass `Elab.Deriving.beq
154
+
155
+ end Lean.Elab.Deriving.BEq
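
Note (illustration, not part of this diff): the handler above is what runs when a declaration asks for `BEq` via `deriving`. A minimal sketch with a hypothetical enum `Color`:

inductive Color where
  | red | green | blue
  deriving BEq

-- For enum-like types the derived instance compares `toCtorIdx` values
-- (see `mkBEqEnumFun` above), so:
#eval Color.red == Color.red   -- true
#eval Color.red == Color.blue  -- false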
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/Basic.lean ADDED
@@ -0,0 +1,134 @@
1
+ /-
2
+ Copyright (c) 2020 Microsoft Corporation. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Leonardo de Moura, Wojciech Nawrocki
5
+ -/
6
+ prelude
7
+ import Lean.Elab.Command
8
+ import Lean.Elab.DeclarationRange
9
+
10
+ namespace Lean.Elab
11
+ open Command
12
+
13
+ namespace Term
14
+ open Meta
15
+
16
+ /-- Result for `mkInst?` -/
17
+ structure MkInstResult where
18
+ instVal : Expr
19
+ instType : Expr
20
+ outParams : Array Expr := #[]
21
+
22
+ /--
23
+ Construct an instance for `className out₁ ... outₙ type`.
24
+ The method support classes with a prefix of `outParam`s (e.g. `MonadReader`). -/
25
+ private partial def mkInst? (className : Name) (type : Expr) : MetaM (Option MkInstResult) := do
26
+ let rec go? (instType instTypeType : Expr) (outParams : Array Expr) : MetaM (Option MkInstResult) := do
27
+ let instTypeType ← whnfD instTypeType
28
+ unless instTypeType.isForall do
29
+ return none
30
+ let d := instTypeType.bindingDomain!
31
+ if d.isOutParam then
32
+ let mvar ← mkFreshExprMVar d
33
+ go? (mkApp instType mvar) (instTypeType.bindingBody!.instantiate1 mvar) (outParams.push mvar)
34
+ else
35
+ unless (← isDefEqGuarded (← inferType type) d) do
36
+ return none
37
+ let instType ← instantiateMVars (mkApp instType type)
38
+ let instVal ← synthInstance instType
39
+ return some { instVal, instType, outParams }
40
+ let instType ← mkConstWithFreshMVarLevels className
41
+ go? instType (← inferType instType) #[]
42
+
43
+ def processDefDeriving (className : Name) (declName : Name) : TermElabM Bool := do
44
+ try
45
+ let ConstantInfo.defnInfo info ← getConstInfo declName | return false
46
+ let some result ← mkInst? className info.value | return false
47
+ let instTypeNew := mkApp result.instType.appFn! (Lean.mkConst declName (info.levelParams.map mkLevelParam))
48
+ Meta.check instTypeNew
49
+ let instName ← liftMacroM <| mkUnusedBaseName (declName.appendBefore "inst" |>.appendAfter className.getString!)
50
+ addAndCompile <| Declaration.defnDecl {
51
+ name := instName
52
+ levelParams := info.levelParams
53
+ type := (← instantiateMVars instTypeNew)
54
+ value := (← instantiateMVars result.instVal)
55
+ hints := info.hints
56
+ safety := info.safety
57
+ }
58
+ addInstance instName AttributeKind.global (eval_prio default)
59
+ addDeclarationRangesFromSyntax instName (← getRef)
60
+ return true
61
+ catch _ =>
62
+ return false
63
+
64
+ end Term
65
+
66
+ def DerivingHandler := (typeNames : Array Name) → CommandElabM Bool
67
+
68
+ builtin_initialize derivingHandlersRef : IO.Ref (NameMap (List DerivingHandler)) ← IO.mkRef {}
69
+
70
+ /-- A `DerivingHandler` is called on the fully qualified names of all types it is running for
71
+ as well as the syntax of a `with` argument, if present.
72
+
73
+ For example, `deriving instance Foo with fooArgs for Bar, Baz` invokes
74
+ ``fooHandler #[`Bar, `Baz] `(fooArgs)``. -/
75
+ def registerDerivingHandler (className : Name) (handler : DerivingHandler) : IO Unit := do
76
+ unless (← initializing) do
77
+ throw (IO.userError "failed to register deriving handler, it can only be registered during initialization")
78
+ derivingHandlersRef.modify fun m => match m.find? className with
79
+ | some handlers => m.insert className (handler :: handlers)
80
+ | none => m.insert className [handler]
81
+
82
+ def defaultHandler (className : Name) (typeNames : Array Name) : CommandElabM Unit := do
83
+ throwError "default handlers have not been implemented yet, class: '{className}' types: {typeNames}"
84
+
85
+ def applyDerivingHandlers (className : Name) (typeNames : Array Name) : CommandElabM Unit := do
86
+ withTraceNode `Elab.Deriving (fun _ => return m!"running deriving handlers for '{className}'") do
87
+ match (← derivingHandlersRef.get).find? className with
88
+ | some handlers =>
89
+ for handler in handlers do
90
+ if (← handler typeNames) then
91
+ return ()
92
+ defaultHandler className typeNames
93
+ | none => defaultHandler className typeNames
94
+
95
+ private def tryApplyDefHandler (className : Name) (declName : Name) : CommandElabM Bool :=
96
+ liftTermElabM do
97
+ Term.processDefDeriving className declName
98
+
99
+ @[builtin_command_elab «deriving»] def elabDeriving : CommandElab
100
+ | `(deriving instance $[$classes],* for $[$declNames],*) => do
101
+ let declNames ← liftCoreM <| declNames.mapM realizeGlobalConstNoOverloadWithInfo
102
+ for cls in classes do
103
+ try
104
+ let className ← liftCoreM <| realizeGlobalConstNoOverloadWithInfo cls
105
+ withRef cls do
106
+ if declNames.size == 1 then
107
+ if (← tryApplyDefHandler className declNames[0]!) then
108
+ return ()
109
+ applyDerivingHandlers className declNames
110
+ catch ex =>
111
+ logException ex
112
+ | _ => throwUnsupportedSyntax
113
+
114
+ structure DerivingClassView where
115
+ ref : Syntax
116
+ className : Name
117
+
118
+ def getOptDerivingClasses (optDeriving : Syntax) : CoreM (Array DerivingClassView) := do
119
+ match optDeriving with
120
+ | `(Parser.Command.optDeriving| deriving $[$classes],*) =>
121
+ let mut ret := #[]
122
+ for cls in classes do
123
+ let className ← realizeGlobalConstNoOverloadWithInfo cls
124
+ ret := ret.push { ref := cls, className := className }
125
+ return ret
126
+ | _ => return #[]
127
+
128
+ def DerivingClassView.applyHandlers (view : DerivingClassView) (declNames : Array Name) : CommandElabM Unit :=
129
+ withRef view.ref do applyDerivingHandlers view.className declNames
130
+
131
+ builtin_initialize
132
+ registerTraceClass `Elab.Deriving
133
+
134
+ end Lean.Elab
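
Note (illustration, not part of this diff): the standalone `deriving instance ... for ...` command is the syntax elaborated by `elabDeriving` above; the class names are resolved and dispatched to the handlers registered via `registerDerivingHandler`. A sketch with a hypothetical structure `Point`:

structure Point where
  x : Nat
  y : Nat

-- Standalone deriving command, handled by `elabDeriving`:
deriving instance BEq, Repr for Point

#eval Point.mk 1 2 == Point.mk 1 2  -- true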
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/DecEq.lean ADDED
@@ -0,0 +1,212 @@
1
+ /-
2
+ Copyright (c) 2020 Microsoft Corporation. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Leonardo de Moura
5
+ -/
6
+ prelude
7
+ import Lean.Meta.Transform
8
+ import Lean.Meta.Inductive
9
+ import Lean.Elab.Deriving.Basic
10
+ import Lean.Elab.Deriving.Util
11
+
12
+ namespace Lean.Elab.Deriving.DecEq
13
+ open Lean.Parser.Term
14
+ open Meta
15
+
16
+ def mkDecEqHeader (indVal : InductiveVal) : TermElabM Header := do
17
+ mkHeader `DecidableEq 2 indVal
18
+
19
+ def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do
20
+ let discrs ← mkDiscrs header indVal
21
+ let alts ← mkAlts
22
+ `(match $[$discrs],* with $alts:matchAlt*)
23
+ where
24
+ mkSameCtorRhs : List (Ident × Ident × Option Name × Bool) → TermElabM Term
25
+ | [] => ``(isTrue rfl)
26
+ | (a, b, recField, isProof) :: todo => withFreshMacroScope do
27
+ let rhs ← if isProof then
28
+ `(have h : @$a = @$b := rfl; by subst h; exact $(← mkSameCtorRhs todo):term)
29
+ else
30
+ let sameCtor ← mkSameCtorRhs todo
31
+ `(if h : @$a = @$b then
32
+ by subst h; exact $sameCtor:term
33
+ else
34
+ isFalse (by intro n; injection n; apply h _; assumption))
35
+ if let some auxFunName := recField then
36
+ -- add local instance for `a = b` using the function being defined `auxFunName`
37
+ `(let inst := $(mkIdent auxFunName) @$a @$b; $rhs)
38
+ else
39
+ return rhs
40
+
41
+ mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do
42
+ let mut alts := #[]
43
+ for ctorName₁ in indVal.ctors do
44
+ let ctorInfo ← getConstInfoCtor ctorName₁
45
+ for ctorName₂ in indVal.ctors do
46
+ let mut patterns := #[]
47
+ -- add `_` pattern for indices
48
+ for _ in [:indVal.numIndices] do
49
+ patterns := patterns.push (← `(_))
50
+ if ctorName₁ == ctorName₂ then
51
+ let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
52
+ let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
53
+ let mut patterns := patterns
54
+ let mut ctorArgs1 := #[]
55
+ let mut ctorArgs2 := #[]
56
+ -- add `_` for inductive parameters, they are inaccessible
57
+ for _ in [:indVal.numParams] do
58
+ ctorArgs1 := ctorArgs1.push (← `(_))
59
+ ctorArgs2 := ctorArgs2.push (← `(_))
60
+ let mut todo := #[]
61
+ for i in [:ctorInfo.numFields] do
62
+ let x := xs[indVal.numParams + i]!
63
+ if type.containsFVar x.fvarId! then
64
+ -- If resulting type depends on this field, we don't need to compare
65
+ ctorArgs1 := ctorArgs1.push (← `(_))
66
+ ctorArgs2 := ctorArgs2.push (← `(_))
67
+ else
68
+ let a := mkIdent (← mkFreshUserName `a)
69
+ let b := mkIdent (← mkFreshUserName `b)
70
+ ctorArgs1 := ctorArgs1.push a
71
+ ctorArgs2 := ctorArgs2.push b
72
+ let xType ← inferType x
73
+ let indValNum :=
74
+ ctx.typeInfos.findIdx?
75
+ (xType.isAppOf ∘ ConstantVal.name ∘ InductiveVal.toConstantVal)
76
+ let recField := indValNum.map (ctx.auxFunNames[·]!)
77
+ let isProof ← isProp xType
78
+ todo := todo.push (a, b, recField, isProof)
79
+ patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs1:term*))
80
+ patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs2:term*))
81
+ let rhs ← mkSameCtorRhs todo.toList
82
+ `(matchAltExpr| | $[$patterns:term],* => $rhs:term)
83
+ alts := alts.push alt
84
+ else if (← compatibleCtors ctorName₁ ctorName₂) then
85
+ patterns := patterns ++ #[(← `($(mkIdent ctorName₁) ..)), (← `($(mkIdent ctorName₂) ..))]
86
+ let rhs ← `(isFalse (by intro h; injection h))
87
+ alts := alts.push (← `(matchAltExpr| | $[$patterns:term],* => $rhs:term))
88
+ return alts
89
+
90
+ def mkAuxFunction (ctx : Context) (auxFunName : Name) (indVal : InductiveVal): TermElabM (TSyntax `command) := do
91
+ let header ← mkDecEqHeader indVal
92
+ let body ← mkMatch ctx header indVal
93
+ let binders := header.binders
94
+ let target₁ := mkIdent header.targetNames[0]!
95
+ let target₂ := mkIdent header.targetNames[1]!
96
+ let termSuffix ← if indVal.isRec
97
+ then `(Parser.Termination.suffix|termination_by structural $target₁)
98
+ else `(Parser.Termination.suffix|)
99
+ let type ← `(Decidable ($target₁ = $target₂))
100
+ `(def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $type:term := $body:term
101
+ $termSuffix:suffix)
102
+
103
+ def mkAuxFunctions (ctx : Context) : TermElabM (TSyntax `command) := do
104
+ let mut res : Array (TSyntax `command) := #[]
105
+ for i in [:ctx.auxFunNames.size] do
106
+ let auxFunName := ctx.auxFunNames[i]!
107
+ let indVal := ctx.typeInfos[i]!
108
+ res := res.push (← mkAuxFunction ctx auxFunName indVal)
109
+ `(command| mutual $[$res:command]* end)
110
+
111
+ def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do
112
+ let ctx ← mkContext "decEq" indVal.name
113
+ let cmds := #[← mkAuxFunctions ctx] ++ (← mkInstanceCmds ctx `DecidableEq #[indVal.name] (useAnonCtor := false))
114
+ trace[Elab.Deriving.decEq] "\n{cmds}"
115
+ return cmds
116
+
117
+ open Command
118
+
119
+ def mkDecEq (declName : Name) : CommandElabM Bool := do
120
+ let indVal ← getConstInfoInduct declName
121
+ if indVal.isNested then
122
+ return false -- nested inductive types are not supported yet
123
+ else
124
+ let cmds ← liftTermElabM <| mkDecEqCmds indVal
125
+ -- `cmds` can have a number of syntax nodes quadratic in the number of constructors
126
+ -- and thus create as many info tree nodes, which we never make use of but which can
127
+ -- significantly slow down e.g. the unused variables linter; avoid creating them
128
+ withEnableInfoTree false do
129
+ cmds.forM elabCommand
130
+ return true
131
+
132
+ partial def mkEnumOfNat (declName : Name) : MetaM Unit := do
133
+ let indVal ← getConstInfoInduct declName
134
+ let enumType := mkConst declName
135
+ let ctors := indVal.ctors.toArray
136
+ withLocalDeclD `n (mkConst ``Nat) fun n => do
137
+ let cond := mkConst ``cond [1]
138
+ let rec mkDecTree (low high : Nat) : Expr :=
139
+ if low + 1 == high then
140
+ mkConst ctors[low]!
141
+ else if low + 2 == high then
142
+ mkApp4 cond enumType (mkApp2 (mkConst ``Nat.beq) n (mkRawNatLit low)) (mkConst ctors[low]!) (mkConst ctors[low+1]!)
143
+ else
144
+ let mid := (low + high)/2
145
+ let lowBranch := mkDecTree low mid
146
+ let highBranch := mkDecTree mid high
147
+ mkApp4 cond enumType (mkApp2 (mkConst ``Nat.ble) (mkRawNatLit mid) n) highBranch lowBranch
148
+ let value ← mkLambdaFVars #[n] (mkDecTree 0 ctors.size)
149
+ let type ← mkArrow (mkConst ``Nat) enumType
150
+ addAndCompile <| Declaration.defnDecl {
151
+ name := Name.mkStr declName "ofNat"
152
+ levelParams := []
153
+ safety := DefinitionSafety.safe
154
+ hints := ReducibilityHints.abbrev
155
+ value, type
156
+ }
157
+
158
+ def mkEnumOfNatThm (declName : Name) : MetaM Unit := do
159
+ let indVal ← getConstInfoInduct declName
160
+ let toCtorIdx := mkConst (Name.mkStr declName "toCtorIdx")
161
+ let ofNat := mkConst (Name.mkStr declName "ofNat")
162
+ let enumType := mkConst declName
163
+ let eqEnum := mkApp (mkConst ``Eq [levelOne]) enumType
164
+ let rflEnum := mkApp (mkConst ``Eq.refl [levelOne]) enumType
165
+ let ctors := indVal.ctors
166
+ withLocalDeclD `x enumType fun x => do
167
+ let resultType := mkApp2 eqEnum (mkApp ofNat (mkApp toCtorIdx x)) x
168
+ let motive ← mkLambdaFVars #[x] resultType
169
+ let casesOn := mkConst (mkCasesOnName declName) [levelZero]
170
+ let mut value := mkApp2 casesOn motive x
171
+ for ctor in ctors do
172
+ value := mkApp value (mkApp rflEnum (mkConst ctor))
173
+ value ← mkLambdaFVars #[x] value
174
+ let type ← mkForallFVars #[x] resultType
175
+ addAndCompile <| Declaration.thmDecl {
176
+ name := Name.mkStr declName "ofNat_toCtorIdx"
177
+ levelParams := []
178
+ value, type
179
+ }
180
+
181
+ def mkDecEqEnum (declName : Name) : CommandElabM Unit := do
182
+ liftTermElabM <| mkEnumOfNat declName
183
+ liftTermElabM <| mkEnumOfNatThm declName
184
+ let ofNatIdent := mkIdent (Name.mkStr declName "ofNat")
185
+ let auxThmIdent := mkIdent (Name.mkStr declName "ofNat_toCtorIdx")
186
+ let cmd ← `(
187
+ instance : DecidableEq $(mkIdent declName) :=
188
+ fun x y =>
189
+ if h : x.toCtorIdx = y.toCtorIdx then
190
+ -- We use `rfl` in the following proof because the first script fails for unit-like datatypes due to etaStruct.
191
+ isTrue (by first | have aux := congrArg $ofNatIdent h; rw [$auxThmIdent:ident, $auxThmIdent:ident] at aux; assumption | rfl)
192
+ else
193
+ isFalse fun h => by subst h; contradiction
194
+ )
195
+ trace[Elab.Deriving.decEq] "\n{cmd}"
196
+ elabCommand cmd
197
+
198
+ def mkDecEqInstance (declName : Name) : CommandElabM Bool := do
199
+ if (← isEnumType declName) then
200
+ mkDecEqEnum declName
201
+ return true
202
+ else
203
+ mkDecEq declName
204
+
205
+ def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
206
+ declNames.foldlM (fun b n => andM (pure b) (mkDecEqInstance n)) true
207
+
208
+ builtin_initialize
209
+ registerDerivingHandler `DecidableEq mkDecEqInstanceHandler
210
+ registerTraceClass `Elab.Deriving.decEq
211
+
212
+ end Lean.Elab.Deriving.DecEq
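
Note (illustration, not part of this diff): a small example of the `DecidableEq` deriving implemented above, using a hypothetical enum `Shade`; for enum-like types the instance is built from `toCtorIdx`/`ofNat` in `mkDecEqEnum`, which is what lets `decide` close the goal:

inductive Shade where
  | light | dark
  deriving DecidableEq

example : Shade.light ≠ Shade.dark := by decide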
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Deriving/FromToJson.lean ADDED
@@ -0,0 +1,249 @@
1
+ /-
2
+ Copyright (c) 2020 Sebastian Ullrich. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Sebastian Ullrich, Dany Fabian
5
+ -/
6
+ prelude
7
+ import Lean.Meta.Transform
8
+ import Lean.Elab.Deriving.Basic
9
+ import Lean.Elab.Deriving.Util
10
+ import Lean.Data.Json.FromToJson
11
+
12
+ namespace Lean.Elab.Deriving.FromToJson
13
+ open Lean.Elab.Command
14
+ open Lean.Json
15
+ open Lean.Parser.Term
16
+ open Lean.Meta
17
+
18
+ def mkToJsonHeader (indVal : InductiveVal) : TermElabM Header := do
19
+ mkHeader ``ToJson 1 indVal
20
+
21
+ def mkFromJsonHeader (indVal : InductiveVal) : TermElabM Header := do
22
+ let header ← mkHeader ``FromJson 0 indVal
23
+ let jsonArg ← `(bracketedBinderF|(json : Json))
24
+ return {header with
25
+ binders := header.binders.push jsonArg}
26
+
27
+ def mkJsonField (n : Name) : CoreM (Bool × Term) := do
28
+ let .str .anonymous s := n | throwError "invalid json field name {n}"
29
+ let s₁ := s.dropRightWhile (· == '?')
30
+ return (s != s₁, Syntax.mkStrLit s₁)
31
+
32
+ def mkToJsonBodyForStruct (header : Header) (indName : Name) : TermElabM Term := do
33
+ let fields := getStructureFieldsFlattened (← getEnv) indName (includeSubobjectFields := false)
34
+ let fields ← fields.mapM fun field => do
35
+ let (isOptField, nm) ← mkJsonField field
36
+ let target := mkIdent header.targetNames[0]!
37
+ if isOptField then ``(opt $nm $target.$(mkIdent field))
38
+ else ``([($nm, toJson ($target).$(mkIdent field))])
39
+ `(mkObj <| List.flatten [$fields,*])
40
+
41
+ def mkToJsonBodyForInduct (ctx : Context) (header : Header) (indName : Name) : TermElabM Term := do
42
+ let indVal ← getConstInfoInduct indName
43
+ let toJsonFuncId := mkIdent ctx.auxFunNames[0]!
44
+ -- Return syntax to JSONify `id`, either via `ToJson` or recursively
45
+ -- if `id`'s type is the type we're deriving for.
46
+ let mkToJson (id : Ident) (type : Expr) : TermElabM Term := do
47
+ if type.isAppOf indVal.name then `($toJsonFuncId:ident $id:ident)
48
+ else ``(toJson $id:ident)
49
+ let discrs ← mkDiscrs header indVal
50
+ let alts ← mkAlts indVal fun ctor args userNames => do
51
+ let ctorStr := ctor.name.eraseMacroScopes.getString!
52
+ match args, userNames with
53
+ | #[], _ => ``(toJson $(quote ctorStr))
54
+ | #[(x, t)], none => ``(mkObj [($(quote ctorStr), $(← mkToJson x t))])
55
+ | xs, none =>
56
+ let xs ← xs.mapM fun (x, t) => mkToJson x t
57
+ ``(mkObj [($(quote ctorStr), Json.arr #[$[$xs:term],*])])
58
+ | xs, some userNames =>
59
+ let xs ← xs.mapIdxM fun idx (x, t) => do
60
+ `(($(quote userNames[idx]!.eraseMacroScopes.getString!), $(← mkToJson x t)))
61
+ ``(mkObj [($(quote ctorStr), mkObj [$[$xs:term],*])])
62
+ `(match $[$discrs],* with $alts:matchAlt*)
63
+
64
+ where
65
+ mkAlts
66
+ (indVal : InductiveVal)
67
+ (rhs : ConstructorVal → Array (Ident × Expr) → Option (Array Name) → TermElabM Term): TermElabM (Array (TSyntax ``matchAlt)) := do
68
+ let mut alts := #[]
69
+ for ctorName in indVal.ctors do
70
+ let ctorInfo ← getConstInfoCtor ctorName
71
+ let alt ← forallTelescopeReducing ctorInfo.type fun xs _ => do
72
+ let mut patterns := #[]
73
+ -- add `_` pattern for indices
74
+ for _ in [:indVal.numIndices] do
75
+ patterns := patterns.push (← `(_))
76
+ let mut ctorArgs := #[]
77
+ -- add `_` for inductive parameters, they are inaccessible
78
+ for _ in [:indVal.numParams] do
79
+ ctorArgs := ctorArgs.push (← `(_))
80
+ -- bound constructor arguments and their types
81
+ let mut binders := #[]
82
+ let mut userNames := #[]
83
+ for i in [:ctorInfo.numFields] do
84
+ let x := xs[indVal.numParams + i]!
85
+ let localDecl ← x.fvarId!.getDecl
86
+ if !localDecl.userName.hasMacroScopes then
87
+ userNames := userNames.push localDecl.userName
88
+ let a := mkIdent (← mkFreshUserName `a)
89
+ binders := binders.push (a, localDecl.type)
90
+ ctorArgs := ctorArgs.push a
91
+ patterns := patterns.push (← `(@$(mkIdent ctorInfo.name):ident $ctorArgs:term*))
92
+ let rhs ← rhs ctorInfo binders (if userNames.size == binders.size then some userNames else none)
93
+ `(matchAltExpr| | $[$patterns:term],* => $rhs:term)
94
+ alts := alts.push alt
95
+ return alts
96
+
97
+ def mkFromJsonBodyForStruct (indName : Name) : TermElabM Term := do
98
+ let fields := getStructureFieldsFlattened (← getEnv) indName (includeSubobjectFields := false)
99
+ let getters ← fields.mapM (fun field => do
100
+ let getter ← `(getObjValAs? json _ $(Prod.snd <| ← mkJsonField field))
101
+ let getter ← `(doElem| Except.mapError (fun s => (toString $(quote indName)) ++ "." ++ (toString $(quote field)) ++ ": " ++ s) <| $getter)
102
+ return getter
103
+ )
104
+ let fields := fields.map mkIdent
105
+ `(do
106
+ $[let $fields:ident ← $getters]*
107
+ return { $[$fields:ident := $(id fields)],* })
108
+
109
+ def mkFromJsonBodyForInduct (ctx : Context) (indName : Name) : TermElabM Term := do
110
+ let indVal ← getConstInfoInduct indName
111
+ let alts ← mkAlts indVal
112
+ let auxTerm ← alts.foldrM (fun xs x => `(Except.orElseLazy $xs (fun _ => $x))) (← `(Except.error "no inductive constructor matched"))
113
+ `($auxTerm)
114
+ where
115
+ mkAlts (indVal : InductiveVal) : TermElabM (Array Term) := do
116
+ let mut alts := #[]
117
+ for ctorName in indVal.ctors do
118
+ let ctorInfo ← getConstInfoCtor ctorName
119
+ let alt ← do forallTelescopeReducing ctorInfo.type fun xs _ => do
120
+ let mut binders := #[]
121
+ let mut userNames := #[]
122
+ for i in [:ctorInfo.numFields] do
123
+ let x := xs[indVal.numParams + i]!
124
+ let localDecl ← x.fvarId!.getDecl
125
+ if !localDecl.userName.hasMacroScopes then
126
+ userNames := userNames.push localDecl.userName
127
+ let a := mkIdent (← mkFreshUserName `a)
128
+ binders := binders.push (a, localDecl.type)
129
+ let fromJsonFuncId := mkIdent ctx.auxFunNames[0]!
130
+ -- Return syntax to parse `id`, either via `FromJson` or recursively
131
+ -- if `id`'s type is the type we're deriving for.
132
+ let mkFromJson (idx : Nat) (type : Expr) : TermElabM (TSyntax ``doExpr) :=
133
+ if type.isAppOf indVal.name then `(Lean.Parser.Term.doExpr| $fromJsonFuncId:ident jsons[$(quote idx)]!)
134
+ else `(Lean.Parser.Term.doExpr| fromJson? jsons[$(quote idx)]!)
135
+ let identNames := binders.map Prod.fst
136
+ let fromJsons ← binders.mapIdxM fun idx (_, type) => mkFromJson idx type
137
+ let userNamesOpt ← if binders.size == userNames.size then
138
+ ``(some #[$[$(userNames.map quote)],*])
139
+ else
140
+ ``(none)
141
+ let stx ←
142
+ `((Json.parseTagged json $(quote ctorName.eraseMacroScopes.getString!) $(quote ctorInfo.numFields) $(quote userNamesOpt)).bind
143
+ (fun jsons => do
144
+ $[let $identNames:ident ← $fromJsons:doExpr]*
145
+ return $(mkIdent ctorName):ident $identNames*))
146
+ pure (stx, ctorInfo.numFields)
147
+ alts := alts.push alt
148
+ -- the smaller cases, especially the ones without fields are likely faster
149
+ let alts' := alts.qsort (fun (_, x) (_, y) => x < y)
150
+ return alts'.map Prod.fst
151
+
152
+ def mkToJsonBody (ctx : Context) (header : Header) (e : Expr): TermElabM Term := do
153
+ let indName := e.getAppFn.constName!
154
+ if isStructure (← getEnv) indName then
155
+ mkToJsonBodyForStruct header indName
156
+ else
157
+ mkToJsonBodyForInduct ctx header indName
158
+
159
+ def mkToJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
160
+ let auxFunName := ctx.auxFunNames[i]!
161
+ let header ← mkToJsonHeader ctx.typeInfos[i]!
162
+ let binders := header.binders
163
+ Term.elabBinders binders fun _ => do
164
+ let type ← Term.elabTerm header.targetType none
165
+ let mut body ← mkToJsonBody ctx header type
166
+ if ctx.usePartial then
167
+ let letDecls ← mkLocalInstanceLetDecls ctx ``ToJson header.argNames
168
+ body ← mkLet letDecls body
169
+ `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term)
170
+ else
171
+ `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Json := $body:term)
172
+
173
+ def mkFromJsonBody (ctx : Context) (e : Expr) : TermElabM Term := do
174
+ let indName := e.getAppFn.constName!
175
+ if isStructure (← getEnv) indName then
176
+ mkFromJsonBodyForStruct indName
177
+ else
178
+ mkFromJsonBodyForInduct ctx indName
179
+
180
+ def mkFromJsonAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
181
+ let auxFunName := ctx.auxFunNames[i]!
182
+ let indval := ctx.typeInfos[i]!
183
+ let header ← mkFromJsonHeader indval --TODO fix header info
184
+ let binders := header.binders
185
+ Term.elabBinders binders fun _ => do
186
+ let type ← Term.elabTerm header.targetType none
187
+ let mut body ← mkFromJsonBody ctx type
188
+ if ctx.usePartial || indval.isRec then
189
+ let letDecls ← mkLocalInstanceLetDecls ctx ``FromJson header.argNames
190
+ body ← mkLet letDecls body
191
+ `(private partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← mkInductiveApp ctx.typeInfos[i]! header.argNames) := $body:term)
192
+ else
193
+ `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Except String $(← mkInductiveApp ctx.typeInfos[i]! header.argNames) := $body:term)
194
+
195
+
196
+ def mkToJsonMutualBlock (ctx : Context) : TermElabM Command := do
197
+ let mut auxDefs := #[]
198
+ for i in [:ctx.typeInfos.size] do
199
+ auxDefs := auxDefs.push (← mkToJsonAuxFunction ctx i)
200
+ `(mutual
201
+ $auxDefs:command*
202
+ end)
203
+
204
+ def mkFromJsonMutualBlock (ctx : Context) : TermElabM Command := do
205
+ let mut auxDefs := #[]
206
+ for i in [:ctx.typeInfos.size] do
207
+ auxDefs := auxDefs.push (← mkFromJsonAuxFunction ctx i)
208
+ `(mutual
209
+ $auxDefs:command*
210
+ end)
211
+
212
+ private def mkToJsonInstance (declName : Name) : TermElabM (Array Command) := do
213
+ let ctx ← mkContext "toJson" declName
214
+ let cmds := #[← mkToJsonMutualBlock ctx] ++ (← mkInstanceCmds ctx ``ToJson #[declName])
215
+ trace[Elab.Deriving.toJson] "\n{cmds}"
216
+ return cmds
217
+
218
+ private def mkFromJsonInstance (declName : Name) : TermElabM (Array Command) := do
219
+ let ctx ← mkContext "fromJson" declName
220
+ let cmds := #[← mkFromJsonMutualBlock ctx] ++ (← mkInstanceCmds ctx ``FromJson #[declName])
221
+ trace[Elab.Deriving.fromJson] "\n{cmds}"
222
+ return cmds
223
+
224
+ def mkToJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
225
+ if (← declNames.allM isInductive) && declNames.size > 0 then
226
+ for declName in declNames do
227
+ let cmds ← liftTermElabM <| mkToJsonInstance declName
228
+ cmds.forM elabCommand
229
+ return true
230
+ else
231
+ return false
232
+
233
+ def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
234
+ if (← declNames.allM isInductive) && declNames.size > 0 then
235
+ for declName in declNames do
236
+ let cmds ← liftTermElabM <| mkFromJsonInstance declName
237
+ cmds.forM elabCommand
238
+ return true
239
+ else
240
+ return false
241
+
242
+ builtin_initialize
243
+ registerDerivingHandler ``ToJson mkToJsonInstanceHandler
244
+ registerDerivingHandler ``FromJson mkFromJsonInstanceHandler
245
+
246
+ registerTraceClass `Elab.Deriving.toJson
247
+ registerTraceClass `Elab.Deriving.fromJson
248
+
249
+ end Lean.Elab.Deriving.FromToJson
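
Note (illustration, not part of this diff): a sketch of the `ToJson`/`FromJson` deriving implemented above, with a hypothetical structure `Msg` (assumes `import Lean` so the handlers and the `Lean.Json` API are available). A field name ending in `?` is treated as optional by `mkJsonField` and serialized under the name without the `?`:

import Lean
open Lean

structure Msg where
  body : String
  tag? : Option String := none
  deriving ToJson, FromJson

def m : Msg := { body := "hi" }

-- `tag?` is omitted from the JSON object when it is `none`;
-- the compressed output contains only the "body" key.
#eval (toJson m).compress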
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Term.lean ADDED
@@ -0,0 +1,2128 @@
1
+ /-
2
+ Copyright (c) 2019 Microsoft Corporation. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Leonardo de Moura, Sebastian Ullrich
5
+ -/
6
+ prelude
7
+ import Lean.ReservedNameAction
8
+ import Lean.Meta.AppBuilder
9
+ import Lean.Meta.CollectMVars
10
+ import Lean.Meta.Coe
11
+ import Lean.Util.CollectLevelMVars
12
+ import Lean.Linter.Deprecated
13
+ import Lean.Elab.Config
14
+ import Lean.Elab.Level
15
+ import Lean.Elab.DeclModifiers
16
+ import Lean.Elab.PreDefinition.TerminationHint
17
+ import Lean.Elab.DeclarationRange
18
+ import Lean.Elab.WhereFinally
19
+ import Lean.Language.Basic
20
+ import Lean.Elab.InfoTree.InlayHints
21
+
22
+ namespace Lean.Elab
23
+
24
+ namespace Term
25
+
26
+ /-- Saved context for postponed terms and tactics to be executed. -/
27
+ structure SavedContext where
28
+ declName? : Option Name
29
+ options : Options
30
+ openDecls : List OpenDecl
31
+ macroStack : MacroStack
32
+ errToSorry : Bool
33
+ levelNames : List Name
34
+
35
+ /-- The kind of a tactic metavariable, used for additional error reporting. -/
36
+ inductive TacticMVarKind
37
+ /-- Standard tactic metavariable, arising from `by ...` syntax. -/
38
+ | term
39
+ /-- Tactic metavariable arising from an autoparam for a function application. -/
40
+ | autoParam (argName : Name)
41
+ /-- Tactic metavariable arising from an autoparam for a structure field. -/
42
+ | fieldAutoParam (fieldName structName : Name)
43
+
44
+ /-- We use synthetic metavariables as placeholders for pending elaboration steps. -/
45
+ inductive SyntheticMVarKind where
46
+ /--
47
+ Use typeclass resolution to synthesize value for metavariable.
48
+ If `extraErrorMsg?` is `some msg`, `msg` contains additional information to include in error messages
49
+ regarding type class synthesis failure.
50
+ -/
51
+ | typeClass (extraErrorMsg? : Option MessageData)
52
+ /--
53
+ Use coercion to synthesize value for the metavariable.
54
+ If synthesis fails, then throws an error.
55
+ - If `mkErrorMsg?` is provided, then the error `mkErrorMsg expectedType e` is thrown.
56
+ The `mkErrorMsg` function is allowed to throw an error itself.
57
+ - Otherwise, throws a default type mismatch error message.
58
+ If `header?` is not provided, the default header is "type mismatch".
59
+ If `f?` is provided, then throws an application type mismatch error.
60
+ -/
61
+ | coe (header? : Option String) (expectedType : Expr) (e : Expr) (f? : Option Expr)
62
+ (mkErrorMsg? : Option (MVarId → Expr → Expr → MetaM MessageData))
63
+ /--
64
+ Use tactic to synthesize value for metavariable.
65
+
66
+ If `delayOnMVars` is true, the tactic will not be executed until the goal is free of unassigned
67
+ expr metavariables.
68
+ -/
69
+ | tactic (tacticCode : Syntax) (ctx : SavedContext) (kind : TacticMVarKind) (delayOnMVars := false)
70
+ /-- Metavariable represents a hole whose elaboration has been postponed. -/
71
+ | postponed (ctx : SavedContext)
72
+ deriving Inhabited
73
+
74
+ /--
75
+ Convert an "extra" optional error message into a message `"\n{msg}"` (if `some msg`) and `MessageData.nil` (if `none`)
76
+ -/
77
+ def extraMsgToMsg (extraErrorMsg? : Option MessageData) : MessageData :=
78
+ if let some msg := extraErrorMsg? then m!"\n{msg}" else .nil
79
+
80
+ instance : ToString SyntheticMVarKind where
81
+ toString
82
+ | .typeClass .. => "typeclass"
83
+ | .coe .. => "coe"
84
+ | .tactic .. => "tactic"
85
+ | .postponed .. => "postponed"
86
+
87
+ structure SyntheticMVarDecl where
88
+ stx : Syntax
89
+ kind : SyntheticMVarKind
90
+ deriving Inhabited
91
+
92
+ /--
93
+ We can optionally associate an error context with a metavariable (see `MVarErrorInfo`).
94
+ We have three different kinds of error context.
95
+ -/
96
+ inductive MVarErrorKind where
97
+ /-- Metavariable for implicit arguments. `ctx` is the parent application,
98
+ `lctx` is a local context where it is valid (necessary for eta feature for named arguments). -/
99
+ | implicitArg (lctx : LocalContext) (ctx : Expr)
100
+ /-- Metavariable for explicit holes provided by the user (e.g., `_` and `?m`) -/
101
+ | hole
102
+ /-- "Custom", `msgData` stores the additional error messages. -/
103
+ | custom (msgData : MessageData)
104
+ deriving Inhabited
105
+
106
+ instance : ToString MVarErrorKind where
107
+ toString
108
+ | .implicitArg _ _ => "implicitArg"
109
+ | .hole => "hole"
110
+ | .custom _ => "custom"
111
+
112
+ /--
113
+ We can optionally associate an error context with metavariables.
114
+ -/
115
+ structure MVarErrorInfo where
116
+ mvarId : MVarId
117
+ ref : Syntax
118
+ kind : MVarErrorKind
119
+ deriving Inhabited
120
+
121
+ /--
122
+ When reporting unexpected universe level metavariables, it is useful to localize the errors
123
+ to particular terms, especially at `let` bindings and function binders,
124
+ where universe polymorphism is not permitted.
125
+ -/
126
+ structure LevelMVarErrorInfo where
127
+ lctx : LocalContext
128
+ expr : Expr
129
+ ref : Syntax
130
+ msgData? : Option MessageData := none
131
+ deriving Inhabited
132
+
133
+ /--
134
+ Nested `let rec` expressions are eagerly lifted by the elaborator.
135
+ We store the information necessary for performing the lifting here.
136
+ -/
137
+ structure LetRecToLift where
138
+ ref : Syntax
139
+ fvarId : FVarId
140
+ attrs : Array Attribute
141
+ shortDeclName : Name
142
+ declName : Name
143
+ parentName? : Option Name
144
+ lctx : LocalContext
145
+ localInstances : LocalInstances
146
+ type : Expr
147
+ val : Expr
148
+ mvarId : MVarId
149
+ termination : TerminationHints
150
+ deriving Inhabited
151
+
152
+ /--
153
+ State of the `TermElabM` monad.
154
+ -/
155
+ structure State where
156
+ levelNames : List Name := []
157
+ syntheticMVars : MVarIdMap SyntheticMVarDecl := {}
158
+ pendingMVars : List MVarId := {}
159
+ /-- List of errors associated to a metavariable that are shown to the user if the metavariable could not be fully instantiated -/
160
+ mvarErrorInfos : List MVarErrorInfo := []
161
+ /-- List of data to be able to localize universe level metavariable errors to particular expressions. -/
162
+ levelMVarErrorInfos : List LevelMVarErrorInfo := []
163
+ /--
164
+ `mvarArgNames` stores the argument names associated to metavariables.
165
+ These are used in combination with `mvarErrorInfos` for throwing errors about metavariables that could not be fully instantiated.
166
+ For example when elaborating `List _`, the argument name of the placeholder will be `α`.
167
+
168
+ While elaborating an application, `mvarArgNames` is set for each metavariable argument, using the available argument name.
169
+ This may happen before or after the `mvarErrorInfos` is set for the same metavariable.
170
+
171
+ We used to store the argument names in `mvarErrorInfos`, updating the `MVarErrorInfos` to add the argument name when it is available,
172
+ but this doesn't work if the argument name is available _before_ the `mvarErrorInfos` is set for that metavariable.
173
+ -/
174
+ mvarArgNames : MVarIdMap Name := {}
175
+ letRecsToLift : List LetRecToLift := []
176
+ deriving Inhabited
177
+
178
+ /--
179
+ Backtrackable state for the `TermElabM` monad.
180
+ -/
181
+ structure SavedState where
182
+ «meta» : Meta.SavedState
183
+ «elab» : State
184
+ deriving Nonempty
185
+
186
+ end Term
187
+
188
+ namespace Tactic
189
+
190
+ /--
191
+ State of the `TacticM` monad.
192
+ -/
193
+ structure State where
194
+ goals : List MVarId
195
+ deriving Inhabited
196
+
197
+ /--
198
+ Snapshots are used to implement the `save` tactic.
199
+ This tactic caches the state of the system, and allows us to "replay"
200
+ expensive proofs efficiently. This is only relevant implementing the
201
+ LSP server.
202
+ -/
203
+ structure Snapshot where
204
+ core : Core.State
205
+ «meta» : Meta.State
206
+ term : Term.State
207
+ tactic : Tactic.State
208
+ stx : Syntax
209
+
210
+ /--
211
+ Key for the cache used to implement the `save` tactic.
212
+ -/
213
+ structure CacheKey where
214
+ mvarId : MVarId -- TODO: should include all goals
215
+ pos : String.Pos
216
+ deriving BEq, Hashable, Inhabited
217
+
218
+ /--
219
+ Cache for the `save` tactic.
220
+ -/
221
+ structure Cache where
222
+ pre : PHashMap CacheKey Snapshot := {}
223
+ post : PHashMap CacheKey Snapshot := {}
224
+ deriving Inhabited
225
+
226
+ section Snapshot
227
+ open Language
228
+
229
+ structure SavedState where
230
+ term : Term.SavedState
231
+ tactic : State
232
+
233
+ /-- Snapshot after finishing execution of a tactic. -/
234
+ structure TacticFinishedSnapshot extends Language.Snapshot where
235
+ /-- State saved for reuse, if no fatal exception occurred. -/
236
+ state? : Option SavedState
237
+ /-- Untyped snapshots from `logSnapshotTask`, saved at this level for cancellation. -/
238
+ moreSnaps : Array (SnapshotTask SnapshotTree)
239
+ deriving Inhabited
240
+ instance : ToSnapshotTree TacticFinishedSnapshot where
241
+ toSnapshotTree s := ⟨s.toSnapshot, s.moreSnaps⟩
242
+
243
+ /-- Snapshot just before execution of a tactic. -/
244
+ structure TacticParsedSnapshot extends Language.Snapshot where
245
+ /-- Syntax tree of the tactic, stored and compared for incremental reuse. -/
246
+ stx : Syntax
247
+ /-- Task for nested incrementality, if enabled for tactic. -/
248
+ inner? : Option (SnapshotTask TacticParsedSnapshot) := none
249
+ /-- Task for state after tactic execution. -/
250
+ finished : SnapshotTask TacticFinishedSnapshot
251
+ /-- Tasks for subsequent, potentially parallel, tactic steps. -/
252
+ next : Array (SnapshotTask TacticParsedSnapshot) := #[]
253
+ deriving Inhabited
254
+ partial instance : ToSnapshotTree TacticParsedSnapshot where
255
+ toSnapshotTree := go where
256
+ go := fun s => ⟨s.toSnapshot,
257
+ s.inner?.toArray.map (·.map (sync := true) go) ++
258
+ #[s.finished.map (sync := true) toSnapshotTree] ++
259
+ s.next.map (·.map (sync := true) go)⟩
260
+
261
+ end Snapshot
262
+ end Tactic
263
+
264
+ namespace Term
265
+
266
+ structure Context where
267
+ declName? : Option Name := none
268
+ macroStack : MacroStack := []
269
+ /--
270
+ When `mayPostpone == true`, an elaboration function may interrupt its execution by throwing `Exception.postpone`.
271
+ The function `elabTerm` catches this exception and creates fresh synthetic metavariable `?m`, stores `?m` in
272
+ the list of pending synthetic metavariables, and returns `?m`. -/
273
+ mayPostpone : Bool := true
274
+ /--
275
+ When `errToSorry` is set to true, the method `elabTerm` catches
276
+ exceptions and converts them into synthetic `sorry`s.
277
+ The implementation of choice nodes and overloaded symbols rely on the fact
278
+ that when `errToSorry` is set to false for an elaboration function `F`, then
279
+ `errToSorry` remains `false` for all elaboration functions invoked by `F`.
280
+ That is, it is safe to transition `errToSorry` from `true` to `false`, but
281
+ we must not set `errToSorry` to `true` when it is currently set to `false`. -/
282
+ errToSorry : Bool := true
283
+ /--
284
+ When `autoBoundImplicit` is set to true, instead of producing
285
+ an "unknown identifier" error for unbound variables, we generate an
286
+ internal exception. This exception is caught at `elabBinders` and
287
+ `elabTypeWithUnboldImplicit`. Both methods add implicit declarations
288
+ for the unbound variable and try again. -/
289
+ autoBoundImplicit : Bool := false
290
+ autoBoundImplicits : PArray Expr := {}
291
+ /--
292
+ A name `n` is only eligible to be an auto implicit name if `autoBoundImplicitForbidden n = false`.
293
+ We use this predicate to disallow `f` to be considered an auto implicit name in a definition such
294
+ as
295
+ ```
296
+ def f : f → Bool := fun _ => true
297
+ ```
298
+ -/
299
+ autoBoundImplicitForbidden : Name → Bool := fun _ => false
300
+ /-- Map from user name to internal unique name -/
301
+ sectionVars : NameMap Name := {}
302
+ /-- Map from internal name to fvar -/
303
+ sectionFVars : NameMap Expr := {}
304
+ /-- Enable/disable implicit lambdas feature. -/
305
+ implicitLambda : Bool := true
306
+ /-- Heed `elab_as_elim` attribute. -/
307
+ heedElabAsElim : Bool := true
308
+ /-- Noncomputable sections automatically add the `noncomputable` modifier to any declaration we cannot generate code for. -/
309
+ isNoncomputableSection : Bool := false
310
+ /-- When `true` we skip TC failures. We use this option when processing patterns. -/
311
+ ignoreTCFailures : Bool := false
312
+ /-- `true` when elaborating patterns. It affects how we elaborate named holes. -/
313
+ inPattern : Bool := false
314
+ /--
315
+ Snapshot for incremental processing of current tactic, if any.
316
+
317
+ Invariant: if the bundle's `old?` is set, then the state *up to the start* of the tactic is
318
+ unchanged, i.e. reuse is possible.
319
+ -/
320
+ tacSnap? : Option (Language.SnapshotBundle Tactic.TacticParsedSnapshot) := none
321
+ /--
322
+ If `true`, we store in the `Expr` the `Syntax` for recursive applications (i.e., applications
323
+ of free variables tagged with `isAuxDecl`). We store the `Syntax` using `mkRecAppWithSyntax`.
324
+ We use the `Syntax` object to produce better error messages at `Structural.lean` and `WF.lean`. -/
325
+ saveRecAppSyntax : Bool := true
326
+ /--
327
+ If `holesAsSyntheticOpaque` is `true`, then we mark metavariables associated
328
+ with `_`s as `syntheticOpaque` if they do not occur in patterns.
329
+ This option is useful when elaborating terms in tactics such as `refine'` where
330
+ we want holes there to become new goals. See issue #1681, where we have
331
+ `refine' (fun x => _)`
332
+ -/
333
+ holesAsSyntheticOpaque : Bool := false
334
+ /--
335
+ If `checkDeprecated := true`, then `Linter.checkDeprecated` is invoked when creating constants.
336
+ -/
337
+ checkDeprecated : Bool := true
338
+
339
+ abbrev TermElabM := ReaderT Context $ StateRefT State MetaM
340
+ abbrev TermElab := Syntax → Option Expr → TermElabM Expr
341
+
342
+ /-
343
+ Make the compiler generate specialized `pure`/`bind` so we do not have to optimize through the
344
+ whole monad stack at every use site. May eventually be covered by `deriving`.
345
+ -/
346
+ @[always_inline]
347
+ instance : Monad TermElabM :=
348
+ let i := inferInstanceAs (Monad TermElabM)
349
+ { pure := i.pure, bind := i.bind }
350
+
351
+ open Meta
352
+
353
+ instance : Inhabited (TermElabM α) where
354
+ default := throw default
355
+
356
+ protected def saveState : TermElabM SavedState :=
357
+ return { «meta» := (← Meta.saveState), «elab» := (← get) }
358
+
359
+ def SavedState.restore (s : SavedState) (restoreInfo : Bool := false) : TermElabM Unit := do
360
+ let traceState ← getTraceState -- We never backtrack trace messages
361
+ let infoState ← getInfoState -- We also do not backtrack the info nodes when `restoreInfo == false`
362
+ s.meta.restore
363
+ set s.elab
364
+ setTraceState traceState
365
+ unless restoreInfo do
366
+ setInfoState infoState
367
+
368
+ /--
369
+ Like `Meta.withRestoreOrSaveFull` for `TermElabM`, but also takes a `tacSnap?` that
370
+ * when running `act`, is set as `Context.tacSnap?`
371
+ * otherwise (i.e. on restore) is used to update the new snapshot promise to the old task's
372
+ value.
373
+ This extra restore step is necessary because while `reusableResult?` can be used to replay any
374
+ effects on `State`, `Context.tacSnap?` is not part of it but changed via an `IO` side effect, so
375
+ it needs to be replayed separately.
376
+
377
+ We use an explicit parameter instead of accessing `Context.tacSnap?` directly because this prevents
378
+ `withRestoreOrSaveFull` and `withReader` from being used in the wrong order.
379
+ -/
380
+ @[specialize]
381
+ def withRestoreOrSaveFull (reusableResult? : Option (α × SavedState))
382
+ (tacSnap? : Option (Language.SnapshotBundle Tactic.TacticParsedSnapshot)) (act : TermElabM α) :
383
+ TermElabM (α × SavedState) := do
384
+ if let some (_, state) := reusableResult? then
385
+ set state.elab
386
+ if let some snap := tacSnap? then
387
+ let some old := snap.old?
388
+ | throwError "withRestoreOrSaveFull: expected old snapshot in `tacSnap?`"
389
+ snap.new.resolve old.val.get
390
+
391
+ let reusableResult? := reusableResult?.map (fun (val, state) => (val, state.meta))
392
+ let (a, «meta») ← withReader ({ · with tacSnap? }) do
393
+ controlAt MetaM fun runInBase => do
394
+ Meta.withRestoreOrSaveFull reusableResult? <| runInBase act
395
+ return (a, { «meta», «elab» := (← get) })
396
+
397
+ instance : MonadBacktrack SavedState TermElabM where
398
+ saveState := Term.saveState
399
+ restoreState b := b.restore
400
+
401
+ /--
402
+ Incremental elaboration helper. Avoids leakage of data from outside syntax via the monadic context
403
+ when running `act` on `stx` by
404
+ * setting `stx` as the `ref` and
405
+ * deactivating `suppressElabErrors` if `stx` is `missing`-free, which also helps with not hiding
406
+ useful errors in this part of the input. Note that if `stx` has `missing`, this should always be
407
+ true for the outer syntax as well, so taking the old value of `suppressElabErrors` into account
408
+ should not introduce data leakage.
409
+
410
+ This combinator should always be used when narrowing reuse to a syntax subtree, usually (in the case
411
+ of tactics, to be generalized) via `withNarrowed(Arg)TacticReuse`.
412
+ -/
413
+ def withReuseContext [Monad m] [MonadWithReaderOf Core.Context m] (stx : Syntax) (act : m α) :
414
+ m α := do
415
+ withTheReader Core.Context (fun ctx => { ctx with
416
+ ref := stx
417
+ suppressElabErrors := ctx.suppressElabErrors && stx.hasMissing }) act
418
+
419
+ /--
420
+ Manages reuse information for nested tactics by `split`ting given syntax into an outer and inner
421
+ part. `act` is then run on the inner part but with reuse information adjusted as follows:
422
+ * If the old (from `tacSnap?`'s `SyntaxGuarded.stx`) and new (from `stx`) outer syntax are not
423
+ identical according to `Syntax.eqWithInfo`, reuse is disabled.
424
+ * Otherwise, the old syntax as stored in `tacSnap?` is updated to the old *inner* syntax.
425
+ * In any case, `withReuseContext` is used on the new inner syntax to further prepare the monadic
426
+ context.
427
+
428
+ For any tactic that participates in reuse, `withNarrowedTacticReuse` should be applied to the
429
+ tactic's syntax and `act` should be used to do recursive tactic evaluation of nested parts. Also,
430
+ after this function, `getAndEmptySnapshotTasks` should be called and the result stored in a snapshot
431
+ so that the tasks don't end up in a snapshot further up and are cancelled together with it; see
432
+ note [Incremental Cancellation].
433
+ -/
434
+ def withNarrowedTacticReuse [Monad m] [MonadReaderOf Context m] [MonadLiftT BaseIO m]
435
+ [MonadWithReaderOf Core.Context m] [MonadWithReaderOf Context m] [MonadOptions m]
436
+ (split : Syntax → Syntax × Syntax) (act : Syntax → m α) (stx : Syntax) : m α := do
437
+ let (outer, inner) := split stx
438
+ let opts ← getOptions
439
+ let ctx ← readThe Term.Context
440
+ withTheReader Term.Context (fun ctx => { ctx with tacSnap? := ctx.tacSnap?.map fun tacSnap =>
441
+ { tacSnap with old? := tacSnap.old?.bind fun old => do
442
+ let (oldOuter, oldInner) := split old.stx
443
+ guard <| outer.eqWithInfoAndTraceReuse opts oldOuter
444
+ return { old with stx := oldInner }
445
+ }
446
+ }) do
447
+ if let some oldOuter := ctx.tacSnap?.bind (·.old?) then
448
+ if (← read).tacSnap?.bind (·.old?) |>.isNone then
449
+ oldOuter.val.cancelRec
450
+ withReuseContext inner (act inner)
451
+
452
+ /--
453
+ A variant of `withNarrowedTacticReuse` that uses `stx[argIdx]` as the inner syntax and all `stx`
454
+ child nodes before that as the outer syntax, i.e. reuse is disabled if there was any change before
455
+ `argIdx`.
456
+
457
+ NOTE: child nodes after `argIdx` are not tested (which would almost always disable reuse as they are
458
+ necessarily shifted by changes at `argIdx`) so it must be ensured that the result of `arg` does not
459
+ depend on them (i.e. they should not be inspected beforehand).
460
+ -/
461
+ def withNarrowedArgTacticReuse [Monad m] [MonadReaderOf Context m] [MonadLiftT BaseIO m]
462
+ [MonadWithReaderOf Core.Context m] [MonadWithReaderOf Context m] [MonadOptions m]
463
+ (argIdx : Nat) (act : Syntax → m α) (stx : Syntax) : m α :=
464
+ withNarrowedTacticReuse (fun stx => (mkNullNode stx.getArgs[*...argIdx], stx[argIdx])) act stx
465
+
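+ /- Illustrative sketch (hypothetical tactic node, not part of this file): an
+    incremental elaborator for a node of shape `"myTac " ident " => " tacticSeq`
+    could narrow reuse to the trailing tactic sequence at child index 3:
+    ```
+    withNarrowedArgTacticReuse (argIdx := 3) evalTactic stx
+    ```
+    Reuse is then disabled whenever any child before index 3 changed, and the
+    old inner syntax is compared only against `stx[3]`.
+ -/
+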
466
+ /--
467
+ Disables incremental tactic reuse *and* reporting for `act` if `cond` is true by setting `tacSnap?`
468
+ to `none`. This should be done for tactic blocks that are run multiple times as otherwise the
469
+ reported progress will jump back and forth (and partial reuse for these kinds of tactic blocks is
470
+ similarly questionable).
471
+ -/
472
+ def withoutTacticIncrementality [Monad m] [MonadWithReaderOf Context m] [MonadOptions m]
473
+ (cond : Bool) (act : m α) : m α := do
474
+ let opts ← getOptions
475
+ withTheReader Term.Context (fun ctx => { ctx with tacSnap? := ctx.tacSnap?.filter fun tacSnap => Id.run do
476
+ if let some old := tacSnap.old? then
477
+ if cond && opts.getBool `trace.Elab.reuse then
478
+ dbg_trace "reuse stopped: guard failed at {old.stx}"
479
+ return !cond
480
+ }) act
481
+
482
+ /-- Disables incremental tactic reuse for `act` if `cond` is true. -/
483
+ def withoutTacticReuse [Monad m] [MonadWithReaderOf Context m] [MonadOptions m]
484
+ (cond : Bool) (act : m α) : m α := do
485
+ let opts ← getOptions
486
+ withTheReader Term.Context (fun ctx => { ctx with tacSnap? := ctx.tacSnap?.map fun tacSnap =>
487
+ { tacSnap with old? := tacSnap.old?.filter fun old => Id.run do
488
+ if cond && opts.getBool `trace.Elab.reuse then
489
+ dbg_trace "reuse stopped: guard failed at {old.stx}"
490
+ return !cond }
491
+ }) act
492
+
493
+ @[inherit_doc Core.wrapAsyncAsSnapshot]
494
+ def wrapAsyncAsSnapshot {α : Type} (act : α → TermElabM Unit) (cancelTk? : Option IO.CancelToken)
495
+ (desc : String := by exact decl_name%.toString) :
496
+ TermElabM (α → BaseIO Language.SnapshotTree) := do
497
+ let ctx ← read
498
+ let st ← get
499
+ let metaCtx ← readThe Meta.Context
500
+ let metaSt ← getThe Meta.State
501
+ Core.wrapAsyncAsSnapshot (cancelTk? := cancelTk?) (desc := desc) fun a =>
502
+ act a |>.run ctx |>.run' st |>.run' metaCtx metaSt
503
+
504
+ abbrev TermElabResult (α : Type) := EStateM.Result Exception SavedState α
505
+
506
+ /--
507
+ Execute `x`, save resulting expression and new state.
508
+ We remove any `Info` created by `x`.
509
+ The info nodes are committed when we execute `applyResult`.
510
+ We use `observing` to implement overloaded notation and decls.
511
+ We want to save `Info` nodes for the chosen alternative.
512
+ -/
513
+ def observing (x : TermElabM α) : TermElabM (TermElabResult α) := do
514
+ let s ← saveState
515
+ try
516
+ let e ← x
517
+ let sNew ← saveState
518
+ s.restore (restoreInfo := true)
519
+ return EStateM.Result.ok e sNew
520
+ catch
521
+ | ex@(.error ..) =>
522
+ let sNew ← saveState
523
+ s.restore (restoreInfo := true)
524
+ return .error ex sNew
525
+ | ex@(.internal id _) =>
526
+ if id == postponeExceptionId then
527
+ s.restore (restoreInfo := true)
528
+ throw ex
529
+
530
+ /--
531
+ Apply the result/exception and state captured with `observing`.
532
+ We use this method to implement overloaded notation and symbols. -/
533
+ def applyResult (result : TermElabResult α) : TermElabM α := do
534
+ match result with
535
+ | .ok a r => r.restore (restoreInfo := true); return a
536
+ | .error ex r => r.restore (restoreInfo := true); throw ex
537
+
538
+ /--
539
+ Execute `x`, but keep state modifications only if `x` did not postpone.
540
+ This method is useful to implement elaboration functions that cannot decide whether
541
+ they need to postpone or not without updating the state. -/
542
+ def commitIfDidNotPostpone (x : TermElabM α) : TermElabM α := do
543
+ -- We just reuse the implementation of `observing` and `applyResult`.
544
+ let r ← observing x
545
+ applyResult r
546
+
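+ /- Illustrative sketch (not part of the original source): a hypothetical helper
+    that tries two elaboration alternatives with `observing` and commits to the
+    successful one via `applyResult`, replaying its state and info nodes.
+    ```
+    def elabFirstOf (a b : TermElabM Expr) : TermElabM Expr := do
+      match ← observing a with
+      | .ok e s    => applyResult (.ok e s)
+      | .error _ _ => applyResult (← observing b)
+    ```
+ -/
+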
547
+ /--
548
+ Return the universe level names explicitly provided by the user.
549
+ -/
550
+ def getLevelNames : TermElabM (List Name) :=
551
+ return (← get).levelNames
552
+
553
+ /--
554
+ Given a free variable `fvar`, return its declaration.
555
+ This function panics if `fvar` is not a free variable.
556
+ -/
557
+ def getFVarLocalDecl! (fvar : Expr) : TermElabM LocalDecl := do
558
+ match (← getLCtx).find? fvar.fvarId! with
559
+ | some d => pure d
560
+ | none => unreachable!
561
+
562
+ instance : AddErrorMessageContext TermElabM where
563
+ add ref msg := do
564
+ let ctx ← read
565
+ let ref := getBetterRef ref ctx.macroStack
566
+ let msg ← addMessageContext msg
567
+ let msg ← addMacroStack msg ctx.macroStack
568
+ pure (ref, msg)
569
+
570
+ /--
571
+ Execute `x` without storing `Syntax` for recursive applications. See `saveRecAppSyntax` field at `Context`.
572
+ -/
573
+ def withoutSavingRecAppSyntax (x : TermElabM α) : TermElabM α :=
574
+ withReader (fun ctx => { ctx with saveRecAppSyntax := false }) x
575
+
576
+ unsafe def mkTermElabAttributeUnsafe (ref : Name) : IO (KeyedDeclsAttribute TermElab) :=
577
+ mkElabAttribute TermElab `builtin_term_elab `term_elab `Lean.Parser.Term `Lean.Elab.Term.TermElab "term" ref
578
+
579
+ @[implemented_by mkTermElabAttributeUnsafe]
580
+ opaque mkTermElabAttribute (ref : Name) : IO (KeyedDeclsAttribute TermElab)
581
+
582
+ /--
583
+ Registers a term elaborator for the given syntax node kind.
584
+
585
+ A term elaborator should have type `Lean.Elab.Term.TermElab` (which is
586
+ `Lean.Syntax → Option Lean.Expr → Lean.Elab.Term.TermElabM Lean.Expr`), i.e. should take syntax of
587
+ the given syntax node kind and an optional expected type as parameters and produce an expression.
588
+
589
+ The `elab_rules` and `elab` commands should usually be preferred over using this attribute
590
+ directly.
591
+ -/
592
+ @[builtin_doc]
593
+ builtin_initialize termElabAttribute : KeyedDeclsAttribute TermElab ← mkTermElabAttribute decl_name%
594
+
595
+ /--
596
+ Auxiliary datatype for presenting a Lean lvalue modifier.
597
+ We represent an unelaborated lvalue as a `Syntax` (or `Expr`) and `List LVal`.
598
+ Example: `a.foo.1` is represented as the `Syntax` `a` and the list
599
+ `[LVal.fieldName "foo", LVal.fieldIdx 1]`.
600
+ -/
601
+ inductive LVal where
602
+ | fieldIdx (ref : Syntax) (i : Nat)
603
+ /-- Field `suffix?` is for producing better error messages because `x.y` may be a field access or a hierarchical/composite name.
604
+ `ref` is the syntax object representing the field. `fullRef` includes the LHS. -/
605
+ | fieldName (ref : Syntax) (name : String) (suffix? : Option Name) (fullRef : Syntax)
606
+
607
+ def LVal.getRef : LVal → Syntax
608
+ | .fieldIdx ref _ => ref
609
+ | .fieldName ref .. => ref
610
+
611
+ def LVal.isFieldName : LVal → Bool
612
+ | .fieldName .. => true
613
+ | _ => false
614
+
615
+ instance : ToString LVal where
616
+ toString
617
+ | .fieldIdx _ i => toString i
618
+ | .fieldName _ n .. => n
619
+
620
+ /-- Return the name of the declaration being elaborated if available. -/
621
+ def getDeclName? : TermElabM (Option Name) := return (← read).declName?
622
+ /-- Return the list of nested `let rec` declarations that need to be lifted. -/
623
+ def getLetRecsToLift : TermElabM (List LetRecToLift) := return (← get).letRecsToLift
624
+ /-- Return the declaration of the given metavariable -/
625
+ def getMVarDecl (mvarId : MVarId) : TermElabM MetavarDecl := return (← getMCtx).getDecl mvarId
626
+
627
+ instance : MonadParentDecl TermElabM where
628
+ getParentDeclName? := getDeclName?
629
+
630
+ /--
631
+ Executes `x` in the context of the given declaration name. Ensures that the info tree is set up
632
+ correctly and adjusts the declaration name generator to generate names below this name, resetting
633
+ the nested counter.
634
+ -/
635
+ def withDeclName (name : Name) (x : TermElabM α) : TermElabM α :=
636
+ withReader (fun ctx => { ctx with declName? := name }) do
637
+ withSaveParentDeclInfoContext do
638
+ withDeclNameForAuxNaming name do
639
+ x
640
+
641
+ /-- Update the universe level parameter names. -/
642
+ def setLevelNames (levelNames : List Name) : TermElabM Unit :=
643
+ modify fun s => { s with levelNames := levelNames }
644
+
645
+ /-- Execute `x` using `levelNames` as the universe level parameter names. See `getLevelNames`. -/
646
+ def withLevelNames (levelNames : List Name) (x : TermElabM α) : TermElabM α := do
647
+ let levelNamesSaved ← getLevelNames
648
+ setLevelNames levelNames
649
+ try x finally setLevelNames levelNamesSaved
650
+
651
+ def withoutErrToSorryImp (x : TermElabM α) : TermElabM α :=
652
+ withReader (fun ctx => { ctx with errToSorry := false }) x
653
+
654
+ /--
655
+ Execute `x` without converting errors (i.e., exceptions) to `sorry` applications.
656
+ Recall that when `errToSorry = true`, the method `elabTerm` catches exceptions and converts them into `sorry` applications.
657
+ -/
658
+ def withoutErrToSorry [MonadFunctorT TermElabM m] : m α → m α :=
659
+ monadMap (m := TermElabM) withoutErrToSorryImp
660
+
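+ /- Illustrative sketch: when probing whether an alternative even elaborates
+    (e.g. for overloaded notation), errors must surface as exceptions rather
+    than being converted to `sorry`:
+    ```
+    let r ← observing <| withoutErrToSorry <| elabTerm stx expectedType?
+    ```
+    (`elabTerm` is defined later in this file.)
+ -/
+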
661
+ def withoutHeedElabAsElimImp (x : TermElabM α) : TermElabM α :=
662
+ withReader (fun ctx => { ctx with heedElabAsElim := false }) x
663
+
664
+ /--
665
+ Execute `x` without heeding the `elab_as_elim` attribute. Useful when there is
666
+ no expected type (so `elabAppArgs` would fail), but expect that the user wants
667
+ to use such constants.
668
+ -/
669
+ def withoutHeedElabAsElim [MonadFunctorT TermElabM m] : m α → m α :=
670
+ monadMap (m := TermElabM) withoutHeedElabAsElimImp
671
+
672
+ /--
673
+ Execute `x` but discard changes performed at `Term.State` and `Meta.State`.
674
+ Recall that the `Environment`, `InfoState` and messages are at `Core.State`. Thus, any updates to
675
+ it will be preserved.
676
+ This method is useful for performing computations where all metavariable must be resolved or
677
+ discarded.
678
+ The `InfoTree`s are not discarded, however, and wrapped in `InfoTree.Context`
679
+ to store their metavariable context.
680
+ -/
681
+ def withoutModifyingElabMetaStateWithInfo (x : TermElabM α) : TermElabM α := do
682
+ let s ← get
683
+ let sMeta ← getThe Meta.State
684
+ try
685
+ withSaveInfoContext x
686
+ finally
687
+ set s
688
+ set sMeta
689
+
690
+ /--
691
+ Execute `x` but discard changes performed to the state.
692
+ However, the info trees and messages are not discarded. -/
693
+ private def withoutModifyingStateWithInfoAndMessagesImpl (x : TermElabM α) : TermElabM α := do
694
+ let saved ← saveState
695
+ try
696
+ withSaveInfoContext x
697
+ finally
698
+ let saved := { saved with meta.core.infoState := (← getInfoState), meta.core.messages := (← getThe Core.State).messages }
699
+ restoreState saved
700
+
701
+ /--
702
+ Wraps the trees returned from `getInfoTrees`, if any, in an `InfoTree.context` node based on the
703
+ current monadic context and state. This is mainly used to report info trees early via
704
+ `Snapshot.infoTree?`. The trees are not removed from the `getInfoTrees` state as the final info tree
705
+ of the elaborated command should be complete and not depend on whether parts have been reported
706
+ early.
707
+
708
+ As `InfoTree.context` can have only one child, this function panics if `trees` contains more than 1
709
+ tree. Also, `PartialContextInfo.parentDeclCtx` is not currently generated as that information is not
710
+ available in the monadic context and only needed for the final info tree.
711
+ -/
712
+ def getInfoTreeWithContext? : TermElabM (Option InfoTree) := do
713
+ let st ← getInfoState
714
+ if st.trees.size > 1 then
715
+ return panic! "getInfoTreeWithContext: overfull tree"
716
+ let some t := st.trees[0]? |
717
+ return none
718
+ let t := t.substitute st.assignment
719
+ let ctx ← readThe Core.Context
720
+ let s ← getThe Core.State
721
+ let ctx := PartialContextInfo.commandCtx {
722
+ env := s.env, fileMap := ctx.fileMap, mctx := {}, currNamespace := ctx.currNamespace,
723
+ openDecls := ctx.openDecls, options := ctx.options, ngen := s.ngen
724
+ }
725
+ return InfoTree.context ctx t
726
+
727
+ /-- For testing `TermElabM` methods. The #eval command will signal the error. -/
728
+ def throwErrorIfErrors : TermElabM Unit := do
729
+ if (← MonadLog.hasErrors) then
730
+ throwError "Error(s)"
731
+
732
+ def traceAtCmdPos (cls : Name) (msg : Unit → MessageData) : TermElabM Unit :=
733
+ withRef Syntax.missing <| trace cls msg
734
+
735
+ def ppGoal (mvarId : MVarId) : TermElabM Format :=
736
+ Meta.ppGoal mvarId
737
+
738
+ open Level (LevelElabM)
739
+
740
+ def liftLevelM (x : LevelElabM α) : TermElabM α := do
741
+ let ctx ← read
742
+ let mctx ← getMCtx
743
+ let ngen ← getNGen
744
+ let lvlCtx : Level.Context := { options := (← getOptions), ref := (← getRef), autoBoundImplicit := ctx.autoBoundImplicit }
745
+ match (x lvlCtx).run { ngen := ngen, mctx := mctx, levelNames := (← getLevelNames) } with
746
+ | .ok a newS => setMCtx newS.mctx; setNGen newS.ngen; setLevelNames newS.levelNames; pure a
747
+ | .error ex _ => throw ex
748
+
749
+ def elabLevel (stx : Syntax) : TermElabM Level :=
750
+ liftLevelM <| Level.elabLevel stx
751
+
752
+ /-- Elaborate `x` with `stx` on the macro stack -/
753
+ def withPushMacroExpansionStack (beforeStx afterStx : Syntax) (x : TermElabM α) : TermElabM α :=
754
+ withReader (fun ctx => { ctx with macroStack := { before := beforeStx, after := afterStx } :: ctx.macroStack }) x
755
+
756
+ /-- Elaborate `x` with `stx` on the macro stack and produce macro expansion info -/
757
+ def withMacroExpansion (beforeStx afterStx : Syntax) (x : TermElabM α) : TermElabM α :=
758
+ withMacroExpansionInfo beforeStx afterStx do
759
+ withPushMacroExpansionStack beforeStx afterStx x
760
+
761
+ /--
762
+ Add the given metavariable to the list of pending synthetic metavariables.
763
+ The method `synthesizeSyntheticMVars` is used to process the metavariables on this list. -/
764
+ def registerSyntheticMVar (stx : Syntax) (mvarId : MVarId) (kind : SyntheticMVarKind) : TermElabM Unit := do
765
+ modify fun s => { s with syntheticMVars := s.syntheticMVars.insert mvarId { stx, kind }, pendingMVars := mvarId :: s.pendingMVars }
766
+
767
+ def registerSyntheticMVarWithCurrRef (mvarId : MVarId) (kind : SyntheticMVarKind) : TermElabM Unit := do
768
+ registerSyntheticMVar (← getRef) mvarId kind
769
+
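+ /- Illustrative sketch: an elaborator that cannot make progress yet can hand
+    back a fresh synthetic opaque metavariable and register it, so that
+    `synthesizeSyntheticMVars` revisits it later (cf. `postponeElabTermCore`
+    below):
+    ```
+    let mvar ← mkFreshExprMVar expectedType? MetavarKind.syntheticOpaque
+    registerSyntheticMVarWithCurrRef mvar.mvarId! (.postponed (← saveContext))
+    return mvar
+    ```
+ -/
+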
770
+ def registerMVarErrorInfo (mvarErrorInfo : MVarErrorInfo) : TermElabM Unit :=
771
+ modify fun s => { s with mvarErrorInfos := mvarErrorInfo :: s.mvarErrorInfos }
772
+
773
+ def registerMVarErrorHoleInfo (mvarId : MVarId) (ref : Syntax) : TermElabM Unit :=
774
+ registerMVarErrorInfo { mvarId, ref, kind := .hole }
775
+
776
+ def registerMVarErrorImplicitArgInfo (mvarId : MVarId) (ref : Syntax) (app : Expr) : TermElabM Unit := do
777
+ registerMVarErrorInfo { mvarId, ref, kind := .implicitArg (← getLCtx) app }
778
+
779
+ def registerMVarErrorCustomInfo (mvarId : MVarId) (ref : Syntax) (msgData : MessageData) : TermElabM Unit := do
780
+ registerMVarErrorInfo { mvarId, ref, kind := .custom msgData }
781
+
782
+ def registerCustomErrorIfMVar (e : Expr) (ref : Syntax) (msgData : MessageData) : TermElabM Unit :=
783
+ match e.getAppFn with
784
+ | Expr.mvar mvarId => registerMVarErrorCustomInfo mvarId ref msgData
785
+ | _ => pure ()
786
+
787
+ def registerMVarArgName (mvarId : MVarId) (argName : Name) : TermElabM Unit :=
788
+ modify fun s => { s with mvarArgNames := s.mvarArgNames.insert mvarId argName }
789
+
790
+ /--
791
+ Auxiliary method for reporting errors of the form "... contains metavariables ...".
792
+ This kind of error is thrown, for example, at `Match.lean` where elaboration
793
+ cannot continue if there are metavariables in patterns.
794
+ We only want to log it if we haven't logged any errors so far. -/
795
+ def throwMVarError (m : MessageData) : TermElabM α := do
796
+ if (← MonadLog.hasErrors) then
797
+ throwAbortTerm
798
+ else
799
+ throwError m
800
+
801
+ def MVarErrorInfo.logError (mvarErrorInfo : MVarErrorInfo) (extraMsg? : Option MessageData) : TermElabM Unit := do
802
+ match mvarErrorInfo.kind with
803
+ | MVarErrorKind.implicitArg lctx app => withLCtx lctx {} do
804
+ let app ← instantiateMVars app
805
+ let msg ← addArgName "don't know how to synthesize implicit argument"
806
+ let msg := msg ++ m!"{indentExpr app.setAppPPExplicitForExposingMVars}" ++ Format.line ++ "context:" ++ Format.line ++ MessageData.ofGoal mvarErrorInfo.mvarId
807
+ logErrorAt mvarErrorInfo.ref (appendExtra msg)
808
+ | MVarErrorKind.hole => do
809
+ let msg ← addArgName "don't know how to synthesize placeholder" " for argument"
810
+ let msg := msg ++ Format.line ++ "context:" ++ Format.line ++ MessageData.ofGoal mvarErrorInfo.mvarId
811
+ logErrorAt mvarErrorInfo.ref (MessageData.tagged `Elab.synthPlaceholder <| appendExtra msg)
812
+ | MVarErrorKind.custom msg =>
813
+ logErrorAt mvarErrorInfo.ref (appendExtra msg)
814
+ where
815
+ /-- Append the argument name (if available) to the message.
816
+ Remark: if the argument name contains macro scopes we do not append it. -/
817
+ addArgName (msg : MessageData) (extra : String := "") : TermElabM MessageData := do
818
+ match (← get).mvarArgNames.find? mvarErrorInfo.mvarId with
819
+ | none => return msg
820
+ | some argName => return if argName.hasMacroScopes then msg else msg ++ extra ++ m!" '{argName}'"
821
+
822
+ appendExtra (msg : MessageData) : MessageData :=
823
+ match extraMsg? with
824
+ | none => msg
825
+ | some extraMsg => msg ++ extraMsg
826
+
827
+ /--
828
+ Try to log errors for the unassigned metavariables `pendingMVarIds`.
829
+
830
+ Return `true` if there were "unfilled holes", and we should "abort" the declaration.
831
+ TODO: try to fill "all" holes using synthetic "sorry's"
832
+
833
+ Remark: We only log the "unfilled holes" as new errors if no error has been logged so far. -/
834
+ def logUnassignedUsingErrorInfos (pendingMVarIds : Array MVarId) (extraMsg? : Option MessageData := none) : TermElabM Bool := do
835
+ if pendingMVarIds.isEmpty then
836
+ return false
837
+ else
838
+ let hasOtherErrors ← MonadLog.hasErrors
839
+ let mut hasNewErrors := false
840
+ let mut alreadyVisited : MVarIdSet := {}
841
+ let mut errors : Array MVarErrorInfo := #[]
842
+ for mvarErrorInfo in (← get).mvarErrorInfos do
843
+ let mvarId := mvarErrorInfo.mvarId
844
+ unless alreadyVisited.contains mvarId do
845
+ alreadyVisited := alreadyVisited.insert mvarId
846
+ /- The metavariable `mvarErrorInfo.mvarId` may have been assigned or
847
+ delayed assigned to another metavariable that is unassigned. -/
848
+ let mvarDeps ← getMVars (mkMVar mvarId)
849
+ if mvarDeps.any pendingMVarIds.contains then do
850
+ unless hasOtherErrors do
851
+ errors := errors.push mvarErrorInfo
852
+ hasNewErrors := true
853
+ -- To sort the errors by position use
854
+ -- let sortedErrors := errors.qsort fun e₁ e₂ => e₁.ref.getPos?.getD 0 < e₂.ref.getPos?.getD 0
855
+ for error in errors do
856
+ error.mvarId.withContext do
857
+ error.logError extraMsg?
858
+ return hasNewErrors
859
+
860
+ def registerLevelMVarErrorInfo (levelMVarErrorInfo : LevelMVarErrorInfo) : TermElabM Unit :=
861
+ modify fun s => { s with levelMVarErrorInfos := levelMVarErrorInfo :: s.levelMVarErrorInfos }
862
+
863
+ def registerLevelMVarErrorExprInfo (expr : Expr) (ref : Syntax) (msgData? : Option MessageData := none) : TermElabM Unit := do
864
+ registerLevelMVarErrorInfo { lctx := (← getLCtx), expr, ref, msgData? }
865
+
866
+ def exposeLevelMVars (e : Expr) : MetaM Expr :=
867
+ Core.transform e
868
+ (post := fun e => do
869
+ match e with
870
+ | .const _ us => return .done <| if us.any (·.isMVar) then e.setPPUniverses true else e
871
+ | .sort u => return .done <| if u.isMVar then e.setPPUniverses true else e
872
+ | .lam _ t _ _ => return .done <| if t.hasLevelMVar then e.setOption `pp.funBinderTypes true else e
873
+ | .letE _ t _ _ _ => return .done <| if t.hasLevelMVar then e.setOption `pp.letVarTypes true else e
874
+ | _ => return .done e)
875
+
876
+ def LevelMVarErrorInfo.logError (levelMVarErrorInfo : LevelMVarErrorInfo) : TermElabM Unit :=
877
+ Meta.withLCtx levelMVarErrorInfo.lctx {} do
878
+ let e' ← exposeLevelMVars (← instantiateMVars levelMVarErrorInfo.expr)
879
+ let msg := levelMVarErrorInfo.msgData?.getD m!"don't know how to synthesize universe level metavariables"
880
+ let msg := m!"{msg}{indentExpr e'}"
881
+ logErrorAt levelMVarErrorInfo.ref msg
882
+
883
+ /--
884
+ Try to log errors for unassigned level metavariables `pendingLevelMVarIds`.
885
+
886
+ Returns `true` if there are any relevant `LevelMVarErrorInfo`s and we should "abort" the declaration.
887
+
888
+ Remark: we only log unassigned level metavariables as new errors if no error has been logged so far.
889
+ -/
890
+ def logUnassignedLevelMVarsUsingErrorInfos (pendingLevelMVarIds : Array LMVarId) : TermElabM Bool := do
891
+ if pendingLevelMVarIds.isEmpty then
892
+ return false
893
+ else
894
+ let hasOtherErrors ← MonadLog.hasErrors
895
+ let mut hasNewErrors := false
896
+ let mut errors : Array LevelMVarErrorInfo := #[]
897
+ for levelMVarErrorInfo in (← get).levelMVarErrorInfos do
898
+ let e ← instantiateMVars levelMVarErrorInfo.expr
899
+ let lmvars := (collectLevelMVars {} e).result
900
+ if lmvars.any pendingLevelMVarIds.contains then do
901
+ unless hasOtherErrors do
902
+ errors := errors.push levelMVarErrorInfo
903
+ hasNewErrors := true
904
+ for error in errors do
905
+ error.logError
906
+ return hasNewErrors
907
+
908
+ /-- Ensure metavariables registered using `registerMVarErrorInfos` (and used in the given declaration) have been assigned. -/
909
+ def ensureNoUnassignedMVars (decl : Declaration) : TermElabM Unit := do
910
+ let pendingMVarIds ← getMVarsAtDecl decl
911
+ if (← logUnassignedUsingErrorInfos pendingMVarIds) then
912
+ throwAbortCommand
913
+
914
+ /--
915
+ Execute `x` without allowing it to postpone elaboration tasks.
916
+ That is, `tryPostpone` is a noop. -/
917
+ def withoutPostponing (x : TermElabM α) : TermElabM α :=
918
+ withReader (fun ctx => { ctx with mayPostpone := false }) x
919
+
920
+ /-- Creates syntax for `(` <ident> `:` <type> `)` -/
921
+ def mkExplicitBinder (ident : Syntax) (type : Syntax) : Syntax :=
922
+ mkNode ``Lean.Parser.Term.explicitBinder #[mkAtom "(", mkNullNode #[ident], mkNullNode #[mkAtom ":", type], mkNullNode, mkAtom ")"]
923
+
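+ /- Illustrative sketch (hypothetical usage): synthesizing the binder `(x : Nat)`
+    from pieces of syntax:
+    ```
+    let ty ← `(Nat)
+    let binder := mkExplicitBinder (mkIdent `x) ty
+    ```
+ -/
+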
924
+ /--
925
+ Convert unassigned universe level metavariables into parameters.
926
+ The new parameter names are fresh names of the form `u_i` with regard to `ctx.levelNames`, which is updated with the new names. -/
927
+ def levelMVarToParam (e : Expr) (except : LMVarId → Bool := fun _ => false) : TermElabM Expr := do
928
+ let levelNames ← getLevelNames
929
+ let r := (← getMCtx).levelMVarToParam (fun n => levelNames.elem n) except e `u 1
930
+ -- Recall that the most recent universe is the first element of the field `levelNames`.
931
+ setLevelNames (r.newParamNames.reverse.toList ++ levelNames)
932
+ setMCtx r.mctx
933
+ return r.expr
934
+
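+ /- Illustrative sketch: when finishing a declaration, leftover universe
+    metavariables in its type can be generalized to parameters:
+    ```
+    let type ← instantiateMVars type
+    let type ← levelMVarToParam type
+    -- `getLevelNames` now also returns the freshly introduced `u_i` parameters
+    ```
+ -/
+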
935
+ /--
936
+ Creates a fresh inaccessible binder name based on `x`.
937
+ Equivalent to ``Lean.Core.mkFreshUserName `x``.
938
+
939
+ Do not confuse with `Lean.mkFreshId`, for creating fresh free variable and metavariable ids.
940
+ -/
941
+ def mkFreshBinderName [Monad m] [MonadQuotation m] : m Name :=
942
+ withFreshMacroScope <| MonadQuotation.addMacroScope `x
943
+
944
+ /--
945
+ Auxiliary method for creating a `Syntax.ident` containing
946
+ a fresh name. This method is intended for creating fresh binder names.
947
+ It is just a thin layer on top of `mkFreshUserName`. -/
948
+ def mkFreshIdent [Monad m] [MonadQuotation m] (ref : Syntax) (canonical := false) : m Ident :=
949
+ return mkIdentFrom ref (← mkFreshBinderName) canonical
950
+
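+ /- Illustrative sketch (hypothetical macro body): creating a binder that can
+    neither capture nor be captured by user-written identifiers:
+    ```
+    let x ← mkFreshIdent ref
+    `(fun $x => $x)
+    ```
+ -/
+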
951
+ private def applyAttributesCore
952
+ (declName : Name) (attrs : Array Attribute)
953
+ (applicationTime? : Option AttributeApplicationTime) : TermElabM Unit := do profileitM Exception "attribute application" (← getOptions) do
954
+ /-
955
+ Remark: if the declaration has syntax errors, `declName` may be `.anonymous`; see issue #4309.
956
+ In this case, we skip attribute application.
957
+ -/
958
+ if declName == .anonymous then
959
+ return
960
+ withDeclName declName do
961
+ for attr in attrs do
962
+ withTraceNode `Elab.attribute (fun _ => pure m!"applying [{attr.stx}]") do
963
+ withRef attr.stx do withLogging do
964
+ let env ← getEnv
965
+ match getAttributeImpl env attr.name with
966
+ | Except.error errMsg => throwError errMsg
967
+ | Except.ok attrImpl =>
968
+ let runAttr := attrImpl.add declName attr.stx attr.kind
969
+ let runAttr := do
970
+ -- not truly an elaborator, but a sensible target for go-to-definition
971
+ let elaborator := attrImpl.ref
972
+ if (← getInfoState).enabled then
973
+ withInfoContext (mkInfo := return .ofCommandInfo { elaborator, stx := attr.stx }) do
974
+ try runAttr
975
+ finally if attr.stx[0].isIdent || attr.stx[0].isAtom then
976
+ -- Add an additional node over the leading identifier if there is one to make it look more function-like.
977
+ -- Do this last because we want user-created infos to take precedence
978
+ pushInfoLeaf <| .ofCommandInfo { elaborator, stx := attr.stx[0] }
979
+ else
980
+ runAttr
981
+ match applicationTime? with
982
+ | none => runAttr
983
+ | some applicationTime =>
984
+ if applicationTime == attrImpl.applicationTime then
985
+ runAttr
986
+
987
+ /-- Apply given attributes **at** a given application time -/
988
+ def applyAttributesAt (declName : Name) (attrs : Array Attribute) (applicationTime : AttributeApplicationTime) : TermElabM Unit :=
989
+ applyAttributesCore declName attrs applicationTime
990
+
991
+ def applyAttributes (declName : Name) (attrs : Array Attribute) : TermElabM Unit :=
992
+ applyAttributesCore declName attrs none
993
+
994
+ def mkTypeMismatchError (header? : Option MessageData) (e : Expr) (eType : Expr) (expectedType : Expr) : MetaM MessageData := do
995
+ let header : MessageData := match header? with
996
+ | some header => m!"{header} "
997
+ | none => m!"type mismatch{indentExpr e}\n"
998
+ return m!"{header}{← mkHasTypeButIsExpectedMsg eType expectedType}"
999
+
1000
+ def throwTypeMismatchError (header? : Option MessageData) (expectedType : Expr) (eType : Expr) (e : Expr)
1001
+ (f? : Option Expr := none) (_extraMsg? : Option MessageData := none) : MetaM α := do
1002
+ /-
1003
+ We ignore `extraMsg?` for now. In all our tests, it contained no useful information. It was
1004
+ always of the form:
1005
+ ```
1006
+ failed to synthesize instance
1007
+ CoeT <eType> <e> <expectedType>
1008
+ ```
1009
+ We should revisit this decision in the future and decide whether it may contain useful information
1010
+ or not. -/
1011
+ let extraMsg := Format.nil
1012
+ /-
1013
+ let extraMsg : MessageData := match extraMsg? with
1014
+ | none => Format.nil
1015
+ | some extraMsg => Format.line ++ extraMsg;
1016
+ -/
1017
+ match f? with
1018
+ | none => throwError "{← mkTypeMismatchError header? e eType expectedType}{extraMsg}"
1019
+ | some f => Meta.throwAppTypeMismatch f e
1020
+
1021
+ def withoutMacroStackAtErr (x : TermElabM α) : TermElabM α :=
1022
+ withTheReader Core.Context (fun (ctx : Core.Context) => { ctx with options := pp.macroStack.set ctx.options false }) x
1023
+
1024
+ namespace ContainsPendingMVar
1025
+
1026
+ abbrev M := MonadCacheT Expr Unit (OptionT MetaM)
1027
+
1028
+ /-- See `containsPostponedTerm` -/
1029
+ partial def visit (e : Expr) : M Unit := do
1030
+ checkCache e fun _ => do
1031
+ match e with
1032
+ | .forallE _ d b _ => visit d; visit b
1033
+ | .lam _ d b _ => visit d; visit b
1034
+ | .letE _ t v b _ => visit t; visit v; visit b
1035
+ | .app f a => visit f; visit a
1036
+ | .mdata _ b => visit b
1037
+ | .proj _ _ b => visit b
1038
+ | .fvar fvarId .. =>
1039
+ match (← fvarId.getDecl) with
1040
+ | .cdecl .. => return ()
1041
+ | .ldecl (value := v) .. => visit v
1042
+ | .mvar mvarId .. =>
1043
+ let e' ← instantiateMVars e
1044
+ if e' != e then
1045
+ visit e'
1046
+ else
1047
+ match (← getDelayedMVarAssignment? mvarId) with
1048
+ | some d => visit (mkMVar d.mvarIdPending)
1049
+ | none => failure
1050
+ | _ => return ()
1051
+
1052
+ end ContainsPendingMVar
1053
+
1054
+ /-- Return `true` if `e` contains a pending metavariable. Remark: it also visits let-declarations. -/
1055
+ def containsPendingMVar (e : Expr) : MetaM Bool := do
1056
+ match (← ContainsPendingMVar.visit e |>.run.run) with
1057
+ | some _ => return false
1058
+ | none => return true
1059
+
1060
+ /--
1061
+ Try to synthesize metavariable using type class resolution.
1062
+ This method assumes the local context and local instances of `instMVar` coincide
1063
+ with the current local context and local instances.
1064
+ Return `true` if the instance was synthesized successfully, and `false` if
1065
+ the instance contains unassigned metavariables that are blocking the type class
1066
+ resolution procedure. Throw an exception if resolution or assignment irrevocably fails.
1067
+
1068
+ If `extraErrorMsg?` is not none, it contains additional information that should be attached
1069
+ to type class synthesis failures.
1070
+ -/
1071
+ def synthesizeInstMVarCore (instMVar : MVarId) (maxResultSize? : Option Nat := none) (extraErrorMsg? : Option MessageData := none): TermElabM Bool := do
1072
+ let extraErrorMsg := extraMsgToMsg extraErrorMsg?
1073
+ let instMVarDecl ← getMVarDecl instMVar
1074
+ let type := instMVarDecl.type
1075
+ let type ← instantiateMVars type
1076
+ let result ← trySynthInstance type maxResultSize?
1077
+ match result with
1078
+ | LOption.some val =>
1079
+ if (← instMVar.isAssigned) then
1080
+ let oldVal ← instantiateMVars (mkMVar instMVar)
1081
+ unless (← isDefEq oldVal val) do
1082
+ if (← containsPendingMVar oldVal <||> containsPendingMVar val) then
1083
+ /- If `val` or `oldVal` contains metavariables directly or indirectly (e.g., in a let-declaration),
1084
+ we return `false` to indicate we should try again later. This is very coarse grain since
1085
+ the metavariable may not be responsible for the failure. We should refine the test in the future if needed.
1086
+ This check has been added to address dependencies between postponed metavariables. The following
1087
+ example demonstrates the issue fixed by this test.
1088
+ ```
1089
+ structure Point where
1090
+ x : Nat
1091
+ y : Nat
1092
+
1093
+ def Point.compute (p : Point) : Point :=
1094
+ let p := { p with x := 1 }
1095
+ let p := { p with y := 0 }
1096
+ if (p.x - p.y) > p.x then p else p
1097
+ ```
1098
+ The `isDefEq` test above fails for `Decidable (p.x - p.y ≤ p.x)` when the structure instance assigned to
1099
+ `p` has not been elaborated yet.
1100
+ -/
1101
+ return false -- we will try again later
1102
+ let oldValType ← inferType oldVal
1103
+ let valType ← inferType val
1104
+ unless (← isDefEq oldValType valType) do
1105
+ let (oldValType, valType) ← addPPExplicitToExposeDiff oldValType valType
1106
+ throwError "synthesized type class instance type is not definitionally equal to expected type, synthesized{indentExpr val}\nhas type{indentExpr valType}\nexpected{indentExpr oldValType}{extraErrorMsg}"
1107
+ let (oldVal, val) ← addPPExplicitToExposeDiff oldVal val
1108
+ throwError "synthesized type class instance is not definitionally equal to expression inferred by typing rules, synthesized{indentExpr val}\ninferred{indentExpr oldVal}{extraErrorMsg}"
1109
+ else
1110
+ unless (← isDefEq (mkMVar instMVar) val) do
1111
+ throwError "failed to assign synthesized type class instance{indentExpr val}{extraErrorMsg}"
1112
+ return true
1113
+ | .undef => return false -- we will try later
1114
+ | .none =>
1115
+ if (← read).ignoreTCFailures then
1116
+ return false
1117
+ else
1118
+ throwError "failed to synthesize{indentExpr type}{extraErrorMsg}{useDiagnosticMsg}"
1119
+
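+ /- Illustrative sketch (with `instType` a stand-in for the instance argument's
+    type): instance-implicit arguments are typically elaborated by creating a
+    synthetic metavariable and trying to discharge it eagerly:
+    ```
+    let instMVar ← mkFreshExprMVar (some instType) MetavarKind.synthetic
+    let ok ← synthesizeInstMVarCore instMVar.mvarId!
+    -- `ok = false` means synthesis is blocked on metavariables; the caller then
+    -- registers the metavariable as pending and retries in `synthesizeSyntheticMVars`
+    ```
+ -/
+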
1120
+ def mkCoe (expectedType : Expr) (e : Expr) (f? : Option Expr := none) (errorMsgHeader? : Option String := none)
1121
+ (mkErrorMsg? : Option (MVarId → (expectedType e : Expr) → MetaM MessageData) := none)
1122
+ (mkImmedErrorMsg? : Option ((errorMsg? : Option MessageData) → (expectedType e : Expr) → MetaM MessageData) := none) : TermElabM Expr := do
1123
+ withTraceNode `Elab.coe (fun _ => return m!"adding coercion for {e} : {← inferType e} =?= {expectedType}") do
1124
+ try
1125
+ withoutMacroStackAtErr do
1126
+ match ← coerce? e expectedType with
1127
+ | .some eNew => return eNew
1128
+ | .none => failure
1129
+ | .undef =>
1130
+ let mvarAux ← mkFreshExprMVar expectedType MetavarKind.syntheticOpaque
1131
+ registerSyntheticMVarWithCurrRef mvarAux.mvarId! (.coe errorMsgHeader? expectedType e f? mkErrorMsg?)
1132
+ return mvarAux
1133
+ catch
1134
+ | .error _ msg =>
1135
+ if let some mkImmedErrorMsg := mkImmedErrorMsg? then
1136
+ throwError (← mkImmedErrorMsg msg expectedType e)
1137
+ else
1138
+ throwTypeMismatchError errorMsgHeader? expectedType (← inferType e) e f? msg
1139
+ | _ =>
1140
+ if let some mkImmedErrorMsg := mkImmedErrorMsg? then
1141
+ throwError (← mkImmedErrorMsg none expectedType e)
1142
+ else
1143
+ throwTypeMismatchError errorMsgHeader? expectedType (← inferType e) e f?
1144
+
1145
+ def mkCoeWithErrorMsgs (expectedType : Expr) (e : Expr)
1146
+ (mkImmedErrorMsg : (errorMsg? : Option MessageData) → (expectedType e : Expr) → MetaM MessageData)
1147
+ (mkErrorMsg : MVarId → (expectedType e : Expr) → MetaM MessageData) : TermElabM Expr := do
1148
+ mkCoe expectedType e (mkImmedErrorMsg? := mkImmedErrorMsg) (mkErrorMsg? := mkErrorMsg)
1149
+
1150
+ /--
1151
+ If `expectedType?` is `some t`, then ensures `t` and `eType` are definitionally equal by inserting a coercion if necessary.
1152
+
1153
+ Argument `f?` is used only for generating error messages when inserting coercions fails.
1154
+ -/
1155
+ def ensureHasType (expectedType? : Option Expr) (e : Expr)
1156
+ (errorMsgHeader? : Option String := none) (f? : Option Expr := none) : TermElabM Expr := do
1157
+ let some expectedType := expectedType? | return e
1158
+ if (← isDefEq (← inferType e) expectedType) then
1159
+ return e
1160
+ else
1161
+ mkCoe expectedType e f? errorMsgHeader?
1162
+
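+ /- Illustrative sketch: a term elaborator typically finishes by coercing its
+    result to the expected type:
+    ```
+    let e ← elabSomething stx
+    ensureHasType expectedType? e
+    ```
+    (`elabSomething` is a placeholder.) If the inferred type is already
+    definitionally equal to the expected one, the term is returned unchanged;
+    otherwise `mkCoe` inserts a coercion or registers a synthetic metavariable.
+ -/
+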
1163
+ def ensureHasTypeWithErrorMsgs (expectedType? : Option Expr) (e : Expr)
1164
+ (mkImmedErrorMsg : (errorMsg? : Option MessageData) → (expectedType e : Expr) → MetaM MessageData)
1165
+ (mkErrorMsg : MVarId → (expectedType e : Expr) → MetaM MessageData) : TermElabM Expr := do
1166
+ let some expectedType := expectedType? | return e
1167
+ if (← isDefEq (← inferType e) expectedType) then
1168
+ return e
1169
+ else
1170
+ mkCoeWithErrorMsgs expectedType e mkImmedErrorMsg mkErrorMsg
1171
+
1172
+ /--
1173
+ Create a synthetic sorry for the given expected type. If `expectedType? = none`, then a fresh
1174
+ metavariable is created to represent the type.
1175
+ -/
1176
+ private def mkSyntheticSorryFor (expectedType? : Option Expr) : TermElabM Expr := do
1177
+ let expectedType ← match expectedType? with
1178
+ | none => mkFreshTypeMVar
1179
+ | some expectedType => pure expectedType
1180
+ mkLabeledSorry expectedType (synthetic := true) (unique := false)
1181
+
1182
+ /--
1183
+ Log the given exception, and create a synthetic sorry for representing the failed
1184
+ elaboration step with exception `ex`.
1185
+ -/
1186
+ def exceptionToSorry (ex : Exception) (expectedType? : Option Expr) : TermElabM Expr := do
1187
+ let syntheticSorry ← mkSyntheticSorryFor expectedType?
1188
+ logException ex
1189
+ pure syntheticSorry
1190
+
1191
+ /-- If `mayPostpone == true`, throw `Exception.postpone`. -/
1192
+ def tryPostpone : TermElabM Unit := do
1193
+ if (← read).mayPostpone then
1194
+ throwPostpone
1195
+
1196
+ /-- Return `true` if `e` reduces (by unfolding only `[reducible]` declarations) to `?m ...` -/
1197
+ def isMVarApp (e : Expr) : TermElabM Bool :=
1198
+ return (← whnfR e).getAppFn.isMVar
1199
+
1200
+ /-- If `mayPostpone == true` and `e`'s head is a metavariable, throw `Exception.postpone`. -/
1201
+ def tryPostponeIfMVar (e : Expr) : TermElabM Unit := do
1202
+ if (← isMVarApp e) then
1203
+ tryPostpone
1204
+
1205
+ /-- If `e? = some e`, then `tryPostponeIfMVar e`, otherwise it is just `tryPostpone`. -/
1206
+ def tryPostponeIfNoneOrMVar (e? : Option Expr) : TermElabM Unit :=
1207
+ match e? with
1208
+ | some e => tryPostponeIfMVar e
1209
+ | none => tryPostpone
1210
+
1211
+ /--
1212
+ Throws `Exception.postpone`, if `expectedType?` contains unassigned metavariables.
1213
+ It is a noop if `mayPostpone == false`.
1214
+ -/
1215
+ def tryPostponeIfHasMVars? (expectedType? : Option Expr) : TermElabM (Option Expr) := do
1216
+ tryPostponeIfNoneOrMVar expectedType?
1217
+ let some expectedType := expectedType? | return none
1218
+ let expectedType ← instantiateMVars expectedType
1219
+ if expectedType.hasExprMVar then
1220
+ tryPostpone
1221
+ return none
1222
+ return some expectedType
1223
+
1224
+ /--
1225
+ Throws `Exception.postpone`, if `expectedType?` contains unassigned metavariables.
1226
+ If `mayPostpone == false`, it throws error `msg`.
1227
+ -/
1228
+ def tryPostponeIfHasMVars (expectedType? : Option Expr) (msg : String) : TermElabM Expr := do
1229
+ let some expectedType ← tryPostponeIfHasMVars? expectedType? |
1230
+ throwError "{msg}, expected type contains metavariables{indentD expectedType?}"
1231
+ return expectedType
1232
+
1233
+ def withExpectedType (expectedType? : Option Expr) (x : Expr → TermElabM Expr) : TermElabM Expr := do
1234
+ tryPostponeIfNoneOrMVar expectedType?
1235
+ let some expectedType ← pure expectedType?
1236
+ | throwError "expected type must be known"
1237
+ x expectedType
1238
+
1239
+ /--
1240
+ Save relevant context for term elaboration postponement.
1241
+ -/
1242
+ def saveContext : TermElabM SavedContext :=
1243
+ return {
1244
+ macroStack := (← read).macroStack
1245
+ declName? := (← read).declName?
1246
+ options := (← getOptions)
1247
+ openDecls := (← getOpenDecls)
1248
+ errToSorry := (← read).errToSorry
1249
+ levelNames := (← get).levelNames
1250
+ }
1251
+
1252
+ /--
1253
+ Execute `x` with the context saved using `saveContext`.
1254
+ -/
1255
+ def withSavedContext (savedCtx : SavedContext) (x : TermElabM α) : TermElabM α := do
1256
+ withReader (fun ctx => { ctx with declName? := savedCtx.declName?, macroStack := savedCtx.macroStack, errToSorry := savedCtx.errToSorry }) <|
1257
+ withTheReader Core.Context (fun ctx => { ctx with options := savedCtx.options, openDecls := savedCtx.openDecls }) <|
1258
+ withLevelNames savedCtx.levelNames x
1259
+
1260
+ /--
1261
+ Delay the elaboration of `stx`, and return a fresh metavariable that works as a placeholder.
1262
+ Remark: the caller is responsible for making sure the info tree is properly updated.
1263
+ This method is used only at `elabUsingElabFnsAux`.
1264
+ -/
1265
+ private def postponeElabTermCore (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
1266
+ trace[Elab.postpone] "{stx} : {expectedType?}"
1267
+ let mvar ← mkFreshExprMVar expectedType? MetavarKind.syntheticOpaque
1268
+ registerSyntheticMVar stx mvar.mvarId! (SyntheticMVarKind.postponed (← saveContext))
1269
+ return mvar
1270
+
1271
+ def getSyntheticMVarDecl? (mvarId : MVarId) : TermElabM (Option SyntheticMVarDecl) :=
1272
+ return (← get).syntheticMVars.find? mvarId
1273
+
1274
+ register_builtin_option debug.byAsSorry : Bool := {
1275
+ defValue := false
1276
+ group := "debug"
1277
+ descr := "replace `by ..` blocks with `sorry` IF the expected type is a proposition"
1278
+ }
1279
+
1280
+ /--
1281
+ Creates a new metavariable of type `type` that will be synthesized using the tactic code.
1282
+ The `tacticCode` syntax is the full `by ..` syntax.
1283
+ -/
1284
+ def mkTacticMVar (type : Expr) (tacticCode : Syntax) (kind : TacticMVarKind)
1285
+ (delayOnMVars := false) : TermElabM Expr := do
1286
+ if ← pure (debug.byAsSorry.get (← getOptions)) <&&> isProp type then
1287
+ withRef tacticCode <| mkLabeledSorry type false (unique := true)
1288
+ else
1289
+ let mvar ← mkFreshExprMVar type MetavarKind.syntheticOpaque
1290
+ let mvarId := mvar.mvarId!
1291
+ let ref ← getRef
1292
+ registerSyntheticMVar ref mvarId <| .tactic tacticCode (← saveContext) kind delayOnMVars
1293
+ return mvar
1294
+
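+ /- Illustrative sketch (simplified; `tacticCode` and `expectedType?` as in the
+    surrounding API): an elaborator for `by ...` essentially reduces to creating
+    a tactic metavariable for the expected type:
+    ```
+    let type ← match expectedType? with
+      | some t => pure t
+      | none   => mkFreshTypeMVar
+    mkTacticMVar type tacticCode .term
+    ```
+    The returned metavariable is solved later when `synthesizeSyntheticMVars`
+    runs the tactic block.
+ -/
+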
1295
+ /--
1296
+ Create an auxiliary annotation to make sure we create an `Info` even if `e` is a metavariable.
1297
+ See `mkTermInfo`.
1298
+
1299
+ We use this function because some elaboration functions elaborate subterms that may not be immediately
1300
+ part of the resulting term. Example:
1301
+ ```
1302
+ let_mvar% ?m := b; wait_if_type_mvar% ?m; body
1303
+ ```
1304
+ If the type of `b` is not known, then `wait_if_type_mvar% ?m; body` is postponed and just returns a fresh
1305
+ metavariable `?n`. The elaborator for
1306
+ ```
1307
+ let_mvar% ?m := b; wait_if_type_mvar% ?m; body
1308
+ ```
1309
+ returns `mkSaveInfoAnnotation ?n` to make sure the info nodes created when elaborating `b` are "saved".
1310
+ This is a bit hackish, but elaborators like `let_mvar%` are rare.
1311
+ -/
1312
+ def mkSaveInfoAnnotation (e : Expr) : Expr :=
1313
+ if e.isMVar then
1314
+ mkAnnotation `save_info e
1315
+ else
1316
+ e
1317
+
1318
+ def isSaveInfoAnnotation? (e : Expr) : Option Expr :=
1319
+ annotation? `save_info e
1320
+
1321
+ partial def removeSaveInfoAnnotation (e : Expr) : Expr :=
1322
+ match isSaveInfoAnnotation? e with
1323
+ | some e => removeSaveInfoAnnotation e
1324
+ | _ => e
1325
+
1326
+ /--
1327
+ Return `some mvarId` if `e` corresponds to a hole that is going to be filled "later" by executing a tactic or resuming elaboration.
1328
+
1329
+ We do not save `ofTermInfo` for this kind of node in the `InfoTree`.
1330
+ -/
1331
+ def isTacticOrPostponedHole? (e : Expr) : TermElabM (Option MVarId) := do
1332
+ match e with
1333
+ | Expr.mvar mvarId =>
1334
+ match (← getSyntheticMVarDecl? mvarId) with
1335
+ | some { kind := .tactic .., .. } => return mvarId
1336
+ | some { kind := .postponed .., .. } => return mvarId
1337
+ | _ => return none
1338
+ | _ => pure none
1339
+
1340
+ def mkTermInfo (elaborator : Name) (stx : Syntax) (e : Expr) (expectedType? : Option Expr := none)
1341
+ (lctx? : Option LocalContext := none) (isBinder := false) :
1342
+ TermElabM (Sum Info MVarId) := do
1343
+ match (← isTacticOrPostponedHole? e) with
1344
+ | some mvarId => return Sum.inr mvarId
1345
+ | none =>
1346
+ let e := removeSaveInfoAnnotation e
1347
+ return Sum.inl <| Info.ofTermInfo { elaborator, lctx := lctx?.getD (← getLCtx), expr := e, stx, expectedType?, isBinder }
1348
+
1349
+ def mkPartialTermInfo (elaborator : Name) (stx : Syntax) (expectedType? : Option Expr := none)
1350
+ (lctx? : Option LocalContext := none) :
1351
+ TermElabM Info := do
1352
+ return Info.ofPartialTermInfo { elaborator, lctx := lctx?.getD (← getLCtx), stx, expectedType? }
1353
+
1354
+ /--
1355
+ Pushes a new leaf node to the info tree associating the expression `e` to the syntax `stx`.
1356
+ As a result, when the user hovers over `stx` they will see the type of `e`, and if `e`
1357
+ is a constant they will see the constant's doc string.
1358
+
1359
+ * `expectedType?`: the expected type of `e` at the point of elaboration, if available
1360
+ * `lctx?`: the local context in which to interpret `e` (otherwise it will use `← getLCtx`)
1361
+ * `elaborator`: a declaration name used as an alternative target for go-to-definition
1362
+ * `isBinder`: if true, this will be treated as defining `e` (which should be a local constant)
1363
+ for the purpose of go-to-definition on local variables
1364
+ * `force`: In patterns, the effect of `addTermInfo` is usually suppressed and replaced
1365
+ by a `patternWithRef?` annotation which will be turned into a term info on the
1366
+ post-match-elaboration expression. This flag overrides that behavior and adds the term
1367
+ info immediately. (See https://github.com/leanprover/lean4/pull/1664.)
1368
+ -/
1369
+ def addTermInfo (stx : Syntax) (e : Expr) (expectedType? : Option Expr := none)
1370
+ (lctx? : Option LocalContext := none) (elaborator := Name.anonymous)
1371
+ (isBinder := false) (force := false) : TermElabM Expr := do
1372
+ if (← read).inPattern && !force then
1373
+ return mkPatternWithRef e stx
1374
+ else
1375
+ discard <| withInfoContext'
1376
+ (pure ())
1377
+ (fun _ => mkTermInfo elaborator stx e expectedType? lctx? isBinder)
1378
+ (mkPartialTermInfo elaborator stx expectedType? lctx?)
1379
+ return e
1380
+
1381
+ def addTermInfo' (stx : Syntax) (e : Expr) (expectedType? : Option Expr := none) (lctx? : Option LocalContext := none) (elaborator := Name.anonymous) (isBinder := false) : TermElabM Unit :=
1382
+ discard <| addTermInfo stx e expectedType? lctx? elaborator isBinder
1383
+
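+ /- Illustrative sketch: an elaborator that resolves an identifier itself should
+    record hover and go-to-definition information for it:
+    ```
+    addTermInfo' stx e
+    ```
+    Passing `isBinder := true` instead marks `stx` as the defining occurrence of
+    the local constant `e`.
+ -/
+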
1384
+ def withInfoContext' (stx : Syntax) (x : TermElabM Expr)
1385
+ (mkInfo : Expr → TermElabM (Sum Info MVarId)) (mkInfoOnError : TermElabM Info) :
1386
+ TermElabM Expr := do
1387
+ if (← read).inPattern then
1388
+ let e ← x
1389
+ return mkPatternWithRef e stx
1390
+ else
1391
+ Elab.withInfoContext' x mkInfo mkInfoOnError
1392
+
1393
+ /-- Info node capturing `def/let rec` bodies, used by the unused variables linter. -/
1394
+ structure BodyInfo where
1395
+ /-- The body as a fully elaborated term. `none` if the body failed to elaborate. -/
1396
+ value? : Option Expr
1397
+ deriving TypeName
1398
+
1399
+ /-- Creates an `Info.ofCustomInfo` node backed by a `BodyInfo`. -/
1400
+ def mkBodyInfo (stx : Syntax) (value? : Option Expr) : Info :=
1401
+ .ofCustomInfo { stx, value := .mk { value? : BodyInfo } }
1402
+
1403
+ /-- Extracts a `BodyInfo` custom info. -/
1404
+ def getBodyInfo? : Info → Option BodyInfo
1405
+ | .ofCustomInfo { value, .. } => value.get? BodyInfo
1406
+ | _ => none
1407
+
1408
+ def withTermInfoContext' (elaborator : Name) (stx : Syntax) (x : TermElabM Expr)
1409
+ (expectedType? : Option Expr := none) (lctx? : Option LocalContext := none)
1410
+ (isBinder : Bool := false) :
1411
+ TermElabM Expr :=
1412
+ withInfoContext' stx x
1413
+ (mkTermInfo elaborator stx (expectedType? := expectedType?) (lctx? := lctx?) (isBinder := isBinder))
1414
+ (mkPartialTermInfo elaborator stx (expectedType? := expectedType?) (lctx? := lctx?))
1415
+
1416
+ /--
1417
+ Postpone the elaboration of `stx`, return a metavariable that acts as a placeholder, and
1418
+ ensures the info tree is updated and a hole id is introduced.
1419
+ When `stx` is elaborated, new info nodes are created and attached to the new hole id in the info tree.
1420
+ -/
1421
+ def postponeElabTerm (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
1422
+ withTermInfoContext' .anonymous stx (expectedType? := expectedType?) do
1423
+ postponeElabTermCore stx expectedType?
1424
+
1425
+ /--
1426
+ Helper function for `elabTerm` that tries the registered elaboration functions for `stxNode` kind until it finds one that supports the syntax or
1427
+ an error is found. -/
1428
+ private def elabUsingElabFnsAux (s : SavedState) (stx : Syntax) (expectedType? : Option Expr) (catchExPostpone : Bool)
1429
+ : List (KeyedDeclsAttribute.AttributeEntry TermElab) → TermElabM Expr
1430
+ | [] => do throwError "unexpected syntax{indentD stx}"
1431
+ | (elabFn::elabFns) =>
1432
+ try
1433
+ -- record elaborator in info tree, but only when not backtracking to other elaborators (outer `try`)
1434
+ withTermInfoContext' elabFn.declName stx (expectedType? := expectedType?)
1435
+ (try
1436
+ elabFn.value stx expectedType?
1437
+ catch ex => match ex with
1438
+ | .error .. =>
1439
+ if (← read).errToSorry then
1440
+ exceptionToSorry ex expectedType?
1441
+ else
1442
+ throw ex
1443
+ | .internal id _ =>
1444
+ if (← read).errToSorry && id == abortTermExceptionId then
1445
+ exceptionToSorry ex expectedType?
1446
+ else if id == unsupportedSyntaxExceptionId then
1447
+ throw ex -- to outer try
1448
+ else if catchExPostpone && id == postponeExceptionId then
1449
+ /- If `elab` threw `Exception.postpone`, we reset any state modifications.
1450
+ For example, we want to make sure pending synthetic metavariables created by `elab` before
1451
+ it threw `Exception.postpone` are discarded.
1452
+ Note that we are also discarding the messages created by `elab`.
1453
+
1454
+ For example, consider the expression.
1455
+ `((f.x a1).x a2).x a3`
1456
+ Now, suppose the elaboration of `f.x a1` produces an `Exception.postpone`.
1457
+ Then, a new metavariable `?m` is created. Then, `?m.x a2` also throws `Exception.postpone`
1458
+ because the type of `?m` is not yet known. Then another, metavariable `?n` is created, and
1459
+ finally `?n.x a3` also throws `Exception.postpone`. If we did not restore the state, we would
1460
+ keep "dead" metavariables `?m` and `?n` on the pending synthetic metavariable list. This is
1461
+ wasteful because when we resume the elaboration of `((f.x a1).x a2).x a3`, we start it from scratch
1462
+ and new metavariables are created for the nested functions. -/
1463
+ s.restore
1464
+ postponeElabTermCore stx expectedType?
1465
+ else
1466
+ throw ex)
1467
+ catch ex => match ex with
1468
+ | .internal id _ =>
1469
+ if id == unsupportedSyntaxExceptionId then
1470
+ s.restore -- also removes the info tree created above
1471
+ elabUsingElabFnsAux s stx expectedType? catchExPostpone elabFns
1472
+ else
1473
+ throw ex
1474
+ | _ => throw ex
1475
+
1476
+ private def elabUsingElabFns (stx : Syntax) (expectedType? : Option Expr) (catchExPostpone : Bool) : TermElabM Expr := do
1477
+ let s ← saveState
1478
+ let k := stx.getKind
1479
+ match termElabAttribute.getEntries (← getEnv) k with
1480
+ | [] => throwError "elaboration function for '{k}' has not been implemented{indentD stx}"
1481
+ | elabFns => elabUsingElabFnsAux s stx expectedType? catchExPostpone elabFns
1482
+
1483
+ instance : MonadMacroAdapter TermElabM where
1484
+ getCurrMacroScope := getCurrMacroScope
1485
+ getNextMacroScope := return (← getThe Core.State).nextMacroScope
1486
+ setNextMacroScope next := modifyThe Core.State fun s => { s with nextMacroScope := next }
1487
+
1488
+ private def isExplicit (stx : Syntax) : Bool :=
1489
+ match stx with
1490
+ | `(@$_) => true
1491
+ | _ => false
1492
+
1493
+ private def isExplicitApp (stx : Syntax) : Bool :=
1494
+ stx.getKind == ``Lean.Parser.Term.app && isExplicit stx[0]
1495
+
1496
+ /--
1497
+ Return true if `stx` is a lambda abstraction containing a `{}` or `[]` binder annotation.
1498
+ Example: `fun {α} (a : α) => a` -/
1499
+ private def isLambdaWithImplicit (stx : Syntax) : Bool :=
1500
+ match stx with
1501
+ | `(fun $binders* => $_) => binders.raw.any fun b => b.isOfKind ``Lean.Parser.Term.implicitBinder || b.isOfKind `Lean.Parser.Term.instBinder
1502
+ | _ => false
1503
+
1504
+ private partial def dropTermParens : Syntax → Syntax := fun stx =>
1505
+ match stx with
1506
+ | `(($stx)) => dropTermParens stx
1507
+ | _ => stx
1508
+
1509
+ private def isHole (stx : Syntax) : Bool :=
1510
+ stx.isOfKind ``Lean.Parser.Term.hole || stx.isOfKind ``Lean.Parser.Term.syntheticHole
1511
+
1512
+ private def isTacticBlock (stx : Syntax) : Bool :=
1513
+ match stx with
1514
+ | `(by $_:tacticSeq) => true
1515
+ | _ => false
1516
+
1517
+ private def isNoImplicitLambda (stx : Syntax) : Bool :=
1518
+ match stx with
1519
+ | `(no_implicit_lambda% $_:term) => true
1520
+ | _ => false
1521
+
1522
+ private def isTypeAscription (stx : Syntax) : Bool :=
1523
+ match stx with
1524
+ | `(($_ : $[$_]?)) => true
1525
+ | _ => false
1526
+
1527
+ def hasNoImplicitLambdaAnnotation (type : Expr) : Bool :=
1528
+ annotation? `noImplicitLambda type |>.isSome
1529
+
1530
+ def mkNoImplicitLambdaAnnotation (type : Expr) : Expr :=
1531
+ if hasNoImplicitLambdaAnnotation type then
1532
+ type
1533
+ else
1534
+ mkAnnotation `noImplicitLambda type
1535
+
1536
+ /-- Block usage of implicit lambdas if `stx` is `@f` or `@f arg1 ...` or `fun` with an implicit binder annotation. -/
1537
+ def blockImplicitLambda (stx : Syntax) : Bool :=
1538
+ let stx := dropTermParens stx
1539
+ -- TODO: make it extensible
1540
+ isExplicit stx || isExplicitApp stx || isLambdaWithImplicit stx || isHole stx || isTacticBlock stx ||
1541
+ isNoImplicitLambda stx || isTypeAscription stx
1542
+
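A small sketch of what the predicate above means at the surface level (all three examples should elaborate; only the first relies on the implicit-lambda feature):

```lean
-- a bare `id` is wrapped as `fun {α} => id` by the implicit-lambda feature
example : {α : Type} → α → α := id
-- `@f` and explicit `{}` binders are on the block list, so these are elaborated as written
example : {α : Type} → α → α := @id
example : {α : Type} → α → α := fun {α} a => a
```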
1543
+ /-- Return true iff `stx` is a `Syntax.ident`, and it is a local variable. -/
1544
+ def isLocalIdent? (stx : Syntax) : TermElabM (Option Expr) :=
1545
+ match stx with
1546
+ | Syntax.ident _ _ val _ => do
1547
+ let r? ← resolveLocalName val
1548
+ match r? with
1549
+ | some (fvar, []) => return some fvar
1550
+ | _ => return none
1551
+ | _ => return none
1552
+
1553
+ inductive UseImplicitLambdaResult where
1554
+ | no
1555
+ | yes (expectedType : Expr)
1556
+ | postpone
1557
+
1558
+ /--
1559
+ Return normalized expected type if it is of the form `{a : α} → β` or `[a : α] → β` and
1560
+ `blockImplicitLambda stx` is not true, else return `.no`.
1561
+
1562
+ Remark: implicit lambdas are not triggered by the strict implicit binder annotation `{{a : α}} → β`
1563
+ -/
1564
+ private def useImplicitLambda (stx : Syntax) (expectedType? : Option Expr) : TermElabM UseImplicitLambdaResult := do
1565
+ if blockImplicitLambda stx then
1566
+ return .no
1567
+ let some expectedType := expectedType? | return .no
1568
+ if hasNoImplicitLambdaAnnotation expectedType then
1569
+ return .no
1570
+ let expectedType ← whnfForall expectedType
1571
+ let .forallE _ _ _ c := expectedType | return .no
1572
+ unless c.isImplicit || c.isInstImplicit do
1573
+ return .no
1574
+ if let some x ← isLocalIdent? stx then
1575
+ if (← isMVarApp (← inferType x)) then
1576
+ /-
1577
+ If `stx` is a local variable without type information, then adding implicit lambdas makes elaboration fail.
1578
+ We should try to postpone elaboration until the type of the local variable becomes available, or disable
1579
+ implicit lambdas if we cannot postpone anymore.
1580
+ Here is an example where this special case is useful.
1581
+ ```
1582
+ def foo2mk (_ : ∀ {α : Type} (a : α), a = a) : Nat := 37
1583
+ example (x) : foo2mk x = foo2mk x := rfl
1584
+ ```
1585
+ The example above would fail without this special case.
1586
+ The expected type would be `(a : α✝) → a = a`, where `α✝` is a new free variable introduced by the implicit lambda.
1587
+ Now, let `?m` be the type of `x`. Then, the constraint `?m =?= (a : α✝) → a = a` cannot be solved using the
1588
+ assignment `?m := (a : α✝) → a = a` since `α✝` is not in the scope of `?m`.
1589
+
1590
+ Note that this workaround does not prevent the following example from failing.
1591
+ ```
1592
+ example (x) : foo2mk (id x) = 37 := rfl
1593
+ ```
1594
+ The user can write
1595
+ ```
1596
+ example (x) : foo2mk (id @x) = 37 := rfl
1597
+ ```
1598
+ -/
1599
+ return .postpone
1600
+ return .yes expectedType
1601
+
1602
+ private def decorateErrorMessageWithLambdaImplicitVars (ex : Exception) (impFVars : Array Expr) : TermElabM Exception := do
1603
+ match ex with
1604
+ | .error ref msg =>
1605
+ if impFVars.isEmpty then
1606
+ return Exception.error ref msg
1607
+ else
1608
+ let mut msg := m!"{msg}\nthe following variables have been introduced by the implicit lambda feature"
1609
+ for impFVar in impFVars do
1610
+ let auxMsg := m!"{impFVar} : {← inferType impFVar}"
1611
+ let auxMsg ← addMessageContext auxMsg
1612
+ msg := m!"{msg}{indentD auxMsg}"
1613
+ msg := m!"{msg}\nyou can disable implicit lambdas using `@` or by writing a lambda expression with `\{}` or `[]` binder annotations."
1614
+ return Exception.error ref msg
1615
+ | _ => return ex
1616
+
1617
+ private def elabImplicitLambdaAux (stx : Syntax) (catchExPostpone : Bool) (expectedType : Expr) (impFVars : Array Expr) : TermElabM Expr := do
1618
+ let body ← elabUsingElabFns stx expectedType catchExPostpone
1619
+ try
1620
+ let body ← ensureHasType expectedType body
1621
+ let r ← mkLambdaFVars impFVars body
1622
+ trace[Elab.implicitForall] r
1623
+ return r
1624
+ catch ex =>
1625
+ throw (← decorateErrorMessageWithLambdaImplicitVars ex impFVars)
1626
+
1627
+ private partial def elabImplicitLambda (stx : Syntax) (catchExPostpone : Bool) (type : Expr) : TermElabM Expr :=
1628
+ loop type #[]
1629
+ where
1630
+ loop (type : Expr) (fvars : Array Expr) : TermElabM Expr := do
1631
+ match (← whnfForall type) with
1632
+ | .forallE n d b c =>
1633
+ if c.isExplicit then
1634
+ elabImplicitLambdaAux stx catchExPostpone type fvars
1635
+ else withFreshMacroScope do
1636
+ let n ← MonadQuotation.addMacroScope n
1637
+ withLocalDecl n c d fun fvar => do
1638
+ let type := b.instantiate1 fvar
1639
+ loop type (fvars.push fvar)
1640
+ | _ =>
1641
+ elabImplicitLambdaAux stx catchExPostpone type fvars
1642
+
1643
+ /-- Main loop for `elabTerm` -/
1644
+ private partial def elabTermAux (expectedType? : Option Expr) (catchExPostpone : Bool) (implicitLambda : Bool) : Syntax → TermElabM Expr
1645
+ | .missing => mkSyntheticSorryFor expectedType?
1646
+ | stx => withFreshMacroScope <| withIncRecDepth do
1647
+ withTraceNode `Elab.step (fun _ => return m!"expected type: {expectedType?}, term\n{stx}")
1648
+ (tag := stx.getKind.toString) do
1649
+ checkSystem "elaborator"
1650
+ let env ← getEnv
1651
+ let result ← match (← liftMacroM (expandMacroImpl? env stx)) with
1652
+ | some (decl, stxNew?) =>
1653
+ let stxNew ← liftMacroM <| liftExcept stxNew?
1654
+ withTermInfoContext' decl stx (expectedType? := expectedType?) <|
1655
+ withMacroExpansion stx stxNew <|
1656
+ withRef stxNew <|
1657
+ elabTermAux expectedType? catchExPostpone implicitLambda stxNew
1658
+ | _ =>
1659
+ let useImplicitResult ← if implicitLambda && (← read).implicitLambda then useImplicitLambda stx expectedType? else pure .no
1660
+ match useImplicitResult with
1661
+ | .yes expectedType => elabImplicitLambda stx catchExPostpone expectedType
1662
+ | .no => elabUsingElabFns stx expectedType? catchExPostpone
1663
+ | .postpone =>
1664
+ /-
1665
+ Try to postpone elaboration, and if we cannot postpone anymore disable implicit lambdas.
1666
+ See comment at `useImplicitLambda`.
1667
+ -/
1668
+ if (← read).mayPostpone then
1669
+ if catchExPostpone then
1670
+ postponeElabTerm stx expectedType?
1671
+ else
1672
+ throwPostpone
1673
+ else
1674
+ elabUsingElabFns stx expectedType? catchExPostpone
1675
+ trace[Elab.step.result] result
1676
+ pure result
1677
+
1678
+ /-- Store in the `InfoTree` that `e` is a "dot"-completion target. `stx` should cover the entire term. -/
1679
+ def addDotCompletionInfo (stx : Syntax) (e : Expr) (expectedType? : Option Expr) : TermElabM Unit := do
1680
+ addCompletionInfo <| CompletionInfo.dot { expr := e, stx, lctx := (← getLCtx), elaborator := .anonymous, expectedType? } (expectedType? := expectedType?)
1681
+
1682
+ /--
1683
+ Main function for elaborating terms.
1684
+ It extracts the elaboration methods from the environment using the node kind.
1685
+ Recall that the environment has a mapping from `SyntaxNodeKind` to `TermElab` methods.
1686
+ It creates a fresh macro scope for executing the elaboration method.
1687
+ All unlogged trace messages produced by the elaboration method are logged using
1688
+ the position information at `stx`. If the elaboration method throws an `Exception.error` and `errToSorry == true`,
1689
+ the error is logged and a synthetic sorry expression is returned.
1690
+ If the elaboration throws `Exception.postpone` and `catchExPostpone == true`,
1691
+ a new synthetic metavariable of kind `SyntheticMVarKind.postponed` is created, registered,
1692
+ and returned.
1693
+ The option `catchExPostpone == false` is used to implement `resumeElabTerm`
1694
+ to prevent the creation of another synthetic metavariable when resuming the elaboration.
1695
+
1696
+ If `implicitLambda == false`, then disable the implicit-lambda feature for the given syntax, but not for its subterms.
1697
+ We use this flag to implement, for example, the `@` modifier. If `Context.implicitLambda == false`, then this parameter has no effect.
1698
+ -/
1699
+ def elabTerm (stx : Syntax) (expectedType? : Option Expr) (catchExPostpone := true) (implicitLambda := true) : TermElabM Expr :=
1700
+ withRef stx <| elabTermAux expectedType? catchExPostpone implicitLambda stx
1701
+
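A rough usage sketch of `elabTerm` from a user-defined elaborator (a standalone file with `import Lean`; the syntax kind `myTerm` and the elaborator name are hypothetical):

```lean
import Lean
open Lean Elab Term

syntax (name := myTerm) "myTerm! " term : term

@[term_elab myTerm] def elabMyTerm : TermElab := fun stx expectedType? =>
  -- forward the wrapped term to `elabTerm`, reusing the expected type
  elabTerm stx[1] expectedType?

#check myTerm! (2 + 2)  -- elaborates exactly like `(2 + 2)`
```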
1702
+ /--
1703
+ Similar to `Lean.Elab.Term.elabTerm`, but ensures that the type of the elaborated term is `expectedType?`
1704
+ by inserting coercions if necessary.
1705
+
1706
+ If `errToSorry` is true, then if coercion insertion fails, this function returns `sorry` and logs the error.
1707
+ Otherwise, it throws the error.
1708
+ -/
1709
+ def elabTermEnsuringType (stx : Syntax) (expectedType? : Option Expr) (catchExPostpone := true) (implicitLambda := true) (errorMsgHeader? : Option String := none) : TermElabM Expr := do
1710
+ let e ← elabTerm stx expectedType? catchExPostpone implicitLambda
1711
+ try
1712
+ withRef stx <| ensureHasType expectedType? e errorMsgHeader?
1713
+ catch ex =>
1714
+ if (← read).errToSorry && ex matches .error .. then
1715
+ withRef stx <| exceptionToSorry ex expectedType?
1716
+ else
1717
+ throw ex
1718
+
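The coercion insertion described here is the same one visible in an ordinary type ascription; a minimal illustration:

```lean
-- the `Nat` literal is elaborated first, then coerced to `Int` to satisfy the expected type
#check ((5 : Nat) : Int)
```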
1719
+ /-- Execute `x` and return `some` if no new errors were recorded and no exceptions were thrown. Otherwise, return `none`. -/
1720
+ def commitIfNoErrors? (x : TermElabM α) : TermElabM (Option α) := do
1721
+ let saved ← saveState
1722
+ Core.resetMessageLog
1723
+ try
1724
+ let a ← x
1725
+ if (← MonadLog.hasErrors) then
1726
+ restoreState saved
1727
+ return none
1728
+ else
1729
+ Core.setMessageLog (saved.meta.core.messages ++ (← Core.getMessageLog))
1730
+ return a
1731
+ catch _ =>
1732
+ restoreState saved
1733
+ return none
1734
+
1735
+ /-- Adapt a syntax transformation to a regular, term-producing elaborator. -/
1736
+ def adaptExpander (exp : Syntax → TermElabM Syntax) : TermElab := fun stx expectedType? => do
1737
+ let stx' ← exp stx
1738
+ withMacroExpansion stx stx' <| elabTerm stx' expectedType?
1739
+
1740
+ /--
1741
+ Create a new metavariable with the given type, and try to synthesize it.
1742
+ If type class resolution cannot be executed (e.g., it is stuck because of metavariables in `type`),
1743
+ register the metavariable as a pending one.
1744
+ -/
1745
+ def mkInstMVar (type : Expr) (extraErrorMsg? : Option MessageData := none) : TermElabM Expr := do
1746
+ let mvar ← mkFreshExprMVar type MetavarKind.synthetic
1747
+ let mvarId := mvar.mvarId!
1748
+ unless (← synthesizeInstMVarCore mvarId (extraErrorMsg? := extraErrorMsg?)) do
1749
+ registerSyntheticMVarWithCurrRef mvarId (.typeClass extraErrorMsg?)
1750
+ return mvar
1751
+
1752
+ /--
1753
+ Make sure `e` is a type by inferring its type and making sure it is an `Expr.sort`
1754
+ or is unifiable with `Expr.sort`, or can be coerced into one. -/
1755
+ def ensureType (e : Expr) : TermElabM Expr := do
1756
+ if (← isType e) then
1757
+ return e
1758
+ else
1759
+ let eType ← inferType e
1760
+ let u ← mkFreshLevelMVar
1761
+ if (← isDefEq eType (mkSort u)) then
1762
+ return e
1763
+ else if let some coerced ← coerceToSort? e then
1764
+ return coerced
1765
+ else
1766
+ if (← instantiateMVars e).hasSyntheticSorry then
1767
+ throwAbortTerm
1768
+ throwError "type expected, got\n ({← instantiateMVars e} : {← instantiateMVars eType})"
1769
+
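The `coerceToSort?` fallback above is what lets a non-sort term be used in type position; a minimal sketch (the structure `MyCarrier` and its instance are illustrative only):

```lean
structure MyCarrier where
  carrier : Type

instance : CoeSort MyCarrier Type := ⟨MyCarrier.carrier⟩

-- `t : MyCarrier` is not a `Sort`, so using it as a type goes through the `CoeSort` coercion
example (t : MyCarrier) (x : t) : t := x
```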
1770
+ /-- Elaborate `stx` and ensure result is a type. -/
1771
+ def elabType (stx : Syntax) : TermElabM Expr := do
1772
+ let u ← mkFreshLevelMVar
1773
+ let type ← elabTerm stx (mkSort u)
1774
+ withRef stx <| ensureType type
1775
+
1776
+ /--
1777
+ Enable auto-bound implicits, and execute `k` while catching auto-bound implicit exceptions. When an exception is caught,
1778
+ a new local declaration is created and registered, and `k` is executed again. -/
1779
+ partial def withAutoBoundImplicit (k : TermElabM α) : TermElabM α := do
1780
+ let flag := autoImplicit.get (← getOptions)
1781
+ if flag then
1782
+ withReader (fun ctx => { ctx with autoBoundImplicit := flag, autoBoundImplicits := {} }) do
1783
+ let rec loop (s : SavedState) : TermElabM α := withIncRecDepth do
1784
+ checkSystem "auto-implicit"
1785
+ try
1786
+ k
1787
+ catch
1788
+ | ex => match isAutoBoundImplicitLocalException? ex with
1789
+ | some n =>
1790
+ -- Restore state, declare `n`, and try again
1791
+ s.restore (restoreInfo := true)
1792
+ withLocalDecl n .implicit (← mkFreshTypeMVar) fun x =>
1793
+ withReader (fun ctx => { ctx with autoBoundImplicits := ctx.autoBoundImplicits.push x } ) do
1794
+ loop (← saveState)
1795
+ | none => throw ex
1796
+ loop (← saveState)
1797
+ else
1798
+ k
1799
+
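A surface-level sketch of what the loop above implements (the printed signature is approximate):

```lean
-- `α` is unbound below; it is caught as an auto-bound implicit and becomes an implicit binder
set_option autoImplicit true in
def constId (a : α) : α := a

#check @constId  -- roughly: @constId : {α : Sort u_1} → α → α
```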
1800
+ def withoutAutoBoundImplicit (k : TermElabM α) : TermElabM α := do
1801
+ withReader (fun ctx => { ctx with autoBoundImplicit := false, autoBoundImplicits := {} }) k
1802
+
1803
+ partial def withAutoBoundImplicitForbiddenPred (p : Name → Bool) (x : TermElabM α) : TermElabM α := do
1804
+ withReader (fun ctx => { ctx with autoBoundImplicitForbidden := fun n => p n || ctx.autoBoundImplicitForbidden n }) x
1805
+
1806
+ /--
1807
+ Collect unassigned metavariables in `type` that are not already in `init` and do not satisfy `except`.
1808
+ -/
1809
+ partial def collectUnassignedMVars (type : Expr) (init : Array Expr := #[]) (except : MVarId → Bool := fun _ => false)
1810
+ : TermElabM (Array Expr) := do
1811
+ let mvarIds ← getMVars type
1812
+ if mvarIds.isEmpty then
1813
+ return init
1814
+ else
1815
+ go mvarIds.toList init init
1816
+ where
1817
+ go (mvarIds : List MVarId) (result visited : Array Expr) : TermElabM (Array Expr) := do
1818
+ match mvarIds with
1819
+ | [] => return result
1820
+ | mvarId :: mvarIds => do
1821
+ let visited := visited.push (mkMVar mvarId)
1822
+ if (← mvarId.isAssigned) then
1823
+ go mvarIds result visited
1824
+ else if result.contains (mkMVar mvarId) || except mvarId then
1825
+ go mvarIds result visited
1826
+ else
1827
+ let mvarType := (← getMVarDecl mvarId).type
1828
+ let mvarIdsNew ← getMVars mvarType
1829
+ let mvarIdsNew := mvarIdsNew.filter fun mvarId => !visited.contains (mkMVar mvarId)
1830
+ if mvarIdsNew.isEmpty then
1831
+ go mvarIds (result.push (mkMVar mvarId)) visited
1832
+ else
1833
+ go (mvarIdsNew.toList ++ mvarId :: mvarIds) result visited
1834
+
1835
+ /--
1836
+ Adds an `InlayHintInfo` for the fvar auto implicits in `autos` at `inlayHintPos`.
1837
+ The inserted inlay hint has a hover that denotes the type of the auto-implicit (with meta-variables)
1838
+ and can be inserted at `inlayHintPos`.
1839
+ -/
1840
+ def addAutoBoundImplicitsInlayHint (autos : Array Expr) (inlayHintPos : String.Pos) : TermElabM Unit := do
1841
+ -- If the list of auto-implicits contains a non-type fvar, then the list of auto-implicits will
1842
+ -- also contain an mvar that denotes the type of the non-type fvar.
1843
+ -- For example, the auto-implicit `x` in a type `Foo x` for `Foo.{u} {α : Sort u} (x : α) : Type`
1844
+ -- also comes with an auto-implicit mvar denoting the type of `x`.
1845
+ -- We have no way of displaying this mvar to the user in an inlay hint, as it doesn't have a name,
1846
+ -- so we filter it.
1847
+ -- This also means that inserting the inlay hint with the syntax displayed in the inlay hint will
1848
+ -- cause a "failed to infer binder type" error, since we don't have a name to insert in the code.
1849
+ let autos := autos.filter (· matches .fvar ..)
1850
+ if autos.isEmpty then
1851
+ return
1852
+ let autoNames ← autos.mapM (·.fvarId!.getUserName)
1853
+ let formattedHint := s!" \{{" ".intercalate <| Array.toList <| autoNames.map toString}}"
1854
+ let deferredResolution ih := do
1855
+ let description := "Automatically-inserted implicit parameters:"
1856
+ let codeBlockStart := "```lean"
1857
+ let typeInfos ← autos.mapM fun auto => do
1858
+ let name := toString <| ← auto.fvarId!.getUserName
1859
+ let type := toString <| ← Meta.ppExpr <| ← instantiateMVars (← inferType auto)
1860
+ return s!"{name} : {type}"
1861
+ let codeBlockEnd := "```"
1862
+ let tooltip := "\n".intercalate <| description :: codeBlockStart :: typeInfos.toList ++ [codeBlockEnd]
1863
+ return { ih with tooltip? := tooltip }
1864
+ pushInfoLeaf <| .ofCustomInfo {
1865
+ position := inlayHintPos
1866
+ label := .name formattedHint
1867
+ textEdits := #[{
1868
+ range := ⟨inlayHintPos, inlayHintPos⟩,
1869
+ newText := formattedHint
1870
+ }]
1871
+ kind? := some .parameter
1872
+ lctx := ← getLCtx
1873
+ deferredResolution
1874
+ : InlayHint
1875
+ }.toCustomInfo
1876
+
1877
+ /--
1878
+ Return `autoBoundImplicits ++ xs`.
1879
+ This method throws an error if a variable in `autoBoundImplicits` depends on some `x` in `xs`.
1880
+ The `autoBoundImplicits` may contain free variables created by the auto-implicit feature, and unassigned metavariables.
1881
+ It avoids the hack used at `autoBoundImplicitsOld`.
1882
+
1883
+ If `inlayHintPos?` is set, this function also inserts an inlay hint denoting `autoBoundImplicits`.
1884
+ See `addAutoBoundImplicitsInlayHint` for more information.
1885
+
1886
+ Remark: we cannot simply replace every occurrence of `addAutoBoundImplicitsOld` with this one because a particular
1887
+ use-case may not be able to handle the metavariables in the array being given to `k`.
1888
+ -/
1889
+ def addAutoBoundImplicits (xs : Array Expr) (inlayHintPos? : Option String.Pos) : TermElabM (Array Expr) := do
1890
+ let autos := (← read).autoBoundImplicits
1891
+ go autos.toList #[]
1892
+ where
1893
+ go (todo : List Expr) (autos : Array Expr) : TermElabM (Array Expr) := do
1894
+ match todo with
1895
+ | [] =>
1896
+ if let some inlayHintPos := inlayHintPos? then
1897
+ addAutoBoundImplicitsInlayHint autos inlayHintPos
1898
+ for auto in autos do
1899
+ if auto.isFVar then
1900
+ let localDecl ← auto.fvarId!.getDecl
1901
+ for x in xs do
1902
+ if (← localDeclDependsOn localDecl x.fvarId!) then
1903
+ throwError "invalid auto implicit argument '{auto}', it depends on explicitly provided argument '{x}'"
1904
+ return autos ++ xs
1905
+ | auto :: todo =>
1906
+ let autos ← collectUnassignedMVars (← inferType auto) autos
1907
+ go todo (autos.push auto)
1908
+
1909
+ /--
1910
+ Similar to `addAutoBoundImplicits`, but converts all metavariables into free variables.
1911
+
1912
+ It uses `mkForallFVars` + `forallBoundedTelescope` to convert metavariables into free variables.
1913
+ The type `type` is modified during the process if `type` depends on `xs`.
1914
+ We use this method to simplify the conversion of code using `autoBoundImplicitsOld` to `autoBoundImplicits`.
1915
+ -/
1916
+ def addAutoBoundImplicits' (xs : Array Expr) (type : Expr) (k : Array Expr → Expr → TermElabM α) (inlayHintPos? : Option String.Pos := none) : TermElabM α := do
1917
+ let xs ← addAutoBoundImplicits xs inlayHintPos?
1918
+ if xs.all (·.isFVar) then
1919
+ k xs type
1920
+ else
1921
+ forallBoundedTelescope (← mkForallFVars xs type) xs.size fun xs type => k xs type
1922
+
1923
+ def mkAuxName (suffix : Name) : TermElabM Name := mkAuxDeclName (kind := suffix)
1924
+
1925
+ builtin_initialize registerTraceClass `Elab.letrec
1926
+
1927
+ /-- Return true if `mvarId` is an auxiliary metavariable created for compiling `let rec` or it
1928
+ is delayed assigned to one. -/
1929
+ def isLetRecAuxMVar (mvarId : MVarId) : TermElabM Bool := do
1930
+ trace[Elab.letrec] "mvarId: {mkMVar mvarId} letrecMVars: {(← get).letRecsToLift.map (mkMVar $ ·.mvarId)}"
1931
+ let mvarId ← getDelayedMVarRoot mvarId
1932
+ trace[Elab.letrec] "mvarId root: {mkMVar mvarId}"
1933
+ return (← get).letRecsToLift.any (·.mvarId == mvarId)
1934
+
1935
+ private def checkDeprecatedCore (constName : Name) : TermElabM Unit := do
1936
+ if (← read).checkDeprecated then
1937
+ Linter.checkDeprecated constName
1938
+
1939
+ /--
1940
+ Create an `Expr.const` using the given name and explicit levels.
1941
+ Remark: fresh universe metavariables are created if the constant has more universe
1942
+ parameters than `explicitLevels`.
1943
+
1944
+ If `checkDeprecated := true`, then `Linter.checkDeprecated` is invoked.
1945
+ -/
1946
+ def mkConst (constName : Name) (explicitLevels : List Level := []) : TermElabM Expr := do
1947
+ checkDeprecatedCore constName
1948
+ let cinfo ← getConstVal constName
1949
+ if explicitLevels.length > cinfo.levelParams.length then
1950
+ throwError "too many explicit universe levels for '{constName}'"
1951
+ else
1952
+ let numMissingLevels := cinfo.levelParams.length - explicitLevels.length
1953
+ let us ← mkFreshLevelMVars numMissingLevels
1954
+ return Lean.mkConst constName (explicitLevels ++ us)
1955
+
1956
+ def checkDeprecated (ref : Syntax) (e : Expr) : TermElabM Unit := do
1957
+ if let .const declName _ := e.getAppFn then
1958
+ withRef ref do checkDeprecatedCore declName
1959
+
1960
+ @[inline] def withoutCheckDeprecated [MonadWithReaderOf Context m] : m α → m α :=
1961
+ withTheReader Context (fun ctx => { ctx with checkDeprecated := false })
1962
+
1963
+ private def mkConsts (candidates : List (Name × List String)) (explicitLevels : List Level) : TermElabM (List (Expr × List String)) := do
1964
+ candidates.foldlM (init := []) fun result (declName, projs) => do
1965
+ -- TODO: better support for `mkConst` failure. We may want to cache the failures, and report them if all candidates fail.
1966
+ /-
1967
+ We disable `checkDeprecated` here because there may be many overloaded symbols.
1968
+ Note that this method, `resolveName`, and `resolveName'` return a list of pairs instead of a list of `TermElabResult`s.
1969
+ We perform the `checkDeprecated` test at `resolveId?` and `elabAppFnId`.
1970
+ At `elabAppFnId`, we perform the check when converting the list returned by `resolveName'` into a list of
1971
+ `TermElabResult`s.
1972
+ -/
1973
+ let const ← withoutCheckDeprecated <| mkConst declName explicitLevels
1974
+ return (const, projs) :: result
1975
+
1976
+ def resolveName (stx : Syntax) (n : Name) (preresolved : List Syntax.Preresolved) (explicitLevels : List Level) (expectedType? : Option Expr := none) : TermElabM (List (Expr × List String)) := do
1977
+ addCompletionInfo <| CompletionInfo.id stx stx.getId (danglingDot := false) (← getLCtx) expectedType?
1978
+ if let some (e, projs) ← resolveLocalName n then
1979
+ unless explicitLevels.isEmpty do
1980
+ throwError "invalid use of explicit universe parameters, '{e}' is a local variable"
1981
+ return [(e, projs)]
1982
+ let preresolved := preresolved.filterMap fun
1983
+ | .decl n projs => some (n, projs)
1984
+ | _ => none
1985
+ -- check for section variable capture by a quotation
1986
+ let ctx ← read
1987
+ if let some (e, projs) := preresolved.findSome? fun (n, projs) => ctx.sectionFVars.find? n |>.map (·, projs) then
1988
+ return [(e, projs)] -- section variables should shadow global decls
1989
+ if preresolved.isEmpty then
1990
+ process (← realizeGlobalName n)
1991
+ else
1992
+ process preresolved
1993
+ where
1994
+ process (candidates : List (Name × List String)) : TermElabM (List (Expr × List String)) := do
1995
+ if candidates.isEmpty then
1996
+ if (← read).autoBoundImplicit &&
1997
+ !(← read).autoBoundImplicitForbidden n &&
1998
+ isValidAutoBoundImplicitName n (relaxedAutoImplicit.get (← getOptions)) then
1999
+ throwAutoBoundImplicitLocal n
2000
+ else
2001
+ throwUnknownIdentifierAt stx m!"unknown identifier '{Lean.mkConst n}'"
2002
+ mkConsts candidates explicitLevels
2003
+
2004
+ /--
2005
+ Similar to `resolveName`, but creates identifiers for the main part and each projection with position information derived from `ident`.
2006
+ Example: Assume resolveName `v.head.bla.boo` produces `(v.head, ["bla", "boo"])`, then this method produces
2007
+ `(v.head, id, [f₁, f₂])` where `id` is an identifier for `v.head`, and `f₁` and `f₂` are identifiers for fields `"bla"` and `"boo"`. -/
2008
+ def resolveName' (ident : Syntax) (explicitLevels : List Level) (expectedType? : Option Expr := none) : TermElabM (List (Expr × Syntax × List Syntax)) := do
2009
+ match ident with
2010
+ | .ident _ _ n preresolved =>
2011
+ let r ← resolveName ident n preresolved explicitLevels expectedType?
2012
+ r.mapM fun (c, fields) => do
2013
+ let ids := ident.identComponents (nFields? := fields.length)
2014
+ return (c, ids.head!, ids.tail!)
2015
+ | _ => throwError "identifier expected"
2016
+
2017
+ def resolveId? (stx : Syntax) (kind := "term") (withInfo := false) : TermElabM (Option Expr) := withRef stx do
2018
+ match stx with
2019
+ | .ident _ _ val preresolved =>
2020
+ let rs ← try resolveName stx val preresolved [] catch _ => pure []
2021
+ let rs := rs.filter fun ⟨_, projs⟩ => projs.isEmpty
2022
+ let fs := rs.map fun (f, _) => f
2023
+ match fs with
2024
+ | [] => return none
2025
+ | [f] =>
2026
+ let f ← if withInfo then addTermInfo stx f else pure f
2027
+ checkDeprecated stx f
2028
+ return some f
2029
+ | _ => throwError "ambiguous {kind}, use fully qualified name, possible interpretations {fs}"
2030
+ | _ => throwError "identifier expected"
2031
+
2032
+ def TermElabM.run (x : TermElabM α) (ctx : Context := {}) (s : State := {}) : MetaM (α × State) :=
2033
+ withConfig setElabConfig (x ctx |>.run s)
2034
+
2035
+ @[inline] def TermElabM.run' (x : TermElabM α) (ctx : Context := {}) (s : State := {}) : MetaM α :=
2036
+ (·.1) <$> x.run ctx s
2037
+
2038
+ def TermElabM.toIO (x : TermElabM α)
2039
+ (ctxCore : Core.Context) (sCore : Core.State)
2040
+ (ctxMeta : Meta.Context) (sMeta : Meta.State)
2041
+ (ctx : Context) (s : State) : IO (α × Core.State × Meta.State × State) := do
2042
+ let ((a, s), sCore, sMeta) ← (x.run ctx s).toIO ctxCore sCore ctxMeta sMeta
2043
+ return (a, sCore, sMeta, s)
2044
+
2045
+ /--
2046
+ Execute `x`, then try to solve pending universe constraints.
2047
+ Note that stuck constraints will not be discarded.
2048
+ -/
2049
+ def universeConstraintsCheckpoint (x : TermElabM α) : TermElabM α := do
2050
+ let a ← x
2051
+ discard <| processPostponed (mayPostpone := true) (exceptionOnFailure := true)
2052
+ return a
2053
+
2054
+ def expandDeclId (currNamespace : Name) (currLevelNames : List Name) (declId : Syntax) (modifiers : Modifiers) : TermElabM ExpandDeclIdResult := do
2055
+ let r ← Elab.expandDeclId currNamespace currLevelNames declId modifiers
2056
+ if (← read).sectionVars.contains r.shortName then
2057
+ throwError "invalid declaration name '{r.shortName}', there is a section variable with the same name"
2058
+ return r
2059
+
2060
+ /--
2061
+ Helper function for "embedding" an `Expr` in `Syntax`.
2062
+ It creates a named hole `?m` and immediately assigns `e` to it.
2063
+ Example:
2064
+ ```lean
2065
+ let e := mkConst ``Nat.zero
2066
+ `(Nat.succ $(← exprToSyntax e))
2067
+ ```
2068
+ -/
2069
+ def exprToSyntax (e : Expr) : TermElabM Term := withFreshMacroScope do
2070
+ let result ← `(?m)
2071
+ let eType ← inferType e
2072
+ let mvar ← elabTerm result eType
2073
+ mvar.mvarId!.assign e
2074
+ return result
2075
+
2076
+ end Term
2077
+
2078
+ open Term in
2079
+ def withoutModifyingStateWithInfoAndMessages [MonadControlT TermElabM m] [Monad m] (x : m α) : m α := do
2080
+ controlAt TermElabM fun runInBase => withoutModifyingStateWithInfoAndMessagesImpl <| runInBase x
2081
+
2082
+ builtin_initialize
2083
+ registerTraceClass `Elab.postpone
2084
+ registerTraceClass `Elab.coe
2085
+ registerTraceClass `Elab.debug
2086
+ registerTraceClass `Elab.reuse
2087
+
2088
+ /--
2089
+ Marks an elaborator (tactic or command, currently) as supporting incremental elaboration.
2090
+
2091
+ For unmarked elaborators, the corresponding snapshot bundle field in the elaboration context is
2092
+ unset so as to prevent accidental, incorrect reuse.
2093
+ -/
2094
+ @[builtin_doc]
2095
+ builtin_initialize incrementalAttr : TagAttribute ←
2096
+ registerTagAttribute `incremental "Marks an elaborator (tactic or command, currently) as \
2097
+ supporting incremental elaboration. For unmarked elaborators, the corresponding snapshot bundle \
2098
+ field in the elaboration context is unset so as to prevent accidental, incorrect reuse."
2099
+
2100
+ builtin_initialize builtinIncrementalElabs : IO.Ref NameSet ← IO.mkRef {}
2101
+
2102
+ def addBuiltinIncrementalElab (decl : Name) : IO Unit := do
2103
+ builtinIncrementalElabs.modify fun s => s.insert decl
2104
+
2105
+ @[inherit_doc incrementalAttr, builtin_doc]
2106
+ builtin_initialize
2107
+ registerBuiltinAttribute {
2108
+ name := `builtin_incremental
2109
+ descr := s!"(builtin) {incrementalAttr.attr.descr}"
2110
+ applicationTime := .afterCompilation
2111
+ add := fun decl stx kind => do
2112
+ Attribute.Builtin.ensureNoArgs stx
2113
+ unless kind == AttributeKind.global do
2114
+ throwError "invalid attribute 'builtin_incremental', must be global"
2115
+ declareBuiltin decl <| mkApp (mkConst ``addBuiltinIncrementalElab) (toExpr decl)
2116
+ }
2117
+
2118
+ /-- Checks whether a declaration is annotated with `[builtin_incremental]` or `[incremental]`. -/
2119
+ def isIncrementalElab [Monad m] [MonadEnv m] [MonadLiftT IO m] (decl : Name) : m Bool :=
2120
+ (return (← builtinIncrementalElabs.get (m := IO)).contains decl) <||>
2121
+ (return incrementalAttr.hasTag (← getEnv) decl)
2122
+
2123
+ export Term (TermElabM)
2124
+
2125
+ builtin_initialize
2126
+ registerTraceClass `Elab.implicitForall
2127
+
2128
+ end Lean.Elab
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Time.lean ADDED
@@ -0,0 +1,26 @@
1
+ /-
2
+ Copyright (c) 2021 Mario Carneiro. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Mario Carneiro
5
+ -/
6
+ prelude
7
+ import Lean.Elab.Command
8
+
9
+ /-!
10
+ # Defines `#time` command.
11
+
12
+ Time the elaboration of a command, and print the result (in milliseconds).
13
+ -/
14
+
15
+ namespace Lean.Elab.Time
16
+
17
+ open Command
18
+
19
+ @[builtin_command_elab Lean.Parser.timeCmd] def elabTimeCmd : CommandElab
20
+ | `(#time%$tk $stx:command) => do
21
+ let start ← IO.monoMsNow
22
+ elabCommand stx
23
+ logInfoAt tk m!"time: {(← IO.monoMsNow) - start}ms"
24
+ | _ => throwUnsupportedSyntax
25
+
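A usage sketch (the measured time naturally varies between runs):

```lean
#time #eval (List.range 100000).length  -- prints e.g. `time: 12ms` at the `#time` token
```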
26
+ end Lean.Elab.Time
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/Util.lean ADDED
@@ -0,0 +1,249 @@
1
+ /-
2
+ Copyright (c) 2019 Microsoft Corporation. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Leonardo de Moura
5
+ -/
6
+ prelude
7
+ import Lean.Parser.Command
8
+ import Lean.KeyedDeclsAttribute
9
+ import Lean.Elab.Exception
10
+
11
+ namespace Lean
12
+
13
+ def Syntax.prettyPrint (stx : Syntax) : Format :=
14
+ match stx.unsetTrailing.reprint with -- TODO use syntax pretty printer
15
+ | some str => format str.toFormat
16
+ | none => format stx
17
+
18
+ def MacroScopesView.format (view : MacroScopesView) (mainModule : Name) : Format :=
19
+ Std.format <|
20
+ if view.scopes.isEmpty then
21
+ view.name
22
+ else if view.mainModule == mainModule then
23
+ view.scopes.foldl Name.mkNum (view.name ++ view.imported)
24
+ else
25
+ view.scopes.foldl Name.mkNum (view.name ++ view.imported ++ view.mainModule)
26
+
27
+ /--
28
+ Two names are from the same lexical scope if their scoping information modulo `MacroScopesView.name`
29
+ is equal.
30
+ -/
31
+ def MacroScopesView.equalScope (a b : MacroScopesView) : Bool :=
32
+ a.scopes == b.scopes && a.mainModule == b.mainModule && a.imported == b.imported
33
+
34
+ namespace Elab
35
+
36
+ def expandOptNamedPrio (stx : Syntax) : MacroM Nat :=
37
+ if stx.isNone then
38
+ return eval_prio default
39
+ else match stx[0] with
40
+ | `(Parser.Command.namedPrio| (priority := $prio)) => evalPrio prio
41
+ | _ => Macro.throwUnsupported
42
+
43
+ structure MacroStackElem where
44
+ before : Syntax
45
+ after : Syntax
46
+
47
+ abbrev MacroStack := List MacroStackElem
48
+
49
+ /-- If `ref` does not have position information, then try to use `macroStack`. -/
50
+ def getBetterRef (ref : Syntax) (macroStack : MacroStack) : Syntax :=
51
+ match ref.getPos? with
52
+ | some _ => ref
53
+ | none =>
54
+ match macroStack.find? (·.before.getPos? != none) with
55
+ | some elem => elem.before
56
+ | none => ref
57
+
58
+ register_builtin_option pp.macroStack : Bool := {
59
+ defValue := false
60
+ group := "pp"
61
+ descr := "display macro expansion stack"
62
+ }
63
+
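Enabling it is an ordinary `set_option`; with it set, elaboration errors arising inside macro expansions carry the "while expanding ..." trailer built by `addMacroStack` below:

```lean
set_option pp.macroStack true
```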
64
+ def addMacroStack {m} [Monad m] [MonadOptions m] (msgData : MessageData) (macroStack : MacroStack) : m MessageData := do
65
+ if !pp.macroStack.get (← getOptions) then pure msgData else
66
+ match macroStack with
67
+ | [] => pure msgData
68
+ | stack@(top::_) =>
69
+ let msgData := msgData ++ Format.line ++ "with resulting expansion" ++ indentD top.after
70
+ pure $ stack.foldl
71
+ (fun (msgData : MessageData) (elem : MacroStackElem) =>
72
+ msgData ++ Format.line ++ "while expanding" ++ indentD elem.before)
73
+ msgData
74
+
75
+ def checkSyntaxNodeKind [Monad m] [MonadEnv m] [MonadError m] (k : Name) : m Name := do
76
+ if Parser.isValidSyntaxNodeKind (← getEnv) k then pure k
77
+ else throwError "failed"
78
+
79
+ def checkSyntaxNodeKindAtNamespaces [Monad m] [MonadEnv m] [MonadError m] (k : Name) : Name → m Name
80
+ | n@(.str p _) => checkSyntaxNodeKind (n ++ k) <|> checkSyntaxNodeKindAtNamespaces k p
81
+ | .anonymous => checkSyntaxNodeKind k
82
+ | _ => throwError "failed"
83
+
84
+ def checkSyntaxNodeKindAtCurrentNamespaces (k : Name) : AttrM Name := do
85
+ let ctx ← read
86
+ checkSyntaxNodeKindAtNamespaces k ctx.currNamespace
87
+
88
+ def syntaxNodeKindOfAttrParam (defaultParserNamespace : Name) (stx : Syntax) : AttrM SyntaxNodeKind := do
89
+ let k ← Attribute.Builtin.getId stx
90
+ checkSyntaxNodeKindAtCurrentNamespaces k
91
+ <|>
92
+ checkSyntaxNodeKind (defaultParserNamespace ++ k)
93
+ <|>
94
+ throwError "invalid syntax node kind '{k}'"
95
+
96
+ private unsafe def evalSyntaxConstantUnsafe (env : Environment) (opts : Options) (constName : Name) : ExceptT String Id Syntax :=
97
+ env.evalConstCheck Syntax opts `Lean.Syntax constName
98
+
99
+ @[implemented_by evalSyntaxConstantUnsafe]
100
+ opaque evalSyntaxConstant (env : Environment) (opts : Options) (constName : Name) : ExceptT String Id Syntax := throw ""
101
+
102
+ unsafe def mkElabAttribute (γ) (attrBuiltinName attrName : Name) (parserNamespace : Name) (typeName : Name) (kind : String)
103
+ (attrDeclName : Name := by exact decl_name%) : IO (KeyedDeclsAttribute γ) :=
104
+ KeyedDeclsAttribute.init {
105
+ builtinName := attrBuiltinName
106
+ name := attrName
107
+ descr := kind ++ " elaborator"
108
+ valueTypeName := typeName
109
+ evalKey := fun _ stx => do
110
+ let kind ← syntaxNodeKindOfAttrParam parserNamespace stx
111
+ /- Recall that a `SyntaxNodeKind` is often the name of the parser, but this is not always true, and we must check it. -/
112
+ if (← getEnv).contains kind && (← getInfoState).enabled then
113
+ addConstInfo stx[1] kind none
114
+ return kind
115
+ onAdded := fun builtin declName => do
116
+ if builtin then
117
+ declareBuiltinDocStringAndRanges declName
118
+ } attrDeclName
119
+
120
+ unsafe def mkMacroAttributeUnsafe (ref : Name) : IO (KeyedDeclsAttribute Macro) :=
121
+ mkElabAttribute Macro `builtin_macro `macro Name.anonymous `Lean.Macro "macro" ref
122
+
123
+ @[implemented_by mkMacroAttributeUnsafe]
124
+ opaque mkMacroAttribute (ref : Name) : IO (KeyedDeclsAttribute Macro)
125
+
126
+ /--
127
+ Registers a macro expander for a given syntax node kind.
128
+
129
+ A macro expander should have type `Lean.Macro` (which is `Lean.Syntax → Lean.MacroM Lean.Syntax`),
130
+ i.e. should take syntax of the given syntax node kind as a parameter and produce different syntax
131
+ in the same syntax category.
132
+
133
+ The `macro_rules` and `macro` commands should usually be preferred over using this attribute
134
+ directly.
135
+ -/
136
+ @[builtin_doc]
137
+ builtin_initialize macroAttribute : KeyedDeclsAttribute Macro ← mkMacroAttribute decl_name%
138
+
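A sketch of using the attribute directly (as the docstring says, `macro_rules` is normally preferable; the node kind `myNot` is hypothetical):

```lean
syntax (name := myNot) "myNot! " term : term

@[macro myNot] def expandMyNot : Lean.Macro
  | `(myNot! $e) => `(!$e)
  | _            => Lean.Macro.throwUnsupported

#eval myNot! false  -- expands to `!false`, i.e. `true`
```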
139
+ /--
140
+ Try to expand a macro at the root of the syntax tree and return the macro declaration name and the new syntax if successful.
141
+ Return `none` if all macros threw `Macro.Exception.unsupportedSyntax`.
142
+ -/
143
+ def expandMacroImpl? (env : Environment) : Syntax → MacroM (Option (Name × Except Macro.Exception Syntax)) := fun stx => do
144
+ for e in macroAttribute.getEntries env stx.getKind do
145
+ try
146
+ let stx' ← withFreshMacroScope (e.value stx)
147
+ return (e.declName, Except.ok stx')
148
+ catch
149
+ | Macro.Exception.unsupportedSyntax => pure ()
150
+ | ex => return (e.declName, Except.error ex)
151
+ return none
152
+
153
+ class MonadMacroAdapter (m : Type → Type) where
154
+ getCurrMacroScope : m MacroScope
155
+ getNextMacroScope : m MacroScope
156
+ setNextMacroScope : MacroScope → m Unit
157
+
158
+ @[always_inline]
159
+ instance (m n) [MonadLift m n] [MonadMacroAdapter m] : MonadMacroAdapter n := {
160
+ getCurrMacroScope := liftM (MonadMacroAdapter.getCurrMacroScope : m _)
161
+ getNextMacroScope := liftM (MonadMacroAdapter.getNextMacroScope : m _)
162
+ setNextMacroScope := fun s => liftM (MonadMacroAdapter.setNextMacroScope s : m _)
163
+ }
164
+
165
+ def liftMacroM [Monad m] [MonadMacroAdapter m] [MonadEnv m] [MonadRecDepth m] [MonadError m] [MonadResolveName m] [MonadTrace m] [MonadOptions m] [AddMessageContext m] [MonadLiftT IO m] (x : MacroM α) : m α := do
166
+ let env ← getEnv
167
+ let currNamespace ← getCurrNamespace
168
+ let openDecls ← getOpenDecls
169
+ let methods := Macro.mkMethods {
170
+ -- TODO: record recursive expansions in info tree?
171
+ expandMacro? := fun stx => do
172
+ match (← expandMacroImpl? env stx) with
173
+ | some (_, stx?) => liftExcept stx?
174
+ | none => return none
175
+ hasDecl := fun declName => return env.contains declName
176
+ getCurrNamespace := return currNamespace
177
+ resolveNamespace := fun n => return ResolveName.resolveNamespace env currNamespace openDecls n
178
+ resolveGlobalName := fun n => return ResolveName.resolveGlobalName env currNamespace openDecls n
179
+ }
180
+ match x { methods := methods
181
+ ref := ← getRef
182
+ currMacroScope := ← MonadMacroAdapter.getCurrMacroScope
183
+ mainModule := env.mainModule
184
+ currRecDepth := ← MonadRecDepth.getRecDepth
185
+ maxRecDepth := ← MonadRecDepth.getMaxRecDepth
186
+ } { macroScope := (← MonadMacroAdapter.getNextMacroScope) } with
187
+ | EStateM.Result.error Macro.Exception.unsupportedSyntax _ => throwUnsupportedSyntax
188
+ | EStateM.Result.error (Macro.Exception.error ref msg) _ =>
189
+ if msg == maxRecDepthErrorMessage then
190
+ -- Make sure we can detect exception using `Exception.isMaxRecDepth`
191
+ throwMaxRecDepthAt ref
192
+ else
193
+ throwErrorAt ref msg
194
+ | EStateM.Result.ok a s =>
195
+ MonadMacroAdapter.setNextMacroScope s.macroScope
196
+ s.traceMsgs.reverse.forM fun (clsName, msg) => trace clsName fun _ => msg
197
+ return a
198
+
199
+ @[inline] def adaptMacro {m : Type → Type} [Monad m] [MonadMacroAdapter m] [MonadEnv m] [MonadRecDepth m] [MonadError m] [MonadResolveName m] [MonadTrace m] [MonadOptions m] [AddMessageContext m] [MonadLiftT IO m] (x : Macro) (stx : Syntax) : m Syntax :=
200
+ liftMacroM (x stx)
201
+
202
+ partial def mkUnusedBaseName (baseName : Name) : MacroM Name := do
203
+ let currNamespace ← Macro.getCurrNamespace
204
+ if ← Macro.hasDecl (currNamespace ++ baseName) then
205
+ let rec loop (idx : Nat) := do
206
+ let name := baseName.appendIndexAfter idx
207
+ if ← Macro.hasDecl (currNamespace ++ name) then
208
+ loop (idx+1)
209
+ else
210
+ return name
211
+ loop 1
212
+ else
213
+ return baseName
214
+
215
+ def logException [Monad m] [MonadLog m] [AddMessageContext m] [MonadOptions m] [MonadLiftT IO m] (ex : Exception) : m Unit := do
216
+ match ex with
217
+ | Exception.error ref msg => logErrorAt ref msg
218
+ | Exception.internal id _ =>
219
+ unless isAbortExceptionId id || ex.isInterrupt do
220
+ let name ← id.getName
221
+ logError m!"internal exception: {name}"
222
+
223
+ /--
224
+ If `x` throws an exception, catch it and turn it into a log message (using `logException`).
225
+ -/
226
+ def withLogging [Monad m] [MonadLog m] [MonadExcept Exception m] [AddMessageContext m] [MonadOptions m] [MonadLiftT IO m]
227
+ (x : m Unit) : m Unit := do
228
+ try x catch ex => logException ex
229
+
230
+ def nestedExceptionToMessageData [Monad m] [MonadLog m] (ex : Exception) : m MessageData := do
231
+ let pos ← getRefPos
232
+ match ex.getRef.getPos? with
233
+ | none => return ex.toMessageData
234
+ | some exPos =>
235
+ if pos == exPos then
236
+ return ex.toMessageData
237
+ else
238
+ let exPosition := (← getFileMap).toPosition exPos
239
+ return m!"{exPosition.line}:{exPosition.column} {ex.toMessageData}"
240
+
241
+ def throwErrorWithNestedErrors [MonadError m] [Monad m] [MonadLog m] (msg : MessageData) (exs : Array Exception) : m α := do
242
+ throwError "{msg}, errors {toMessageList (← exs.mapM fun | ex => nestedExceptionToMessageData ex)}"
243
+
244
+ builtin_initialize
245
+ registerTraceClass `Elab
246
+ registerTraceClass `Elab.step
247
+ registerTraceClass `Elab.step.result (inherited := true)
248
+
249
+ end Lean.Elab
backend/core/leanprover--lean4---v4.22.0/src/lean/Lean/Elab/WhereFinally.lean ADDED
@@ -0,0 +1,29 @@
1
+ /-
2
+ Copyright (c) 2025 Microsoft Corporation. All rights reserved.
3
+ Released under Apache 2.0 license as described in the file LICENSE.
4
+ Authors: Sebastian Graf
5
+ -/
6
+ prelude
7
+ import Lean.Parser.Term
8
+
9
+ namespace Lean.Elab
10
+
11
+ structure WhereFinallyView where
12
+ ref : Syntax
13
+ tactic : TSyntax ``Lean.Parser.Tactic.tacticSeq
14
+ deriving Inhabited
15
+
16
+ def WhereFinallyView.none : WhereFinallyView := { ref := .missing, tactic := ⟨.missing⟩ }
17
+
18
+ def WhereFinallyView.isNone (o : WhereFinallyView) : Bool := o.ref.isMissing && o.tactic.raw.isMissing
19
+
20
+ /-- Creates a view of the `finally` section of a `whereDecls` syntax object -/
21
+ def mkWhereFinallyView {m} [Monad m] [MonadError m] (stx : TSyntax ``Parser.Term.whereDecls) : m WhereFinallyView := do
22
+ -- Fail gracefully upon partial parses/missing where or finally sections
23
+ let whereFinally := stx.raw[2][0]
24
+ if whereFinally.isMissing then
25
+ return { ref := stx, tactic := ⟨.missing⟩ }
26
+ if !whereFinally[2][0].isMissing then
27
+ throwErrorAt stx "`where ... finally` does not currently support any named sub-sections `| sectionName => ...`"
28
+ let tactic := ⟨whereFinally[1]⟩
29
+ return { ref := whereFinally, tactic }
external/alphageometry/.venv-ag/Lib/site-packages/absl/app.py ADDED
@@ -0,0 +1,488 @@
1
+ # Copyright 2017 The Abseil Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Generic entry point for Abseil Python applications.
16
+
17
+ To use this module, define a ``main`` function with a single ``argv`` argument
18
+ and call ``app.run(main)``. For example::
19
+
20
+ def main(argv):
21
+ if len(argv) > 1:
22
+ raise app.UsageError('Too many command-line arguments.')
23
+
24
+ if __name__ == '__main__':
25
+ app.run(main)
26
+ """
27
+
28
+ import collections
29
+ import errno
30
+ import os
31
+ import pdb
32
+ import sys
33
+ import textwrap
34
+ import traceback
35
+
36
+ from absl import command_name
37
+ from absl import flags
38
+ from absl import logging
39
+
40
+ try:
41
+ import faulthandler
42
+ except ImportError:
43
+ faulthandler = None
44
+
45
+ FLAGS = flags.FLAGS
46
+
47
+ flags.DEFINE_boolean('run_with_pdb', False, 'Set to true for PDB debug mode')
48
+ flags.DEFINE_boolean('pdb_post_mortem', False,
49
+ 'Set to true to handle uncaught exceptions with PDB '
50
+ 'post mortem.')
51
+ flags.DEFINE_alias('pdb', 'pdb_post_mortem')
52
+ flags.DEFINE_boolean('run_with_profiling', False,
53
+ 'Set to true for profiling the script. '
54
+ 'Execution will be slower, and the output format might '
55
+ 'change over time.')
56
+ flags.DEFINE_string('profile_file', None,
57
+ 'Dump profile information to a file (for python -m '
58
+ 'pstats). Implies --run_with_profiling.')
59
+ flags.DEFINE_boolean('use_cprofile_for_profiling', True,
60
+ 'Use cProfile instead of the profile module for '
61
+ 'profiling. This has no effect unless '
62
+ '--run_with_profiling is set.')
63
+ flags.DEFINE_boolean('only_check_args', False,
64
+ 'Set to true to validate args and exit.',
65
+ allow_hide_cpp=True)
66
+
67
+
68
+ # If main() exits via an abnormal exception, call into these
69
+ # handlers before exiting.
70
+ EXCEPTION_HANDLERS = []
71
+
72
+
73
+ class Error(Exception):
74
+ pass
75
+
76
+
77
+ class UsageError(Error):
78
+ """Exception raised when the arguments supplied by the user are invalid.
79
+
80
+ Raise this when the arguments supplied are invalid from the point of
81
+ view of the application. For example when two mutually exclusive
82
+ flags have been supplied or when there are not enough non-flag
83
+ arguments. It is distinct from flags.Error which covers the lower
84
+ level of parsing and validating individual flags.
85
+ """
86
+
87
+ def __init__(self, message, exitcode=1):
88
+ super().__init__(message)
89
+ self.exitcode = exitcode
90
+
91
+
92
+ class HelpFlag(flags.BooleanFlag):
93
+ """Special boolean flag that displays usage and raises SystemExit."""
94
+ NAME = 'help'
95
+ SHORT_NAME = '?'
96
+
97
+ def __init__(self):
98
+ super().__init__(
99
+ self.NAME,
100
+ False,
101
+ 'show this help',
102
+ short_name=self.SHORT_NAME,
103
+ allow_hide_cpp=True,
104
+ )
105
+
106
+ def parse(self, arg):
107
+ if self._parse(arg):
108
+ usage(shorthelp=True, writeto_stdout=True)
109
+ # Advertise --helpfull on stdout, since usage() was on stdout.
110
+ print()
111
+ print('Try --helpfull to get a list of all flags.')
112
+ sys.exit(1)
113
+
114
+
115
+ class HelpshortFlag(HelpFlag):
116
+ """--helpshort is an alias for --help."""
117
+ NAME = 'helpshort'
118
+ SHORT_NAME = None
119
+
120
+
121
+ class HelpfullFlag(flags.BooleanFlag):
122
+ """Display help for flags in the main module and all dependent modules."""
123
+
124
+ def __init__(self):
125
+ super().__init__('helpfull', False, 'show full help', allow_hide_cpp=True)
126
+
127
+ def parse(self, arg):
128
+ if self._parse(arg):
129
+ usage(writeto_stdout=True)
130
+ sys.exit(1)
131
+
132
+
133
+ class HelpXMLFlag(flags.BooleanFlag):
134
+ """Similar to HelpfullFlag, but generates output in XML format."""
135
+
136
+ def __init__(self):
137
+ super().__init__(
138
+ 'helpxml',
139
+ False,
140
+ 'like --helpfull, but generates XML output',
141
+ allow_hide_cpp=True,
142
+ )
143
+
144
+ def parse(self, arg):
145
+ if self._parse(arg):
146
+ flags.FLAGS.write_help_in_xml_format(sys.stdout)
147
+ sys.exit(1)
148
+
149
+
150
+ def parse_flags_with_usage(args):
151
+ """Tries to parse the flags, print usage, and exit if unparsable.
152
+
153
+ Args:
154
+ args: [str], a non-empty list of the command line arguments including
155
+ program name.
156
+
157
+ Returns:
158
+ [str], a non-empty list of remaining command line arguments after parsing
159
+ flags, including program name.
160
+ """
161
+ try:
162
+ return FLAGS(args)
163
+ except flags.Error as error:
164
+ message = str(error)
165
+ if '\n' in message:
166
+ final_message = 'FATAL Flags parsing error:\n%s\n' % textwrap.indent(
167
+ message, ' ')
168
+ else:
169
+ final_message = 'FATAL Flags parsing error: %s\n' % message
170
+ sys.stderr.write(final_message)
171
+ sys.stderr.write('Pass --helpshort or --helpfull to see help on flags.\n')
172
+ sys.exit(1)
173
+
174
+
175
+ _define_help_flags_called = False
176
+
177
+
178
+ def define_help_flags():
179
+ """Registers help flags. Idempotent."""
180
+ # Use a global to ensure idempotence.
181
+ global _define_help_flags_called
182
+
183
+ if not _define_help_flags_called:
184
+ flags.DEFINE_flag(HelpFlag())
185
+ flags.DEFINE_flag(HelpshortFlag()) # alias for --help
186
+ flags.DEFINE_flag(HelpfullFlag())
187
+ flags.DEFINE_flag(HelpXMLFlag())
188
+ _define_help_flags_called = True
189
+
190
+
191
+ def _register_and_parse_flags_with_usage(
192
+ argv=None,
193
+ flags_parser=parse_flags_with_usage,
194
+ ):
195
+ """Registers help flags, parses arguments and shows usage if appropriate.
196
+
197
+ This also calls sys.exit(0) if flag --only_check_args is True.
198
+
199
+ Args:
200
+ argv: [str], a non-empty list of the command line arguments including
201
+ program name, sys.argv is used if None.
202
+ flags_parser: Callable[[List[str]], Any], the function used to parse flags.
203
+ The return value of this function is passed to `main` untouched. It must
204
+ guarantee FLAGS is parsed after this function is called.
205
+
206
+ Returns:
207
+ The return value of `flags_parser`. When using the default `flags_parser`,
208
+ it returns the following:
209
+ [str], a non-empty list of remaining command line arguments after parsing
210
+ flags, including program name.
211
+
212
+ Raises:
213
+ Error: Raised when flags_parser is called, but FLAGS is not parsed.
214
+ SystemError: Raised when it's called more than once.
215
+ """
216
+ # fmt: on
217
+ if _register_and_parse_flags_with_usage.done:
218
+ raise SystemError('Flag registration can be done only once.')
219
+
220
+ define_help_flags()
221
+
222
+ original_argv = sys.argv if argv is None else argv
223
+ args_to_main = flags_parser(original_argv)
224
+ if not FLAGS.is_parsed():
225
+ raise Error('FLAGS must be parsed after flags_parser is called.')
226
+
227
+ # Exit when told so.
228
+ if FLAGS.only_check_args:
229
+ sys.exit(0)
230
+ # Immediately after flags are parsed, bump verbosity to INFO if the flag has
231
+ # not been set.
232
+ if FLAGS['verbosity'].using_default_value:
233
+ FLAGS.verbosity = 0
234
+ _register_and_parse_flags_with_usage.done = True
235
+
236
+ return args_to_main
237
+
238
+ _register_and_parse_flags_with_usage.done = False
239
+
240
+
241
+ def _run_main(main, argv):
242
+ """Calls main, optionally with pdb or profiler."""
243
+ if FLAGS.run_with_pdb:
244
+ sys.exit(pdb.runcall(main, argv))
245
+ elif FLAGS.run_with_profiling or FLAGS.profile_file:
246
+ # Avoid import overhead since most apps (including performance-sensitive
247
+ # ones) won't be run with profiling.
248
+ # pylint: disable=g-import-not-at-top
249
+ import atexit
250
+ if FLAGS.use_cprofile_for_profiling:
251
+ import cProfile as profile
252
+ else:
253
+ import profile
254
+ profiler = profile.Profile()
255
+ if FLAGS.profile_file:
256
+ atexit.register(profiler.dump_stats, FLAGS.profile_file)
257
+ else:
258
+ atexit.register(profiler.print_stats)
259
+ sys.exit(profiler.runcall(main, argv))
260
+ else:
261
+ sys.exit(main(argv))
262
+
263
+
264
+ def _call_exception_handlers(exception):
265
+ """Calls any installed exception handlers."""
266
+ for handler in EXCEPTION_HANDLERS:
267
+ try:
268
+ if handler.wants(exception):
269
+ handler.handle(exception)
270
+ except: # pylint: disable=bare-except
271
+ try:
272
+ # We don't want to stop for exceptions in the exception handlers but
273
+ # we shouldn't hide them either.
274
+ logging.error(traceback.format_exc())
275
+ except: # pylint: disable=bare-except
276
+ # In case even the logging statement fails, ignore.
277
+ pass
278
+
279
+
280
+ def run(
281
+ main,
282
+ argv=None,
283
+ flags_parser=parse_flags_with_usage,
284
+ ):
285
+ """Begins executing the program.
286
+
287
+ Args:
288
+ main: The main function to execute. It takes a single argument "argv",
289
+ which is a list of command line arguments with parsed flags removed.
290
+ The return value is passed to `sys.exit`, and so for example
291
+ a return value of 0 or None results in a successful termination, whereas
292
+ a return value of 1 results in abnormal termination.
293
+ For more details, see https://docs.python.org/3/library/sys#sys.exit
294
+ argv: A non-empty list of the command line arguments including program name;
295
+ sys.argv is used if None.
296
+ flags_parser: Callable[[List[str]], Any], the function used to parse flags.
297
+ The return value of this function is passed to `main` untouched.
298
+ It must guarantee FLAGS is parsed after this function is called.
299
+ Should be passed as a keyword-only arg which will become mandatory in a
300
+ future release.
301
+ - Parses command line flags with the flag module.
302
+ - If there are any errors, prints usage().
303
+ - Calls main() with the remaining arguments.
304
+ - If main() raises a UsageError, prints usage and the error message.
305
+ """
306
+ # fmt: on
307
+ try:
308
+ args = _run_init(
309
+ sys.argv if argv is None else argv,
310
+ flags_parser,
311
+ )
312
+ while _init_callbacks:
313
+ callback = _init_callbacks.popleft()
314
+ callback()
315
+ try:
316
+ _run_main(main, args)
317
+ except UsageError as error:
318
+ usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode)
319
+ except:
320
+ exc = sys.exc_info()[1]
321
+ # Don't try to post-mortem debug successful SystemExits, since those
322
+ # mean there wasn't actually an error. In particular, the test framework
323
+ # raises SystemExit(False) even if all tests passed.
324
+ if isinstance(exc, SystemExit) and not exc.code:
325
+ raise
326
+
327
+ # Check the tty so that we don't hang waiting for input in a
328
+ # non-interactive scenario.
329
+ if FLAGS.pdb_post_mortem and sys.stdout.isatty():
330
+ traceback.print_exc()
331
+ print()
332
+ print(' *** Entering post-mortem debugging ***')
333
+ print()
334
+ pdb.post_mortem()
335
+ raise
336
+ except Exception as e:
337
+ _call_exception_handlers(e)
338
+ raise
339
+
340
+ # Callbacks which have been deferred until after _run_init has been called.
341
+ _init_callbacks = collections.deque()
342
+
343
+
344
+ def call_after_init(callback):
345
+ """Calls the given callback only once ABSL has finished initialization.
346
+
347
+ If ABSL has already finished initialization when ``call_after_init`` is
348
+ called then the callback is executed immediately, otherwise `callback` is
349
+ stored to be executed after ``app.run`` has finished initializing (aka. just
350
+ before the main function is called).
351
+
352
+ If called after ``app.run``, this is equivalent to calling ``callback()`` in
353
+ the caller thread. If called before ``app.run``, callbacks are run
354
+ sequentially (in an undefined order) in the same thread as ``app.run``.
355
+
356
+ Args:
357
+ callback: a callable to be called once ABSL has finished initialization.
358
+ This may be immediate if initialization has already finished. It
359
+ takes no arguments and returns nothing.
360
+ """
361
+ if _run_init.done:
362
+ callback()
363
+ else:
364
+ _init_callbacks.append(callback)
365
+
366
+
367
+ def _run_init(
368
+ argv,
369
+ flags_parser,
370
+ ):
371
+ """Does one-time initialization and re-parses flags on rerun."""
372
+ if _run_init.done:
373
+ return flags_parser(argv)
374
+ command_name.make_process_name_useful()
375
+ # Set up absl logging handler.
376
+ logging.use_absl_handler()
377
+ args = _register_and_parse_flags_with_usage(
378
+ argv=argv,
379
+ flags_parser=flags_parser,
380
+ )
381
+ if faulthandler:
382
+ try:
383
+ faulthandler.enable()
384
+ except Exception: # pylint: disable=broad-except
385
+ # Some tests verify stderr output very closely, so don't print anything.
386
+ # Disabled faulthandler is a low-impact error.
387
+ pass
388
+ _run_init.done = True
389
+ return args
390
+
391
+
392
+ _run_init.done = False
393
+
394
+
395
+ def usage(shorthelp=False, writeto_stdout=False, detailed_error=None,
396
+ exitcode=None):
397
+ """Writes __main__'s docstring to stderr with some help text.
398
+
399
+ Args:
400
+ shorthelp: bool, if True, prints only flags from the main module,
401
+ rather than all flags.
402
+ writeto_stdout: bool, if True, writes help message to stdout,
403
+ rather than to stderr.
404
+ detailed_error: str, additional detail about why usage info was presented.
405
+ exitcode: optional integer, if set, exits with this status code after
406
+ writing help.
407
+ """
408
+ if writeto_stdout:
409
+ stdfile = sys.stdout
410
+ else:
411
+ stdfile = sys.stderr
412
+
413
+ doc = sys.modules['__main__'].__doc__
414
+ if not doc:
415
+ doc = '\nUSAGE: %s [flags]\n' % sys.argv[0]
416
+ doc = flags.text_wrap(doc, indent=' ', firstline_indent='')
417
+ else:
418
+ # Replace all '%s' with sys.argv[0], and all '%%' with '%'.
419
+ num_specifiers = doc.count('%') - 2 * doc.count('%%')
420
+ try:
421
+ doc %= (sys.argv[0],) * num_specifiers
422
+ except (OverflowError, TypeError, ValueError):
423
+ # Just display the docstring as-is.
424
+ pass
425
+ if shorthelp:
426
+ flag_str = FLAGS.main_module_help()
427
+ else:
428
+ flag_str = FLAGS.get_help()
429
+ try:
430
+ stdfile.write(doc)
431
+ if flag_str:
432
+ stdfile.write('\nflags:\n')
433
+ stdfile.write(flag_str)
434
+ stdfile.write('\n')
435
+ if detailed_error is not None:
436
+ stdfile.write('\n%s\n' % detailed_error)
437
+ except OSError as e:
438
+ # We avoid printing a huge backtrace if we get EPIPE, because
439
+ # "foo.par --help | less" is a frequent use case.
440
+ if e.errno != errno.EPIPE:
441
+ raise
442
+ if exitcode is not None:
443
+ sys.exit(exitcode)
444
+
445
+
446
+ class ExceptionHandler:
447
+ """Base exception handler from which other may inherit."""
448
+
449
+ def wants(self, exc):
450
+ """Returns whether this handler wants to handle the exception or not.
451
+
452
+ This base class returns True for all exceptions by default. Override in
453
+ a subclass to be more selective.
454
+
455
+ Args:
456
+ exc: Exception, the current exception.
457
+ """
458
+ del exc # Unused.
459
+ return True
460
+
461
+ def handle(self, exc):
462
+ """Do something with the current exception.
463
+
464
+ Args:
465
+ exc: Exception, the current exception
466
+
467
+ This method must be overridden.
468
+ """
469
+ raise NotImplementedError()
470
+
471
+
472
+ def install_exception_handler(handler):
473
+ """Installs an exception handler.
474
+
475
+ Args:
476
+ handler: ExceptionHandler, the exception handler to install.
477
+
478
+ Raises:
479
+ TypeError: Raised when the handler was not of the correct type.
480
+
481
+ All installed exception handlers will be called if main() exits via
482
+ an abnormal exception, i.e. not one of SystemExit, KeyboardInterrupt,
483
+ FlagsError or UsageError.
484
+ """
485
+ if not isinstance(handler, ExceptionHandler):
486
+ raise TypeError('handler of type %s does not inherit from ExceptionHandler'
487
+ % type(handler))
488
+ EXCEPTION_HANDLERS.append(handler)
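Taken together, define_help_flags, _register_and_parse_flags_with_usage, _run_main and run implement the usual absl entry point. A minimal usage sketch (the `--echo` flag and the module docstring below are illustrative, not part of this file):

"""Example absl program: echoes a flag and its positional arguments."""

from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('echo', None, 'Text to print before the positional args.')  # illustrative flag


def main(argv):
  # argv is what remains of the command line after flag parsing, argv[0] included.
  if FLAGS.echo:
    print(FLAGS.echo)
  print('positional args:', argv[1:])
  return 0  # app.run passes this value to sys.exit


if __name__ == '__main__':
  app.run(main)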
external/alphageometry/.venv-ag/Lib/site-packages/absl/app.pyi ADDED
@@ -0,0 +1,88 @@
1
+ from typing import Any, Callable, Collection, Iterable, List, NoReturn, Optional, TypeVar, Union, overload
2
+
3
+ from absl.flags import _flag
4
+
5
+ _MainArgs = TypeVar('_MainArgs')
6
+ _Exc = TypeVar('_Exc', bound=Exception)
7
+
8
+ class ExceptionHandler():
9
+
10
+ def wants(self, exc: _Exc) -> bool:
11
+ ...
12
+
13
+ def handle(self, exc: _Exc):
14
+ ...
15
+
16
+ EXCEPTION_HANDLERS: List[ExceptionHandler] = ...
17
+
18
+ class HelpFlag(_flag.BooleanFlag):
19
+ def __init__(self):
20
+ ...
21
+
22
+ class HelpshortFlag(HelpFlag):
23
+ ...
24
+
25
+ class HelpfullFlag(_flag.BooleanFlag):
26
+ def __init__(self):
27
+ ...
28
+
29
+ class HelpXMLFlag(_flag.BooleanFlag):
30
+ def __init__(self):
31
+ ...
32
+
33
+ def define_help_flags() -> None:
34
+ ...
35
+
36
+ @overload
37
+ def usage(shorthelp: Union[bool, int] = ...,
38
+ writeto_stdout: Union[bool, int] = ...,
39
+ detailed_error: Optional[Any] = ...,
40
+ exitcode: None = ...) -> None:
41
+ ...
42
+
43
+ @overload
44
+ def usage(shorthelp: Union[bool, int],
45
+ writeto_stdout: Union[bool, int],
46
+ detailed_error: Optional[Any],
47
+ exitcode: int) -> NoReturn:
48
+ ...
49
+
50
+ @overload
51
+ def usage(shorthelp: Union[bool, int] = ...,
52
+ writeto_stdout: Union[bool, int] = ...,
53
+ detailed_error: Optional[Any] = ...,
54
+ *,
55
+ exitcode: int) -> NoReturn:
56
+ ...
57
+
58
+ def install_exception_handler(handler: ExceptionHandler) -> None:
59
+ ...
60
+
61
+ class Error(Exception):
62
+ ...
63
+
64
+ class UsageError(Error):
65
+ exitcode: int
66
+
67
+ def parse_flags_with_usage(args: List[str]) -> List[str]:
68
+ ...
69
+
70
+ def call_after_init(callback: Callable[[], Any]) -> None:
71
+ ...
72
+
73
+ # Without the flag_parser argument, `main` should require a List[str].
74
+ @overload
75
+ def run(
76
+ main: Callable[[List[str]], Any],
77
+ argv: Optional[List[str]] = ...,
78
+ ) -> NoReturn:
79
+ ...
80
+
81
+ @overload
82
+ def run(
83
+ main: Callable[[_MainArgs], Any],
84
+ argv: Optional[List[str]] = ...,
85
+ *,
86
+ flags_parser: Callable[[List[str]], _MainArgs],
87
+ ) -> NoReturn:
88
+ ...
external/alphageometry/.venv-ag/Lib/site-packages/absl/command_name.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright 2017 The Abseil Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """A tiny stand alone library to change the kernel process name on Linux."""
16
+
17
+ import os
18
+ import sys
19
+
20
+ # This library must be kept small and stand alone. It is used by small things
21
+ # that require no extension modules.
22
+
23
+
24
+ def make_process_name_useful():
25
+ """Sets the process name to something better than 'python' if possible."""
26
+ set_kernel_process_name(os.path.basename(sys.argv[0]))
27
+
28
+
29
+ def set_kernel_process_name(name):
30
+ """Changes the Kernel's /proc/self/status process name on Linux.
31
+
32
+ The kernel name is NOT what will be shown by the ps or top command.
33
+ It is a 15 character string stored in the kernel's process table that
34
+ is included in the kernel log when a process is OOM killed.
35
+ The first 15 bytes of name are used. Non-ASCII unicode is replaced with '?'.
36
+
37
+ Does nothing if /proc/self/comm cannot be written or prctl() fails.
38
+
39
+ Args:
40
+ name: bytes|unicode, the Linux kernel's command name to set.
41
+ """
42
+ if not isinstance(name, bytes):
43
+ name = name.encode('ascii', 'replace')
44
+ try:
45
+ # This is preferred to using ctypes to try and call prctl() when possible.
46
+ with open('/proc/self/comm', 'wb') as proc_comm:
47
+ proc_comm.write(name[:15])
48
+ except OSError:
49
+ try:
50
+ import ctypes # pylint: disable=g-import-not-at-top
51
+ except ImportError:
52
+ return # No ctypes.
53
+ try:
54
+ libc = ctypes.CDLL('libc.so.6')
55
+ except OSError:
56
+ return # No libc.so.6.
57
+ pr_set_name = ctypes.c_ulong(15) # linux/prctl.h PR_SET_NAME value.
58
+ zero = ctypes.c_ulong(0)
59
+ try:
60
+ libc.prctl(pr_set_name, name, zero, zero, zero)
61
+ # Ignore the prctl return value. Nothing we can do if it errored.
62
+ except AttributeError:
63
+ return # No prctl.
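A short usage sketch; the effect is Linux-only and failures are deliberately swallowed, so the calls are safe to make unconditionally (the name 'my-worker' is illustrative):

from absl import command_name

command_name.make_process_name_useful()            # uses os.path.basename(sys.argv[0])
command_name.set_kernel_process_name('my-worker')  # only the first 15 bytes are kept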
external/alphageometry/.venv-ag/Lib/site-packages/absl/py.typed ADDED
File without changes
external/alphageometry/.venv-ag/Lib/site-packages/distutils-precedence.pth ADDED
@@ -0,0 +1 @@
1
+ import os; var = 'SETUPTOOLS_USE_DISTUTILS'; enabled = os.environ.get(var, 'local') == 'local'; enabled and __import__('_distutils_hack').add_shim();
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2018 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Note: import <name> as <name> is required for names to be exported.
16
+ # See PEP 484 & https://github.com/google/jax/issues/7570
17
+
18
+ from jax.experimental.x64_context import (
19
+ enable_x64 as enable_x64,
20
+ disable_x64 as disable_x64,
21
+ )
22
+ from jax._src.callback import (
23
+ io_callback as io_callback
24
+ )
25
+ from jax._src.earray import (
26
+ EArray as EArray
27
+ )
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/__init__.py ADDED
@@ -0,0 +1,213 @@
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This module includes experimental JAX support for the `Python array API standard`_.
17
+ Support for this is currently experimental and not fully complete.
18
+
19
+ Example Usage::
20
+
21
+ >>> from jax.experimental import array_api as xp
22
+
23
+ >>> xp.__array_api_version__
24
+ '2023.12'
25
+
26
+ >>> arr = xp.arange(1000)
27
+
28
+ >>> arr.sum()
29
+ Array(499500, dtype=int32)
30
+
31
+ The ``xp`` namespace is the array API compliant analog of :mod:`jax.numpy`,
32
+ and implements most of the API listed in the standard.
33
+
34
+ .. _Python array API standard: https://data-apis.org/array-api/latest/
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__
40
+
41
+ from jax.experimental.array_api import fft as fft
42
+ from jax.experimental.array_api import linalg as linalg
43
+
44
+ from jax.numpy import (
45
+ abs as abs,
46
+ acos as acos,
47
+ acosh as acosh,
48
+ add as add,
49
+ all as all,
50
+ any as any,
51
+ argmax as argmax,
52
+ argmin as argmin,
53
+ argsort as argsort,
54
+ asin as asin,
55
+ asinh as asinh,
56
+ atan as atan,
57
+ atan2 as atan2,
58
+ atanh as atanh,
59
+ bitwise_and as bitwise_and,
60
+ bitwise_invert as bitwise_invert,
61
+ bitwise_left_shift as bitwise_left_shift,
62
+ bitwise_or as bitwise_or,
63
+ bitwise_right_shift as bitwise_right_shift,
64
+ bitwise_xor as bitwise_xor,
65
+ bool as bool,
66
+ broadcast_arrays as broadcast_arrays,
67
+ broadcast_to as broadcast_to,
68
+ can_cast as can_cast,
69
+ complex128 as complex128,
70
+ complex64 as complex64,
71
+ concat as concat,
72
+ conj as conj,
73
+ copysign as copysign,
74
+ cos as cos,
75
+ cosh as cosh,
76
+ cumulative_sum as cumulative_sum,
77
+ divide as divide,
78
+ e as e,
79
+ empty as empty,
80
+ empty_like as empty_like,
81
+ equal as equal,
82
+ exp as exp,
83
+ expand_dims as expand_dims,
84
+ expm1 as expm1,
85
+ flip as flip,
86
+ float32 as float32,
87
+ float64 as float64,
88
+ floor_divide as floor_divide,
89
+ from_dlpack as from_dlpack,
90
+ full as full,
91
+ full_like as full_like,
92
+ greater as greater,
93
+ greater_equal as greater_equal,
94
+ iinfo as iinfo,
95
+ imag as imag,
96
+ inf as inf,
97
+ int16 as int16,
98
+ int32 as int32,
99
+ int64 as int64,
100
+ int8 as int8,
101
+ isdtype as isdtype,
102
+ isfinite as isfinite,
103
+ isinf as isinf,
104
+ isnan as isnan,
105
+ less as less,
106
+ less_equal as less_equal,
107
+ log as log,
108
+ log10 as log10,
109
+ log1p as log1p,
110
+ log2 as log2,
111
+ logaddexp as logaddexp,
112
+ logical_and as logical_and,
113
+ logical_not as logical_not,
114
+ logical_or as logical_or,
115
+ logical_xor as logical_xor,
116
+ matmul as matmul,
117
+ matrix_transpose as matrix_transpose,
118
+ max as max,
119
+ maximum as maximum,
120
+ mean as mean,
121
+ meshgrid as meshgrid,
122
+ min as min,
123
+ minimum as minimum,
124
+ moveaxis as moveaxis,
125
+ multiply as multiply,
126
+ nan as nan,
127
+ negative as negative,
128
+ newaxis as newaxis,
129
+ nonzero as nonzero,
130
+ not_equal as not_equal,
131
+ ones as ones,
132
+ ones_like as ones_like,
133
+ permute_dims as permute_dims,
134
+ pi as pi,
135
+ positive as positive,
136
+ pow as pow,
137
+ prod as prod,
138
+ real as real,
139
+ remainder as remainder,
140
+ repeat as repeat,
141
+ result_type as result_type,
142
+ roll as roll,
143
+ round as round,
144
+ searchsorted as searchsorted,
145
+ sign as sign,
146
+ signbit as signbit,
147
+ sin as sin,
148
+ sinh as sinh,
149
+ sort as sort,
150
+ sqrt as sqrt,
151
+ square as square,
152
+ squeeze as squeeze,
153
+ stack as stack,
154
+ subtract as subtract,
155
+ sum as sum,
156
+ take as take,
157
+ tan as tan,
158
+ tanh as tanh,
159
+ tensordot as tensordot,
160
+ tile as tile,
161
+ tril as tril,
162
+ triu as triu,
163
+ uint16 as uint16,
164
+ uint32 as uint32,
165
+ uint64 as uint64,
166
+ uint8 as uint8,
167
+ unique_all as unique_all,
168
+ unique_counts as unique_counts,
169
+ unique_inverse as unique_inverse,
170
+ unique_values as unique_values,
171
+ unstack as unstack,
172
+ vecdot as vecdot,
173
+ where as where,
174
+ zeros as zeros,
175
+ zeros_like as zeros_like,
176
+ )
177
+
178
+ from jax.experimental.array_api._manipulation_functions import (
179
+ reshape as reshape,
180
+ )
181
+
182
+ from jax.experimental.array_api._creation_functions import (
183
+ arange as arange,
184
+ asarray as asarray,
185
+ eye as eye,
186
+ linspace as linspace,
187
+ )
188
+
189
+ from jax.experimental.array_api._data_type_functions import (
190
+ astype as astype,
191
+ finfo as finfo,
192
+ )
193
+
194
+ from jax.experimental.array_api._elementwise_functions import (
195
+ ceil as ceil,
196
+ clip as clip,
197
+ floor as floor,
198
+ hypot as hypot,
199
+ trunc as trunc,
200
+ )
201
+
202
+ from jax.experimental.array_api._statistical_functions import (
203
+ std as std,
204
+ var as var,
205
+ )
206
+
207
+ from jax.experimental.array_api._utility_functions import (
208
+ __array_namespace_info__ as __array_namespace_info__,
209
+ )
210
+
211
+ from jax.experimental.array_api import _array_methods
212
+ _array_methods.add_array_object_methods()
213
+ del _array_methods
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_array_methods.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Any
18
+
19
+ import jax
20
+ from jax._src.array import ArrayImpl
21
+ from jax.experimental.array_api._version import __array_api_version__
22
+ from jax.sharding import Sharding
23
+
24
+ from jax._src.lib import xla_extension as xe
25
+
26
+
27
+ def _array_namespace(self, /, *, api_version: None | str = None):
28
+ if api_version is not None and api_version != __array_api_version__:
29
+ raise ValueError(f"{api_version=!r} is not available; "
30
+ f"available versions are: {[__array_api_version__]}")
31
+ return jax.experimental.array_api
32
+
33
+
34
+ def _to_device(self, device: xe.Device | Sharding | None, *,
35
+ stream: int | Any | None = None):
36
+ if stream is not None:
37
+ raise NotImplementedError("stream argument of array.to_device()")
38
+ return jax.device_put(self, device)
39
+
40
+
41
+ def add_array_object_methods():
42
+ # TODO(jakevdp): set on tracers as well?
43
+ setattr(ArrayImpl, "__array_namespace__", _array_namespace)
44
+ setattr(ArrayImpl, "to_device", _to_device)
45
+ setattr(ArrayImpl, "device", property(lambda self: self.sharding))
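Once jax.experimental.array_api has been imported (which runs add_array_object_methods), concrete arrays expose these patched methods. A hedged sketch:

from jax.experimental import array_api  # noqa: F401  (importing patches ArrayImpl)
import jax.numpy as jnp

x = jnp.arange(4)
ns = x.__array_namespace__()   # returns the jax.experimental.array_api module itself
y = x.to_device(None)          # device=None keeps JAX's default placement
print(x.device)                # exposed as the array's sharding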
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_creation_functions.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import jax
18
+ import jax.numpy as jnp
19
+
20
+ # TODO(micky774): Deprecate after adding device argument to jax.numpy functions
21
+ def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
22
+ return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)
23
+
24
+ def asarray(obj, /, *, dtype=None, device=None, copy=None):
25
+ return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device)
26
+
27
+ def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
28
+ return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)
29
+
30
+ def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True):
31
+ return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device)
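These wrappers exist mainly to add the standard's `device` keyword on top of the jax.numpy equivalents; a small sketch (device=None leaves placement to JAX):

from jax.experimental import array_api as xp

a = xp.arange(0, 10, 2, dtype=xp.int32, device=None)   # start, stop, step
b = xp.asarray([1.0, 2.0, 3.0], dtype=xp.float32)
c = xp.eye(3, k=1)                                      # superdiagonal of ones
d = xp.linspace(0.0, 1.0, num=5, endpoint=False)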
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_data_type_functions.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import builtins
18
+ from typing import NamedTuple
19
+ import numpy as np
20
+
21
+ import jax.numpy as jnp
22
+
23
+ from jax._src.lib import xla_client as xc
24
+ from jax._src.sharding import Sharding
25
+ from jax._src import dtypes as _dtypes
26
+
27
+ # TODO(micky774): Update jax.numpy dtypes to dtype *objects*
28
+ bool = np.dtype('bool')
29
+ int8 = np.dtype('int8')
30
+ int16 = np.dtype('int16')
31
+ int32 = np.dtype('int32')
32
+ int64 = np.dtype('int64')
33
+ uint8 = np.dtype('uint8')
34
+ uint16 = np.dtype('uint16')
35
+ uint32 = np.dtype('uint32')
36
+ uint64 = np.dtype('uint64')
37
+ float32 = np.dtype('float32')
38
+ float64 = np.dtype('float64')
39
+ complex64 = np.dtype('complex64')
40
+ complex128 = np.dtype('complex128')
41
+
42
+
43
+ # TODO(micky774): Remove when the jax.numpy.astype deprecation is completed
44
+ def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None):
45
+ src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x)
46
+ if (
47
+ src_dtype is not None
48
+ and _dtypes.isdtype(src_dtype, "complex floating")
49
+ and _dtypes.isdtype(dtype, ("integral", "real floating"))
50
+ ):
51
+ raise ValueError(
52
+ "Casting from complex to non-complex dtypes is not permitted. Please "
53
+ "first use jnp.real or jnp.imag to take the real/imaginary component of "
54
+ "your input."
55
+ )
56
+ return jnp.astype(x, dtype, copy=copy, device=device)
57
+
58
+
59
+ class FInfo(NamedTuple):
60
+ bits: int
61
+ eps: float
62
+ max: float
63
+ min: float
64
+ smallest_normal: float
65
+ dtype: jnp.dtype
66
+
67
+ # TODO(micky774): Update jax.numpy.finfo so that its attributes are python
68
+ # floats
69
+ def finfo(type, /) -> FInfo:
70
+ info = jnp.finfo(type)
71
+ return FInfo(
72
+ bits=info.bits,
73
+ eps=float(info.eps),
74
+ max=float(info.max),
75
+ min=float(info.min),
76
+ smallest_normal=float(info.smallest_normal),
77
+ dtype=jnp.dtype(type)
78
+ )
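A sketch of the two differences from plain jax.numpy that these helpers introduce: finfo fields are plain Python floats, and astype refuses silent complex-to-real casts:

from jax.experimental import array_api as xp

info = xp.finfo(xp.float32)
print(info.bits, info.eps, info.max)       # 32, plus ordinary Python floats

z = xp.asarray([1 + 2j], dtype=xp.complex64)
r = xp.astype(xp.real(z), xp.float32)      # allowed: real -> real
# xp.astype(z, xp.float32) would raise ValueError: complex -> real is rejected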
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_elementwise_functions.py ADDED
@@ -0,0 +1,75 @@
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import jax
16
+ from jax.numpy import isdtype
17
+ from jax._src.dtypes import issubdtype
18
+ from jax._src.numpy.util import promote_args
19
+
20
+
21
+ # TODO(micky774): Update jnp.ceil to preserve integral dtype
22
+ def ceil(x, /):
23
+ """Rounds each element x_i of the input array x to the smallest (i.e., closest to -infinity) integer-valued number that is not less than x_i."""
24
+ x, = promote_args("ceil", x)
25
+ if isdtype(x.dtype, "integral"):
26
+ return x
27
+ return jax.numpy.ceil(x)
28
+
29
+
30
+ # TODO(micky774): Remove when jnp.clip deprecation is completed
31
+ # (began 2024-4-2) and default behavior is Array API 2023 compliant
32
+ def clip(x, /, min=None, max=None):
33
+ """Returns the complex conjugate for each element x_i of the input array x."""
34
+ x, = promote_args("clip", x)
35
+
36
+ if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
37
+ raise ValueError(
38
+ "Clip received a complex value either through the input or the min/max "
39
+ "keywords. Complex values have no ordering and cannot be clipped. "
40
+ "Please convert to a real value or array by taking the real or "
41
+ "imaginary components via jax.numpy.real/imag respectively."
42
+ )
43
+ return jax.numpy.clip(x, min=min, max=max)
44
+
45
+
46
+ # TODO(micky774): Update jnp.floor to preserve integral dtype
47
+ def floor(x, /):
48
+ """Rounds each element x_i of the input array x to the greatest (i.e., closest to +infinity) integer-valued number that is not greater than x_i."""
49
+ x, = promote_args("floor", x)
50
+ if isdtype(x.dtype, "integral"):
51
+ return x
52
+ return jax.numpy.floor(x)
53
+
54
+
55
+ # TODO(micky774): Remove when jnp.hypot deprecation is completed
56
+ # (began 2024-4-14) and default behavior is Array API 2023 compliant
57
+ def hypot(x1, x2, /):
58
+ """Computes the square root of the sum of squares for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
59
+ x1, x2 = promote_args("hypot", x1, x2)
60
+
61
+ if issubdtype(x1.dtype, jax.numpy.complexfloating):
62
+ raise ValueError(
63
+ "hypot does not support complex-valued inputs. Please convert to real "
64
+ "values first, such as by using jnp.real or jnp.imag to take the real "
65
+ "or imaginary components respectively.")
66
+ return jax.numpy.hypot(x1, x2)
67
+
68
+
69
+ # TODO(micky774): Update jnp.trunc to preserve integral dtype
70
+ def trunc(x, /):
71
+ """Rounds each element x_i of the input array x to the nearest integer-valued number that is closer to zero than x_i."""
72
+ x, = promote_args("trunc", x)
73
+ if isdtype(x.dtype, "integral"):
74
+ return x
75
+ return jax.numpy.trunc(x)
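A sketch of the dtype-preserving behaviour and the complex-input guard described in the TODOs above:

from jax.experimental import array_api as xp

i = xp.asarray([1, 2, 3], dtype=xp.int32)
print(xp.ceil(i).dtype, xp.floor(i).dtype, xp.trunc(i).dtype)  # int32 passes through unchanged

f = xp.asarray([-1.5, 0.2, 3.7])
print(xp.clip(f, min=0.0, max=1.0))   # clamps to [0, 1]
# xp.clip(xp.asarray([1 + 1j]), min=0) raises ValueError: complex values have no ordering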
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_fft_functions.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import jax.numpy as jnp
16
+
17
+ # TODO(micky774): Remove after adding device parameter to corresponding jnp.fft
18
+ # functions.
19
+ def fftfreq(n, /, *, d=1.0, device=None):
20
+ """Returns the discrete Fourier transform sample frequencies."""
21
+ return jnp.fft.fftfreq(n, d=d).to_device(device)
22
+
23
+ def rfftfreq(n, /, *, d=1.0, device=None):
24
+ """Returns the discrete Fourier transform sample frequencies (for rfft and irfft)."""
25
+ return jnp.fft.rfftfreq(n, d=d).to_device(device)
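A small sketch of the device-aware sample-frequency helpers, re-exported through the fft namespace:

from jax.experimental import array_api as xp

freqs = xp.fft.fftfreq(8, d=0.5)                 # d is the sample spacing
rfreqs = xp.fft.rfftfreq(8, d=0.5, device=None)  # device=None: default placement
print(freqs.shape, rfreqs.shape)                 # (8,), (5,)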
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_linear_algebra_functions.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import jax
16
+
17
+ # TODO(micky774): Remove after deprecation is completed (began 2024-5-14)
18
+ def matrix_rank(x, /, *, rtol=None):
19
+ """
20
+ Returns the rank (i.e., number of non-zero singular values) of a matrix (or a stack of matrices).
21
+ """
22
+ return jax.numpy.linalg.matrix_rank(x, rtol)
23
+
24
+ def pinv(x, /, *, rtol=None):
25
+ """
26
+ Returns the (Moore-Penrose) pseudo-inverse of a matrix (or a stack of matrices) x.
27
+ """
28
+ return jax.numpy.linalg.pinv(x, rtol)
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_manipulation_functions.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import jax
18
+ from jax import Array
19
+
20
+
21
+ # TODO(micky774): Implement copy
22
+ def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:
23
+ """Reshapes an array without changing its data."""
24
+ del copy # unused
25
+ return jax.numpy.reshape(x, shape)
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_statistical_functions.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import jax
16
+
17
+
18
+ def std(x, /, *, axis=None, correction=0.0, keepdims=False):
19
+ """Calculates the standard deviation of the input array x."""
20
+ return jax.numpy.std(x, axis=axis, correction=correction, keepdims=keepdims)
21
+
22
+
23
+ def var(x, /, *, axis=None, correction=0.0, keepdims=False):
24
+ """Calculates the variance of the input array x."""
25
+ return jax.numpy.var(x, axis=axis, correction=correction, keepdims=keepdims)
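The only addition over calling jax.numpy directly is the standard's `correction` keyword (Bessel's correction, ddof in NumPy terms); a sketch:

from jax.experimental import array_api as xp

x = xp.asarray([1.0, 2.0, 3.0, 4.0])
print(xp.var(x))                 # population variance, correction=0
print(xp.var(x, correction=1))   # sample variance
print(xp.std(x, axis=None, keepdims=False))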
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_utility_functions.py ADDED
@@ -0,0 +1,86 @@
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import jax
18
+ from typing import Tuple
19
+ from jax._src.sharding import Sharding
20
+ from jax._src.lib import xla_client as xc
21
+ from jax._src import dtypes as _dtypes, config
22
+
23
+ # TODO(micky774): Add to jax.numpy.util when finalizing jax.experimental.array_api
24
+ # deprecation
25
+ class __array_namespace_info__:
26
+
27
+ def __init__(self):
28
+ self._capabilities = {
29
+ "boolean indexing": True,
30
+ "data-dependent shapes": False,
31
+ }
32
+
33
+
34
+ def _build_dtype_dict(self):
35
+ array_api_types = {
36
+ "bool", "int8", "int16",
37
+ "int32", "uint8", "uint16",
38
+ "uint32", "float32", "complex64"
39
+ }
40
+ if config.enable_x64.value:
41
+ array_api_types |= {"int64", "uint64", "float64", "complex128"}
42
+ return {category: {t.name: t for t in types if t.name in array_api_types}
43
+ for category, types in _dtypes._dtype_kinds.items()}
44
+
45
+ def default_device(self):
46
+ # By default JAX arrays are uncommitted (device=None), meaning that
47
+ # JAX is free to choose the most efficient device placement.
48
+ return None
49
+
50
+ def devices(self):
51
+ return jax.devices()
52
+
53
+ def capabilities(self):
54
+ return self._capabilities
55
+
56
+ def default_dtypes(self, *, device: xc.Device | Sharding | None = None):
57
+ # Array API supported dtypes are device-independent in JAX
58
+ del device
59
+ default_dtypes = {
60
+ "real floating": "f",
61
+ "complex floating": "c",
62
+ "integral": "i",
63
+ "indexing": "i",
64
+ }
65
+ return {
66
+ dtype_name: _dtypes.canonicalize_dtype(
67
+ _dtypes._default_types.get(kind)
68
+ ) for dtype_name, kind in default_dtypes.items()
69
+ }
70
+
71
+ def dtypes(
72
+ self, *,
73
+ device: xc.Device | Sharding | None = None,
74
+ kind: str | Tuple[str, ...] | None = None):
75
+ # Array API supported dtypes are device-independent in JAX
76
+ del device
77
+ data_types = self._build_dtype_dict()
78
+ if kind is None:
79
+ out_dict = data_types["numeric"] | data_types["bool"]
80
+ elif isinstance(kind, tuple):
81
+ out_dict = {}
82
+ for _kind in kind:
83
+ out_dict |= data_types[_kind]
84
+ else:
85
+ out_dict = data_types[kind]
86
+ return out_dict
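A sketch of the inspection API this class provides (the exact dtype set depends on whether x64 is enabled, as the code above shows):

from jax.experimental import array_api as xp

info = xp.__array_namespace_info__()
print(info.capabilities())                        # {'boolean indexing': True, ...}
print(info.default_dtypes())                      # canonical float/complex/integral/indexing dtypes
print(sorted(info.dtypes(kind='real floating')))  # ['float32'] (plus 'float64' when x64 is enabled)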
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/_version.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ __array_api_version__ = '2023.12'
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/fft.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from jax.numpy.fft import (
16
+ fft as fft,
17
+ fftn as fftn,
18
+ fftshift as fftshift,
19
+ hfft as hfft,
20
+ ifft as ifft,
21
+ ifftn as ifftn,
22
+ ifftshift as ifftshift,
23
+ ihfft as ihfft,
24
+ irfft as irfft,
25
+ irfftn as irfftn,
26
+ rfft as rfft,
27
+ rfftn as rfftn,
28
+ )
29
+
30
+ from jax.experimental.array_api._fft_functions import (
31
+ fftfreq as fftfreq,
32
+ rfftfreq as rfftfreq,
33
+ )
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_api/linalg.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from jax.numpy.linalg import (
16
+ cholesky as cholesky,
17
+ cross as cross,
18
+ det as det,
19
+ diagonal as diagonal,
20
+ eigh as eigh,
21
+ eigvalsh as eigvalsh,
22
+ inv as inv,
23
+ matmul as matmul,
24
+ matrix_norm as matrix_norm,
25
+ matrix_power as matrix_power,
26
+ matrix_transpose as matrix_transpose,
27
+ outer as outer,
28
+ qr as qr,
29
+ slogdet as slogdet,
30
+ solve as solve,
31
+ svd as svd,
32
+ svdvals as svdvals,
33
+ tensordot as tensordot,
34
+ vecdot as vecdot,
35
+ vector_norm as vector_norm,
36
+ )
37
+
38
+ from jax.numpy.linalg import trace as trace
39
+
40
+ from jax.experimental.array_api._linear_algebra_functions import (
41
+ matrix_rank as matrix_rank,
42
+ pinv as pinv,
43
+ )
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_serialization/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2021 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_serialization/serialization.py ADDED
@@ -0,0 +1,635 @@
1
+ # Copyright 2021 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Array serialization and deserialization."""
15
+
16
+ from __future__ import annotations
17
+
18
+ import abc
19
+ import asyncio
20
+ from collections.abc import Awaitable, Sequence
21
+ from functools import partial
22
+ import itertools
23
+ import logging
24
+ import os
25
+ import re
26
+ import sys
27
+ import threading
28
+ import time
29
+ from typing import Any, Callable, Optional, Union
30
+
31
+ import jax
32
+ from jax._src import array
33
+ from jax._src import distributed
34
+ from jax._src import sharding
35
+ from jax._src import sharding_impls
36
+ from jax._src.layout import Layout, DeviceLocalLayout as DLL
37
+ from jax._src import typing
38
+ from jax._src import util
39
+ from jax._src.lib import xla_extension as xe
40
+ import jax.numpy as jnp
41
+ import numpy as np
42
+ import tensorstore as ts
43
+
44
+
45
+ TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}})
46
+ _REMOVED_VALUE = 'Value removed'
47
+ _CHECKPOINT_SUCCESS = 'checkpoint_write_success'
48
+ _module_unique_count = itertools.count()
49
+ _DEFAULT_DRIVER = 'file'
50
+ _DISTRIBUTED_SYSTEM_MSG = (
51
+ 'Please initialize the distributed system via '
52
+ '`jax.distributed.initialize()` at the start of your program.')
53
+ _REMOTE_URL_PREFIXES = ['gs://', 's3://']
54
+ _REMOTE_DRIVER_VALIDATIONS = [
55
+ {'driver': 'gcs', 'path_regex': None},
56
+ {'driver': 's3', 'path_regex': None},
57
+ ]
58
+
59
+ class BarrierTimeoutException(Exception):
60
+ pass
61
+
62
+ _BARRIER_TIMED_OUT_MSG = (
63
+ "Suggestions for possible fixes:\n"
64
+ "* Check the logs to see if one or more processes failed.\n"
65
+ "* Make sure the training and checkpointing endpoints are close geographically.\n"
66
+ "* Try increasing the timeout you pass to GlobalAsyncCheckpointManager.")
67
+
68
+ logger = logging.getLogger(__name__)
69
+
70
+
71
+ async def create_async_array_from_callback(
72
+ global_shape: array.Shape,
73
+ inp_sharding: jax.sharding.Sharding,
74
+ data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]],
75
+ ):
76
+ device_to_index_map = inp_sharding.devices_indices_map(global_shape)
77
+ addressable_da = inp_sharding._addressable_device_assignment
78
+ future_arrays = [data_callback(device_to_index_map[d], d)
79
+ for d in addressable_da]
80
+ dbs = await asyncio.gather(*future_arrays)
81
+ return array.make_array_from_single_device_arrays(
82
+ global_shape, inp_sharding, dbs)
83
+
84
+
85
+ def _get_metadata(arr):
86
+ local_shape = arr.addressable_data(0).shape
87
+ return {
88
+ 'compressor': {'id': 'zstd'},
89
+ 'shape': arr.shape,
90
+ 'chunks': np.array(np.maximum(1, local_shape)),
91
+ }
92
+
93
+
94
+ def _spec_has_metadata(tree):
95
+ if not isinstance(tree, dict):
96
+ return False
97
+ return 'metadata' in tree or any(
98
+ _spec_has_metadata(subtree) for _, subtree in tree.items())
99
+
100
+ def _get_kvstore_for_gcs(ckpt_path: str):
101
+ m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL)
102
+ if m is None:
103
+ raise ValueError('The ckpt_path should contain the bucket name and the '
104
+ f'file path inside the bucket. Got: {ckpt_path}')
105
+ gcs_bucket = m.group(1)
106
+ path_without_bucket = m.group(2)
107
+ return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket}
108
+
109
+ def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
110
+ # Normalize path to exclude trailing '/'. In GCS path case, we will need to
111
+ # fix the path prefix to add back the stripped '/'.
112
+ ckpt_path = os.path.normpath(ckpt_path).replace('gs:/', 'gs://')
113
+ is_gcs_path = ckpt_path.startswith('gs://')
114
+ spec = {'driver': 'zarr', 'kvstore': {}}
115
+ if ocdbt:
116
+ if not is_gcs_path and not os.path.isabs(ckpt_path):
117
+ raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}')
118
+ base_path = os.path.dirname(ckpt_path)
119
+ spec['kvstore'] = {
120
+ 'driver': 'ocdbt',
121
+ 'base': base_path if is_gcs_path else f'{_DEFAULT_DRIVER}://{base_path}',
122
+ 'path': os.path.basename(ckpt_path),
123
+ }
124
+ else:
125
+ if is_gcs_path:
126
+ spec['kvstore'] = _get_kvstore_for_gcs(ckpt_path)
127
+ else:
128
+ spec['kvstore'] = {'driver': _DEFAULT_DRIVER, 'path': ckpt_path}
129
+
130
+ return spec
131
+
132
+
133
+ def is_remote_storage(tspec: Union[dict[str, Any], str]) -> bool:
134
+ """Detect if user is using cloud storages.
135
+
136
+ This can detect common cases but is unable to detect some corner cases such as
137
+ using gcsfuse.
138
+ """
139
+ if isinstance(tspec, str):
140
+ # KvStoreUrl
141
+ if re.match(rf'^({"|".join(_REMOTE_URL_PREFIXES)})', tspec):
142
+ return True
143
+ else:
144
+ return False
145
+
146
+ for key in ('base', 'kvstore'):
147
+ if key in tspec:
148
+ return is_remote_storage(tspec[key])
149
+
150
+ if 'driver' in tspec:
151
+ for rule in _REMOTE_DRIVER_VALIDATIONS:
152
+ if tspec['driver'] == rule['driver']:
153
+ if rule['path_regex'] is None:
154
+ return True
155
+
156
+ # check if path matches the regex.
157
+ if re.match(rule['path_regex'], tspec['path']):
158
+ return True
159
+
160
+ return False
161
+
162
+
163
+ # Lifted from T5X.
164
+ class _LimitInFlightBytes:
165
+ """Limits in-flight bytes when reading/writing checkpoints per process."""
166
+
167
+ def __init__(self, num_bytes):
168
+ self._max_bytes = num_bytes
169
+ self._available_bytes = num_bytes
170
+ self._cv = asyncio.Condition(lock=asyncio.Lock())
171
+
172
+ async def wait_for_bytes(self, requested_bytes):
173
+ if requested_bytes >= self._max_bytes:
174
+ raise ValueError('Requested more bytes than we reserved space for: '
175
+ f'{requested_bytes} > {self._max_bytes}')
176
+ async with self._cv:
177
+ await self._cv.wait_for(lambda: self._available_bytes > requested_bytes)
178
+ self._available_bytes -= requested_bytes
179
+ assert self._available_bytes >= 0
180
+
181
+ async def release_bytes(self, requested_bytes):
182
+ async with self._cv:
183
+ self._available_bytes += requested_bytes
184
+ assert self._available_bytes <= self._max_bytes
185
+ self._cv.notify_all()
186
+
187
+
188
+ async def async_serialize(
189
+ arr_inp,
190
+ tensorstore_spec,
191
+ commit_future=None,
192
+ context=TS_CONTEXT,
193
+ primary_host: Optional[int] = 0,
194
+ replica_id: int = 0,
195
+ ):
196
+ """Serialize an array using TensorStore.
197
+
198
+ Args:
199
+ arr_inp: The array to serialize.
200
+ tensorstore_spec: The tensorstore spec to use.
201
+ commit_future: A list of futures that will be appended to. The futures can
202
+ be awaited asynchronously. If None, the futures will be awaited
203
+ synchronously by this method.
204
+ context: ts.Context instance.
205
+ primary_host: Primary host, which indicates the host that will be treated as
206
+ the "leader". If None, all hosts are treated as the primary. DO NOT USE
207
+ unless you are sure you know what you are doing.
208
+ replica_id: Allows overriding the shard replica id that will be saved.
209
+ DO NOT USE unless you are sure you know what you are doing.
210
+ """
211
+ if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
212
+ arr_inp.is_fully_addressable):
213
+ raise ValueError(
214
+ f'Passing fully addressable arrays to a multiprocess '
215
+ f'serialization is not allowed, as this may lead to a race condition '
216
+ f'between processes. Serialization has failed for the array with '
217
+ f'the path "{tensorstore_spec["kvstore"]["path"]}".')
218
+
219
+ # 'metadata' may not be present at the top level (for example, if we are using
220
+ # a 'cast' driver).
221
+ if not _spec_has_metadata(tensorstore_spec):
222
+ tensorstore_spec['metadata'] = _get_metadata(arr_inp)
223
+
224
+ # Set dtype if it's not in spec
225
+ if 'dtype' not in tensorstore_spec:
226
+ tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name
227
+
228
+ # If primary_host is None, all hosts will checkpoint. This is used
229
+ # for checkpointing to local filesystem.
230
+ if primary_host is None or jax.process_index() == primary_host:
231
+ open_future = ts.open(
232
+ ts.Spec(tensorstore_spec),
233
+ create=True,
234
+ open=True,
235
+ context=context,
236
+ )
237
+ # Asynchronous case.
238
+ if commit_future is not None:
239
+ assert isinstance(commit_future, list)
240
+ commit_future.append(open_future)
241
+ else:
242
+ await open_future
243
+
244
+ # `ts.open` runs twice for process `primary_host` because for the first time,
245
+ # we just get the future to be awaited upon in the background thread. The
246
+ # second one runs with `assume_metadata=True` which does no I/O operation and
247
+ # returns the tensorstore object.
248
+ # For every process other than `primary_host`, we open with
249
+ # `assume_metadata=True`.
250
+ t = await ts.open(
251
+ ts.Spec(tensorstore_spec),
252
+ open=True,
253
+ assume_metadata=True,
254
+ context=context,
255
+ )
256
+
257
+ async def _write_array(shard):
258
+ if shard.replica_id == replica_id:
259
+ write_future = t[shard.index].write(shard.data)
260
+ if commit_future is not None:
261
+ assert isinstance(commit_future, list)
262
+ commit_future.append(write_future.commit)
263
+ await write_future.copy
264
+ else:
265
+ await write_future.commit
266
+
267
+ local_shards = arr_inp.addressable_shards
268
+ future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
269
+ return await asyncio.gather(*future_write_state)
270
+
271
+
272
+ def run_serialization(arrays, tensorstore_specs):
273
+ async def _run_serializer():
274
+ future_writer = jax.tree_util.tree_map(async_serialize, arrays, tensorstore_specs)
275
+ return await asyncio.gather(*future_writer)
276
+ asyncio.run(_run_serializer())
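A hedged single-process sketch of writing one array with these helpers (the checkpoint path is illustrative; multi-host use additionally requires jax.distributed.initialize and the async manager defined below):

import jax.numpy as jnp
from jax.experimental.array_serialization import serialization

arr = jnp.arange(16.0)
spec = serialization.get_tensorstore_spec('/tmp/ckpt/arr0')  # local 'file' driver spec
serialization.run_serialization([arr], [spec])               # blocks until fully written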
277
+
278
+
279
+ def estimate_read_memory_footprint(t: ts.TensorStore,
280
+ domain: ts.IndexDomain) -> int:
281
+ rank = t.rank
282
+ num_bytes = t.dtype.numpy_dtype.itemsize
283
+ chunk_template = t.chunk_layout.read_chunk_template
284
+ if domain is None:
285
+ domain = t.domain
286
+ origin = domain.origin
287
+ shape = domain.shape
288
+ chunk_origin = chunk_template.origin
289
+ chunk_shape = chunk_template.shape
290
+
291
+ # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver.
292
+ # For those, instead of returning a near-infinite memory footprint, estimate
293
+ # the footprint as the entire shape.
294
+ for i in range(rank):
295
+ if not chunk_template[i].finite:
296
+ return domain.size * num_bytes
297
+
298
+ # Otherwise, if we have a chunked driver, estimate based on chunk size.
299
+ for i in range(rank):
300
+ origin_value = origin[i]
301
+ chunk_origin_value = chunk_origin[i]
302
+ chunk_size = chunk_shape[i]
303
+ lower = origin_value - chunk_origin_value
304
+ upper = origin_value + shape[i] - chunk_origin_value
305
+ lower_aligned = lower // chunk_size * chunk_size
306
+ upper_aligned = -(-upper // chunk_size) * chunk_size
307
+ num_bytes *= (upper_aligned - lower_aligned)
308
+
309
+ return num_bytes
310
+
311
+
312
+ async def async_deserialize(
313
+ user_in_sharding: jax.sharding.Sharding | Layout,
314
+ tensorstore_spec: ts.Spec | dict[str, Any],
315
+ global_shape: Sequence[int] | None = None,
316
+ dtype=None,
317
+ byte_limiter: _LimitInFlightBytes | None = None,
318
+ context=TS_CONTEXT,
319
+ assume_metadata: bool = False,
320
+ ):
321
+ in_sharding = (user_in_sharding.sharding
322
+ if isinstance(user_in_sharding, Layout) else user_in_sharding)
323
+ if not isinstance(in_sharding, jax.sharding.Sharding):
324
+ raise ValueError(
325
+ 'sharding passed to deserialization should be specified, concrete and'
326
+ f' an instance of `jax.sharding.Sharding`. Got {in_sharding}')
327
+ dll = (user_in_sharding.device_local_layout
328
+ if isinstance(user_in_sharding, Layout) else None)
329
+ t = await ts.open(
330
+ tensorstore_spec,
331
+ open=True,
332
+ assume_metadata=assume_metadata,
333
+ context=context,
334
+ )
335
+ shape = t.shape if global_shape is None else global_shape
336
+ new_shard_shape = in_sharding.shard_shape(tuple(shape))
337
+
338
+ async def cb(index: array.Index, device: jax.Device):
339
+ requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
340
+ restricted_domain = t.domain.intersect(requested_domain)
341
+ requested_bytes = estimate_read_memory_footprint(t, restricted_domain)
342
+ # Limit the bytes read for every shard.
343
+ if byte_limiter is not None:
344
+ await byte_limiter.wait_for_bytes(requested_bytes)
345
+ # This may be needed because the shape the array was saved with is smaller
346
+ # than the requested shape of the array in which it will be reloaded. So
347
+ # the extra values will be filled with 0s.
348
+ out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
349
+ await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][
350
+ restricted_domain].write(t[restricted_domain])
351
+ if dtype is not None:
352
+ # Cast while reloading on process to avoid 2 copies on device if the
353
+ # casting is done on device.
354
+ out = out.astype(dtype)
355
+ # Convert to jnp array so that layouts are initialized properly for
356
+ # sub-byte dtypes.
357
+ # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to
358
+ # make this work.
359
+ if out.dtype == jnp.int4:
360
+ out = jnp.asarray(out) # type: ignore
361
+ result = jax.device_put(
362
+ out, Layout(dll, jax.sharding.SingleDeviceSharding(device)))
363
+ if byte_limiter is not None:
364
+ # NB: `out` actually might not be ready for garbage collection by the
365
+ # time we call release_bytes. Thus peak memory usage might still grow
366
+ # beyond what the byte_limiter limit suggests it should. The simplest option
367
+ # would be to call `result.block_until_ready()` here. However, it
368
+ # also comes with ~15-20% perf penalty as we would be waiting for CPU->GPU
369
+ # transfer instead of loading data. In the future, if memory pressure
370
+ # becomes a problem, we can instead instrument the byte_limiter to
371
+ # keep track of all in-flight tensors and only call block_until_ready if the
372
+ # byte limiter hits its limit, to get reduced memory usage without losing
373
+ # performance in common use cases.
374
+ await byte_limiter.release_bytes(requested_bytes)
375
+ return result
376
+
377
+ return await create_async_array_from_callback(tuple(shape), in_sharding, cb)
378
+
379
+
380
+ def run_deserialization(shardings: Sequence[sharding.Sharding | Layout],
381
+ tensorstore_specs: Sequence[dict[str, Any]],
382
+ global_shapes: Sequence[array.Shape] | None = None,
383
+ dtypes: Sequence[typing.DTypeLike] | None = None,
384
+ concurrent_gb: int = 32):
385
+ concurrent_bytes = concurrent_gb * 10**9
386
+
387
+ async def _run_deserializer():
388
+ # Object should be created once per process.
389
+ byte_limiter = _LimitInFlightBytes(concurrent_bytes)
390
+
391
+ future_arrays = jax.tree_util.tree_map(
392
+ partial(async_deserialize, byte_limiter=byte_limiter),
393
+ shardings, tensorstore_specs,
394
+ [None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
395
+ [None] * len(tensorstore_specs) if dtypes is None else dtypes)
396
+ return await asyncio.gather(*future_arrays)
397
+ return asyncio.run(_run_deserializer())
398
+
399
+
400
+ def _get_key(key: int):
401
+ return f'tensorstore_checkpoint_{key}'
402
+
403
+
404
+ class GlobalAsyncCheckpointManagerBase(util.StrictABC):
405
+ """Interface for checkpointing GDAs asynchronously.
406
+
407
+ This class manages the state of an ongoing asynchronous checkpoint.
408
+
409
+ For example, say a checkpoint happens on every step. You checkpoint on
410
+ step 1 and, after some computation, the model reaches step 2, but step 1's
411
+ checkpoint hasn't finished committing to the storage layer yet. Until that
412
+ commit finishes, the checkpoint for step 2 has to be blocked. Keeping this
413
+ state in a class makes that bookkeeping possible.
414
+
415
+ Example:
416
+
417
+ Below is a simplified training loop:
418
+
419
+ ```
420
+ # Call this at the start of your program.
421
+ jax.distributed.initialize()
422
+
423
+ manager = GlobalAsyncCheckpointManager()
424
+
425
+ # Restore checkpoint if available or initialize the train_state from
426
+ # init_fn().
427
+ train_state = manager.deserialize(...)
428
+
429
+ while ...:
430
+ if step % num_steps_between_checkpoints == 0:
431
+ manager.serialize(train_state, temp_checkpoint_dir=...,
432
+ final_checkpoint_dir=...)
433
+ train_state = train_step(train_state, input)
434
+ # This is a non-blocking call.
435
+ manager.check_for_errors()
436
+
437
+ manager.serialize(train_state, temp_checkpoint_dir=...,
438
+ final_checkpoint_dir=...)
439
+ # Wait before the end of the program for the checkpoint to finish. This is a
440
+ # blocking call.
441
+ manager.wait_until_finished()
442
+ ```
443
+ """
444
+
445
+ @abc.abstractmethod
446
+ def check_for_errors(self):
447
+ """Checks if any errors have been raised in the child thread.
448
+
449
+ This is a non-blocking call that can be called in the main thread.
450
+ """
451
+
452
+ @abc.abstractmethod
453
+ def wait_until_finished(self):
454
+ """Blocks until serialization has finished."""
455
+
456
+ @abc.abstractmethod
457
+ def serialize(self, arrays, tensorstore_specs, *,
458
+ on_commit_callback: Callable[[], None]):
459
+ """Serializes GDAs to TensorStore."""
460
+
461
+ @abc.abstractmethod
462
+ def deserialize(self, shardings: Sequence[sharding.Sharding],
463
+ tensorstore_specs: Sequence[dict[str, Any]],
464
+ global_shapes: Sequence[array.Shape] | None = None,
465
+ dtypes: Sequence[typing.DTypeLike] | None = None):
466
+ """Deserializes GDAs from TensorStore."""
467
+
468
+
469
+ class AsyncManager:
470
+
471
+ def __init__(self, timeout_secs=300):
472
+ self._timeout_secs = timeout_secs
473
+ self._timeout_in_ms = self._timeout_secs * 1000
474
+
475
+ self._commit_futures = None
476
+ self._thread = None
477
+ self._exception = None
478
+
479
+ if jax.process_count() > 1 and distributed.global_state.client is None:
480
+ raise ValueError(_DISTRIBUTED_SYSTEM_MSG)
481
+ if jax.process_count() > 1:
482
+ self._client = distributed.global_state.client
483
+ self._count = None
484
+
485
+ def __del__(self):
486
+ if self._thread is not None and self._thread.is_alive():
487
+ logger.warning('Please add `.wait_until_finished()` in the main thread '
488
+ 'before your program finishes because there is a '
489
+ 'possibility of losing errors raised if '
490
+ 'this class is deleted before writing is completed.')
491
+
492
+ def _thread_func(self):
493
+ try:
494
+ current_process = jax.process_index()
495
+ process_count = jax.process_count()
496
+ logger.info('Starting commit to storage layer by process: %s',
497
+ current_process)
498
+ thread_start_time = time.time()
499
+ for future in self._commit_futures:
500
+ future.result()
501
+ logger.info('Finished committing to storage layer by process: %s',
502
+ current_process)
503
+
504
+ if process_count > 1:
505
+ # All processes will wait at the barrier. When all processes are at the
506
+ # barrier, the barrier will be satisfied. If not, then it will timeout.
507
+ key_for_barrier = _get_key(self._count)
508
+ logger.info('Key used for barrier is %s for process %s',
509
+ key_for_barrier, current_process)
510
+ self._client.wait_at_barrier(key_for_barrier, self._timeout_in_ms)
511
+ logger.info('Finished waiting at barrier for process %s',
512
+ current_process)
513
+
514
+ if current_process == 0:
515
+ self._on_commit_callback()
516
+ logger.info('on_commit_callback successfully ran!')
517
+ if process_count > 1:
518
+ self._client.key_value_set(key_for_barrier, _CHECKPOINT_SUCCESS)
519
+ logger.info('Process 0 successfully set key %s in the kv store',
520
+ key_for_barrier)
521
+
522
+ jax.monitoring.record_event_duration_secs(
523
+ '/jax/checkpoint/write/async/thread_duration_sec',
524
+ time.time() - thread_start_time)
525
+
526
+ except Exception as e:
527
+ self._exception = e
528
+
529
+ def _start_async_commit(self, on_commit_callback):
530
+ self._count = next(_module_unique_count)
531
+
532
+ self._on_commit_callback = on_commit_callback
533
+ self._thread = threading.Thread(target=self._thread_func)
534
+ self._thread.start()
535
+
536
+ def check_for_errors(self):
537
+ if self._exception is not None:
538
+ # Clears self._exception so it is only raised once.
539
+ exception = self._exception
540
+ self._exception = None
541
+ if (isinstance(exception, xe.XlaRuntimeError) and
542
+ 'DEADLINE_EXCEEDED: Barrier timed out' in str(exception)):
543
+ raise BarrierTimeoutException(
544
+ '\n'.join([str(exception), _BARRIER_TIMED_OUT_MSG]))
545
+ raise exception # pylint: disable=raising-bad-type
546
+
547
+ def wait_until_finished(self):
548
+ if self._thread is not None:
549
+ self._thread.join()
550
+ self._thread = None
551
+ logger.info('Thread joined successfully')
552
+
553
+ self.check_for_errors()
554
+ logger.info('Error check finished successfully')
555
+
556
+ if jax.process_count() > 1 and self._count is not None:
557
+ # Block until process 0 writes success value to the key value store.
558
+ # If it fails to write it, then `blocking_key_value_get` will time out.
559
+ get_key = _get_key(self._count)
560
+ self._client.blocking_key_value_get(get_key, self._timeout_in_ms)
561
+ logger.info('blocking_key_value_get on key %s was successfully '
562
+ 'completed.', get_key)
563
+
564
+ def _add_futures(self, futures: Sequence[asyncio.Future]):
565
+ self._commit_futures = futures
566
+
567
+
568
+ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBase):
569
+ """Responsible for serializing GDAs via TensorStore."""
570
+
571
+ def serialize(self, arrays, tensorstore_specs, *, on_commit_callback):
572
+ """Serializes GDAs or Arrays via TensorStore asynchronously.
573
+
574
+ TensorStore writes to a storage layer in 2 steps:
575
+ * Reading/copying from the source after which the source can be modified.
576
+ * Returns a copy future.
577
+ * Writing/committing to the storage layer.
578
+ * Returns a commit future.
579
+
580
+ In asynchronous mode, the serialization waits for the commit future to
581
+ finish in a separate thread allowing other computation to proceed.
582
+
583
+ Args:
584
+ arrays: GDAs or Arrays that should be serialized.
585
+ tensorstore_specs: TensorStore specs that are used to serialize GDAs or
586
+ Arrays.
587
+ on_commit_callback: This callback will be executed after all processes
588
+ have finished writing their checkpoints to disk. On filesystems where
589
+ atomic rename operations are supported, you can rename from the
590
+ temporary directory to the final directory. On GCS, you write to the
591
+ final directory directly and in `on_commit_callback` you write a
592
+ success file indicating that the serialization was successful because
593
+ GCS does not support atomic rename operations.
594
+ """
595
+ logger.info('Waiting for previous serialization to finish.')
596
+ self.wait_until_finished()
597
+
598
+ commit_futures = [[] for _ in range(len(tensorstore_specs))]
599
+
600
+ async def _run_serializer():
601
+ future_writer = jax.tree_util.tree_map(
602
+ async_serialize, arrays, tensorstore_specs, commit_futures)
603
+ return await asyncio.gather(*future_writer)
604
+
605
+ asyncio.run(_run_serializer())
606
+
607
+ self._add_futures(jax.tree_util.tree_flatten(commit_futures)[0])
608
+
609
+ # Used in wait_until_finished to check on process != 0, if the checkpoint
610
+ # has finished writing.
611
+ self._start_async_commit(on_commit_callback)
612
+
613
+ def serialize_with_paths(self, arrays: Sequence[jax.Array],
614
+ paths: Sequence[str], *, on_commit_callback):
615
+ tspecs = jax.tree.map(get_tensorstore_spec, paths)
616
+ self.serialize(arrays, tspecs, on_commit_callback=on_commit_callback)
617
+
618
+ def deserialize(self, shardings: Sequence[sharding.Sharding | Layout],
619
+ tensorstore_specs: Sequence[dict[str, Any]],
620
+ global_shapes: Sequence[array.Shape] | None = None,
621
+ dtypes: Sequence[typing.DTypeLike] | None = None,
622
+ concurrent_gb: int = 32):
623
+ self.wait_until_finished()
624
+ return run_deserialization(shardings, tensorstore_specs,
625
+ global_shapes, dtypes, concurrent_gb)
626
+
627
+ def deserialize_with_paths(
628
+ self, shardings: Sequence[sharding.Sharding],
629
+ paths: Sequence[str],
630
+ global_shapes: Sequence[array.Shape] | None = None,
631
+ dtypes: Sequence[typing.DTypeLike] | None = None,
632
+ concurrent_gb: int = 32):
633
+ tspecs = jax.tree.map(get_tensorstore_spec, paths)
634
+ return self.deserialize(shardings, tspecs, global_shapes, dtypes,
635
+ concurrent_gb)
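Tying the pieces of this module together, a hedged end-to-end sketch of the path-based manager API (directory names are hypothetical; the rename-on-commit callback follows the pattern described in the `serialize` docstring and in the tests below):

```python
# Hedged sketch, not part of the library: single-host usage with made-up paths.
import os
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.array_serialization import serialization

mesh = Mesh(jax.devices(), ('x',))
arr = jax.device_put(jnp.arange(64.0), NamedSharding(mesh, P('x')))

tmp_dir, final_dir = '/tmp/ckpt.tmp', '/tmp/ckpt'   # hypothetical directories
manager = serialization.GlobalAsyncCheckpointManager()
manager.serialize_with_paths(
    [arr], [os.path.join(tmp_dir, 'arr0')],
    on_commit_callback=partial(os.rename, tmp_dir, final_dir))
manager.wait_until_finished()   # block until the commit thread has finished

restored, = manager.deserialize_with_paths(
    [NamedSharding(mesh, P('x'))], [os.path.join(final_dir, 'arr0')])
```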
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/array_serialization/serialization_test.py ADDED
@@ -0,0 +1,493 @@
1
+ # Copyright 2021 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for serialization and deserialization of GDA."""
15
+
16
+ import asyncio
17
+ import contextlib
18
+ import math
19
+ from functools import partial
20
+ import re
21
+ import os
22
+ import pathlib
23
+ import tracemalloc as tm
24
+
25
+ from absl.testing import absltest
26
+ from absl.testing import parameterized
27
+ import jax
28
+ import jax.numpy as jnp
29
+ from jax._src import test_util as jtu
30
+ from jax._src import array
31
+ from jax._src import xla_bridge as xb
32
+ from jax.sharding import NamedSharding, GSPMDSharding
33
+ from jax.sharding import PartitionSpec as P
34
+ from jax.experimental.array_serialization import serialization
35
+ from jax.experimental.layout import Layout, DeviceLocalLayout as DLL
36
+ import numpy as np
37
+ import tensorstore as ts
38
+
39
+ jax.config.parse_flags_with_absl()
40
+ _exit_stack = contextlib.ExitStack()
41
+
42
+ def setUpModule():
43
+ _exit_stack.enter_context(jtu.set_host_platform_device_count(8))
44
+
45
+ def tearDownModule():
46
+ _exit_stack.close()
47
+
48
+
49
+ pattern = re.compile(r"\{(.*?):")
50
+
51
+ def extract_minor_to_major(l):
52
+ match = re.search(pattern, str(l))
53
+ return tuple(int(i) for i in match.groups()[0].split(','))
54
+
55
+
56
+ class CheckpointTest(jtu.JaxTestCase):
57
+
58
+ def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir):
59
+ os.rename(temp_ckpt_dir, final_ckpt_dir)
60
+
61
+ @jtu.skip_on_devices('cpu')
62
+ def test_memory_consumption(self):
63
+ global_mesh = jtu.create_global_mesh((2, 4), ('x', 'y'))
64
+ inp_shape = (2_048, 4_096)
65
+ pspec = P('x', 'y')
66
+ num = math.prod(inp_shape)
67
+ sharding = NamedSharding(global_mesh, pspec)
68
+ src = jnp.arange(num, dtype=np.int32).reshape(inp_shape)  # ~8.4e6 elements (~32 MB as int32)
69
+ inp = array.make_array_from_callback(
70
+ inp_shape, sharding,
71
+ lambda idx: src[idx])
72
+ ckpt_dir = pathlib.Path(self.create_tempdir('memprof').full_path)
73
+ tspec = serialization.get_tensorstore_spec(str(ckpt_dir))
74
+
75
+ manager = serialization.GlobalAsyncCheckpointManager()
76
+ manager.serialize(
77
+ [inp], [tspec],
78
+ on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
79
+ manager.wait_until_finished()
80
+
81
+ async def deserialize_with_byte_limit():
82
+ r = await serialization.async_deserialize(
83
+ sharding, tspec, inp_shape,
84
+ byte_limiter=serialization._LimitInFlightBytes(4_200_000))
85
+ r.block_until_ready()
86
+
87
+ tm.start()
88
+ asyncio.run(deserialize_with_byte_limit())
89
+ unused_current, peak = tm.get_traced_memory()
90
+ # NB: some padding + tensorstore overhead. It should always be
91
+ # less than array size (2048 * 4096 * 4 = 32M)
92
+ self.assertLess(peak, 10_000_000)
93
+ deserialize_wo_limit = serialization.async_deserialize(
94
+ sharding, tspec, inp_shape)
95
+ tm.clear_traces()
96
+ # NB: calling block_until_ready() is important here and above
97
+ # because otherwise this leads to a race condition and a segfault, with
98
+ # tensorstore attempting to dealloc using tracemalloc, which has already been
99
+ # destroyed.
100
+ asyncio.run(deserialize_wo_limit).block_until_ready()
101
+
102
+ unused_current, peak = tm.get_traced_memory()
103
+ # We load entire array in memory here.
104
+ self.assertGreater(peak, 30_000_000)
105
+ tm.stop()
106
+
107
+ def test_checkpointing_with_path_variant(self):
108
+ global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
109
+ inp_shape = (8, 2)
110
+ pspec = P('x', 'y')
111
+ num = math.prod(inp_shape)
112
+
113
+ # First Array
114
+ global_input_data1 = np.arange(num, dtype=np.int32).reshape(inp_shape)
115
+ a1 = array.make_array_from_callback(
116
+ inp_shape, NamedSharding(global_mesh, pspec),
117
+ lambda idx: global_input_data1[idx])
118
+ ckpt_dir = pathlib.Path(self.create_tempdir('ckpt_variant').full_path)
119
+ ckpt_path1 = pathlib.Path(
120
+ self.create_tempdir(f'{ckpt_dir}/first').full_path)
121
+
122
+ ckpt_paths = [str(ckpt_path1)]
123
+ manager = serialization.GlobalAsyncCheckpointManager()
124
+ manager.serialize_with_paths(
125
+ [a1], ckpt_paths,
126
+ on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
127
+ manager.wait_until_finished()
128
+
129
+ m1, = manager.deserialize_with_paths(
130
+ [NamedSharding(global_mesh, pspec)], ckpt_paths)
131
+ self.assertIsInstance(m1, array.ArrayImpl)
132
+ self.assertArraysEqual(np.asarray(m1.addressable_shards[0].data),
133
+ np.array([[0], [2]], dtype=np.int32))
134
+ self.assertArraysEqual(np.asarray(m1.addressable_shards[1].data),
135
+ np.array([[1], [3]], dtype=np.int32))
136
+ self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
137
+ self.assertEqual(m1.dtype, np.int32)
138
+
139
+ def test_checkpointing_jax_array(self):
140
+ global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
141
+ inp_shape = (8, 2)
142
+ pspec = P('x', 'y')
143
+ num = math.prod(inp_shape)
144
+
145
+ # First Array
146
+ global_input_data1 = np.arange(num, dtype=np.int32).reshape(inp_shape)
147
+ a1 = array.make_array_from_callback(
148
+ inp_shape, NamedSharding(global_mesh, pspec),
149
+ lambda idx: global_input_data1[idx])
150
+ ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path)
151
+ ckpt_path1 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path)
152
+
153
+ # Second Array
154
+ global_input_data2 = np.arange(
155
+ num, num + num, dtype=np.int32).reshape(inp_shape)
156
+ a2 = array.make_array_from_callback(
157
+ inp_shape, NamedSharding(global_mesh, pspec),
158
+ lambda idx: global_input_data2[idx])
159
+ ckpt_path2 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/second').full_path)
160
+
161
+ # Third Array
162
+ def cb3(_):
163
+ return np.array([], dtype=np.float32)
164
+ global_mesh1d = jtu.create_global_mesh((8,), ('x',))
165
+ a3 = array.make_array_from_callback(
166
+ (0,), NamedSharding(global_mesh1d, P(None)), cb3)
167
+ ckpt_path3 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/third').full_path)
168
+
169
+ ckpt_paths = [str(ckpt_path1), str(ckpt_path2), str(ckpt_path3)]
170
+ tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
171
+
172
+ manager = serialization.GlobalAsyncCheckpointManager()
173
+ manager.serialize(
174
+ [a1, a2, a3], tspecs,
175
+ on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
176
+ manager.wait_until_finished()
177
+
178
+ m1, m2, m3 = serialization.run_deserialization(
179
+ [NamedSharding(global_mesh, pspec),
180
+ NamedSharding(global_mesh, P('x')),
181
+ NamedSharding(global_mesh1d, P(None))],
182
+ tspecs)
183
+
184
+ self.assertIsInstance(m1, array.ArrayImpl)
185
+ self.assertArraysEqual(np.asarray(m1.addressable_shards[0].data),
186
+ np.array([[0], [2]], dtype=np.int32))
187
+ self.assertArraysEqual(np.asarray(m1.addressable_shards[1].data),
188
+ np.array([[1], [3]], dtype=np.int32))
189
+ self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
190
+ self.assertEqual(m1.dtype, np.int32)
191
+
192
+ self.assertIsInstance(m2, array.ArrayImpl)
193
+ self.assertArraysEqual(np.asarray(m2.addressable_shards[0].data),
194
+ np.array([[16, 17], [18, 19]], dtype=np.int32))
195
+ self.assertArraysEqual(np.asarray(m2.addressable_shards[1].data),
196
+ np.array([[16, 17], [18, 19]], dtype=np.int32))
197
+ self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
198
+ self.assertEqual(m2.dtype, np.int32)
199
+
200
+ self.assertIsInstance(m3, array.ArrayImpl)
201
+ for i, s in enumerate(m3.addressable_shards):
202
+ self.assertEqual(s.index, (slice(None),))
203
+ self.assertEqual(s.replica_id, i)
204
+ self.assertArraysEqual(np.asarray(s.data), np.array([], dtype=np.float32))
205
+ self.assertEqual(m3.dtype, np.float32)
206
+
207
+ @parameterized.product(input_dtype=[np.int32, jnp.bfloat16])
208
+ def test_checkpointing_with_bigger_shape_jax_array(self, input_dtype):
209
+ global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
210
+ global_input_shape = (8, 2)
211
+ num = math.prod(global_input_shape)
212
+
213
+ global_input_data1 = np.arange(num, dtype=input_dtype).reshape(
214
+ global_input_shape
215
+ )
216
+ def cb1(index):
217
+ return global_input_data1[index]
218
+ arr = array.make_array_from_callback(
219
+ global_input_shape, NamedSharding(global_mesh, P('x', 'y')), cb1)
220
+ ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path)
221
+
222
+ ckpt_paths = [str(ckpt_dir)]
223
+ tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
224
+
225
+ manager = serialization.GlobalAsyncCheckpointManager()
226
+ manager.serialize(
227
+ [arr], tspecs,
228
+ on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
229
+ manager.wait_until_finished()
230
+
231
+ ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))
232
+
233
+ m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)],
234
+ [np.float32])
235
+
236
+ expected_data = {
237
+ 0: np.array([[0], [2], [4]], dtype=np.float32),
238
+ 1: np.array([[1], [3], [5]], dtype=np.float32),
239
+ 2: np.array([[6], [8], [10]], dtype=np.float32),
240
+ 3: np.array([[7], [9], [11]], dtype=np.float32),
241
+ 4: np.array([[12], [14], [0]], dtype=np.float32),
242
+ 5: np.array([[13], [15], [0]], dtype=np.float32),
243
+ 6: np.array([[0], [0], [0]], dtype=np.float32),
244
+ 7: np.array([[0], [0], [0]], dtype=np.float32),
245
+ }
246
+
247
+ for l in m1.addressable_shards:
248
+ self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
249
+
250
+ new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat))
251
+ m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [np.float32])
252
+ for l in m2.addressable_shards:
253
+ self.assertArraysEqual(l.data, global_input_data1.astype('float32'))
254
+
255
+ @parameterized.product(input_dtype=[jnp.int4, jnp.int8])
256
+ def test_checkpointing_with_int4(self, input_dtype):
257
+ global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
258
+ global_input_shape = (8, 2)
259
+ num = math.prod(global_input_shape)
260
+
261
+ global_input_data = np.arange(num, dtype=input_dtype).reshape(
262
+ global_input_shape
263
+ )
264
+ def cb(index):
265
+ return global_input_data[index]
266
+ arr = array.make_array_from_callback(
267
+ global_input_shape, NamedSharding(global_mesh, P('x', 'y')), cb)
268
+ ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path)
269
+
270
+ ckpt_paths = [str(ckpt_dir)]
271
+ tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
272
+
273
+ manager = serialization.GlobalAsyncCheckpointManager()
274
+ manager.serialize(
275
+ [arr], tspecs,
276
+ on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
277
+ manager.wait_until_finished()
278
+
279
+ ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))
280
+
281
+ target_dtype = jnp.dtype('int4')
282
+ m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)],
283
+ [target_dtype])
284
+
285
+ # values bigger than 7 are converted properly.
286
+ expected_data = {
287
+ 0: jnp.array([[0], [2], [4]], dtype=target_dtype),
288
+ 1: jnp.array([[1], [3], [5]], dtype=target_dtype),
289
+ 2: jnp.array([[6], [8], [10]], dtype=target_dtype),
290
+ 3: jnp.array([[7], [9], [11]], dtype=target_dtype),
291
+ 4: jnp.array([[12], [14], [0]], dtype=target_dtype),
292
+ 5: jnp.array([[13], [15], [0]], dtype=target_dtype),
293
+ 6: jnp.array([[0], [0], [0]], dtype=target_dtype),
294
+ 7: jnp.array([[0], [0], [0]], dtype=target_dtype),
295
+ }
296
+
297
+ for l in m1.addressable_shards:
298
+ self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
299
+
300
+ new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat))
301
+ m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [target_dtype])
302
+ for l in m2.addressable_shards:
303
+ self.assertArraysEqual(l.data, global_input_data.astype(target_dtype))
304
+
305
+ def test_checkpointing_scalar_jax_array(self):
306
+ global_mesh = jtu.create_global_mesh((2,), ('x'))
307
+ global_input_shape = ()
308
+ data = np.array(4)
309
+ s = NamedSharding(global_mesh, P(None))
310
+ array1 = array.make_array_from_callback(
311
+ global_input_shape, s, lambda idx: data[idx])
312
+ ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path)
313
+
314
+ ckpt_paths = [str(ckpt_dir)]
315
+ tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
316
+
317
+ manager = serialization.GlobalAsyncCheckpointManager()
318
+ manager.serialize(
319
+ [array1], tspecs,
320
+ on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
321
+ manager.wait_until_finished()
322
+
323
+ ds = NamedSharding(jtu.create_global_mesh((2,), ('x')), P(None))
324
+
325
+ m1, = serialization.run_deserialization(
326
+ [ds],
327
+ tspecs,
328
+ [()],
329
+ [np.float32]
330
+ )
331
+
332
+ for l in m1.addressable_shards:
333
+ self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))
334
+
335
+ def test_deserialize_tensorstore_array_jax_array(self):
336
+ global_mesh = jtu.create_global_mesh((2,), ('x'))
337
+ data = np.arange(1024)
338
+ tspec = ts.array(data).spec()
339
+ m1, = serialization.run_deserialization(
340
+ [NamedSharding(global_mesh, P(None))],
341
+ [tspec]
342
+ )
343
+ for l in m1.addressable_shards:
344
+ self.assertArraysEqual(np.asarray(l.data), data)
345
+
346
+ def test_spec_has_metadata(self):
347
+ spec = {
348
+ 'a': {
349
+ 'b': 1,
350
+ 'c': 2,
351
+ },
352
+ 'd': 3,
353
+ 'e': {
354
+ 'a': 2,
355
+ 'metadata': 3
356
+ },
357
+ 'f': 4
358
+ }
359
+ self.assertTrue(serialization._spec_has_metadata(spec))
360
+ self.assertTrue(
361
+ serialization._spec_has_metadata({
362
+ 'driver': 'zarr',
363
+ 'kvstore': 'gfile',
364
+ 'metadata': {
365
+ 'chunks': 4,
366
+ 'shape': (32, 64)
367
+ },
368
+ 'one_more': 'thing'
369
+ }))
370
+
371
+ def test_spec_has_no_metadata(self):
372
+ spec = {
373
+ 'a': {
374
+ 'b': 1,
375
+ 'c': 2,
376
+ },
377
+ 'd': 3,
378
+ 'e': {
379
+ 'a': 2,
380
+ },
381
+ 'f': 4
382
+ }
383
+ self.assertFalse(serialization._spec_has_metadata(spec))
384
+
385
+ def test_empty_spec_has_no_metadata(self):
386
+ spec = {}
387
+ self.assertFalse(serialization._spec_has_metadata(spec))
388
+
389
+ @parameterized.named_parameters(
390
+ ('gcs', 'gs://my/ckpt/dir/path'),
391
+ ('file', '/my/ckpt/dir/path')
392
+ )
393
+ def test_get_tensorstore_spec_ocdbt(self, path):
394
+ spec = serialization.get_tensorstore_spec(path, ocdbt=True)
395
+ is_gcs_path = path.startswith('gs://')
396
+ if is_gcs_path:
397
+ self.assertEqual(spec['kvstore']['base'], os.path.dirname(path))
398
+ else:
399
+ self.assertEqual(spec['kvstore']['base'],
400
+ f'{serialization._DEFAULT_DRIVER}://{os.path.dirname(path)}')
401
+ self.assertEqual(spec['kvstore']['path'], 'path')
402
+
403
+ def test_get_tensorstore_spec_not_absolute_path(self):
404
+ path = 'my/ckpt/path'
405
+ with self.assertRaisesRegex(ValueError,
406
+ "Checkpoint path should be absolute"):
407
+ serialization.get_tensorstore_spec(path, ocdbt=True)
408
+
409
+ def test_maybe_cloud_storage(self):
410
+ gs_path = 'gs://some-buck/path'
411
+ gs_spec = serialization.get_tensorstore_spec(gs_path, ocdbt=True)
412
+ self.assertTrue(serialization.is_remote_storage(gs_spec))
413
+
414
+ local_path = '/tmp/checkpoint'
415
+ local_spec = serialization.get_tensorstore_spec(local_path, ocdbt=True)
416
+ self.assertFalse(serialization.is_remote_storage(local_spec))
417
+
418
+ nested_tspec = {
419
+ 'driver': 'cast',
420
+ 'dtype': 'int32',
421
+ 'base': {
422
+ 'driver': 'zarr',
423
+ 'kvstore': {'driver': 'ocdbt', 'base': 's3://some-bucket/path'},
424
+ },
425
+ }
426
+ self.assertTrue(serialization.is_remote_storage(nested_tspec))
427
+
428
+ def test_load_with_layout(self):
429
+ if not jtu.test_device_matches(['tpu']):
430
+ self.skipTest('Layouts are only supported on TPUs')
431
+
432
+ mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
433
+ np_inp = np.arange(32).reshape(8, 4)
434
+ s = NamedSharding(mesh, P('x', 'y'))
435
+ arr = jax.device_put(np_inp, s)
436
+
437
+ out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower(
438
+ arr).compile().output_layouts()
439
+ self.assertEqual(extract_minor_to_major(arr.layout),
440
+ extract_minor_to_major(out_layout)[::-1])
441
+
442
+ ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path)
443
+ ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path)
444
+ tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, [ckpt_path])
445
+
446
+ manager = serialization.GlobalAsyncCheckpointManager()
447
+ manager.serialize(
448
+ [arr], tspecs,
449
+ on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
450
+ manager.wait_until_finished()
451
+
452
+ out, = serialization.run_deserialization([out_layout], tspecs)
453
+
454
+ self.assertEqual(out.layout, out_layout)
455
+ self.assertIsInstance(out, array.ArrayImpl)
456
+ self.assertArraysEqual(out, np_inp)
457
+ for s in out.addressable_shards:
458
+ self.assertArraysEqual(s.data, np_inp[s.index])
459
+
460
+ def test_deserialization_with_int4(self):
461
+ dtype = jnp.int4
462
+ shape = (8, 2)
463
+ arr = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype)
464
+
465
+ ckpt_dir = pathlib.Path(self.create_tempdir('test_ckpt').full_path)
466
+
467
+ # Run serialization.
468
+ sharding = jax.sharding.GSPMDSharding.get_replicated(jax.devices())
469
+ tspecs = jax.tree_util.tree_map(
470
+ serialization.get_tensorstore_spec, [ckpt_dir]
471
+ )
472
+ manager = serialization.GlobalAsyncCheckpointManager()
473
+ manager.serialize(
474
+ [arr],
475
+ tspecs,
476
+ on_commit_callback=lambda: None,
477
+ )
478
+ manager.wait_until_finished()
479
+
480
+ # Run deserialization.
481
+ deserialized_arr, = serialization.run_deserialization(
482
+ shardings=[sharding],
483
+ tensorstore_specs=tspecs,
484
+ global_shapes=[shape],
485
+ dtypes=[dtype],
486
+ )
487
+
488
+ out = deserialized_arr.astype(jnp.int8) # doesn't crash
489
+ self.assertEqual(out.dtype, jnp.int8)
490
+ self.assertArraysEqual(out + out, out * 2)
491
+
492
+ if __name__ == '__main__':
493
+ absltest.main(testLoader=jtu.JaxTestLoader())
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/compilation_cache/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2021 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/compilation_cache/compilation_cache.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright 2021 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from jax._src.compilation_cache import (
16
+ is_initialized as is_initialized, # deprecated
17
+ initialize_cache as initialize_cache, # deprecated; use set_cache_dir instead
18
+ set_cache_dir as set_cache_dir,
19
+ reset_cache as reset_cache,
20
+ )
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/export/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright 2023 The JAX Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from jax._src.export._export import (
17
+ minimum_supported_serialization_version,
18
+ maximum_supported_serialization_version,
19
+ Exported,
20
+ export,
21
+ call_exported, # TODO: deprecate
22
+ call,
23
+ DisabledSafetyCheck,
24
+ default_lowering_platform,
25
+ )
26
+ from jax._src.export.shape_poly import (
27
+ is_symbolic_dim,
28
+ symbolic_shape,
29
+ symbolic_args_specs,
30
+ SymbolicScope,
31
+ )
32
+ from jax._src.export.serialization import (
33
+ serialize,
34
+ deserialize,
35
+ )
36
+ from jax._src.export import shape_poly_decision
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright 2020 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from jax.experimental.jax2tf.jax2tf import (
16
+ convert as convert,
17
+ eval_polymorphic_shape as eval_polymorphic_shape,
18
+ dtype_of_val as dtype_of_val,
19
+ split_to_logical_devices as split_to_logical_devices,
20
+ DisabledSafetyCheck as DisabledSafetyCheck,
21
+ PolyShape as PolyShape # TODO: deprecate
22
+ )
23
+ from jax.experimental.jax2tf.call_tf import call_tf as call_tf
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/call_tf.py ADDED
@@ -0,0 +1,682 @@
1
+ # Copyright 2021 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Allows JAX to call TensorFlow functions with support for autodiff.
15
+
16
+ **Experimental: please give feedback, and expect changes.**
17
+
18
+ This module introduces the function :func:`call_tf` that allows JAX to call
19
+ TensorFlow functions.
20
+
21
+ For examples and details, see
22
+ https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax.
23
+
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ from collections.abc import Sequence
29
+ import dataclasses
30
+ import functools
31
+ from typing import Any, Callable, Optional
32
+
33
+ from absl import logging
34
+ import jax
35
+ from jax import dlpack
36
+ from jax import dtypes
37
+ from jax import numpy as jnp
38
+ from jax import tree_util
39
+ from jax._src import ad_util
40
+ from jax._src import core
41
+ from jax._src import effects
42
+ from jax._src import util
43
+ from jax._src.lib import xla_client
44
+ from jax._src.lib.mlir import ir
45
+ from jax._src.lib.mlir.dialects import func as func_dialect
46
+ from jax._src.lib.mlir.dialects import hlo
47
+ from jax.experimental.jax2tf import jax2tf as jax2tf_internal
48
+ from jax.interpreters import mlir
49
+ import numpy as np
50
+ import tensorflow as tf
51
+
52
+
53
+ map = util.safe_map
54
+ zip = util.safe_zip
55
+
56
+ TfConcreteFunction = Any
57
+ TfVal = jax2tf_internal.TfVal
58
+
59
+ # The platforms for which to use DLPack to avoid copying (only works on GPU
60
+ # and CPU at the moment, and only for Array). For CPU we don't need
61
+ # DLPack, if we are careful.
62
+ _DLPACK_PLATFORMS = ("gpu",)
63
+
64
+ class UnspecifiedOutputShapeDtype:
65
+ pass
66
+
67
+ def call_tf(
68
+ callable_tf: Callable,
69
+ has_side_effects=True,
70
+ ordered=False,
71
+ output_shape_dtype=UnspecifiedOutputShapeDtype(),
72
+ call_tf_graph=False,
73
+ ) -> Callable:
74
+ """Calls a TensorFlow function from JAX, with support for reverse autodiff.
75
+
76
+ The ``callable_tf`` will be called with TensorFlow-compatible arguments (
77
+ numpy.ndarray, ``tf.Tensor`` or ``tf.Variable``) or pytrees thereof. The
78
+ function must return the same type of results.
79
+
80
+ If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
81
+ or :func:`jax.pmap`, or :func:`jax.xmap`, or a control-flow primitive) then
82
+ ``callable_tf`` will be compiled with ``tf.function(callable_tf,
83
+ jit_compile=True)``
84
+ and the resulting XLA computation will be embedded in JAX's XLA computation.
85
+
86
+ If ``call_tf`` appears outside a JAX staging context, it will be called inline
87
+ using TensorFlow eager mode.
88
+
89
+ The ``call_tf`` supports JAX's reverse-mode autodiff, in which case the
90
+ ``callable_tf`` will be differentiated using ``tf.GradientTape``. This means
91
+ that the gradient will be TensorFlow-accurate, e.g., will respect the
92
+ custom gradients that may be defined for the code in ``callable_tf``.
93
+
94
+ For an example and more details see the
95
+ `README
96
+ <https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.
97
+
98
+ Args:
99
+ callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow
100
+ arguments.
101
+ has_side_effects: if True then it ensures that instances of this primitive
102
+ are not removed or replicated by JAX optimizations such as dead-code
103
+ elimination.
104
+ ordered: If true, calls are modeled as having ordered effects.
105
+ output_shape_dtype: An optional declaration of the expected shape and dtype
106
+ of the result of the called TensorFlow function. If given it will be used
107
+ during JAX tracing to form the abstract values of the results of the
108
+ `call_tf`. If not given then we form a `tf.Graph` for the called
109
+ TensorFlow function and we use the TensorFlow-inferred shapes and types.
110
+ Must be a pytree matching the structure of the nested structure returned
111
+ from the TensorFlow function, containing objects with `.shape` and
112
+ `.dtype` attributes, e.g., `jax.ShapeDtypeStruct` or `jax.Array`.
113
+ call_tf_graph: EXPERIMENTAL, DO NOT USE. We may change the name in the
114
+ future.
115
+
116
+ Returns: a JAX callable that can be invoked with JAX pytree arguments, in
117
+ op-by-op mode or in a staged context. This callable can be used with JAX's
118
+ reverse-mode autodiff (:func:`jax.grad`).
119
+ """
120
+ @jax.custom_vjp
121
+ def make_call(*args_jax):
122
+ """We wrap it all in `make_call` so that we can attach custom VJP."""
123
+
124
+ args_flat_jax, args_treedef = tree_util.tree_flatten(args_jax)
125
+ # Canonicalize the arguments; e.g., makes them x32 if JAX is in 32-bit mode
126
+ def canonical_arg(v):
127
+ v = v if getattr(v, "dtype", None) else np.asarray(v)
128
+ dtype = dtypes.canonicalize_dtype(v.dtype)
129
+ if dtype != v.dtype:
130
+ v = v.astype(dtype)
131
+ return v
132
+
133
+ args_flat_jax = tuple(map(canonical_arg, args_flat_jax))
134
+ def make_tensorspec(a_jax):
135
+ a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
136
+ a_tf_shape = [d if core.is_constant_dim(d) else None for d in a_jax.shape]
137
+ return tf.TensorSpec(a_tf_shape, a_tf_dtype)
138
+ args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))
139
+
140
+ if not isinstance(output_shape_dtype, UnspecifiedOutputShapeDtype):
141
+ output_shape_dtype_flat, output_shape_dtype_tree = tree_util.tree_flatten(output_shape_dtype)
142
+ output_avals = tuple(core.ShapedArray(st.shape, st.dtype) for st in output_shape_dtype_flat)
143
+ else:
144
+ output_avals, output_shape_dtype_tree = None, None
145
+
146
+ res_treedef = None # We'll store here the result treedef
147
+ res_tf_flat = None # For error reporting
148
+ # The function below will be called at least once, either in eager
149
+ # mode during jax2tf_call_tf or in graph mode during _get_concrete_function_tf()
150
+ def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]:
151
+ args_tf = args_treedef.unflatten(args_tf_flat)
152
+ res_tf = callable_tf(*args_tf)
153
+
154
+ # b/279454591: When `callable_tf` is a tf function with zero outputs, it
155
+ # returns a `StatefulPartitionedCall` (if the function is stateful) or
156
+ # `PartitionedCall` (if the function is stateless) op instead of
157
+ # tf.Tensors. We work around this issue by replacing the output `res_tf`
158
+ # with an empty list.
159
+
160
+ if isinstance(res_tf, tf.Operation):
161
+ assert (
162
+ res_tf.type == "StatefulPartitionedCall"
163
+ or res_tf.type == "PartitionedCall"
164
+ )
165
+ t_out = res_tf.get_attr("Tout")
166
+ # t_out should be an empty list.
167
+ assert not t_out, (
168
+ "The TF function returned an unexpected result, please check its"
169
+ f" function body. res_tf = {res_tf}"
170
+ )
171
+ res_tf = t_out
172
+
173
+ nonlocal res_treedef, res_tf_flat
174
+ res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf)
175
+ assert res_treedef is None or res_treedef == res_treedef_now, (
176
+ f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}")
177
+ res_treedef = res_treedef_now
178
+ if output_avals is not None:
179
+ if res_treedef != output_shape_dtype_tree:
180
+ raise ValueError(
181
+ "The pytree of the TensorFlow function results does not match the "
182
+ "pytree of the declared output_shape_dtype:\n"
183
+ f"results pytree: {res_treedef}\noutput_shape_dtype tree: {output_shape_dtype_tree}")
184
+ assert len(output_avals) == len(res_tf_flat)
185
+
186
+ checked_res_tf_flat = [
187
+ check_tf_result(i, r_tf, r_aval)
188
+ for i, (r_tf, r_aval) in enumerate(
189
+ zip(res_tf_flat,
190
+ (output_avals
191
+ if output_avals is not None
192
+ else (None,) * len(res_tf_flat))))]
193
+ return checked_res_tf_flat
194
+
195
+ # Prepare a tf.function ahead of time, to cache the concrete functions. This
196
+ # won't be used in op-by-op execution mode.
197
+ function_flat_tf = tf.function(
198
+ callable_flat_tf, autograph=False, jit_compile=not call_tf_graph)
199
+
200
+ res_jax_flat = call_tf_p.bind(
201
+ *args_flat_jax,
202
+ # Carry the actual function such that op-by-op call can call in TF eager mode.
203
+ callable_flat_tf=callable_flat_tf,
204
+ function_flat_tf=function_flat_tf,
205
+ args_flat_sig_tf=args_flat_sig_tf,
206
+ output_avals=output_avals,
207
+ has_side_effects=has_side_effects,
208
+ ordered=ordered,
209
+ call_tf_graph=call_tf_graph,
210
+ )
211
+
212
+ # We must have called callable_flat_tf by now
213
+ assert res_treedef is not None
214
+ return res_treedef.unflatten(res_jax_flat)
215
+
216
+ # Define the fwd and bwd custom_vjp functions
217
+ def make_call_vjp_fwd(*args_jax):
218
+ # Return the primal arguments as the residual
219
+ return make_call(*args_jax), args_jax
220
+
221
+ def make_call_vjp_bwd(residual_jax, ct_res_jax):
222
+ args_jax = residual_jax # residual is the primal argument
223
+
224
+ def tf_vjp_fun(args_tf, ct_res_tf):
225
+ """Invoke TF gradient."""
226
+
227
+ # TF does not like us to watch non-float vars
228
+ def replace_non_float(arg_tf):
229
+ if arg_tf.dtype.is_floating or arg_tf.dtype.is_complex:
230
+ return arg_tf
231
+ else:
232
+ # When watched, this will be ignored. When used in results it will
233
+ # result in a floating 0. gradient, which JAX will ignore (and
234
+ # replace it with a float0)
235
+ return tf.zeros((), dtype=tf.float32)
236
+
237
+ watched_args_tf = tf.nest.map_structure(replace_non_float, args_tf)
238
+ with tf.GradientTape(persistent=True) as tape:
239
+ tape.watch(watched_args_tf)
240
+ res = callable_tf(*args_tf)
241
+
242
+ tf.nest.assert_same_structure(res, ct_res_tf)
243
+ dres_darg = tape.gradient(
244
+ tf.nest.map_structure(replace_non_float, res),
245
+ sources=watched_args_tf,
246
+ output_gradients=ct_res_tf,
247
+ unconnected_gradients=tf.UnconnectedGradients.ZERO)
248
+
249
+ dres_darg = tree_util.tree_map(
250
+ lambda x: x if x is None else tf.convert_to_tensor(x),
251
+ dres_darg,
252
+ )
253
+ tf.nest.assert_same_structure(dres_darg, args_tf)
254
+ return dres_darg
255
+
256
+ # Use call_tf to call the VJP function
257
+ ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax)
258
+ # We must make the float0s that JAX expects
259
+ def fix_float0(arg_jax, ct_arg_jax):
260
+ arg_dtype = dtypes.result_type(arg_jax) # May be scalar
261
+ ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype)
262
+ if ct_arg_dtype != ct_arg_jax.dtype:
263
+ return ad_util.zeros_like_aval(core.ShapedArray(np.shape(arg_jax),
264
+ ct_arg_dtype))
265
+ return ct_arg_jax
266
+
267
+ ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax)
268
+ return ct_args_jax_fixed
269
+
270
+ make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
271
+ return util.wraps(callable_tf)(make_call)
272
+
273
+
274
+ def check_tf_result(idx: int, r_tf: TfVal, r_aval: core.ShapedArray | None) -> TfVal:
275
+ # Check that the TF function returns values of expected types. This
276
+ # improves error reporting, preventing hard-to-diagnose errors downstream
277
+ try:
278
+ jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf)
279
+ except Exception as e:
280
+ msg = ("The called TF function returns a result that is not "
281
+ f"convertible to JAX: {r_tf}.")
282
+ raise ValueError(msg) from e
283
+
284
+ if r_aval is None:
285
+ return r_tf
286
+ # We convert to TF type, and canonicalize to 32-bit if necessary
287
+ r_aval_dtype_tf = jax2tf_internal._to_tf_dtype(r_aval.dtype)
288
+ # Checking shapes is trickier in the presence of dynamic shapes. Ideally we
290
+ # would check at runtime that the returned shape matches the declared shape;
291
+ # tf.ensure_shape would be the tool for this, but it can only take shapes that
292
+ # contain None, not computed shapes. However, in eager mode we should be able to resolve
292
+ # the declared shapes to constants and we get better checking.
293
+ if tf.executing_eagerly():
294
+ r_aval_shape_tf = jax2tf_internal._eval_shape(r_aval.shape)
295
+ else:
296
+ r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval)
297
+ # We do as much checking as we can here, instead of relying on tf.ensure_shape
298
+ # because the latter gives different errors in eager vs. compiled mode.
299
+ # TODO(b/279454591): This strange error comes from TF. An eager function is
301
+ # supposed to return a tf value with a concrete shape, but here it does not.
302
+ # We downgrade the exception to a warning and bypass it; this needs revisiting on the TF side.
302
+ try:
303
+ _ = len(r_tf.shape)
304
+ except ValueError as e:
305
+ msg = (
306
+ "The shape check test cannot be performed because the shape of the"
307
+ "`r_tf` tensor cannot be obtained."
308
+ f"r_tf = {r_tf}, r_aval = {r_aval}"
309
+ )
310
+ msg += str(e)
311
+ logging.warning(msg)
312
+ return r_tf
313
+ if (r_tf.dtype != r_aval_dtype_tf or
314
+ len(r_tf.shape) != len(r_aval_shape_tf) or
315
+ any(r_aval_d is not None and r_tf_d is not None and r_aval_d != r_tf_d
316
+ for r_tf_d, r_aval_d in zip(r_tf.shape, r_aval_shape_tf))):
317
+ msg = ("The shapes or dtypes returned by the TensorFlow function "
318
+ "do not match the declared output_shape_dtype:\n"
319
+ f"Result[{idx}] is {r_tf.dtype}[{r_tf.shape}] vs. expected {r_aval_dtype_tf}[{r_aval_shape_tf}]")
320
+ raise ValueError(msg)
321
+ # At this point tf.ensure_shape does not do much, it should never throw an
322
+ # error, albeit it may refine the shape a bit.
323
+ return tf.ensure_shape(r_tf, r_aval_shape_tf)
324
+
325
+
326
+ call_tf_p = core.Primitive("call_tf")
327
+ call_tf_p.multiple_results = True
328
+
329
+ # The impl will be used in op-by-op mode and calls callable_tf in TF eager mode.
330
+ def _call_tf_impl(*args_jax_flat, callable_flat_tf, **_):
331
+ # On GPU we use dlpack to avoid copies of data to the host.
332
+ def _arg_jax_to_tf(arg_jax):
333
+ if (isinstance(arg_jax, jax.Array) and
334
+ list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and
335
+ arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES):
336
+ arg_dlpack = jax.dlpack.to_dlpack(arg_jax)
337
+ return tf.experimental.dlpack.from_dlpack(arg_dlpack)
338
+ # The following avoids copies to the host on CPU, always for Array
339
+ # and even for ndarray if they are sufficiently aligned.
340
+ # TODO(necula): on TPU this copies to the host!
341
+ if getattr(arg_jax, 'dtype', None) == dtypes.float0:
342
+ return tf.zeros(shape=arg_jax.shape,
343
+ dtype=jax2tf_internal._tf_np_dtype_for_float0)
344
+ return tf.constant(np.asarray(arg_jax))
345
+
346
+ args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
347
+ with jax2tf_internal.inside_call_tf():
348
+ # Call in TF eager mode
349
+ res_tf_flat = callable_flat_tf(*args_tf_flat)
350
+
351
+ def _res_tf_to_jax(res_tf: TfVal):
352
+ res_tf, jax_dtype = jax2tf_internal._tfval_to_tensor_jax_dtype(res_tf)
353
+ if isinstance(res_tf, tf.Tensor) and jax_dtype.type in dlpack.SUPPORTED_DTYPES:
354
+ res_tf_platform = tf.DeviceSpec.from_string(res_tf.backing_device).device_type
355
+ res_jax_platform = res_tf_platform.lower()
356
+ if res_jax_platform in _DLPACK_PLATFORMS:
357
+ res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
358
+ return jax.dlpack.from_dlpack(res_dlpack)
359
+
360
+ # When working with a bfloat16 scalar tf.Tensor,np.asarray() can fail.
361
+ # To handle this special case, we create a numpy copy.
362
+ if res_tf.shape == tf.TensorShape([]) and res_tf.dtype == tf.bfloat16:
363
+ return jax.device_put(jnp.array(res_tf.numpy()))
364
+ else:
365
+ return jax.device_put(np.asarray(res_tf))
366
+
367
+ return list(map(_res_tf_to_jax, res_tf_flat))
368
+
369
+
370
+ call_tf_p.def_impl(_call_tf_impl)
371
+
372
+ @functools.lru_cache(maxsize=128)
373
+ def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf): # -> tf.ConcreteFunction
374
+ with jax2tf_internal.inside_call_tf():
375
+ return function_flat_tf.get_concrete_function(*args_flat_sig_tf)
376
+
377
+
378
+ # Mark the effectful instances of call_tf
379
+ @dataclasses.dataclass(frozen=True)
380
+ class CallTfEffect(effects.Effect):
381
+ __str__ = lambda _: "CallTfEffect"
382
+
383
+ call_tf_effect = CallTfEffect()
384
+
385
+ effects.lowerable_effects.add_type(CallTfEffect)
386
+ effects.control_flow_allowed_effects.add_type(CallTfEffect)
387
+ effects.remat_allowed_effects.add_type(CallTfEffect)
388
+ effects.custom_derivatives_allowed_effects.add_type(CallTfEffect)
389
+
390
+
391
+ class CallTfOrderedEffect(effects.Effect):
392
+ __str__ = lambda _: "CallTfOrderedEffect"
393
+
394
+
395
+ call_tf_ordered_effect = CallTfOrderedEffect()
396
+
397
+ effects.lowerable_effects.add_type(CallTfOrderedEffect)
398
+ effects.control_flow_allowed_effects.add_type(CallTfOrderedEffect)
399
+ effects.remat_allowed_effects.add_type(CallTfOrderedEffect)
400
+ effects.custom_derivatives_allowed_effects.add_type(CallTfOrderedEffect)
401
+ effects.ordered_effects.add_type(CallTfOrderedEffect)
402
+ effects.shardable_ordered_effects.add_type(CallTfOrderedEffect)
403
+
404
+
405
+ def _call_tf_abstract_eval(
406
+ *args_flat_avals,
407
+ function_flat_tf,
408
+ args_flat_sig_tf,
409
+ has_side_effects,
410
+ ordered,
411
+ output_avals,
412
+ call_tf_graph,
413
+ **__,
414
+ ):
415
+ # Called only when we form a Jaxpr, i.e., under jit, scan, etc.
416
+ effects = set()
417
+ if ordered:
418
+ effects.add(call_tf_ordered_effect)
419
+ elif has_side_effects:
420
+ effects.add(call_tf_effect)
421
+
422
+ # If no output_avals is given, then we ask TF to infer the output shapes.
423
+ # We call this even if output_avals is given because it will ensure that
424
+ # callable_flat_tf is called. Since _get_concrete_function_tf is cached
425
+ # there is a small cost of calling it more often than needed.
426
+ concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf,
427
+ args_flat_sig_tf)
428
+
429
+ # In the case that the tf.function has no return value
430
+ if len(concrete_function_flat_tf.outputs) == 0:
431
+ return (), effects
432
+
433
+ if output_avals is not None:
434
+ return output_avals, effects
435
+
436
+ def is_fully_known_shape(s):
437
+ return s.rank is not None and all(d is not None for d in s)
438
+
439
+ if all(is_fully_known_shape(s)
440
+ for s in concrete_function_flat_tf.output_shapes):
441
+ avals_from_tf = tuple(
442
+ # We convert to JAX type, and canonicalize to 32-bit if necessary
443
+ core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
444
+ for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
445
+ concrete_function_flat_tf.output_shapes))
446
+ return avals_from_tf, effects
447
+
448
+ msg = ("call_tf cannot call functions whose output has dynamic shape. "
449
+ f"Found output shapes: {concrete_function_flat_tf.output_shapes}. "
450
+ "Consider using the `output_shape_dtype` argument to call_tf. "
451
+ "\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
452
+ " for a discussion.")
453
+ raise ValueError(msg)
454
+
455
+
456
+ call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval)
457
+
458
+
459
+ def _call_tf_lowering(
460
+ ctx: mlir.LoweringRuleContext,
461
+ *args_op,
462
+ platform,
463
+ function_flat_tf,
464
+ args_flat_sig_tf,
465
+ has_side_effects,
466
+ ordered,
467
+ call_tf_graph,
468
+ output_avals,
469
+ **_,
470
+ ):
471
+ # We use the same TF lowering device as for the embedding JAX computation.
472
+ # One example when this is needed is when the code refers to variables on one
473
+ # device. Or, for sharding annotations (only supported on TPU).
474
+
475
+ if platform in ["cpu", "tpu"]:
476
+ tf_platform = platform.upper()
477
+ elif platform == "cuda":
478
+ tf_platform = "GPU"
479
+ else:
480
+ raise ValueError(f"platform {platform} not supported")
481
+
482
+ concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf)
483
+
484
+ captured_inputs = []
485
+ if concrete_function_flat_tf.captured_inputs:
486
+ # The function uses either captured variables or tensors.
487
+ msg = (
488
+ "call_tf works best with a TensorFlow function that does not capture "
489
+ "variables or tensors from the context. "
490
+ "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. "
491
+ f"The following captures were found {concrete_function_flat_tf.captured_inputs}")
492
+ logging.warning(msg)
493
+ for inp in concrete_function_flat_tf.captured_inputs:
494
+ if inp.dtype == tf.resource: # A variable; lookup by handle
495
+ inp_vars = [v for v in concrete_function_flat_tf.variables if inp is v.handle]
496
+ assert len(inp_vars) == 1, f"Found {inp_vars}"
497
+ captured_inputs.append(inp_vars[0])
498
+ else:
499
+ captured_inputs.append(inp)
500
+
501
+ captured_ops = tuple(
502
+ mlir.ir_constant(np.asarray(inp))
503
+ for inp in captured_inputs
504
+ )
505
+
506
+ if call_tf_graph:
507
+ with jax2tf_internal.inside_call_tf():
508
+ return emit_tf_embedded_graph_custom_call(
509
+ ctx,
510
+ concrete_function_flat_tf,
511
+ tuple(args_op) + captured_ops,
512
+ has_side_effects,
513
+ ordered,
514
+ output_avals,
515
+ )
516
+
517
+ def convert_to_spec(x):
518
+ if isinstance(x, tf.TensorSpec):
519
+ return x
520
+ else:
521
+ return tf.TensorSpec.from_tensor(x)
522
+
523
+ args_tf_flat = [convert_to_spec(a) for a in args_flat_sig_tf]
524
+
525
+ with jax2tf_internal.inside_call_tf():
526
+ try:
527
+ func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(
528
+ *args_tf_flat
529
+ )(stage="hlo_serialized", platform_name=tf_platform)
530
+ except Exception as e:
531
+ msg = ("Error compiling TensorFlow function (see below for the caught exception)." +
532
+ "\ncall_tf can be used " +
533
+ "in a staged context (under jax.jit, lax.scan, etc.) only with " +
534
+ "compilable functions with static output shapes.\n" +
535
+ "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." +
536
+ "\n\nCaught TensorFlow exception: " + str(e))
537
+ raise ValueError(msg) from e
538
+
539
+ xla_comp = xla_client.XlaComputation(func_tf_hlo)
540
+
541
+ # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
542
+ def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray:
543
+ if not res_shape.is_static():
544
+ msg = ("Compiled TensorFlow function has dynamic output shape " +
545
+ f"{res_shape}. call_tf can be used " +
546
+ "in a staged context (under jax.jit, lax.scan, etc.) only with " +
547
+ "compilable functions with static output shapes. " +
548
+ "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.")
549
+ raise ValueError(msg)
550
+
551
+ res_dtype = res_shape.numpy_dtype()
552
+ jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
553
+ return core.ShapedArray(res_shape.dimensions(), jax_res_dtype)
554
+
555
+ result_shape = xla_comp.program_shape().result_shape()
556
+ if not result_shape.is_tuple():
557
+ # TF does not wrap singletons as tuples, but JAX expects tuples because
558
+ # call_tf is a multiple_results primitive.
559
+ result_shapes = (result_shape,)
560
+ else:
561
+ result_shapes = result_shape.tuple_shapes() # type: ignore
562
+
563
+ result_avals = tuple(map(canonical_res_aval, result_shapes))
564
+
565
+ submodule = mlir.xla_computation_to_mlir_module(xla_comp)
566
+ symtab = ir.SymbolTable(submodule.operation)
567
+ callee_result_types = symtab["main"].type.results
568
+ fn = mlir.merge_mlir_modules(ctx.module_context.module,
569
+ f"call_tf_{function_flat_tf.name}",
570
+ submodule,
571
+ dst_symtab=ctx.module_context.symbol_table)
572
+ call = func_dialect.CallOp(callee_result_types,
573
+ ir.FlatSymbolRefAttr.get(fn),
574
+ tuple(args_op) + captured_ops)
575
+ if result_shape.is_tuple():
576
+ flat_results = [hlo.get_tuple_element(call, mlir.i32_attr(i))
577
+ for i in range(len(result_shapes))]
578
+ else:
579
+ flat_results = call.results
580
+
581
+ if ordered:
582
+ raise NotImplementedError(
583
+ "ordered=True is not supported in the jitted context without"
584
+ " `call_tf_graph=True`"
585
+ )
586
+
587
+ outputs = []
588
+ for op, res_aval, res_shape in zip(flat_results, result_avals,
589
+ result_shapes):
590
+ if res_aval.dtype != res_shape.numpy_dtype():
591
+ op = hlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result
592
+ outputs.append(op)
593
+ return outputs
594
+
595
+
596
+ def _register_call_lowering(platform):
597
+ mlir.register_lowering(call_tf_p, functools.partial(_call_tf_lowering,
598
+ platform=platform),
599
+ platform=platform)
600
+ for platform in ("cpu", "cuda", "tpu"):
601
+ _register_call_lowering(platform)
602
+
603
+ # Support the call_tf under jax2tf.convert in eager mode
604
+ def _jax2tf_call_tf(*args: TfVal,
605
+ callable_flat_tf: Callable,
606
+ **_) -> TfVal:
607
+ with jax2tf_internal.inside_call_tf():
608
+ res_tf_flat = callable_flat_tf(*args)
609
+ return res_tf_flat
610
+
611
+ jax2tf_internal.tf_impl[call_tf_p] = _jax2tf_call_tf
612
+
613
+
614
+ def emit_tf_embedded_graph_custom_call(
615
+ ctx: mlir.LoweringRuleContext,
616
+ concrete_function_flat_tf,
617
+ operands: Sequence[ir.Value],
618
+ has_side_effects,
619
+ ordered,
620
+ output_avals,
621
+ ):
622
+ """Emits a custom call referencing a tf.Graph embedding of the TF function.
623
+
624
+ Information about every function called via call_tf is stored in tf.metadata.
625
+ This includes:
626
+ (1) The called function name: This name will be used by the runtime to execute
627
+ the callback.
628
+ (2) The called function index in the XLACallModule `function_list` attribute.
629
+ """
630
+ call_tf_concrete_function_list = jax2tf_internal.get_thread_local_state_call_tf_concrete_function_list()
631
+ if call_tf_concrete_function_list is None:
632
+ raise ValueError(
633
+ "call_tf_graph=True currently only supports exporting via jax2tf.convert."
634
+ )
635
+ # TODO(necula): It is dangerous to modify global state when lowering because
636
+ # there are a number of lowering caches that only cache the StableHLO.
637
+ # See call_tf_test.py:test_multi_platform_call_tf_graph.
638
+ called_index = add_to_call_tf_concrete_function_list(
639
+ concrete_function_flat_tf, call_tf_concrete_function_list)
640
+ tf_backend_config = {
641
+ "has_token_input_output": ir.BoolAttr.get(ordered),
642
+ "called_index": mlir.i64_attr(called_index),
643
+ }
644
+ result_avals = ctx.avals_out if ctx.avals_out is not None else ()
645
+
646
+ operands = list(operands)
647
+ result_types = list(
648
+ util.flatten([mlir.aval_to_ir_types(aval) for aval in result_avals])
649
+ )
650
+ if ordered:
651
+ operands.insert(0, ctx.tokens_in.get(call_tf_ordered_effect)[0])
652
+ result_types.insert(0, mlir.token_type()[0])
653
+
654
+ custom_call = hlo.CustomCallOp(
655
+ result_types,
656
+ operands,
657
+ call_target_name=ir.StringAttr.get("tf.call_tf_function"),
658
+ has_side_effect=ir.BoolAttr.get(has_side_effects),
659
+ api_version=mlir.i32_attr(2),
660
+ called_computations=ir.ArrayAttr.get([]),
661
+ backend_config=ir.StringAttr.get(""),
662
+ )
663
+ # Store TF metadata in unregistered attribute
664
+ custom_call.attributes["tf.backend_config"] = ir.DictAttr.get(
665
+ tf_backend_config
666
+ )
667
+
668
+ results = list(custom_call.results)
669
+ if ordered:
670
+ token = results.pop(0)
671
+ ctx.set_tokens_out(mlir.TokenSet({call_tf_ordered_effect: (token,)}))
672
+
673
+ return results
674
+
675
+
676
+ def add_to_call_tf_concrete_function_list(concrete_tf_fn: Any, call_tf_concrete_function_list: list[Any]) -> int:
677
+ try:
678
+ called_index = call_tf_concrete_function_list.index(concrete_tf_fn)
679
+ except ValueError:
680
+ called_index = len(call_tf_concrete_function_list)
681
+ call_tf_concrete_function_list.append(concrete_tf_fn)
682
+ return called_index
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2021 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/keras_reuse_main.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright 2020 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Demonstrates reuse of a jax2tf model in Keras.
15
+
16
+ Includes the flags from saved_model_main.py.
17
+
18
+ See README.md.
19
+ """
20
+ import logging
21
+ from absl import app
22
+ from absl import flags
23
+ from jax.experimental.jax2tf.examples import mnist_lib
24
+ from jax.experimental.jax2tf.examples import saved_model_main
25
+ import tensorflow as tf
26
+ import tensorflow_datasets as tfds # type: ignore
27
+ import tensorflow_hub as hub # type: ignore
28
+
29
+
30
+ FLAGS = flags.FLAGS
31
+
32
+
33
+ def main(_):
34
+ FLAGS.model_classifier_layer = False # We only need the features
35
+ # Train the model and save the feature extractor
36
+ saved_model_main.train_and_save()
37
+
38
+ tf_accelerator, _ = saved_model_main.tf_accelerator_and_tolerances()
39
+ feature_model_dir = saved_model_main.savedmodel_dir()
40
+
41
+ # With Keras, we use the tf.distribute.OneDeviceStrategy as the high-level
42
+ # analogue of the tf.device(...) placement seen above.
43
+ # It works on CPU, GPU and TPU.
44
+ # Actual high-performance training would use the appropriately replicated
45
+ # TF Distribution Strategy.
46
+ strategy = tf.distribute.OneDeviceStrategy(tf_accelerator)
47
+ with strategy.scope():
48
+ images = tf.keras.layers.Input(
49
+ mnist_lib.input_shape, batch_size=mnist_lib.train_batch_size)
50
+ keras_feature_extractor = hub.KerasLayer(feature_model_dir, trainable=True)
51
+ features = keras_feature_extractor(images)
52
+ predictor = tf.keras.layers.Dense(10, activation="softmax")
53
+ predictions = predictor(features)
54
+ keras_model = tf.keras.Model(images, predictions)
55
+
56
+ keras_model.compile(
57
+ loss=tf.keras.losses.categorical_crossentropy,
58
+ optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
59
+ metrics=["accuracy"])
60
+ logging.info(keras_model.summary())
61
+
62
+ train_ds = mnist_lib.load_mnist(
63
+ tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
64
+ test_ds = mnist_lib.load_mnist(
65
+ tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
66
+ keras_model.fit(train_ds, epochs=FLAGS.num_epochs, validation_data=test_ds)
67
+
68
+ if saved_model_main.SHOW_IMAGES.value:
69
+ mnist_lib.plot_images(
70
+ test_ds,
71
+ 1,
72
+ 5,
73
+ f"Keras inference with reuse of {saved_model_main.model_description()}",
74
+ inference_fn=lambda images: keras_model(tf.convert_to_tensor(images)))
75
+
76
+
77
+ if __name__ == "__main__":
78
+ app.run(main)
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/keras_reuse_main_test.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright 2020 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from absl import flags
17
+ from absl.testing import absltest
18
+ from absl.testing import parameterized
19
+ import jax
20
+ from jax._src import test_util as jtu
21
+
22
+ from jax.experimental.jax2tf.examples import keras_reuse_main
23
+ from jax.experimental.jax2tf.tests import tf_test_util
24
+
25
+ jax.config.parse_flags_with_absl()
26
+ FLAGS = flags.FLAGS
27
+
28
+
29
+ class KerasReuseMainTest(tf_test_util.JaxToTfTestCase):
30
+
31
+ def setUp(self):
32
+ super().setUp()
33
+ FLAGS.model_path = os.path.join(absltest.get_default_test_tmpdir(),
34
+ "saved_models")
35
+ FLAGS.num_epochs = 1
36
+ FLAGS.test_savedmodel = True
37
+ FLAGS.mock_data = True
38
+ FLAGS.show_images = False
39
+ FLAGS.serving_batch_size = 1
40
+
41
+ @parameterized.named_parameters(
42
+ dict(testcase_name=f"_{model}", model=model)
43
+ for model in ["mnist_pure_jax", "mnist_flax"])
44
+ def test_keras_reuse(self, model="mnist_pure_jax"):
45
+ FLAGS.model = model
46
+ keras_reuse_main.main(None)
47
+
48
+
49
+ if __name__ == "__main__":
50
+ absltest.main(testLoader=jtu.JaxTestLoader())
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/mnist_lib.py ADDED
@@ -0,0 +1,324 @@
1
+ # Copyright 2020 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Definitions of two versions of MNIST (model and training code).
15
+
16
+ One definition uses pure JAX (for those who prefer an example with fewer
17
+ moving parts, at the expense of code size), and another using Flax.
18
+
19
+ See README.md for how these are used.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ from collections.abc import Sequence
25
+ import functools
26
+ import logging
27
+ import re
28
+ import time
29
+ from typing import Any, Callable, Optional
30
+ from absl import flags
31
+
32
+ import flax
33
+ from flax import linen as nn
34
+
35
+ import jax
36
+ import jax.numpy as jnp
37
+
38
+ from matplotlib import pyplot as plt
39
+ import numpy as np
40
+ import optax
41
+ import tensorflow as tf
42
+ import tensorflow_datasets as tfds # type: ignore
43
+
44
+ _MOCK_DATA = flags.DEFINE_boolean("mock_data", False,
45
+ "Use fake data, for testing.")
46
+
47
+ #### Model parameters
48
+
49
+ # For fun, let's use different batch sizes for training and for evaluation.
50
+ train_batch_size = 128
51
+ test_batch_size = 16
52
+
53
+ # Define common parameters for both the JAX and the Flax models.
54
+ input_shape = (28, 28, 1) # Excluding batch_size
55
+ layer_sizes = [784, 512, 512, 10] # 10 is the number of classes
56
+ param_scale = 0.1
57
+ step_size = 0.001
58
+
59
+
60
+ def load_mnist(split: tfds.Split, batch_size: int):
61
+ """Loads either training or test MNIST data.
62
+
63
+ Args:
64
+ split: either tfds.Split.TRAIN or tfds.Split.TEST.
65
+
66
+ Returns:
67
+ an iterator with pairs (images, labels). The images have shape
68
+ (B, 28, 28, 1) and the labels have shape (B, 10), where B is the batch_size.
69
+ """
70
+ if _MOCK_DATA.value:
71
+ with tfds.testing.mock_data(num_examples=batch_size):
72
+ try:
73
+ ds = tfds.load("mnist", split=split)
74
+ except Exception as e:
75
+ m = re.search(r'metadata files were not found in (.+/)mnist/', str(e))
76
+ if m:
77
+ msg = ("TFDS mock_data is missing the mnist metadata files. Run the "
78
+ "`saved_model_main.py` binary and see where TFDS downloads "
79
+ "the mnist data set (typically ~/tensorflow_datasets/mnist). "
80
+ f"Copy the `mnist` directory to {m.group(1)} and re-run the test")
81
+ raise ValueError(msg) from e
82
+ else:
83
+ raise e
84
+ else:
85
+ ds = tfds.load("mnist", split=split)
86
+
87
+ def _prepare_example(x):
88
+ image = tf.cast(x["image"], tf.float32) / 255.0
89
+ label = tf.one_hot(x["label"], 10)
90
+ return (image, label)
91
+
92
+ ds = ds.map(_prepare_example)
93
+ # drop_remainder=True is important for use with Keras
94
+ ds = ds.cache().shuffle(1000).batch(batch_size, drop_remainder=True)
95
+ return ds
96
+
97
+
98
+ class PureJaxMNIST:
99
+ """An MNIST model written using pure JAX.
100
+
101
+ There is an option for the model to skip the classifier layer, for
102
+ demonstrating reuse of the classifier-less model into a larger model.
103
+ See README.md.
104
+ """
105
+
106
+ name = "mnist_pure_jax"
107
+
108
+ @staticmethod
109
+ def predict(params: Sequence[tuple[Any, Any]], inputs, with_classifier=True):
110
+ """The prediction function.
111
+
112
+ Args:
113
+ params: a list with pairs of weights and biases for each layer.
114
+ inputs: the batch of images (B, 28, 28, 1)
115
+ with_classifier: whether to include the classifier layer.
116
+
117
+ Returns:
118
+ either the predictions (B, 10) if with_classifier=True, or the
119
+ final set of logits of shape (B, 512).
120
+ """
121
+ x = inputs.reshape((inputs.shape[0], -1)) # flatten to f32[B, 784]
122
+ for w, b in params[:-1]:
123
+ x = jnp.dot(x, w) + b
124
+ x = jnp.tanh(x)
125
+
126
+ if not with_classifier:
127
+ return x
128
+ final_w, final_b = params[-1]
129
+ logits = jnp.dot(x, final_w) + final_b
130
+ return logits - jax.scipy.special.logsumexp(
131
+ logits, axis=1, keepdims=True)
132
+
133
+ @staticmethod
134
+ def loss(params, inputs, labels):
135
+ predictions = PureJaxMNIST.predict(params, inputs, with_classifier=True)
136
+ return -jnp.mean(jnp.sum(predictions * labels, axis=1))
137
+
138
+ @staticmethod
139
+ def accuracy(predict: Callable, params, dataset):
140
+
141
+ @jax.jit
142
+ def _per_batch(inputs, labels):
143
+ target_class = jnp.argmax(labels, axis=1)
144
+ predicted_class = jnp.argmax(predict(params, inputs), axis=1)
145
+ return jnp.mean(predicted_class == target_class)
146
+
147
+ batched = [
148
+ _per_batch(inputs, labels) for inputs, labels in tfds.as_numpy(dataset)
149
+ ]
150
+ return jnp.mean(jnp.stack(batched))
151
+
152
+ @staticmethod
153
+ def update(params, inputs, labels):
154
+ grads = jax.grad(PureJaxMNIST.loss)(params, inputs, labels)
155
+ return [(w - step_size * dw, b - step_size * db)
156
+ for (w, b), (dw, db) in zip(params, grads)]
157
+
158
+ @staticmethod
159
+ def train(train_ds, test_ds, num_epochs, with_classifier=True):
160
+ """Trains a pure JAX MNIST predictor.
161
+
162
+ Returns:
163
+ a tuple with two elements:
164
+ - a predictor function with signature "(Params, ImagesBatch) ->
165
+ Predictions".
166
+ If `with_classifier=False` then the output of the predictor function
167
+ is the last layer of logits.
168
+ - the parameters "Params" for the predictor function
169
+ """
170
+ rng = jax.random.PRNGKey(0)
171
+ params = [(param_scale * jax.random.normal(rng, (m, n)),
172
+ param_scale * jax.random.normal(rng, (n,)))
173
+ for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]
174
+
175
+ for epoch in range(num_epochs):
176
+ start_time = time.time()
177
+ for inputs, labels in tfds.as_numpy(train_ds):
178
+ params = jax.jit(PureJaxMNIST.update)(params, inputs, labels)
179
+ epoch_time = time.time() - start_time
180
+ train_acc = PureJaxMNIST.accuracy(PureJaxMNIST.predict, params, train_ds)
181
+ test_acc = PureJaxMNIST.accuracy(PureJaxMNIST.predict, params, test_ds)
182
+ logging.info("%s: Epoch %d in %0.2f sec", PureJaxMNIST.name, epoch,
183
+ epoch_time)
184
+ logging.info("%s: Training set accuracy %0.2f%%", PureJaxMNIST.name,
185
+ 100. * train_acc)
186
+ logging.info("%s: Test set accuracy %0.2f%%", PureJaxMNIST.name,
187
+ 100. * test_acc)
188
+
189
+ return (functools.partial(
190
+ PureJaxMNIST.predict, with_classifier=with_classifier), params)
191
+
192
+
193
+ class FlaxMNIST:
194
+ """An MNIST model using Flax."""
195
+
196
+ name = "mnist_flax"
197
+
198
+ class Module(nn.Module):
199
+ """A simple CNN model for MNIST.
200
+
201
+ There is an option for the model to skip the classifier layer, for
202
+ demonstrating reuse of the classifier-less model into a larger model.
203
+ See README.md.
204
+ """
205
+
206
+ @nn.compact
207
+ def __call__(self, x, with_classifier=True):
208
+ x = nn.Conv(features=32, kernel_size=(3, 3))(x)
209
+ x = nn.relu(x)
210
+ x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
211
+ x = nn.Conv(features=64, kernel_size=(3, 3))(x)
212
+ x = nn.relu(x)
213
+ x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
214
+ x = x.reshape((x.shape[0], -1)) # flatten
215
+ x = nn.Dense(features=256)(x)
216
+ x = nn.relu(x)
217
+ if not with_classifier:
218
+ return x
219
+ x = nn.Dense(features=10)(x)
220
+ x = nn.log_softmax(x)
221
+ return x
222
+
223
+ # Create the model and save it
224
+ model = Module()
225
+
226
+ @staticmethod
227
+ def predict(params, inputs, with_classifier=True):
228
+ return FlaxMNIST.model.apply({"params": params},
229
+ inputs,
230
+ with_classifier=with_classifier)
231
+
232
+ @staticmethod
233
+ def loss(params, inputs, labels): # Same as the pure JAX example
234
+ # Must use the classifier layer because the labels are classes
235
+ predictions = FlaxMNIST.predict(params, inputs, with_classifier=True)
236
+ return -jnp.mean(jnp.sum(predictions * labels, axis=1))
237
+
238
+ @staticmethod
239
+ def update(tx, params, opt_state, inputs, labels):
240
+ grad = jax.grad(FlaxMNIST.loss)(params, inputs, labels)
241
+ updates, opt_state = tx.update(grad, opt_state)
242
+ params = optax.apply_updates(params, updates)
243
+ return params, opt_state
244
+
245
+ @staticmethod
246
+ def train(train_ds, test_ds, num_epochs, with_classifier=True):
247
+ """Trains a pure JAX MNIST predictor.
248
+
249
+ Returns:
250
+ a tuple with two elements:
251
+ - a predictor function with signature "(Params, ImagesBatch) ->
252
+ Predictions".
253
+ If `with_classifier=False` then the output of the predictor function
254
+ is the last layer of logits.
255
+ - the parameters "Params" for the predictor function
256
+ """
257
+ rng = jax.random.PRNGKey(0)
258
+ momentum_mass = 0.9
259
+
260
+ init_shape = jnp.ones((1,) + input_shape, jnp.float32)
261
+ params = FlaxMNIST.model.init(rng, init_shape)["params"]
262
+ tx = optax.sgd(learning_rate=step_size, momentum=momentum_mass)
263
+ opt_state = tx.init(params)
264
+
265
+ for epoch in range(num_epochs):
266
+ start_time = time.time()
267
+ for inputs, labels in tfds.as_numpy(train_ds):
268
+ params, opt_state = jax.jit(FlaxMNIST.update,
269
+ static_argnums=0)(tx, params, opt_state,
270
+ inputs, labels)
271
+ epoch_time = time.time() - start_time
272
+ # Same accuracy function as for the pure JAX example
273
+ train_acc = PureJaxMNIST.accuracy(FlaxMNIST.predict, params,
274
+ train_ds)
275
+ test_acc = PureJaxMNIST.accuracy(FlaxMNIST.predict, params,
276
+ test_ds)
277
+ logging.info("%s: Epoch %d in %0.2f sec", FlaxMNIST.name, epoch,
278
+ epoch_time)
279
+ logging.info("%s: Training set accuracy %0.2f%%", FlaxMNIST.name,
280
+ 100. * train_acc)
281
+ logging.info("%s: Test set accuracy %0.2f%%", FlaxMNIST.name,
282
+ 100. * test_acc)
283
+
284
+ # See discussion in README.md for packaging Flax models for conversion
285
+ predict_fn = functools.partial(FlaxMNIST.predict,
286
+ with_classifier=with_classifier)
287
+ return (predict_fn, params)
288
+
289
+
290
+ def plot_images(ds,
291
+ nr_rows: int,
292
+ nr_cols: int,
293
+ title: str,
294
+ inference_fn: Callable | None = None):
295
+ """Plots a grid of images with their predictions.
296
+
297
+ Params:
298
+ ds: a tensorflow dataset from where to pick the images and labels.
299
+ nr_rows, nr_cols: the size of the grid to plot
300
+ title: the title of the plot
301
+ inference_fn: if None then print the existing label, else use this function
302
+ on the batch of images to produce a batch of inference results, which
303
+ get printed.
306
+ """
307
+ count = nr_rows * nr_cols
308
+ fig = plt.figure(figsize=(8., 4.), num=title)
309
+ # Get the first batch
310
+ (images, labels), = list(tfds.as_numpy(ds.take(1)))
311
+ if inference_fn:
312
+ inferred_labels = inference_fn(images)
313
+ for i, image in enumerate(images[:count]):
314
+ digit = fig.add_subplot(nr_rows, nr_cols, i + 1)
315
+ if inference_fn:
316
+ digit_title = f"infer: {np.argmax(inferred_labels[i])}\n"
317
+ else:
318
+ digit_title = ""
319
+ digit_title += f"label: {np.argmax(labels[i])}"
320
+ digit.set_title(digit_title)
321
+ plt.imshow(
322
+ (np.reshape(image, (28, 28)) * 255).astype(np.uint8),
323
+ interpolation="nearest")
324
+ plt.show()
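A short sketch of how the pieces of this module fit together, assuming the functions defined above (`load_mnist`, `PureJaxMNIST.train`, `plot_images`); the epoch count is only illustrative:

import tensorflow_datasets as tfds
from jax.experimental.jax2tf.examples import mnist_lib

# Load the training and test sets with the module's batch sizes.
train_ds = mnist_lib.load_mnist(tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
test_ds = mnist_lib.load_mnist(tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)

# Train the pure-JAX model and obtain (predict_fn, params).
predict_fn, params = mnist_lib.PureJaxMNIST.train(train_ds, test_ds, num_epochs=1)

# Plot a few test images together with the model's predictions.
mnist_lib.plot_images(test_ds, 1, 5, "MNIST predictions",
                      inference_fn=lambda images: predict_fn(params, images))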
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/saved_model_lib.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright 2020 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Defines a helper function for creating a SavedModel from a jax2tf trained model.
15
+
16
+ This has been tested with TensorFlow Hub, TensorFlow JavaScript,
17
+ and TensorFlow Serving.
18
+
19
+ Note that the code in this file is provided only as an example. The functions
20
+ generated by `jax2tf.convert` are standard TensorFlow functions and you can
21
+ save them in a SavedModel using standard TensorFlow code. This decoupling
22
+ of jax2tf from SavedModel is important, because it allows the user to have full
23
+ control over what metadata is saved in the SavedModel. Please copy and
24
+ customize this function as needed.
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ from collections.abc import Sequence
30
+ from typing import Any, Callable, Optional, Union
31
+
32
+ from jax.experimental import jax2tf
33
+ import tensorflow as tf
34
+
35
+
36
+ def convert_and_save_model(
37
+ jax_fn: Callable[[Any, Any], Any],
38
+ params,
39
+ model_dir: str,
40
+ *,
41
+ input_signatures: Sequence[tf.TensorSpec],
42
+ polymorphic_shapes: str | None = None,
43
+ with_gradient: bool = False,
44
+ enable_xla: bool = True,
45
+ compile_model: bool = True,
46
+ saved_model_options: tf.saved_model.SaveOptions | None = None):
47
+ """Converts a JAX function and saves a SavedModel.
48
+
49
+ This is an example; we do not promise backwards compatibility for this code.
50
+ For serious uses, please copy and expand it as needed (see note at the top
51
+ of the module).
52
+
53
+ Use this function if you have a trained ML model that has both a prediction
54
+ function and trained parameters, which you want to save separately from the
55
+ function graph as variables (e.g., to avoid limits on the size of the
56
+ GraphDef, or to enable fine-tuning.) If you don't have such parameters,
57
+ you can still use this library function but probably don't need it
58
+ (see jax2tf/README.md for some simple examples).
59
+
60
+ In order to use this wrapper you must first convert your model to a function
61
+ with two arguments: the parameters and the input on which you want to do
62
+ inference. Both arguments may be np.ndarray or (nested)
63
+ tuples/lists/dictionaries thereof.
64
+
65
+ See the README.md for a discussion of how to prepare Flax and Haiku models.
66
+
67
+ Args:
68
+ jax_fn: a JAX function taking two arguments, the parameters and the inputs.
69
+ Both arguments may be (nested) tuples/lists/dictionaries of np.ndarray.
70
+ params: the parameters, to be used as first argument for `jax_fn`. These
71
+ must be (nested) tuples/lists/dictionaries of np.ndarray, and will be
72
+ saved as the variables of the SavedModel.
73
+ model_dir: the directory where the model should be saved.
74
+ input_signatures: the input signatures for the second argument of `jax_fn`
75
+ (the input). A signature must be a `tensorflow.TensorSpec` instance, or a
76
+ (nested) tuple/list/dictionary thereof with a structure matching the
77
+ second argument of `jax_fn`. The first input_signature will be saved as
78
+ the default serving signature. The additional signatures will be used
79
+ only to ensure that the `jax_fn` is traced and converted to TF for the
80
+ corresponding input shapes.
81
+ with_gradient: the value to use for the `with_gradient` parameter for
82
+ `jax2tf.convert`.
83
+ enable_xla: the value to use for the `enable_xla` parameter for
84
+ `jax2tf.convert`.
85
+ compile_model: use TensorFlow jit_compiler on the SavedModel. This
86
+ is needed if the SavedModel will be used for TensorFlow serving.
87
+ polymorphic_shapes: if given then it will be used as the
88
+ `polymorphic_shapes` argument to jax2tf.convert for the second parameter of
89
+ `jax_fn`. In this case, a single `input_signatures` is supported, and
90
+ should have `None` in the polymorphic dimensions.
91
+ saved_model_options: options to pass to savedmodel.save.
92
+ """
93
+ if not input_signatures:
94
+ raise ValueError("At least one input_signature must be given")
95
+ if polymorphic_shapes is not None:
96
+ if len(input_signatures) > 1:
97
+ raise ValueError("For shape-polymorphic conversion a single "
98
+ "input_signature is supported.")
99
+ tf_fn = jax2tf.convert(
100
+ jax_fn,
101
+ with_gradient=with_gradient,
102
+ polymorphic_shapes=[None, polymorphic_shapes],
103
+ enable_xla=enable_xla)
104
+
105
+ # Create tf.Variables for the parameters. If you want more useful variable
106
+ # names, you can use `tree.map_structure_with_path` from the `dm-tree` package
107
+ param_vars = tf.nest.map_structure(
108
+ lambda param: tf.Variable(param, trainable=with_gradient),
109
+ params)
110
+ tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
111
+ autograph=False,
112
+ jit_compile=compile_model)
113
+
114
+ signatures = {}
115
+ # This signature is needed for TensorFlow Serving use.
116
+ signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
117
+ tf_graph.get_concrete_function(input_signatures[0])
118
+ for input_signature in input_signatures[1:]:
119
+ # If there are more signatures, trace and cache a TF function for each one
120
+ tf_graph.get_concrete_function(input_signature)
121
+ wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
122
+ if with_gradient:
123
+ if not saved_model_options:
124
+ saved_model_options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
125
+ else:
126
+ saved_model_options.experimental_custom_gradients = True
127
+ tf.saved_model.save(wrapper, model_dir, signatures=signatures,
128
+ options=saved_model_options)
129
+
130
+
131
+ class _ReusableSavedModelWrapper(tf.train.Checkpoint):
132
+ """Wraps a function and its parameters for saving to a SavedModel.
133
+
134
+ Implements the interface described at
135
+ https://www.tensorflow.org/hub/reusable_saved_models.
136
+ """
137
+
138
+ def __init__(self, tf_graph, param_vars):
139
+ """Args:
140
+
141
+ tf_graph: a tf.function taking one argument (the inputs), which can be
142
+ tuples/lists/dictionaries of np.ndarray or tensors. The function
143
+ may have references to the tf.Variables in `param_vars`.
144
+ param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,
145
+ to be saved as the variables of the SavedModel.
146
+ """
147
+ super().__init__()
148
+ # Implement the interface from https://www.tensorflow.org/hub/reusable_saved_models
149
+ self.variables = tf.nest.flatten(param_vars)
150
+ self.trainable_variables = [v for v in self.variables if v.trainable]
151
+ # If you intend to prescribe regularization terms for users of the model,
152
+ # add them as @tf.functions with no inputs to this list. Else drop this.
153
+ self.regularization_losses = []
154
+ self.__call__ = tf_graph
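A minimal sketch of calling the helper above, assuming the `convert_and_save_model` function defined in this file; the prediction function, parameters, and output directory are placeholders:

import numpy as np
import tensorflow as tf
from jax.experimental.jax2tf.examples import saved_model_lib

def predict_fn(params, images):
  # Placeholder linear model; a real model would compute its own logits.
  return images.reshape((images.shape[0], -1)) @ params["w"]

params = {"w": np.zeros((28 * 28, 10), dtype=np.float32)}

saved_model_lib.convert_and_save_model(
    predict_fn,
    params,
    "/tmp/jax2tf/example_saved_model",  # placeholder directory
    input_signatures=[tf.TensorSpec((1, 28, 28, 1), tf.float32)],
    with_gradient=False)

The first (and here only) input signature becomes the default serving signature of the resulting SavedModel.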
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/saved_model_main.py ADDED
@@ -0,0 +1,210 @@
1
+ # Copyright 2020 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Demonstrates training models and saving the result as a SavedModel.
15
+
16
+ By default, uses a pure JAX implementation of MNIST. There are flags to choose
17
+ a Flax CNN version of MNIST, or to skip the training and just test a
18
+ previously saved SavedModel. It is possible to save a batch-polymorphic
19
+ version of the model, or a model prepared for specific batch sizes.
20
+
21
+ Try --help to see all flags.
22
+
23
+ This file is used both as an executable, and as a library in two other examples.
24
+ See discussion in README.md.
25
+ """
26
+
27
+ import logging
28
+ import os
29
+
30
+ from absl import app
31
+ from absl import flags
32
+
33
+ from jax.experimental.jax2tf.examples import mnist_lib
34
+ from jax.experimental.jax2tf.examples import saved_model_lib
35
+
36
+ import numpy as np
37
+ import tensorflow as tf
38
+ import tensorflow_datasets as tfds # type: ignore
39
+
40
+ _MODEL = flags.DEFINE_enum(
41
+ "model", "mnist_flax", ["mnist_flax", "mnist_pure_jax"],
42
+ "Which model to use.")
43
+ _MODEL_CLASSIFIER_LAYER = flags.DEFINE_boolean("model_classifier_layer", True,
44
+ ("The model should include the classifier layer, or just "
45
+ "the last layer of logits. Set this to False when you "
46
+ "want to reuse the classifier-less model in a larger "
47
+ "model. See keras_reuse_main.py and README.md."))
48
+ _MODEL_PATH = flags.DEFINE_string("model_path", "/tmp/jax2tf/saved_models",
49
+ "Path under which to save the SavedModel.")
50
+ _MODEL_VERSION = flags.DEFINE_integer("model_version", 1,
51
+ ("The version number for the SavedModel. Needed for "
52
+ "serving, larger versions will take precedence"),
53
+ lower_bound=1)
54
+ _SERVING_BATCH_SIZE = flags.DEFINE_integer("serving_batch_size", 1,
55
+ "For what batch size to prepare the serving signature. "
56
+ "Use -1 for converting and saving with batch polymorphism.")
57
+ flags.register_validator(
58
+ "serving_batch_size",
59
+ lambda serving_batch_size: serving_batch_size > 0
60
+ or serving_batch_size == -1,
61
+ message="--serving_batch_size must be either -1 or a positive integer.",
62
+ )
63
+
64
+ _NUM_EPOCHS = flags.DEFINE_integer("num_epochs", 3,
65
+ "For how many epochs to train.",
66
+ lower_bound=1)
67
+ _GENERATE_MODEL = flags.DEFINE_boolean(
68
+ "generate_model", True,
69
+ "Train and save a new model. Otherwise, use an existing SavedModel.")
70
+ _COMPILE_MODEL = flags.DEFINE_boolean(
71
+ "compile_model", True,
72
+ "Enable TensorFlow jit_compiler for the SavedModel. This is "
73
+ "necessary if you want to use the model for TensorFlow serving.")
74
+ _SHOW_MODEL = flags.DEFINE_boolean("show_model", True,
75
+ "Show details of saved SavedModel.")
76
+ SHOW_IMAGES = flags.DEFINE_boolean(
77
+ "show_images", False,
78
+ "Plot some sample images with labels and inference results.")
79
+ _TEST_SAVEDMODEL = flags.DEFINE_boolean(
80
+ "test_savedmodel", True,
81
+ "Test TensorFlow inference using the SavedModel w.r.t. the JAX model.")
82
+
83
+
84
+ def train_and_save():
85
+ logging.info("Loading the MNIST TensorFlow dataset")
86
+ train_ds = mnist_lib.load_mnist(
87
+ tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
88
+ test_ds = mnist_lib.load_mnist(
89
+ tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
90
+
91
+ if SHOW_IMAGES.value:
92
+ mnist_lib.plot_images(train_ds, 1, 5, "Training images", inference_fn=None)
93
+
94
+ the_model_class = pick_model_class()
95
+ model_dir = savedmodel_dir(with_version=True)
96
+
97
+ if _GENERATE_MODEL.value:
98
+ model_descr = model_description()
99
+ logging.info("Generating model for %s", model_descr)
100
+ (predict_fn, predict_params) = the_model_class.train(
101
+ train_ds,
102
+ test_ds,
103
+ num_epochs=_NUM_EPOCHS.value,
104
+ with_classifier=_MODEL_CLASSIFIER_LAYER.value)
105
+
106
+ if _SERVING_BATCH_SIZE.value == -1:
107
+ # Batch-polymorphic SavedModel
108
+ input_signatures = [
109
+ tf.TensorSpec((None,) + mnist_lib.input_shape, tf.float32),
110
+ ]
111
+ polymorphic_shapes = "(batch, ...)"
112
+ else:
113
+ input_signatures = [
114
+ # The first one will be the serving signature
115
+ tf.TensorSpec((_SERVING_BATCH_SIZE.value,) + mnist_lib.input_shape,
116
+ tf.float32),
117
+ tf.TensorSpec((mnist_lib.train_batch_size,) + mnist_lib.input_shape,
118
+ tf.float32),
119
+ tf.TensorSpec((mnist_lib.test_batch_size,) + mnist_lib.input_shape,
120
+ tf.float32),
121
+ ]
122
+ polymorphic_shapes = None
123
+
124
+ logging.info("Saving model for %s", model_descr)
125
+ saved_model_lib.convert_and_save_model(
126
+ predict_fn,
127
+ predict_params,
128
+ model_dir,
129
+ with_gradient=True,
130
+ input_signatures=input_signatures,
131
+ polymorphic_shapes=polymorphic_shapes,
132
+ compile_model=_COMPILE_MODEL.value)
133
+
134
+ if _TEST_SAVEDMODEL.value:
135
+ tf_accelerator, tolerances = tf_accelerator_and_tolerances()
136
+ with tf.device(tf_accelerator):
137
+ logging.info("Testing savedmodel")
138
+ pure_restored_model = tf.saved_model.load(model_dir)
139
+
140
+ if SHOW_IMAGES.value and _MODEL_CLASSIFIER_LAYER.value:
141
+ mnist_lib.plot_images(
142
+ test_ds,
143
+ 1,
144
+ 5,
145
+ f"Inference results for {model_descr}",
146
+ inference_fn=pure_restored_model)
147
+
148
+ test_input = np.ones(
149
+ (mnist_lib.test_batch_size,) + mnist_lib.input_shape,
150
+ dtype=np.float32)
151
+ np.testing.assert_allclose(
152
+ pure_restored_model(tf.convert_to_tensor(test_input)),
153
+ predict_fn(predict_params, test_input), **tolerances)
154
+
155
+ if _SHOW_MODEL.value:
156
+ def print_model(model_dir: str):
157
+ cmd = f"saved_model_cli show --all --dir {model_dir}"
158
+ print(cmd)
159
+ os.system(cmd)
160
+
161
+ print_model(model_dir)
162
+
163
+
164
+ def pick_model_class():
165
+ """Picks one of PureJaxMNIST or FlaxMNIST."""
166
+ if _MODEL.value == "mnist_pure_jax":
167
+ return mnist_lib.PureJaxMNIST
168
+ elif _MODEL.value == "mnist_flax":
169
+ return mnist_lib.FlaxMNIST
170
+ else:
171
+ raise ValueError(f"Unrecognized model: {_MODEL.value}")
172
+
173
+
174
+ def model_description() -> str:
175
+ """A short description of the picked model."""
176
+ res = pick_model_class().name
177
+ if not _MODEL_CLASSIFIER_LAYER.value:
178
+ res += " (features_only)"
179
+ return res
180
+
181
+
182
+ def savedmodel_dir(with_version: bool = True) -> str:
183
+ """The directory where we save the SavedModel."""
184
+ model_dir = os.path.join(
185
+ _MODEL_PATH.value,
186
+ _MODEL.value + ('' if _MODEL_CLASSIFIER_LAYER.value else '_features')
187
+ )
188
+ if with_version:
189
+ model_dir = os.path.join(model_dir, str(_MODEL_VERSION.value))
190
+ return model_dir
191
+
192
+
193
+ def tf_accelerator_and_tolerances():
194
+ """Picks the TF accelerator to use and the tolerances for numerical checks."""
195
+ tf_accelerator = (tf.config.list_logical_devices("TPU") +
196
+ tf.config.list_logical_devices("GPU") +
197
+ tf.config.list_logical_devices("CPU"))[0]
198
+ logging.info("Using tf_accelerator = %s", tf_accelerator)
199
+ if tf_accelerator.device_type == "TPU":
200
+ tolerances = dict(atol=1e-6, rtol=1e-6)
201
+ elif tf_accelerator.device_type == "GPU":
202
+ tolerances = dict(atol=1e-6, rtol=1e-4)
203
+ elif tf_accelerator.device_type == "CPU":
204
+ tolerances = dict(atol=1e-5, rtol=1e-5)
205
+ logging.info("Using tolerances %s", tolerances)
206
+ return tf_accelerator, tolerances
207
+
208
+
209
+ if __name__ == "__main__":
210
+ app.run(lambda _: train_and_save())
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/saved_model_main_test.py ADDED
@@ -0,0 +1,70 @@
1
+ # Copyright 2020 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for mnist_lib, saved_model_lib, saved_model_main."""
15
+
16
+ import os
17
+ from absl import flags
18
+ from absl.testing import absltest
19
+ from absl.testing import parameterized
20
+
21
+ from jax._src import config
22
+ from jax._src import test_util as jtu
23
+
24
+ from jax.experimental.jax2tf.examples import saved_model_main
25
+ from jax.experimental.jax2tf.tests import tf_test_util
26
+
27
+ config.parse_flags_with_absl()
28
+ FLAGS = flags.FLAGS
29
+
30
+
31
+ class SavedModelMainTest(tf_test_util.JaxToTfTestCase):
32
+
33
+ def setUp(self):
34
+ super().setUp()
35
+ FLAGS.model_path = os.path.join(absltest.get_default_test_tmpdir(),
36
+ "saved_models")
37
+ FLAGS.num_epochs = 1
38
+ FLAGS.test_savedmodel = True
39
+ FLAGS.mock_data = True
40
+
41
+ @parameterized.named_parameters(
42
+ dict(
43
+ testcase_name=f"_{model}_batch={serving_batch_size}",
44
+ model=model,
45
+ serving_batch_size=serving_batch_size)
46
+ for model in ["mnist_pure_jax", "mnist_flax"]
47
+ for serving_batch_size in [1, -1])
48
+ def test_train_and_save_full(self,
49
+ model="mnist_flax",
50
+ serving_batch_size=-1):
51
+ if (serving_batch_size == -1 and
52
+ config.jax2tf_default_native_serialization.value and
53
+ not config.dynamic_shapes.value):
54
+ self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.")
55
+ FLAGS.model = model
56
+ FLAGS.model_classifier_layer = True
57
+ FLAGS.serving_batch_size = serving_batch_size
58
+ saved_model_main.train_and_save()
59
+
60
+ @parameterized.named_parameters(
61
+ dict(testcase_name=f"_{model}", model=model)
62
+ for model in ["mnist_pure_jax", "mnist_flax"])
63
+ def test_train_and_save_features(self, model="mnist_flax"):
64
+ FLAGS.model = model
65
+ FLAGS.model_classifier_layer = False
66
+ saved_model_main.train_and_save()
67
+
68
+
69
+ if __name__ == "__main__":
70
+ absltest.main(testLoader=jtu.JaxTestLoader())
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/serving/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2021 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/examples/serving/model_server_request.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright 2021 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Demonstrates using jax2tf with TensorFlow model server.
15
+
16
+ See README.md for instructions.
17
+ """
18
+ import grpc # type: ignore
19
+ import json
20
+ import logging
21
+ import requests
22
+
23
+ from absl import app
24
+ from absl import flags
25
+
26
+ from jax.experimental.jax2tf.examples import mnist_lib
27
+
28
+ import numpy as np
29
+ import tensorflow as tf
30
+ import tensorflow_datasets as tfds # type: ignore[import-not-found]
31
+ from tensorflow_serving.apis import predict_pb2 # type: ignore[import-not-found]
32
+ from tensorflow_serving.apis import prediction_service_pb2_grpc
33
+
34
+
35
+ _USE_GRPC = flags.DEFINE_boolean(
36
+ "use_grpc", True,
37
+ "Use the gRPC API (default), or the HTTP REST API.")
38
+
39
+ _MODEL_SPEC_NAME = flags.DEFINE_string(
40
+ "model_spec_name", "",
41
+ "The name you used to export your model to model server (e.g., mnist_flax).")
42
+
43
+ _PREDICTION_SERVICE_ADDR = flags.DEFINE_string(
44
+ "prediction_service_addr",
45
+ "localhost:8500",
46
+ "Stubby endpoint for the prediction service. If you serve your model "
47
+ "locally using TensorFlow model server, then you can use \"localhost:8500\""
48
+ "for the gRPC server and \"localhost:8501\" for the HTTP REST server.")
49
+
50
+ _SERVING_BATCH_SIZE = flags.DEFINE_integer(
51
+ "serving_batch_size",
52
+ 1,
53
+ "Batch size for the serving request. Must match the "
54
+ "batch size at which the model was saved. Must divide "
55
+ "--count_images",
56
+ lower_bound=1,
57
+ )
58
+ _COUNT_IMAGES = flags.DEFINE_integer(
59
+ "count_images", 16, "How many images to test.", lower_bound=1
60
+ )
61
+
62
+
63
+ def serving_call_mnist(images):
64
+ """Send an RPC or REST request to the model server.
65
+
66
+ Args:
67
+ images: A numpy.ndarray of shape [B, 28, 28, 1] with the batch of images to
68
+ perform inference on.
69
+
70
+ Returns:
71
+ A numpy.ndarray of shape [B, 10] with the one-hot inference response.
72
+ """
73
+ if _USE_GRPC.value:
74
+ channel = grpc.insecure_channel(_PREDICTION_SERVICE_ADDR.value)
75
+ stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
76
+
77
+ request = predict_pb2.PredictRequest()
78
+ request.model_spec.name = _MODEL_SPEC_NAME.value
79
+ request.model_spec.signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
80
+ # You can see the name of the input ("inputs") in the SavedModel dump.
81
+ request.inputs["inputs"].CopyFrom(
82
+ tf.make_tensor_proto(images, dtype=images.dtype, shape=images.shape))
83
+ response = stub.Predict(request)
84
+ # We could also use response.outputs["output_0"], where "output_0" is the
85
+ # name of the output (which you can see in the SavedModel dump.)
86
+ # Alternatively, we just get the first output.
87
+ outputs, = response.outputs.values()
88
+ return tf.make_ndarray(outputs)
89
+ else:
90
+ # Use the HTTP REST api
91
+ images_json = json.dumps(images.tolist())
92
+ # You can see the name of the input ("inputs") in the SavedModel dump.
93
+ data = f'{{"inputs": {images_json}}}'
94
+ predict_url = f"http://{_PREDICTION_SERVICE_ADDR.value}/v1/models/{_MODEL_SPEC_NAME.value}:predict"
95
+ response = requests.post(predict_url, data=data)
96
+ if response.status_code != 200:
97
+ msg = (f"Received error response {response.status_code} from model "
98
+ f"server: {response.text}")
99
+ raise ValueError(msg)
100
+ outputs = response.json()["outputs"]
101
+ return np.array(outputs)
102
+
103
+
104
+ def main(_):
105
+ if _COUNT_IMAGES.value % _SERVING_BATCH_SIZE.value != 0:
106
+ raise ValueError(f"The count_images ({_COUNT_IMAGES.value}) must be a "
107
+ "multiple of "
108
+ f"serving_batch_size ({_SERVING_BATCH_SIZE.value})")
109
+ test_ds = mnist_lib.load_mnist(tfds.Split.TEST,
110
+ batch_size=_SERVING_BATCH_SIZE.value)
111
+ images_and_labels = tfds.as_numpy(test_ds.take(
112
+ _COUNT_IMAGES.value // _SERVING_BATCH_SIZE.value))
113
+
114
+ accurate_count = 0
115
+ for batch_idx, (images, labels) in enumerate(images_and_labels):
116
+ predictions_one_hot = serving_call_mnist(images)
117
+ predictions_digit = np.argmax(predictions_one_hot, axis=1)
118
+ labels_digit = np.argmax(labels, axis=1)
119
+ accurate_count += np.sum(labels_digit == predictions_digit)
120
+ running_accuracy = (
121
+ 100. * accurate_count / (1 + batch_idx) / _SERVING_BATCH_SIZE.value)
122
+ logging.info(
123
+ " predicted digits = %s labels %s. Running accuracy %.3f%%",
124
+ predictions_digit, labels_digit, running_accuracy)
125
+
126
+
127
+ if __name__ == "__main__":
128
+ app.run(main)
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/impl_no_xla.py ADDED
@@ -0,0 +1,1287 @@
1
+ # Copyright 2020 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Workarounds for jax2tf transforms when XLA is not linked in."""
15
+
16
+ from __future__ import annotations
17
+
18
+ import builtins
19
+ from collections.abc import Sequence
20
+ import dataclasses
21
+ from functools import partial, wraps
22
+ import math
23
+ import string
24
+ from typing import Any, Callable, Optional
25
+
26
+ from jax._src import core
27
+ from jax import lax
28
+ from jax._src.lax import slicing as lax_slicing
29
+ from jax._src import dtypes
30
+ from jax._src import util
31
+
32
+ from jax.experimental.jax2tf import jax2tf
33
+
34
+ import numpy as np
35
+ import tensorflow as tf
36
+
37
+
38
+ # Implementation rules for primitives when XLA is not linked in. These
39
+ # implementations are workarounds, making use of TF ops that do work when XLA is
40
+ # not linked in. They are only used when the argument `enable_xla=False` is
41
+ # passed to jax2tf.convert().
42
+ tf_impl_no_xla: dict[core.Primitive, Callable[..., Any]] = {}
43
+
44
+
45
+ TfVal = Any
46
+ DType = Any
47
+ PrecisionType = Any
48
+
49
+
50
+ def _error(primitive_name: str, suffix_msg: str = "") -> Exception:
51
+ msg = f"Call to {primitive_name} cannot be converted with enable_xla=False."
52
+ if suffix_msg:
53
+ msg += (f" {suffix_msg} - See source code for the precise conditions under "
54
+ "which it can be converted without XLA.")
55
+ return NotImplementedError(msg)
56
+
57
+ _conv_error = lambda msg: _error("conv_general_dilated", msg)
58
+ _reduce_error = lambda msg: _error("reduce_window", msg)
59
+ _scatter_error = lambda msg: _error("scatter_(update/add/multiply/min/max)", msg
60
+ )
61
+
62
+ def _unimplemented(name):
63
+
64
+ def op(*arg, **kwargs):
65
+ raise _error(name)
66
+
67
+ return op
68
+
69
+
70
+ # TODO(marcvanzee): Remove this function and use `tf.math.invert_permutation`
71
+ # once it is implemented by TFjs:
72
+ # https://github.com/tensorflow/tfjs/issues/6395.
73
+ def _invert_permutation(perm):
74
+ return tuple(perm.index(i) for i in range(len(perm)))
75
+
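A minimal sketch of how this helper behaves (plain Python, not part of the vendored file above):

# Illustrative only: inverting a permutation with the pure-Python helper above.
perm = (0, 2, 3, 1)            # e.g. "NCHW" -> "NHWC"
inv = tuple(perm.index(i) for i in range(len(perm)))
assert inv == (0, 3, 1, 2)     # applying inv after perm restores the original order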
76
+
77
+ def _transpose_with_shape(x: TfVal, x_shape: core.Shape, permutation) -> tuple[TfVal, core.Shape]:
78
+ """Computes transposition of x and its shape.
79
+
80
+ x_shape matches x.shape in the known dimensions, and it has dimension
81
+ polynomials elsewhere, while x.shape has None.
82
+ """
83
+ return tf.transpose(x, perm=permutation), tuple(x_shape[i] for i in permutation)
84
+
85
+
86
+ def _transpose_for_tf_conv(lhs, lhs_shape: core.Shape,
87
+ rhs, rhs_shape: core.Shape, dimension_numbers):
88
+ """Transposes lhs and rhs to respectively NHWC and HWIO so they can be passed to TF functions.
89
+
90
+ The shapes passed in and returned may contain polynomials, and thus may
91
+ be different than lhs.shape and rhs.shape.
92
+ """
93
+ # TODO(marcvanzee): Add tests for this ops for shape polymorphism.
94
+ lhs_perm, rhs_perm, _ = dimension_numbers
95
+
96
+ # TODO(marcvanzee): Consider merging transposes if we want to optimize.
97
+ # For `lhs_perm` / `output_perm`, perm (0, 1, 2, 3) corresponds to "NCHW".
98
+ lhs, lhs_shape = _transpose_with_shape(lhs, lhs_shape, lhs_perm) # lhs --> "NCHW"
99
+ if len(lhs_perm) == 3:
100
+ # For 1D convolution, we add a trivial "W" dimension, so that 2D Convolution
101
+ # logic can be applied downstream.
102
+ lhs = lhs[:, :, :, np.newaxis]
103
+ lhs_shape = tuple(lhs_shape) + (1,)
104
+ # However, the TF ops only support "NHWC" on CPU, so we transpose again.
105
+ lhs, lhs_shape = _transpose_with_shape(lhs, lhs_shape, (0, 2, 3, 1)) # "NCHW" --> "NHWC"
106
+
107
+ # For `rhs_perm`, perm (0, 1, 2, 3) corresponds to "OIHW".
108
+ rhs, rhs_shape = _transpose_with_shape(rhs, rhs_shape, rhs_perm) # rhs --> "OIHW"
109
+ # Handle conv1d case.
110
+ if len(rhs_perm) == 3:
111
+ rhs = rhs[:, :, :, np.newaxis]
112
+ rhs_shape = tuple(rhs_shape) + (1,)
113
+ # For the tf ops, rhs is expected to be "OIHW".
114
+ rhs, rhs_shape = _transpose_with_shape(rhs, rhs_shape, (2, 3, 1, 0)) # "OIHW" --> "HWIO"
115
+ jax2tf._assert_matching_abstract_shape(lhs, lhs_shape)
116
+ jax2tf._assert_matching_abstract_shape(rhs, rhs_shape)
117
+ return lhs, lhs_shape, rhs, rhs_shape
118
+
119
+
120
+ def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str:
121
+ for pad_str in ["VALID", "SAME"]:
122
+ pads = lax.padtype_to_pads(in_shape, window_shape, window_strides, pad_str)
123
+ if list(pads) == list(padding):
124
+ return pad_str
125
+ return "EXPLICIT"
126
+
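A hedged sketch of what this padding classification does, assuming the `pads_to_padtype` defined above is importable; the shapes are illustrative, not taken from the diff:

# Illustrative only: explicit pads that coincide with what lax computes for
# "SAME" are reported as "SAME"; anything else falls back to "EXPLICIT".
from jax import lax

in_shape, window, strides = (1, 8, 8, 1), (1, 3, 3, 1), (1, 1, 1, 1)
same_pads = lax.padtype_to_pads(in_shape, window, strides, "SAME")
assert pads_to_padtype(in_shape, window, strides, same_pads) == "SAME"
assert pads_to_padtype(in_shape, window, strides, [(2, 2)] * 4) == "EXPLICIT"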
127
+
128
+ def _pad_spatial_dims(x, x_shape, padding):
129
+ """Pads `x` using `padding`, which specifies padding for the spatial dimensions."""
130
+ padding = tuple(padding)
131
+ if len(padding) == len(x_shape) - 2:
132
+ # If necessary, add empty padding for batch and feature dimensions.
133
+ no_pad = ((0, 0),)
134
+ padding = no_pad + padding + no_pad
135
+ x = tf.pad(x, padding)
136
+ assert len(x.shape) == len(padding)
137
+ x_shape = tuple(p0 + xs + p1 for xs, (p0, p1) in zip(x_shape, padding))
138
+ jax2tf._assert_matching_abstract_shape(x, x_shape)
139
+ return x, x_shape
140
+
141
+ def _check_pad_spatial_dims(x, x_shape, padding):
142
+ """Pads `x` using `padding`, which specifies padding for the spatial dimensions."""
143
+ padding = tuple(padding)
144
+ if len(padding) == len(x_shape) - 2:
145
+ # If necessary, add empty padding for batch and feature dimensions.
146
+ no_pad = ((0, 0),)
147
+ padding = no_pad + padding + no_pad
148
+ assert len(x.shape) == len(padding)
149
+ x_shape = tuple(p0 + xs + p1 for xs, (p0, p1) in zip(x_shape, padding))
150
+ return x, x_shape, padding
151
+
152
+ def _conv_transpose_pads_to_padtype(kernel_sdims, lhs_dilation, padding):
153
+ """Finds the padding type for a transpose convolution."""
154
+ # This is simply checking agreement with lax._conv_transpose_padding.
155
+ is_valid = True
156
+ is_same = True
157
+ if not len(kernel_sdims) == len(lhs_dilation) == len(padding):
158
+ raise ValueError(f'Found different lengths for '
159
+ f'kernel_sdims ({kernel_sdims}), '
160
+ f'lhs_dilation ({lhs_dilation}), '
161
+ f'and padding ({padding}).')
162
+ for k, s, (begin, end) in zip(kernel_sdims, lhs_dilation, padding):
163
+ # Check for VALID padding.
164
+ pad_len_valid = k + s - 2 + builtins.max(k - s, 0)
165
+ pad_a = k - 1
166
+ pad_b = pad_len_valid - pad_a
167
+ if begin != pad_a or end != pad_b:
168
+ is_valid = False
169
+
170
+ # Check for SAME padding.
171
+ pad_len_same = k + s - 2
172
+ if s > k - 1:
173
+ pad_a = k - 1
174
+ else:
175
+ pad_a = int(np.ceil(pad_len_same / 2))
176
+ pad_b = pad_len_same - pad_a
177
+ if begin != pad_a or end != pad_b:
178
+ is_same = False
179
+
180
+ if is_valid:
181
+ return 'VALID'
182
+ elif is_same:
183
+ return 'SAME'
184
+ raise ValueError('Transpose convolution padding mode must be '
185
+ '`SAME` or `VALID`.')
186
+
187
+ def _validate_spatial_dimensions(lhs: TfVal, lhs_shape: core.Shape,
188
+ rhs: TfVal, rhs_shape: core.Shape):
189
+ """Check spatial dimension support."""
190
+ jax2tf._assert_matching_abstract_shape(lhs, lhs_shape)
191
+ jax2tf._assert_matching_abstract_shape(rhs, rhs_shape)
192
+
193
+ nr_spatial_dimensions = len(lhs_shape) - 2
194
+ # Currently we only support 1D+2D convolutions because it keeps the code
195
+ # relatively simple and covers most cases.
196
+ if nr_spatial_dimensions > 2:
197
+ raise _conv_error(
198
+ "We only support 1D or 2D convolutions, but found "
199
+ f"{nr_spatial_dimensions}.")
200
+
201
+
202
+ def _normalize_padding_and_dilations(
203
+ padding, lhs_dilation, rhs_dilation, is_conv1d):
204
+ if is_conv1d:
205
+ lhs_dilation = list(lhs_dilation) + [1]
206
+ rhs_dilation = list(rhs_dilation) + [1]
207
+ # Empty padding in the dummy dimension.
208
+ # Note that when kernel_size=stride=1, padding of (0, 0) is both 'VALID' and
209
+ # 'SAME'. So the inferred padding type is still determined by the padding of
210
+ # the first (real) spatial dimension.
211
+ padding = list(padding) + [(0, 0)]
212
+ return padding, lhs_dilation, rhs_dilation
213
+
214
+
215
+ def _normalize_window_strides(window_strides):
216
+ """Ensure window_strides has length 4."""
217
+ # Some TF ops require len(window_strides) == 4 while others do not. We simply
218
+ # ensure it always has len(4).
219
+ if len(window_strides) == 1:
220
+ # This is the Conv1D case. We add a dummy dimension to allow using 2D ops,
221
+ # and use stride=1 on the dummy dimension.
222
+ window_strides = list(window_strides) + [1]
223
+ if len(window_strides) == 2:
224
+ window_strides = [1] + list(window_strides) + [1]
225
+ return window_strides
226
+
227
+
228
+ def _validate_conv_features(
229
+ is_transpose, is_atrous, is_depthwise, feature_group_count,
230
+ batch_group_count, preferred_element_type, lhs_dtype):
231
+ if feature_group_count > 1 and not is_depthwise:
232
+ raise _conv_error("Grouped convolutions are unsupported")
233
+ if (is_depthwise and is_atrous) and not is_transpose:
234
+ # We allow dilated depthwise convolutions.
235
+ pass
236
+ elif [is_depthwise, is_atrous, is_transpose].count(True) > 1:
237
+ raise _conv_error(
238
+ f"Can only do one of depthwise ({is_depthwise}), atrous ({is_atrous}) "
239
+ f"and transposed convolutions ({is_transpose})")
240
+
241
+ # We can implement batch grouping when there is a need for it.
242
+ if batch_group_count != 1:
243
+ raise _conv_error("Unimplemented support for batch_group_count != 1 "
244
+ f"(found {batch_group_count})")
245
+
246
+ if (preferred_element_type is not None and
247
+ preferred_element_type != lhs_dtype):
248
+ raise _conv_error("Unimplemented support for preferred_element_type")
249
+
250
+
251
+ def _conv_general_dilated(
252
+ lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
253
+ dimension_numbers: lax.ConvDimensionNumbers, feature_group_count: int,
254
+ batch_group_count: int,
255
+ precision: tuple[PrecisionType, PrecisionType] | None,
256
+ preferred_element_type: DType | None,
257
+ _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray):
258
+ """Implementation of lax.conv_general_dilated_p using XlaConv."""
259
+ # In presence of shape polymorphism, lhs.shape and rhs.shape may contain
260
+ # None. The actual dimension polynomial shapes are in _in_avals.
261
+ del precision # Unused arguments.
262
+ lhs_shape, rhs_shape = _in_avals[0].shape, _in_avals[1].shape
263
+ out_shape = _out_aval.shape
264
+ _validate_spatial_dimensions(lhs, lhs_shape, rhs, rhs_shape)
265
+ is_conv1d = len(lhs_shape) - 2 == 1
266
+
267
+ tf_window_strides = _normalize_window_strides(window_strides)
268
+ padding, lhs_dilation, rhs_dilation = _normalize_padding_and_dilations(
269
+ padding, lhs_dilation, rhs_dilation, is_conv1d)
270
+
271
+ lhs, lhs_shape, rhs, rhs_shape = _transpose_for_tf_conv(lhs, lhs_shape,
272
+ rhs, rhs_shape,
273
+ dimension_numbers)
274
+ in_channels = lhs_shape[-1]
275
+ *rhs_spatial_shapes, _, rhs_out_channel = rhs_shape
276
+
277
+ is_transpose = any(d != 1 for d in lhs_dilation)
278
+ is_atrous = any(d != 1 for d in rhs_dilation)
279
+ is_depthwise = in_channels == feature_group_count and feature_group_count > 1
280
+ _validate_conv_features(is_transpose, is_atrous, is_depthwise,
281
+ feature_group_count, batch_group_count,
282
+ preferred_element_type, lhs.dtype.as_numpy_dtype)
283
+
284
+ rhs_dilated_shape = [
285
+ (k - 1) * r + 1 for k, r in zip(rhs_spatial_shapes, rhs_dilation)
286
+ ]
287
+ output_perm = dimension_numbers[2]
288
+
289
+ if is_transpose:
290
+ padding_type = _conv_transpose_pads_to_padtype(
291
+ rhs_spatial_shapes, lhs_dilation, padding)
292
+ else:
293
+ padding_type = pads_to_padtype(
294
+ lhs_shape[1:3], rhs_dilated_shape, window_strides, padding)
295
+ # We only manually pad if we aren't using a transposed convolution.
296
+ if padding_type == "EXPLICIT":
297
+ lhs, lhs_shape, padding = _check_pad_spatial_dims(lhs, lhs_shape, padding)
298
+ padding_type = padding
299
+
300
+ if padding_type != "SAME" and any(l < r for l, r in zip(lhs_shape[1:3], rhs_dilated_shape)):
301
+ # If the input shape is smaller than the filter shape in a spatial dimension,
302
+ # lax returns only zeros while tf.conv2d returns an error.
303
+ # We thus return zeros to make sure the behavior is consistent.
304
+ return tf.broadcast_to(tf.constant(0, dtype=tf.float32),
305
+ jax2tf._eval_shape(out_shape))
306
+
307
+ if is_depthwise:
308
+ # Reshape filter from
309
+ # [filter_height, filter_width, 1, in_channels * channel_multiplier] to
310
+ # [filter_height, filter_width, in_channels, channel_multiplier].
311
+ new_rhs_shape = tuple(rhs_spatial_shapes) + (in_channels,
312
+ rhs_out_channel // in_channels)
313
+ output = tf.nn.depthwise_conv2d(
314
+ input=lhs,
315
+ filter=tf.reshape(rhs, jax2tf._eval_shape(new_rhs_shape)),
316
+ strides=tf_window_strides,
317
+ padding=padding_type,
318
+ dilations=rhs_dilation)
319
+
320
+ elif is_transpose:
321
+ # tf.nn.conv2d_transpose requires a transposed filter.
322
+ rhs_t = tf.reverse(rhs, [0, 1])
323
+ rhs_t = tf.transpose(rhs_t, (0, 1, 3, 2))
324
+
325
+ # We should transpose `out_shape` to "NHWC", which is what TF expects.
326
+ # First transpose to "NCHW".
327
+ if is_conv1d:
328
+ tf_out_shape = tuple(out_shape[i] for i in output_perm) + (1,)
329
+ else:
330
+ tf_out_shape = tuple(out_shape[i] for i in output_perm)
331
+ # Then transpose "NCHW" to "NHWC".
332
+ tf_out_shape = tuple(tf_out_shape[i] for i in (0, 2, 3, 1))
333
+ output = tf.nn.conv2d_transpose(
334
+ input=lhs,
335
+ filters=rhs_t,
336
+ output_shape=jax2tf._eval_shape(tf_out_shape),
337
+ strides=lhs_dilation,
338
+ padding=padding_type)
339
+
340
+ else:
341
+ output = tf.nn.conv2d(
342
+ input=lhs,
343
+ filters=rhs,
344
+ strides=tf_window_strides,
345
+ padding=padding_type,
346
+ dilations=rhs_dilation)
347
+
348
+ # TF outputs in format "NHWC", so convert to "NCHW", which is lax's default
349
+ # format.
350
+ output = tf.transpose(output, (0, 3, 1, 2)) # "NHWC" --> "NCHW"
351
+ if is_conv1d:
352
+ output = output[:, :, :, 0]
353
+ # To determine the right permutation, we compute the inverse permutation of
354
+ # `output_perm`, so that when `output_perm` is applied to `output`, we obtain
355
+ # the output in NCHW format.
356
+ inverse_perm = _invert_permutation(output_perm)
357
+ output = tf.transpose(output, inverse_perm) # "NCHW" -> desired output shape.
358
+ return output
359
+
360
+
361
+ tf_impl_no_xla[lax.conv_general_dilated_p] = _conv_general_dilated
362
+
363
+
364
+ def _dot_general(lhs, rhs, *, dimension_numbers,
365
+ precision: tuple[PrecisionType, PrecisionType] | None,
366
+ preferred_element_type: DType | None,
367
+ _in_avals: Sequence[core.ShapedArray],
368
+ _out_aval: core.ShapedArray):
369
+ """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
370
+ # Unused arguments.
371
+ del precision
372
+ del preferred_element_type
373
+
374
+ lhs, rhs, convert_result = jax2tf._dot_general_convert_to_common_dtype(
375
+ lhs, _in_avals[0], rhs, _in_avals[1], _out_aval)
376
+
377
+ (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
378
+ lhs_ndim, rhs_ndim = len(lhs.shape), len(rhs.shape)
379
+
380
+ # This condition ensures that:
381
+ # 1) the batch dimensions are ordered in the same way in lhs and rhs (this is
382
+ # not strictly necessary, but we would have to reshape the array if that
383
+ # were not the case);
384
+ # 2) lhs and rhs have the same number of dimensions +/- 1
385
+ # 3) the number of non-batch dimensions in both tensors is either 1 or 2
386
+ # 4) the contracting dimensions are consistent with those of a classic
387
+ # matrix/matrix, vector/matrix or matrix/vector multiplication.
388
+ if (lhs_batch == rhs_batch == tuple(range(len(lhs_batch))) and
389
+ lhs_ndim - rhs_ndim in [-1, 0, 1] and
390
+ 1 <= lhs_ndim - len(lhs_batch) <= 2 and
391
+ 1 <= rhs_ndim - len(rhs_batch) <= 2 and
392
+ lhs_contracting == (len(lhs.shape) - 1,) and
393
+ rhs_contracting == (len(lhs_batch),)):
394
+ # All the inputs to tf.linalg.matmul must have 2 inner dimensions,
395
+ # after their batch dimensions, so we need to expand the dimensions
396
+ # appropriately. We can get to this branch with three combinations of
397
+ # inner shapes:
398
+ # - lhs.inner_shape == [a, b], rhs.inner_shape == [b, c]
399
+ # - in this case, the resulting inner shape is [a, c];
400
+ # - lhs.inner_shape == [b] , rhs.inner_shape == [b, c]
401
+ # - in this case, we need to expand lhs to [1, b], and the resulting
402
+ # shape is [c]. We need to squeeze the result of tf.linalg.matmul
403
+ # as it will have shape [1, c];
404
+ # - lhs.shape == [batch] + [a, b], rhs.shape == [batch] + [b]
405
+ # - in this case, we need to expand rhs to [b, 1], and the resulting
406
+ # shape is [a]. We need to squeeze the result of tf.linalg.matmul
407
+ # as it will have shape [a, 1];
408
+ # - lhs.shape == [batch] + [b] , rhs.shape == [batch] + [b]
409
+ # - in this case, we need to expand lhs to [1, b] and rhs to [b, 1],
410
+ # and the resulting shape is (). We need to squeeze the result of
411
+ # tf.linalg.matmul as it will have shape [1, 1].
412
+ squeeze_idxs = []
413
+ if lhs_ndim - len(lhs_batch) == 1:
414
+ lhs = tf.expand_dims(lhs, lhs_ndim - 1)
415
+ squeeze_idxs.append(len(lhs.shape) - 2)
416
+ if rhs_ndim - len(rhs_batch) == 1:
417
+ rhs = tf.expand_dims(rhs, rhs_ndim)
418
+ squeeze_idxs.append(len(rhs.shape) - 1)
419
+ result = tf.linalg.matmul(lhs, rhs)
420
+ if len(squeeze_idxs) != 0:
421
+ assert all(result.shape[i] == 1 for i in squeeze_idxs)
422
+ result = tf.squeeze(result, squeeze_idxs)
423
+ return convert_result(result)
424
+
425
+ new_id = iter(string.ascii_letters)
426
+ lhs_axis_ids = [next(new_id) for _ in lhs.shape]
427
+ rhs_axis_ids = [next(new_id) for _ in rhs.shape]
428
+ lhs_out_axis_ids = lhs_axis_ids[:]
429
+ rhs_out_axis_ids = rhs_axis_ids[:]
430
+
431
+ for lhs_axis, rhs_axis in zip(lhs_contracting, rhs_contracting):
432
+ shared_id = next(new_id)
433
+ lhs_axis_ids[lhs_axis] = shared_id
434
+ rhs_axis_ids[rhs_axis] = shared_id
435
+ lhs_out_axis_ids[lhs_axis] = None # type: ignore[call-overload]
436
+ rhs_out_axis_ids[rhs_axis] = None # type: ignore[call-overload]
437
+
438
+ batch_ids = []
439
+ for lhs_axis, rhs_axis in zip(lhs_batch, rhs_batch):
440
+ shared_id = next(new_id)
441
+ lhs_axis_ids[lhs_axis] = shared_id
442
+ rhs_axis_ids[rhs_axis] = shared_id
443
+ lhs_out_axis_ids[lhs_axis] = None # type: ignore[call-overload]
444
+ rhs_out_axis_ids[rhs_axis] = None # type: ignore[call-overload]
445
+ batch_ids.append(shared_id)
446
+
447
+ not_none = lambda x: x is not None
448
+ out_axis_ids = list(
449
+ filter(not_none, batch_ids + lhs_out_axis_ids + rhs_out_axis_ids))
450
+ assert lhs.dtype == rhs.dtype
451
+ spec = "{},{}->{}".format("".join(lhs_axis_ids), "".join(rhs_axis_ids),
452
+ "".join(out_axis_ids))
453
+ return convert_result(tf.linalg.einsum(spec, lhs, rhs))
454
+
455
+
456
+ tf_impl_no_xla[lax.dot_general_p] = _dot_general
457
+
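For intuition, a small sketch of the contraction the einsum fallback path would build; the spec letters are hand-derived and illustrative only:

# Illustrative only: for lhs of shape (4, 3, 5), rhs of shape (5, 2), and
# dimension_numbers = (((2,), (0,)), ((), ())), the generated spec is
# equivalent to the einsum below.
import tensorflow as tf
lhs = tf.ones((4, 3, 5))
rhs = tf.ones((5, 2))
out = tf.linalg.einsum("abc,cd->abd", lhs, rhs)   # shape (4, 3, 2)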
458
+
459
+ def _interior_padding(operand, padding_value, padding_config, operand_shape):
460
+ # Used only when enable_xla=False
461
+ # Applies only the interior padding from the padding_config.
462
+ # We do this somewhat inefficiently, as a scatter.
463
+ # For each dimension we compute the indices_by_dim as [0, f, 2f, 3f, ...] where
464
+ # f is the dilation factor for the dimension, i.e., 1 + interior_padding.
465
+ # Then we compute the cartesian product of the indices (using broadcast
466
+ # and concat).
467
+
468
+ # We could make this code more complex and do all the padding at once, but
469
+ # we prefer to keep it simple.
470
+ indices_by_dim = []
471
+ indices_shape = operand_shape + (1,)
472
+ output_shape = [] # considering only interior padding
473
+ for d, (dsz, (_, _, i)) in enumerate(zip(operand_shape, padding_config)):
474
+ dilation_factor = i + 1
475
+ output_shape.append(dsz * dilation_factor - i)
476
+ indices = tf.range(dsz) * dilation_factor
477
+ expansion = [None] * (1 + len(operand_shape))
478
+ expansion[d] = slice(None, None, None)
479
+ indices_by_dim.append(tf.broadcast_to(indices[expansion], indices_shape))
480
+
481
+ indices_cartesian = tf.concat(indices_by_dim, axis=len(operand_shape))
482
+ scattered = tf.scatter_nd(indices_cartesian, operand, output_shape)
483
+ # Mask marking which elements of the output come from the operand.
484
+ mask = tf.scatter_nd(indices_cartesian, tf.ones_like(operand, dtype=np.bool_),
485
+ output_shape)
486
+ return tf.where(mask, scattered, padding_value)
487
+
488
+
489
+ def _pad(operand, padding_value, *, padding_config,
490
+ _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray):
491
+ # Do only the interior padding first. This is rarely needed.
492
+ if any(i != 0 for _, _, i in padding_config):
493
+ operand = _interior_padding(operand, padding_value, padding_config,
494
+ jax2tf._eval_shape(_in_avals[0].shape))
495
+
496
+ # Now do the non-negative edge padding. This is the common case, use tf.pad.
497
+ non_negative_padding = [((lo if lo >= 0 else 0), (hi if hi >= 0 else 0))
498
+ for lo, hi, _ in padding_config]
499
+ operand = tf.pad(
500
+ operand,
501
+ non_negative_padding,
502
+ mode="CONSTANT",
503
+ constant_values=padding_value)
504
+ # Now the negative edge padding (this is also rare)
505
+ if any(lo < 0 or hi < 0 for lo, hi, _ in padding_config):
506
+ output_shape = jax2tf._eval_shape(_out_aval.shape)
507
+ begins = [(-lo if lo < 0 else 0) for lo, _, _ in padding_config]
508
+ operand = tf.slice(operand, begins, output_shape)
509
+
510
+ return operand
511
+
512
+
513
+ tf_impl_no_xla[lax.pad_p] = _pad
514
+
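A small worked sketch of what interior padding means; the values below are hand-computed and illustrative, not from the diff:

# Illustrative only: interior padding of 1 inserts one padding value between
# consecutive elements (lo=0, hi=0, interior=1 on a 1-D operand).
import jax.numpy as jnp
from jax import lax
x = jnp.array([1., 2., 3.])
y = lax.pad(x, 0., [(0, 0, 1)])
# y == [1., 0., 2., 0., 3.]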
515
+
516
+ def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int],
517
+ index_dtype: DType, _in_avals: Sequence[core.ShapedArray],
518
+ _out_aval: core.ShapedArray):
519
+ # The following is known to diverge from JAX behavior for NaN.
520
+ axis, = axes
521
+ output_type = tf.int32
522
+ if dtypes.iinfo(index_dtype).bits > 32:
523
+ output_type = tf.int64
524
+ # TODO(phawkins): handle axes larger than 2^31.
525
+ fn = tf.math.argmin if is_min else tf.math.argmax
526
+ result = fn(operand, axis=axis, output_type=output_type)
527
+ return tf.cast(result, jax2tf._to_tf_dtype(index_dtype))
528
+
529
+
530
+ tf_impl_no_xla[lax.argmin_p] = partial(_argminmax, True)
531
+ tf_impl_no_xla[lax.argmax_p] = partial(_argminmax, False)
532
+
533
+
534
+ def _validate_reduce_window_inputs(operand_shape, computation_name, dtype,
535
+ window_dimensions, window_strides,
536
+ base_dilation, window_dilation):
537
+ if computation_name not in ["min", "max", "add"]:
538
+ raise _reduce_error("Reduction function should be either min, max, or add.")
539
+ if computation_name in ["min", "max"] and dtype in [
540
+ tf.bool, tf.uint32, tf.uint64, tf.complex64, tf.complex128
541
+ ]:
542
+ raise _reduce_error("Min/max pool does not support operands of type "
543
+ f"{dtype}")
544
+ if computation_name == "min" and dtype in [tf.uint8, tf.uint16]:
545
+ # TODO(marcvanzee): We currently implement min pooling by negating the
546
+ # input, but this doesn't work for uint. We could work around it using
547
+ # tf.math.reduce_min.
548
+ raise _reduce_error(f"Min pool does not support operands of type {dtype}")
549
+ if computation_name == "add" and dtype not in [
550
+ tf.bfloat16,
551
+ tf.float16,
552
+ tf.float32,
553
+ tf.float64,
554
+ tf.int16,
555
+ tf.int32,
556
+ ]:
557
+ raise _reduce_error("Add pooling does not support operands of type "
558
+ f"{dtype}")
559
+
560
+ if (len(operand_shape) != len(window_dimensions) != len(window_strides) !=
561
+ len(window_dilation)):
562
+ raise _reduce_error("Input shapes, window dimensions, window stride "
563
+ "dimensions, and window dilation dimensions should "
564
+ "match.")
565
+
566
+ has_only_spatial_dims = True
567
+ if len(operand_shape) > 4:
568
+ raise _reduce_error("Only 1D or 2D input are supported.")
569
+ if len(operand_shape) > 2:
570
+ # operand_shape = (batch, spatial_dims, ..., channel).
571
+ has_only_spatial_dims = False
572
+
573
+ for name, value in [("window_dimensions", window_dimensions),
574
+ ("window_strides", window_strides),
575
+ ("window_dilation", window_dilation)]:
576
+ if value[0] != 1 or value[-1] != 1:
577
+ raise _reduce_error("Only 1D or 2D inputs are supported, expected "
578
+ f"{name}=(1, spatial_dims, ..., 1), but got "
579
+ f"{value}")
580
+
581
+ if list(base_dilation) != [1] * len(operand_shape):
582
+ # TODO(marcvanzee): Add support for base dilations. We can do this using
583
+ # a scatter on operand.
584
+ raise _reduce_error("Unimplemented support for base dilation.")
585
+
586
+ return has_only_spatial_dims
587
+
588
+
589
+ def _padding_reduce_window(operand, operand_shape, computation_name,
590
+ window_dimensions, window_strides, padding):
591
+ padding_type = pads_to_padtype(operand_shape, window_dimensions,
592
+ window_strides, padding)
593
+
594
+ # https://github.com/google/jax/issues/11874.
595
+ needs_manual_padding = (
596
+ padding_type == "SAME" and computation_name == "add" and
597
+ window_dimensions != [1] * len(operand_shape))
598
+
599
+ if needs_manual_padding or padding_type == "EXPLICIT":
600
+ operand, operand_shape = _pad_spatial_dims(operand, operand_shape, padding)
601
+ padding_type = "VALID"
602
+
603
+ return operand, operand_shape, padding_type
604
+
605
+
606
+ def _reshape_reduce_window(operand, operand_shape, window_dimensions,
607
+ window_strides, window_dilation, *,
608
+ has_only_spatial_dims):
609
+ # Reshape inputs so they are accepted by tf.nn.pool, which expects batch and
610
+ # channel dimensions for operand but not for any of the other inputs.
611
+ if has_only_spatial_dims: # len(operand_shape) <= 2
612
+ # Call eval_shape on a shape that may contain polynomials, otherwise TF does
613
+ # not know what to do with polynomials in the shape.
614
+ operand_shape = jax2tf._eval_shape(operand_shape)
615
+ # Add batch and channel dimensions to operand.
616
+ operand = tf.reshape(operand, (1,) + operand_shape + (1,))
617
+ else:
618
+ # This branch assumes operand.shape = (batch, spatial_dims, ..., channel),
619
+ # and dimensions, strides, dilation are all (1, spatial_values, ..., 1).
620
+ # Input validation for this is done in _validate_reduce_window_inputs.
621
+ window_dimensions = window_dimensions[1:-1]
622
+ window_strides = window_strides[1:-1]
623
+ window_dilation = window_dilation[1:-1]
624
+
625
+ return operand, window_dimensions, window_strides, window_dilation
626
+
627
+
628
+ def _reduce_monoid(operand, window_dimensions, window_strides, padding,
629
+ base_dilation, window_dilation, computation_name,
630
+ _in_avals: Sequence[core.ShapedArray],
631
+ _out_aval: core.ShapedArray):
632
+ dtype = operand.dtype
633
+ # In presence of shape polymorphism, operand.shape may contain None. The
634
+ # actual dimension polynomial shapes are in _in_avals.
635
+ operand_shape = _in_avals[0].shape
636
+
637
+ # TODO(marcvanzee): Put reduce_window arguments into dataclass, similar to
638
+ # Gather, to simplify function calls.
639
+ has_only_spatial_dims = _validate_reduce_window_inputs(
640
+ operand_shape, computation_name, dtype, window_dimensions, window_strides,
641
+ base_dilation, window_dilation)
642
+
643
+ operand, operand_shape, padding_type = _padding_reduce_window(
644
+ operand, operand_shape, computation_name, window_dimensions,
645
+ window_strides, padding)
646
+
647
+ operand, window_dimensions, window_strides, dilations = _reshape_reduce_window(
648
+ operand,
649
+ operand_shape,
650
+ window_dimensions,
651
+ window_strides,
652
+ window_dilation,
653
+ has_only_spatial_dims=has_only_spatial_dims)
654
+
655
+ def tf_pool(inputs, pooling_type):
656
+ if any(not core.is_constant_shape(s) for s in
657
+ (window_dimensions, window_strides, dilations)):
658
+ raise NotImplementedError(
659
+ f"TODO: use tf.nn.pool with dynamic shapes¨{window_dimensions=} "
660
+ f" {window_strides=} {dilations=}")
661
+ # tf.nn.pool() currently does not support tf.int32 and so we cast back and
662
+ # forth in order to be able to convert.
663
+ if (inputs.dtype in [tf.int16, tf.int32]) and computation_name == "add":
664
+ original_dtype = inputs.dtype
665
+ inputs = tf.cast(inputs, dtype=tf.float32)
666
+ else:
667
+ original_dtype = None
668
+ result = tf.nn.pool(
669
+ inputs,
670
+ window_shape=window_dimensions,
671
+ pooling_type=pooling_type,
672
+ padding=padding_type,
673
+ strides=window_strides,
674
+ dilations=dilations)
675
+ if original_dtype:
676
+ result = tf.cast(result, dtype=original_dtype)
677
+
678
+ if has_only_spatial_dims:
679
+ # If the input only had spatial dimensions we need to contract the batch
680
+ # and channel dimensions before returning the output.
681
+ result = tf.squeeze(result, [0, -1])
682
+
683
+ jax2tf._assert_matching_abstract_shape(result, _out_aval.shape)
684
+ return result
685
+
686
+ negate = lambda x: tf.multiply(x, tf.constant(-1, dtype))
687
+ if computation_name == "max":
688
+ return tf_pool(operand, "MAX")
689
+ elif computation_name == "min":
690
+ return negate(tf_pool(negate(operand), "MAX"))
691
+ elif computation_name == "add":
692
+ # TODO(marcvanzee): This may give very large deviations on TPU when using
693
+ # floats as inputs. Alternatively, we could implement this using a
694
+ # convolution with an all-1's kernel.
695
+ return tf.multiply(tf_pool(operand, "AVG"), math.prod(window_dimensions))
696
+
697
+
698
+ def _reduce_window(*args, jaxpr, consts, window_dimensions,
699
+ window_strides, padding, base_dilation, window_dilation,
700
+ _in_avals: Sequence[core.ShapedArray],
701
+ _out_aval: tuple[core.ShapedArray, ...]
702
+ ) -> tuple[TfVal, ...]:
703
+ assert len(consts) == 0, "Reduction computation cannot have constants"
704
+ operands, init_values = util.split_list(args, [len(args) // 2])
705
+
706
+ if len(operands) != 1 or len(init_values) != 1:
707
+ raise _reduce_error("jax2tf does not support variadic reduce_window")
708
+
709
+ operand, init_value = operands[0], init_values[0]
710
+ # Infer operation type from jaxpr.
711
+ if (len(jaxpr.eqns) != 1 or
712
+ len(jaxpr.eqns[0].invars) != 2 or
713
+ len(jaxpr.eqns[0].outvars) != 1 or
714
+ jaxpr.eqns[0].primitive.name not in ["min", "max", "add"]):
715
+ raise _reduce_error("Reduction function should be either min, max, or add.")
716
+
717
+ computation_name = jaxpr.eqns[0].primitive.name
718
+ result = _reduce_monoid(operand,
719
+ window_dimensions=window_dimensions,
720
+ window_strides=window_strides,
721
+ padding=padding,
722
+ base_dilation=base_dilation,
723
+ window_dilation=window_dilation,
724
+ computation_name=computation_name,
725
+ _in_avals=(_in_avals[0],), # Don't pass init_value.
726
+ _out_aval=_out_aval[0]) # Returns single value.
727
+
728
+ reduce_fn = {
729
+ "min": tf.minimum,
730
+ "max": tf.maximum,
731
+ "add": tf.add,
732
+ }[computation_name]
733
+ result = reduce_fn(result, init_value)
734
+
735
+ # The output is expected to be wrapped in a tuple, and since we don't use
736
+ # variadic reductions, this tuple always contains a single element.
737
+ return (result,)
738
+
739
+
740
+ tf_impl_no_xla[lax.reduce_window_min_p] = (
741
+ partial(_reduce_monoid, computation_name="min"))
742
+ tf_impl_no_xla[lax.reduce_window_max_p] = (
743
+ partial(_reduce_monoid, computation_name="max"))
744
+ tf_impl_no_xla[lax.reduce_window_sum_p] = (
745
+ partial(_reduce_monoid, computation_name="add"))
746
+
747
+ tf_impl_no_xla[lax.reduce_window_p] = _reduce_window
748
+
749
+ tf_impl_no_xla[lax.reduce_p] = _unimplemented("reduce")
750
+
751
+ tf_impl_no_xla[lax.select_and_scatter_add_p] = _unimplemented(
752
+ "select_and_scatter_add")
753
+
754
+ tf_impl_no_xla[lax.rng_bit_generator_p] = _unimplemented("rng_bit_generator")
755
+
756
+
757
+ def _clip(max_indices: Sequence[TfVal], start_indices: Sequence[TfVal],
758
+ slice_sizes: Sequence[TfVal]):
759
+ """Simulates XLA clipping behavior with TF ops.
760
+
761
+ Various TF ops have different clipping behavior than XLA:
762
+ * If `start_indices` is out-of-bounds, then TF fails but XLA clips the indices
763
+ to
764
+ [0, max_len].
765
+ * If `start_indices + slice_size` is out-of-bounds, then TF fails, but XLA
766
+ adjust
767
+ `start_indices` so that a full slice is returned.
768
+ This function clips the start indices correctly.
769
+ """
770
+ # We cast both arguments to `tf.clip_by_value` to int32. Otherwise, this
771
+ # function may return uint32 which is not always compatible with TF ops, so
772
+ # this may result in type errors.
773
+ max_start = tf.cast(tf.subtract(max_indices, slice_sizes), dtype=tf.int32)
774
+ return tf.clip_by_value(tf.cast(start_indices, dtype=tf.int32), 0, max_start)
775
+
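A quick sketch of the clipping semantics this helper emulates; the numbers are illustrative:

# Illustrative only: XLA-style clipping of slice start indices.
import tensorflow as tf
max_indices, slice_sizes = (10,), (4,)
start = (9,)                                                          # OOB for a size-4 slice
max_start = tf.cast(tf.subtract(max_indices, slice_sizes), tf.int32)  # [6]
clipped = tf.clip_by_value(tf.cast(start, tf.int32), 0, max_start)
# clipped == [6], so the slice [6:10] stays fully in bounds.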
776
+
777
+ @dataclasses.dataclass
778
+ class GatherArgs:
779
+ operand: TfVal
780
+ start_indices: TfVal
781
+ dnums: lax.GatherDimensionNumbers
782
+ slice_sizes: TfVal
783
+ op_shape: core.Shape
784
+ start_indices_shape: core.Shape
785
+ out_aval: core.ShapedArray
786
+
787
+ def __post_init__(self):
788
+ assert len(self.op_shape) == len(self.slice_sizes)
789
+
790
+ def __repr__(self):
791
+ return (f"operand shape={self.op_shape}, "
792
+ f"start_indices={self.start_indices}, "
793
+ f"dimension_numbes={self.dnums}, "
794
+ f"slice_sizes={self.slice_sizes}")
795
+ @property
796
+ def batch_dims(self):
797
+ return tuple(x for x in range(len(self.out_aval.shape))
798
+ if x not in self.dnums.offset_dims)
799
+
800
+ def gather_precondition(precondition_fn: Callable[[GatherArgs], None]):
801
+ """Decorator for specifying a precondition function.
802
+
803
+ This decorator should be put on a function with argument `arg` of type
804
+ `GatherArgs`. It will first call `precondition_fn` with `arg` (which may throw
805
+ an exception), and then call the function it is decorating with `arg` as well.
806
+ """
807
+
808
+ def decorator(gather_fn: Callable[[GatherArgs], Any]):
809
+
810
+ @wraps(gather_fn)
811
+ def wrapper(args: GatherArgs):
812
+ # Call `precondition_fn`; we assume it may throw an exception.
813
+ precondition_fn(args)
814
+ return gather_fn(args)
815
+
816
+ return wrapper
817
+
818
+ return decorator
819
+
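A minimal sketch of how this decorator composes; the precondition and handler names below are hypothetical:

# Illustrative only: a precondition raises ValueError to reject a case, and
# the decorated handler only runs when the precondition passes.
def _pre_check_rank_one(args):                 # hypothetical precondition
  if len(args.start_indices_shape) != 1:
    raise ValueError("expected rank-1 start_indices")

@gather_precondition(_pre_check_rank_one)
def _toy_gather(args):                         # hypothetical handler
  return args.operand
# _gather() further below tries each decorated handler in turn and collects
# the ValueErrors from the ones whose preconditions reject the call.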
820
+
821
+ def _pre_gather_for_scalar_indexing(args: GatherArgs):
822
+ """Returns True if this call to gather represents scalar indexing into arrays.
823
+
824
+ E.g., op[2], op[:, :5, :], jnp.take(op, 0, axis=0).
825
+ """
826
+ # TODO(marcvanzee): Add more assumptions here, because this is currently too
827
+ # permissive.
828
+ if len(args.start_indices_shape) != 1:
829
+ raise ValueError("start_indices shape should be 1")
830
+
831
+
832
+ @gather_precondition(_pre_gather_for_scalar_indexing)
833
+ def _gather_for_scalar_indexing(args: GatherArgs):
834
+ """Implements 'scalar indexing into arrays' cases of lax.gather using tf.slice.
835
+
836
+ E.g., op[2], op[:, :5, :], jnp.take(op, 0, axis=0).
837
+ """
838
+ indices = tf.expand_dims(args.dnums.start_index_map, 1)
839
+ # lax.gather uses an "index map" which maps `start_indices` to the right axes
840
+ # in `operand`. Since tf.strided_slice uses a single array for specifying the
841
+ # start indices, we use a scatter to map the start indices to the right axes.
842
+ op_shape = jax2tf._eval_shape(args.op_shape)
843
+ slice_sizes_tf = jax2tf._eval_shape(args.slice_sizes)
844
+ # TODO(marcvanzee): Consider transposing `operand`, which is probably more
845
+ # optimization friendly.
846
+ begin = tf.scatter_nd(indices, args.start_indices, [len(op_shape)])
847
+ begin = _clip(op_shape, begin, slice_sizes_tf)
848
+ end = slice_sizes_tf + begin
849
+
850
+ # `collapsed_slice_dims` is a tuple of dimensions to collapse, e.g. (0, 2).
851
+ # `tf.strided_slice` expects a binary mask to specify the shrink axes, i.e.,
852
+ # if we want to shrink axis 0 and 2, this corresponds to binary mask 101,
853
+ # which is 5 in decimals. The following line converts the lax representation
854
+ # to the one used by `tf.strided_slice`.
855
+ shrink_mask = sum(2**x for x in args.dnums.collapsed_slice_dims)
856
+ res = tf.strided_slice(args.operand, begin, end, shrink_axis_mask=shrink_mask)
857
+ # Shape inference doesn't work for tf.strided_slice.
858
+ res = jax2tf._ensure_tf_shape_if_dynamic(
859
+ res, jax2tf._aval_to_tf_shape(args.out_aval)
860
+ )
861
+ return res
862
+
863
+
864
+ def _pre_gather_for_multidim_indexing(args: GatherArgs):
865
+ """Returns True if this call to gather represents multi-dimensional indexing.
866
+
867
+ E.g., jnp.take(op, [[0], [1]], axis=0).
868
+ Note we currently only support multi-dimensional indexing if the last
869
+ dimension is 1.
870
+ """
871
+ # Handle only the case when tf.gather argument batch_dims=0.
872
+ # Find axis to match the tf.gather semantics
873
+ # Let I = len(start_indices_shape)
874
+ # let O = len(op_shape)
875
+ # slice_sizes == op_shape[:axis] + (1,) + op_shape[axis+1:]
876
+ # collapsed_slice_dims == (axis,)
877
+ # start_index_map == (axis,)
878
+ # offset_dims == (0, 1, ..., axis - 1, axis + I, ..., O + I - 1)
879
+ # We added a trailing dimension of size 1
880
+ op_shape = args.op_shape
881
+ start_index_map = args.dnums.start_index_map
882
+ collapsed_slice_dims = args.dnums.collapsed_slice_dims
883
+ offset_dims = args.dnums.offset_dims
884
+ if not (len(op_shape) >= 1 and len(start_index_map) == 1 and
885
+ len(collapsed_slice_dims) == 1 and collapsed_slice_dims[0]
886
+ == start_index_map[0] and len(offset_dims) == len(op_shape) - 1):
887
+ raise ValueError("unsupported dimension numbers")
888
+ # We added a trailing dimension of size 1
889
+ if not core.definitely_equal(args.start_indices_shape[-1], 1):
890
+ raise ValueError("start_indices shape[-1] should be 1")
891
+ # Guess the axis
892
+ axis = collapsed_slice_dims[0]
893
+ index_dims = len(args.start_indices_shape) - 1
894
+ expected_offset_dims = tuple(
895
+ list(range(axis)) +
896
+ list(range(axis + index_dims,
897
+ len(op_shape) + index_dims - 1)))
898
+ if offset_dims != expected_offset_dims:
899
+ raise ValueError("unsupported offset_dims")
900
+ expected_slice_sizes = op_shape[:axis] + (1,) + op_shape[axis + 1:] # type: ignore
901
+ if not core.definitely_equal_shape(args.slice_sizes, expected_slice_sizes):
902
+ raise ValueError("unsupported slice_sizes")
903
+
904
+
905
+ @gather_precondition(_pre_gather_for_multidim_indexing)
906
+ def _gather_for_multidim_indexing(args: GatherArgs):
907
+ """Implements 'multi-dimensional indexing into arrays' cases of lax.gather using tf.gather.
908
+
909
+ E.g., jnp.take(op, [[0], [1]], axis=0).
910
+ """
911
+ # Guess the axis.
912
+ axis = args.dnums.collapsed_slice_dims[0]
913
+ squeezed_indices = tf.squeeze(args.start_indices, -1)
914
+ op_shape = jax2tf._eval_shape(args.op_shape)
915
+ start_indices = _clip((op_shape[axis],), squeezed_indices, (1,))
916
+ return tf.gather(args.operand, start_indices, axis=axis, batch_dims=0)
917
+
918
+
919
+ def _pre_gather_with_batch_dim(args: GatherArgs):
920
+ """Returns True if this call to gather has non-empty batch dimensions.
921
+
922
+ This is for instance triggered when doing jax.vmap(lax.dynamic_slice).
923
+ """
924
+ # We assume exactly one batch (and one or more non-batch dimensions).
925
+ if len(args.batch_dims) != 1:
926
+ raise ValueError(f"batch_dims is {len(args.batch_dims)} but should be 1")
927
+
928
+ # `start_index_map` maps indices in `start_indices` to indices in `operand`.
929
+ # For simplicity, we currently only consider the case where this mapping is
930
+ # the identity function, i.e., [2, 3] in `start_indices` maps to
931
+ # `operand[2, 3]`.
932
+ if args.dnums.start_index_map != tuple(range(args.start_indices_shape[-1])):
933
+ raise ValueError("unsupported start_index_map")
934
+
935
+ # The batch dims in `start_indices` and `operand` should agree.
936
+ if not core.definitely_equal(args.op_shape[0], args.start_indices_shape[0]):
937
+ raise ValueError("Batch dimensions in operand and start_indices don't "
938
+ "agree")
939
+
940
+
941
+ def _pre_gather_with_batch_dims(args: GatherArgs):
942
+ """Returns True if this call to gather has non-empty 2D batch dimensions.
943
+
944
+ This is for instance triggered when doing
945
+ jax.vmap(jax.vmap(lax.dynamic_slice)).
946
+ """
947
+ if len(args.dnums.collapsed_slice_dims) != 0:
948
+ # NOTE: this can be relaxed in _gather_with_batch_dims but we might
949
+ # also need to re-work the output reshaping
950
+ raise ValueError("only len(collapsed_slice_dims) == 0 is supported")
951
+
952
+ # NOTE: This supports higher dimensions than listed (the highest dimension
953
+ # in the tests is 3D so it is limited to that, but the implementation is
954
+ # designed to handle higher dimensions (N-Dimensional)).
955
+ if len(args.batch_dims) not in [1, 2, 3]:
956
+ raise ValueError(
957
+ f"Size of batch_dims is {len(args.batch_dims)} but should be up to 3"
958
+ )
959
+
960
+ @gather_precondition(_pre_gather_with_batch_dim)
961
+ def _gather_with_batch_dim(args: GatherArgs):
962
+ """Implements call to gather with non-empty batch dimensions.
963
+
964
+ E.g., when doing `jax.vmap(lax.dynamic_slice)`.
965
+ """
966
+ op_shape = jax2tf._eval_shape(args.op_shape)
967
+ start_indices = _clip(op_shape, args.start_indices, args.slice_sizes)
968
+ result = tf.map_fn(
969
+ lambda idxs: tf.slice(args.operand, begin=idxs, size=args.slice_sizes),
970
+ start_indices,
971
+ fn_output_signature=jax2tf._to_tf_dtype(args.operand.dtype)
972
+ )
973
+ result = tf.reshape(result, jax2tf._eval_shape(args.out_aval.shape))
974
+ return result
975
+
976
+
977
+ def _gather_generate_indices(shape: tuple[int, ...]):
978
+ """
979
+ Returns the indices of all elements of an array with the given `shape`:
980
+ each element in the output is the index of an element of an array
981
+ of the provided shape. The result's shape is (math.prod(shape), len(shape))
982
+
983
+ For example, given shape (2,2) it returns (0,0),(0,1),(1,0),(1,1)
984
+ """
985
+ return tf.reshape(
986
+ tf.stack(
987
+ tf.meshgrid(
988
+ *[tf.range(start=0, limit=x) for x in shape], indexing="ij"
989
+ ),
990
+ axis=-1,
991
+ ),
992
+ (-1, len(shape)),
993
+ )
994
+
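A small check of the index enumeration described in the docstring; the values are hand-computed and illustrative:

# Illustrative only: enumerating all indices of a (2, 2) array, row-major.
import tensorflow as tf
idx = tf.reshape(
    tf.stack(tf.meshgrid(tf.range(2), tf.range(2), indexing="ij"), axis=-1),
    (-1, 2))
# idx == [[0, 0], [0, 1], [1, 0], [1, 1]], shape (4, 2)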
995
+
996
+ @gather_precondition(_pre_gather_with_batch_dims)
997
+ def _gather_with_batch_dims(args: GatherArgs):
998
+ """Implements call to gather with non-empty 2D batch dimensions."""
999
+ op_shape = jax2tf._eval_shape(args.op_shape)
1000
+ output_shape = jax2tf._eval_shape(args.out_aval.shape)
1001
+ # Used to map the start_indices w.r.t start_index_map
1002
+ indices = tf.expand_dims(args.dnums.start_index_map, 1)
1003
+
1004
+ # batch_indices is shaped (N,d) where N is the number of slices and d is
1005
+ # the number of batch_dims; batch_indices_size equals to N
1006
+ batch_indices = _gather_generate_indices(
1007
+ tuple(output_shape[i] for i in args.batch_dims)
1008
+ )
1009
+ batch_indices_size = jax2tf._eval_shape(batch_indices.shape)[0]
1010
+ # offset_indices is shaped (K,d) where K is the number of elements in each
1011
+ # slice and d is the number of offset_dims; offset_indices_size equals K
1012
+ offset_indices = _gather_generate_indices(
1013
+ tuple(output_shape[i] for i in args.dnums.offset_dims)
1014
+ )
1015
+ offset_indices_size = jax2tf._eval_shape(offset_indices.shape)[0]
1016
+
1017
+ # After we compute the result we need to reshape the axes with respect to
1018
+ # the output batch_dims and offset_dims.
1019
+ dim_mask = args.batch_dims + args.dnums.offset_dims
1020
+ mask_output_shape = tuple(output_shape[x] for x in dim_mask)
1021
+
1022
+ def get_scatter_indices(indices, batch_indices_size, size_of_index_map):
1023
+ """Generate the start indices of each slice, which index into the operand."""
1024
+ # Tile indices batch_indices_size times
1025
+ tiled_indices = tf.tile(
1026
+ tf.expand_dims(indices, 0), [batch_indices_size, 1, 1]
1027
+ )
1028
+ # The above tiles need to index the proper element of batch_indices
1029
+ # To do this generate a repeated sequence of numbers
1030
+ temp_batch_indices = tf.repeat(
1031
+ tf.range(start=0, limit=batch_indices_size), size_of_index_map
1032
+ )
1033
+ # Reshape the above sequence so it follows the same shape of tiled_indices
1034
+ temp_batch_indices = tf.reshape(
1035
+ temp_batch_indices, (batch_indices_size, size_of_index_map, 1)
1036
+ )
1037
+ # Now we concatenate to create indices offset by the temp_batch_indices
1038
+ return tf.concat([temp_batch_indices, tiled_indices], axis=-1)
1039
+
1040
+ slice_start_indices = tf.gather_nd(args.start_indices, batch_indices)
1041
+ # TODO: In the case where start_index_map is the identity we can skip this.
1042
+ scatter_indices = get_scatter_indices(
1043
+ indices, batch_indices_size, len(args.dnums.start_index_map)
1044
+ )
1045
+ # We map the scatter_indices w.r.t start_index_map
1046
+ indices_in_operand = tf.scatter_nd(
1047
+ scatter_indices, slice_start_indices, [batch_indices_size, len(op_shape)]
1048
+ )
1049
+
1050
+ # We clip the indices as OOB cases are possible when offsetting past
1051
+ # the operand boundaries
1052
+ clipped_start_indices = _clip(op_shape, indices_in_operand, args.slice_sizes)
1053
+ # Here we need to broadcast clipped_start_indices and add each of the offsets
1054
+ # which will generate a large index tensor of shape (T,d) where T is the
1055
+ # number of slices times the size of each slice (i.e., the total number of
1056
+ # items across all slices); d is rank(operand)
1057
+ slice_element_indices = tf.add(
1058
+ tf.repeat(clipped_start_indices, offset_indices_size, axis=0),
1059
+ tf.tile(offset_indices, (batch_indices_size, 1)),
1060
+ )
1061
+ results = tf.gather_nd(args.operand, slice_element_indices)
1062
+
1063
+ # Here results comes shaped as (N,1). Because collapsed_slice_dims is 0,
1064
+ # offset_dims is effectviely slice_sizes.
1065
+ # We reshape to mask_output_shape because if we directly reshape to the
1066
+ # output shape and our batch_dims are non-contiguous we will produce the
1067
+ # wrong shape. Reshaping to mask_output_shape gives (...,*slice_sizes),
1068
+ # which we then transpose to permute the axes in the proper way.
1069
+ # Note that if the batch_dims are contiguous this won't change the output.
1070
+ temp = tf.reshape(results, shape=mask_output_shape)
1071
+ return tf.transpose(temp, perm=tf.math.invert_permutation(dim_mask))
1072
+
1073
+ def _gather(operand, start_indices, *, dimension_numbers,
1074
+ slice_sizes: core.Shape, indices_are_sorted, unique_indices, mode,
1075
+ fill_value, _in_avals: Sequence[core.ShapedArray],
1076
+ _out_aval: core.ShapedArray):
1077
+ """Tensorflow implementation of gather."""
1078
+ if mode == lax.GatherScatterMode.FILL_OR_DROP:
1079
+ gather_fill_fn = jax2tf._convert_jax_impl(lax_slicing._gather_fill,
1080
+ multiple_results=False)
1081
+ return gather_fill_fn(
1082
+ operand, start_indices, dimension_numbers=dimension_numbers,
1083
+ slice_sizes=slice_sizes, unique_indices=unique_indices,
1084
+ indices_are_sorted=indices_are_sorted, fill_value=fill_value,
1085
+ output_shape=_out_aval.shape, _in_avals=_in_avals, _out_aval=_out_aval)
1086
+
1087
+ # TODO(marcvanzee): Check if we need more tests in shape_poly for gather with
1088
+ # enable_xla=False.
1089
+ gather_args = GatherArgs(
1090
+ operand=operand,
1091
+ start_indices=start_indices,
1092
+ dnums=dimension_numbers,
1093
+ slice_sizes=slice_sizes,
1094
+ op_shape=_in_avals[0].shape,
1095
+ start_indices_shape=_in_avals[1].shape,
1096
+ out_aval=_out_aval)
1097
+
1098
+ errors = []
1099
+
1100
+ for gather_fn in [
1101
+ _gather_for_scalar_indexing,
1102
+ _gather_for_multidim_indexing,
1103
+ _gather_with_batch_dim,
1104
+ _gather_with_batch_dims,
1105
+ ]:
1106
+ try:
1107
+ return gather_fn(gather_args)
1108
+ except ValueError as e:
1109
+ errors.append(f"{gather_fn}: {e!r}")
1110
+
1111
+ error_msg = (f"Unsupported arguments for gather: {gather_args}, errors:\n" +
1112
+ "\n".join(errors))
1113
+
1114
+ raise _error("gather", error_msg)
1115
+
1116
+
1117
+ tf_impl_no_xla[lax.gather_p] = _gather
1118
+
1119
+
1120
+ def _dynamic_slice(operand, *start_indices, slice_sizes: core.Shape,
1121
+ _in_avals: Sequence[core.ShapedArray],
1122
+ _out_aval: core.ShapedArray):
1123
+ start_indices = tf.stack(start_indices)
1124
+ slice_sizes_tf = jax2tf._eval_shape(slice_sizes)
1125
+
1126
+ operand_shape = jax2tf._eval_shape(_in_avals[0].shape)
1127
+ start_indices = _clip(operand_shape, start_indices, slice_sizes_tf)
1128
+ return tf.slice(operand, start_indices, size=slice_sizes_tf)
1129
+
1130
+
1131
+ tf_impl_no_xla[lax.dynamic_slice_p] = _dynamic_slice
1132
+
1133
+
1134
+ def _dynamic_update_slice(operand, update, *start_indices,
1135
+ _in_avals: Sequence[core.ShapedArray],
1136
+ _out_aval: core.ShapedArray):
1137
+ start_indices = tf.stack(start_indices)
1138
+
1139
+ op_shape = jax2tf._eval_shape(_in_avals[0].shape)
1140
+ op_size = tf.size(operand)
1141
+ update_shape_tf = jax2tf._eval_shape(_in_avals[1].shape)
1142
+
1143
+ start_indices = _clip(op_shape, start_indices, update_shape_tf)
1144
+ end_indices = tf.add(start_indices, update_shape_tf)
1145
+
1146
+ # Get the cells to update in `operand` as an array of ids.
1147
+ id_tensor = tf.reshape(tf.range(op_size), op_shape)
1148
+ scattered_indices = tf.strided_slice(id_tensor, start_indices, end_indices)
1149
+
1150
+ # Create an array containing updates at scattered_indices and zeros otherwise.
1151
+ flat_indices = tf.expand_dims(tf.nest.flatten(scattered_indices), -1)
1152
+ flat_update = tf.nest.flatten(update)
1153
+ update = tf.scatter_nd(flat_indices, flat_update, (op_size,))
1154
+ update = tf.reshape(update, op_shape)
1155
+
1156
+ # Create a bool mask that is True only where `operand` should be updated.
1157
+ update_mask = tf.ones_like(flat_update, dtype=tf.bool)
1158
+ update_mask = tf.scatter_nd(flat_indices, update_mask, (op_size,))
1159
+ update_mask = tf.reshape(update_mask, op_shape)
1160
+
1161
+ # Use the mask to only update `operand` with `update`.
1162
+ return tf.where(update_mask, update, operand)
1163
+
1164
+
1165
+ tf_impl_no_xla[lax.dynamic_update_slice_p] = _dynamic_update_slice
1166
+
1167
+
1168
+ def shift_axes_forward(operand,
1169
+ axes: tuple[int, ...],
1170
+ inverse: bool = False,
1171
+ forward: bool = True):
1172
+ """Shifts the tuple of axes to the front of an array"""
1173
+ other_axes = tuple(i for i in range(len(operand.shape)) if i not in axes)
1174
+ fwd_order = axes + other_axes if forward else other_axes + axes
1175
+ order = fwd_order if not inverse else _invert_permutation(fwd_order)
1176
+ return tf.transpose(operand, order)
1177
+
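A brief sketch of the axis shuffling this helper performs, assuming the `shift_axes_forward` defined above is available; shapes are illustrative:

# Illustrative only: moving axis 1 of a (2, 3, 4) tensor to the front and back.
import tensorflow as tf
x = tf.zeros((2, 3, 4))
y = shift_axes_forward(x, axes=(1,))                       # shape (3, 2, 4)
x_back = shift_axes_forward(y, axes=(1,), inverse=True)    # shape (2, 3, 4)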
1178
+ def convert_scatter_jax_to_tf(update_op, unsorted_segment_op=None):
1179
+
1180
+ def _sparse_scatter(operand, scatter_indices, updates, unique_indices, mode,
1181
+ _in_avals: Sequence[core.ShapedArray],
1182
+ _out_aval: core.ShapedArray):
1183
+ """Implementation of scatter specialised to indexing from the front axes.
1184
+
1185
+ This covers unique indices and non-unique indices of single depth.
1186
+ Note on unique indices: `tf.tensor_scatter_nd_update` interprets indices
1187
+ thusly: every axis except the final one encodes a batch dimension, the final
1188
+ axis encoding the actual indices to scatter in to. It enforces, at least
1189
+ one, batch dimension so we add an empty dimension to indices and updates if
1190
+ lacking.
1191
+
1192
+ Note on non-unique indices: There is no tf op for non-single depth indexing,
1193
+ but if indexing is single depth, this can be viewed as a segment op.
1194
+ """
1195
+ # Infer unique indices from lack of batch dimension
1196
+ unique_indices = unique_indices or (len(scatter_indices.shape) == 1)
1197
+ if unique_indices:
1198
+ suboperand = tf.gather_nd(operand, scatter_indices)
1199
+ updated_suboperand = update_op(suboperand, updates)
1200
+ # add a batch dim if none exist
1201
+ if len(scatter_indices.shape) == 1:
1202
+ scatter_indices = scatter_indices[None]
1203
+ updated_suboperand = updated_suboperand[None]
1204
+ y = tf.tensor_scatter_nd_update(operand, scatter_indices, updated_suboperand)
1205
+ else:
1206
+ if (scatter_indices.shape[-1] == 1) and unsorted_segment_op:
1207
+ # If only indexing into the first dimension, it's a segment op
1208
+ operand_update = unsorted_segment_op(updates,
1209
+ tf.squeeze(scatter_indices, -1),
1210
+ operand.shape[0])
1211
+ y = update_op(operand, operand_update)
1212
+ else:
1213
+ raise _scatter_error(
1214
+ "Scatter only supports non-unique "
1215
+ "indices with indexing into only one dimension for (add, mul, min, "
1216
+ "max)")
1217
+ return y
1218
+
1219
+ def sparse_scatter(operand, scatter_indices, updates, update_jaxpr,
1220
+ update_consts, dimension_numbers, indices_are_sorted: bool,
1221
+ unique_indices: bool, mode,
1222
+ _in_avals: Sequence[core.ShapedArray],
1223
+ _out_aval: core.ShapedArray):
1224
+ """
1225
+ Wrapper around the scatter function.
1226
+ The underlying tf ops `tf.tensor_scatter_nd_update` and
1227
+ `tf.math.unsorted_segment_*` index from the front dimensions.
1228
+ `tf.math.unsorted_segment_*` indexes to a depth 1 from the front.
1229
+ `tf.tensor_scatter_nd_update` indexes from the front dimensions onwards,
1230
+ with no ability to skip a dimension. This function shifts the axes to be
1231
+ indexed to the front then calls a front-specific implementation, then
1232
+ inverse-shifts the output.
1233
+
1234
+ scatter_dims_to_operand_dims: dimensions which the scatter indexes into.
1236
+ We shift these to the front to match tf syntax. All other dims are batch dims.
1236
+ update_window_dims: dimensions which are not batch dimensions. We shift
1237
+ these to the back as the remaining dimensions are batch dimensions.
1238
+ """
1239
+ del update_jaxpr, update_consts, indices_are_sorted # Unused arguments
1240
+
1241
+ update_window_dims = dimension_numbers.update_window_dims
1242
+ inserted_window_dims = dimension_numbers.inserted_window_dims
1243
+ scatter_to_operand_dims = dimension_numbers.scatter_dims_to_operand_dims
1244
+
1245
+ dtype = operand.dtype # assume updates has same dtype as operand
1246
+ if dtype in [tf.bool, tf.complex64]:
1247
+ raise _scatter_error(f"Scatter does not support operands of type {dtype}")
1248
+
1249
+ if inserted_window_dims != scatter_to_operand_dims:
1250
+ raise _scatter_error("Complex scatters are not supported")
1251
+
1252
+ if (mode != lax.GatherScatterMode.FILL_OR_DROP and
1253
+ mode != lax.GatherScatterMode.PROMISE_IN_BOUNDS):
1254
+ # The OOB behavior for tf.scatter is as follows:
1255
+ # - When running in eager or graph mode, it throws an error.
1256
+ # TODO(marcvanzee): Fix this case by removing the OOB indices.
1257
+ # - When running in compile mode, the OOB indices are dropped, which is
1258
+ # the same behavior as FILL_OR_DROP and PROMISE_IN_BOUNDS.
1259
+ # To ensure correctness, we disallow CLIP mode for now.
1260
+ raise _scatter_error("Only scatter modes `FILL_OR_DROP` and "
1261
+ "`PROMISE_IN_BOUNDS` are supported.")
1262
+
1263
+ # Shift axes to the front to match tf syntax, inverse afterwards
1264
+ fwd = partial(shift_axes_forward, axes=scatter_to_operand_dims)
1265
+ inv = partial(fwd, inverse=True)
1266
+
1267
+ # Shift update value axes to the back, so batch are at the front
1268
+ updates_shifted = shift_axes_forward(
1269
+ updates, axes=update_window_dims, forward=False)
1270
+
1271
+ return inv(
1272
+ _sparse_scatter(
1273
+ fwd(operand), scatter_indices, updates_shifted, unique_indices,
1274
+ mode, _in_avals, _out_aval))
1275
+ return sparse_scatter
1276
+
1277
+
1278
+ tf_impl_no_xla[lax.scatter_p] = convert_scatter_jax_to_tf(
1279
+ lambda x, y: y) # just replace with the update
1280
+ tf_impl_no_xla[lax.scatter_add_p] = convert_scatter_jax_to_tf(tf.add, tf.math.unsorted_segment_sum)
1281
+ tf_impl_no_xla[lax.scatter_mul_p] = convert_scatter_jax_to_tf(tf.multiply, tf.math.unsorted_segment_prod)
1282
+ tf_impl_no_xla[lax.scatter_min_p] = convert_scatter_jax_to_tf(tf.minimum, tf.math.unsorted_segment_min)
1283
+ tf_impl_no_xla[lax.scatter_max_p] = convert_scatter_jax_to_tf(tf.maximum, tf.math.unsorted_segment_max)
1284
+
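For intuition, a hedged sketch of how the segment-op path handles non-unique depth-1 indices; the values are hand-computed and illustrative only:

# Illustrative only: scatter-add with duplicate depth-1 indices reduces to an
# unsorted segment sum followed by an elementwise add into the operand.
import tensorflow as tf
operand = tf.constant([10., 20., 30.])
indices = tf.constant([[0], [0], [2]])            # non-unique, depth 1
updates = tf.constant([1., 2., 3.])
seg = tf.math.unsorted_segment_sum(updates, tf.squeeze(indices, -1), 3)
result = tf.add(operand, seg)                     # [13., 20., 33.]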
1285
+ tf_impl_no_xla[lax.sort_p] = _unimplemented("sort")
1286
+
1287
+ tf_impl_no_xla[lax.reduce_precision_p] = _unimplemented("reduce_precision")
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/jax2tf.py ADDED
The diff for this file is too large to render. See raw diff
 
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/tests/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2020 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/tests/back_compat_tf_test.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for backwards compatibility of custom calls involving TensorFlow.
15
+
16
+ See the back_compat_test_util module docstring for how to set up and update
17
+ these tests.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import base64
23
+ from collections.abc import Sequence
24
+ import io
25
+ import os
26
+ import tarfile
27
+ from typing import Callable, Optional
28
+
29
+ from absl.testing import absltest
30
+ import jax
31
+ from jax._src import test_util as jtu
32
+ from jax._src.internal_test_util import export_back_compat_test_util as bctu
33
+ from jax._src.lib import xla_extension
34
+ from jax.experimental import jax2tf
35
+ from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function
36
+ import jax.numpy as jnp
37
+ import tensorflow as tf
38
+
39
+
40
+ jax.config.parse_flags_with_absl()
41
+
42
+
43
+ def serialize_directory(directory_path):
44
+ """Seriliaze the directory as a string."""
45
+ tar_buffer = io.BytesIO()
46
+ with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
47
+ tar.add(directory_path, arcname=os.path.basename(directory_path))
48
+
49
+ # Convert the binary data to a base64-encoded string
50
+ serialized_string = base64.b64encode(tar_buffer.getvalue())
51
+ return serialized_string
52
+
53
+
54
+ def deserialize_directory(serialized_string, output_directory):
55
+ """Deserialize the string to the directory."""
56
+ # Convert the base64-encoded string back to binary data
57
+ tar_data = base64.b64decode(serialized_string)
58
+
59
+ # Extract the tar archive to the output directory
60
+ with tarfile.open(fileobj=io.BytesIO(tar_data), mode="r") as tar:
61
+ tar.extractall(output_directory)
62
+
63
+
64
+ class CompatTensoflowTest(bctu.CompatTestBase):
65
+ """Compatibility tests that use TF.
66
+
67
+ Uses tf.Graph to serialize and run the functions; expects that `func`
68
+ contains a `jax2tf.call_tf` and uses `jax2tf.convert` to generate a
69
+ `tf.Graph` containing an XlaCallModule with the actual MLIR module.
70
+ """
71
+
72
+ def run_current(self, func: Callable, data: bctu.CompatTestData):
73
+ # Here we use tf.saved_model and provide string serialize/deserialize methods
74
+ # for the whole directory.
75
+ @tf.function(autograph=False, jit_compile=True)
76
+ def tf_func(the_input): # Use recognizable names for input and result
77
+ res = jax2tf.convert(func, native_serialization=True)(the_input)
78
+ return tf.identity(res, name="the_result")
79
+
80
+ self.tf_func = tf_func
81
+ return tf_func(*data.inputs)
82
+
83
+ def serialize(
84
+ self,
85
+ func: Callable,
86
+ data: bctu.CompatTestData,
87
+ polymorphic_shapes: Sequence[str] | None = None,
88
+ allow_unstable_custom_call_targets: Sequence[str] = (),
89
+ ):
90
+ # We serialize as a tf.Graph
91
+ assert len(data.inputs) == 1 # We only support a single input now
92
+ tf_graph = self.tf_func.get_concrete_function(*data.inputs).graph
93
+ for op in tf_graph.get_operations():
94
+ if op.type == "XlaCallModule":
95
+ serialized_module = op.get_attr("module")
96
+ module_str = xla_extension.mlir.deserialize_portable_artifact(
97
+ serialized_module
98
+ )
99
+ module_version = op.get_attr("version")
100
+ break
101
+ else:
102
+ raise ValueError("Cannot find an XlaCallModule")
103
+ tf_graph_def = tf_graph.as_graph_def()
104
+ # module_str is just for human readability; include both the MLIR module
105
+ # and the tf.Graph
106
+ module_str = (
107
+ "# First the MLIR module:\n"
108
+ + module_str
109
+ + "\n# Then the tf.Graph:\n"
110
+ + str(tf_graph_def)
111
+ )
112
+ # serialized = tf_graph_def.SerializeToString()
113
+ module = tf.Module()
114
+ module.call = self.tf_func.get_concrete_function(*data.inputs)
115
+ root_dir = self.create_tempdir()
116
+ saved_model_dir = os.path.join(root_dir, "saved_model")
117
+ os.mkdir(saved_model_dir)
118
+ tf.saved_model.save(
119
+ module,
120
+ saved_model_dir,
121
+ options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
122
+ )
123
+ serialized = serialize_directory(saved_model_dir)
124
+ nr_devices = 1
125
+ return serialized, module_str, module_version, nr_devices
126
+
127
+ def run_serialized(
128
+ self,
129
+ data: bctu.CompatTestData,
130
+ polymorphic_shapes: Sequence[str] | None = None,
131
+ ):
132
+ root_dir = self.create_tempdir()
133
+ deserialize_directory(data.mlir_module_serialized, root_dir)
134
+ saved_model_dir = os.path.join(root_dir, "saved_model")
135
+ loaded_model = tf.saved_model.load(saved_model_dir)
136
+ return (loaded_model.call(*data.inputs).numpy(),)
137
+
138
+ def test_tf_call_tf_function(self):
139
+ # A custom call tf.call_tf_function is generated when we lower call_tf
140
+ # with the call_tf_graph=True option.
141
+ def func(x):
142
+ def func_tf(x):
143
+ return tf.math.sin(x)
144
+
145
+ return jnp.cos(
146
+ jax2tf.call_tf(func_tf, output_shape_dtype=x, call_tf_graph=True)(x)
147
+ )
148
+
149
+ data = self.load_testdata(tf_call_tf_function.data_2023_07_29)
150
+ self.run_one_test(func, data)
151
+
152
+
153
+ if __name__ == "__main__":
154
+ absltest.main(testLoader=jtu.JaxTestLoader())
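For readers of the `serialize_directory` / `deserialize_directory` helpers above, a small round-trip sketch (illustrative only; the temporary paths and file name are made up, and the snippet assumes the two helpers are in scope):

    import os
    import tempfile

    src_dir = tempfile.mkdtemp()
    with open(os.path.join(src_dir, "data.txt"), "w") as f:
      f.write("hello")

    blob = serialize_directory(src_dir)      # base64-encoded tar of src_dir
    out_dir = tempfile.mkdtemp()
    deserialize_directory(blob, out_dir)     # extracts the tar under out_dir

    restored = os.path.join(out_dir, os.path.basename(src_dir), "data.txt")
    assert open(restored).read() == "hello"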
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/jax2tf/tests/call_tf_test.py ADDED
@@ -0,0 +1,1821 @@
1
+ # Copyright 2020 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for call_tf."""
15
+
16
+ import contextlib
17
+ from functools import partial
18
+ import os
19
+ from typing import Callable
20
+ import unittest
21
+
22
+ from absl import logging
23
+ from absl.testing import absltest
24
+ from absl.testing import parameterized
25
+ import jax
26
+ from jax import dlpack
27
+ from jax import dtypes
28
+ from jax import lax
29
+ from jax import numpy as jnp
30
+ from jax._src import config
31
+ from jax._src import test_util as jtu
32
+ from jax._src.lib.mlir import ir
33
+ from jax._src.lib.mlir.dialects import hlo
34
+ from jax.experimental import export
35
+ from jax.experimental import jax2tf
36
+ from jax.experimental.jax2tf.tests import tf_test_util
37
+ import numpy as np
38
+
39
+ try:
40
+ import tensorflow as tf
41
+ except ImportError:
42
+ tf = None
43
+
44
+ jax.config.parse_flags_with_absl()
45
+
46
+
47
+ def _maybe_jit(with_jit: bool, func: Callable) -> Callable:
48
+ if with_jit:
49
+ return jax.jit(func)
50
+ else:
51
+ return func
52
+
53
+ def _maybe_tf_jit(with_jit: bool, func: Callable) -> Callable:
54
+ if with_jit:
55
+ return tf.function(func, autograph=False, jit_compile=True)
56
+ else:
57
+ return func
58
+
59
+ def _named_test(**kwargs):
60
+ return dict(kwargs,
61
+ testcase_name = "_".join([f"{k}={kwargs[k]}" for k in sorted(kwargs.keys())]))
62
+
63
+ _parameterized_jit = parameterized.named_parameters(
64
+ _named_test(with_jit=with_jit)
65
+ for with_jit in [True, False])
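To make the parameterization helpers above concrete, a tiny sketch of what they evaluate to (illustrative, assuming `_named_test` as defined above):

    # Each kwargs dict gains a "testcase_name" derived from its sorted key=value pairs.
    assert _named_test(with_jit=True) == {
        "with_jit": True, "testcase_name": "with_jit=True"}
    # _parameterized_jit therefore expands a decorated test into one variant
    # per with_jit value, with that value embedded in the generated test name.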
66
+
67
+ _call_tf_non_compilable_error = "Error compiling TensorFlow function"
68
+ _call_tf_dynamic_shape_error = "call_tf cannot call functions whose output has dynamic shape"
69
+
70
+ class CallTfTest(tf_test_util.JaxToTfTestCase):
71
+
72
+ @classmethod
73
+ def setUpClass(cls):
74
+ # One TF device of each device_type
75
+ cls.tf_devices = []
76
+ for tf_device in tf.config.list_logical_devices():
77
+ if tf_device.device_type == "TPU_SYSTEM":
78
+ continue # A virtual device
79
+ if all(tf_device.device_type != d.device_type for d in cls.tf_devices):
80
+ cls.tf_devices.append(tf_device)
81
+
82
+ super().setUpClass()
83
+
84
+ def setUp(self):
85
+ if tf is None:
86
+ raise unittest.SkipTest("Test requires tensorflow")
87
+ # TODO(b/171320191): this line works around a missing context initialization
88
+ # bug in TensorFlow.
89
+ _ = tf.add(1, 1)
90
+ super().setUp()
91
+
92
+ @_parameterized_jit
93
+ def test_eval_scalar_arg(self, with_jit=True):
94
+ def f_tf(x):
95
+ return tf.math.sin(x)
96
+ x = 3.
97
+ res = _maybe_jit(with_jit, jax2tf.call_tf(f_tf))(x)
98
+ self.assertAllClose(jnp.sin(x), res)
99
+
100
+ @_parameterized_jit
101
+ def test_eval_scalar_res(self, with_jit=True):
102
+ x = 3.
103
+ res = _maybe_jit(with_jit, jax2tf.call_tf(lambda x: 4.))(x)
104
+ self.assertAllClose(4., res, check_dtypes=False)
105
+
106
+ @_parameterized_jit
107
+ def test_eval_numpy_arg(self, with_jit=True):
108
+ x = np.ones((2, 3), dtype=np.float32)
109
+ res = _maybe_jit(with_jit, jax2tf.call_tf(tf.math.sin))(x)
110
+ self.assertAllClose(jnp.sin(x), res)
111
+
112
+ @_parameterized_jit
113
+ def test_eval_numpy_res(self, with_jit=False):
114
+ x = np.ones((2, 3))
115
+ res = _maybe_jit(with_jit, jax2tf.call_tf(lambda _: x))(x)
116
+ self.assertAllClose(x, res)
117
+
118
+ @_parameterized_jit
119
+ def test_eval_devicearray_arg(self, with_jit=False):
120
+ x = jnp.ones((2, 3), dtype=np.float32)
121
+ res = _maybe_jit(with_jit, jax2tf.call_tf(tf.math.sin))(x)
122
+ self.assertAllClose(jnp.sin(x), res)
123
+
124
+ x = jnp.array(3.0, dtype=jnp.bfloat16)
125
+ res = jax2tf.call_tf(lambda x: x)(x)
126
+ self.assertAllClose(x, res)
127
+ # bfloat16 scalar will create a copy.
128
+ with self.assertRaises(AssertionError):
129
+ self.assertTrue(np.shares_memory(x, res))
130
+
131
+ @_parameterized_jit
132
+ def test_eval_pytree(self, with_jit=True):
133
+
134
+ def fun_tf(x: dict, y: tuple) -> tuple:
135
+ return (x["first"] * x["second"], y[0] + y[1])
136
+
137
+ x = dict(first=np.float32(3.), second=np.float32(4.))
138
+ y = (np.float64(5.), np.float64(6.))
139
+ fun_jax = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))
140
+ res = fun_jax(x, y)
141
+ self.assertAllClose((np.float32(12.), np.float64(11.)), res)
142
+
143
+ def test_result_tuple(self):
144
+ x1 = np.ones(3, dtype=np.int32)
145
+ x2 = np.ones(5, dtype=np.float32)
146
+ def fun_tf():
147
+ return tf.tuple([x1, x2])
148
+
149
+ fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
150
+ res = fun_jax()
151
+ self.assertAllClose(res, (x1, x2))
152
+
153
+ def test_error_non_compilable_strings(self):
154
+ # Check that in op-by-op we call a function in eager mode.
155
+ def f_tf_non_compilable(x):
156
+ return tf.strings.length(tf.strings.format("Hello {}!", [x]))
157
+
158
+ f_jax = jax2tf.call_tf(f_tf_non_compilable)
159
+ x = np.float32(0.7)
160
+ self.assertAllClose(f_tf_non_compilable(x).numpy(), f_jax(x))
161
+ with self.assertRaisesRegex(ValueError,
162
+ _call_tf_non_compilable_error):
163
+ jax.jit(f_jax)(x)
164
+
165
+ with self.assertRaisesRegex(ValueError,
166
+ _call_tf_non_compilable_error):
167
+ lax.cond(True, lambda x: f_jax(x), lambda x: f_jax(x), x)
168
+
169
+ def test_error_non_compilable_dynamic_shape(self):
170
+ # Check that in op-by-op we call a function in eager mode.
171
+ def f_tf_non_compilable(x):
172
+ return tf.cond(x[0], lambda: x[1:], lambda: x)
173
+
174
+ f_jax = jax2tf.call_tf(f_tf_non_compilable)
175
+ x = np.array([True, False], dtype=np.bool_)
176
+ self.assertAllClose(f_tf_non_compilable(x), f_jax(x)) # Works in eager mode
177
+
178
+ with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
179
+ jax.jit(f_jax)(x)
180
+
181
+ def test_error_bad_result_tensorarray(self):
182
+ # Call a function that returns a tf.TensorArray. This should be detected
183
+ # early on. If we don't the function is actually compilable but returns
184
+ # a tuple instead of a single result.
185
+ def fun_tf():
186
+ ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
187
+ ta = ta.unstack([0, 1, 2, 3, 4])
188
+ return ta
189
+
190
+ with self.assertRaisesRegex(ValueError,
191
+ "The called TF function returns a result that is not convertible to JAX"):
192
+ fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
193
+ fun_jax()
194
+
195
+ def test_error_bad_result_string(self):
196
+ def fun_tf():
197
+ return tf.constant("foo")
198
+
199
+ # Now under jit, should fail because the function is not compilable
200
+ with self.assertRaisesRegex(ValueError,
201
+ "The called TF function returns a result that is not convertible to JAX"):
202
+ fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
203
+ fun_jax()
204
+
205
+ @_parameterized_jit
206
+ def test_control_flow(self, with_jit=True):
207
+
208
+ def times_5_tf(x):
209
+ # Multiply x * 5 using a loop
210
+ c = lambda i, acc: tf.less(i, 5)
211
+ b = lambda i, acc: (tf.add(i, 1), tf.add(acc, x))
212
+ _, acc = tf.while_loop(c, b, [tf.constant(0), tf.constant(0.)])
213
+ return acc
214
+
215
+ def fun_jax(x):
216
+ # Calls times_5_tf 3 times in a loop
217
+ def body(_, acc):
218
+ return jax2tf.call_tf(times_5_tf)(acc)
219
+
220
+ return lax.fori_loop(0, 3, body, x)
221
+
222
+ x = np.float32(3.)
223
+ res = _maybe_jit(with_jit, fun_jax)(x)
224
+ self.assertAllClose(np.float32(x * 5 * 5 * 5), res)
225
+
226
+ @parameterized.named_parameters(
227
+ dict(
228
+ testcase_name=f"_{dtype.__name__}{'_jit' if with_jit else ''}",
229
+ dtype=dtype,
230
+ with_jit=with_jit)
231
+ for dtype in set(jtu.dtypes.all) - {np.bool_}
232
+ for with_jit in [True, False])
233
+ def test_dtypes(self, dtype=np.int32, with_jit=True):
234
+
235
+ def fun_tf(x):
236
+ # AddV2 supports more types
237
+ return tf.raw_ops.AddV2(x=x, y=tf.constant(3, dtype=dtype))
238
+
239
+ def fun_jax(x):
240
+ return jax2tf.call_tf(fun_tf)(x) + x
241
+
242
+ x = np.ones((3,), dtype=dtype)
243
+ res = _maybe_jit(with_jit, fun_jax)(x)
244
+ self.assertAllClose(dtype(2 * x + 3), res)
245
+
246
+ @_parameterized_jit
247
+ def test_bool(self, with_jit=False):
248
+
249
+ def fun_tf(x, y):
250
+ return tf.math.logical_and(x, y)
251
+
252
+ x = np.array([True, False, True, False], dtype=np.bool_)
253
+ y = np.array([True, True, False, False], dtype=np.bool_)
254
+ res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x, y)
255
+ self.assertAllClose(
256
+ np.array([True, False, False, False], dtype=np.bool_), res)
257
+
258
+ @_parameterized_jit
259
+ def test_x64_input(self, with_jit=True):
260
+ def f_tf(x):
261
+ return tf.math.sin(x)
262
+
263
+ x = 5. # TF interprets this as f64
264
+ res_call_tf = _maybe_jit(with_jit, jax2tf.call_tf(f_tf))(x)
265
+ res_jax = jnp.sin(x)
266
+ self.assertAllClose(res_call_tf, res_jax)
267
+
268
+ @_parameterized_jit
269
+ def test_x64_output(self, with_jit=True):
270
+ def f_tf(x):
271
+ return (tf.constant(3., tf.float64), x)
272
+
273
+ x = np.float32(5.)
274
+ res_call_tf = _maybe_jit(with_jit, jax2tf.call_tf(f_tf))(x)
275
+ res_jax = (3., x)
276
+ self.assertAllClose(res_call_tf, res_jax)
277
+
278
+ res_call_tf_jit = jax.jit(jax2tf.call_tf(f_tf))(x)
279
+ self.assertAllClose(res_call_tf_jit, res_jax)
280
+
281
+ @_parameterized_jit
282
+ def test_with_var_read(self, with_jit=True):
283
+ # The variable is placed on the default TF device.
284
+ outer_var_array = np.array([3., 4.], dtype=np.float32)
285
+ outer_var = tf.Variable(outer_var_array)
286
+
287
+ def fun_tf(x):
288
+ return x * outer_var + 1.
289
+
290
+ x = np.array([2., 5.,], dtype=np.float32)
291
+ res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
292
+ self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False)
293
+
294
+ @_parameterized_jit
295
+ def test_with_var_read_x64(self, with_jit=True):
296
+ outer_var_array = np.array([3., 4.], dtype=np.float64)
297
+ outer_var = tf.Variable(outer_var_array)
298
+
299
+ def fun_tf(x):
300
+ return x * tf.cast(outer_var, x.dtype) + 1.
301
+
302
+ x = np.array([2., 5.,], dtype=np.float32)
303
+ res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
304
+ self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False)
305
+
306
+ def test_with_var_different_shape(self):
307
+ # See https://github.com/google/jax/issues/6050
308
+ v = tf.Variable((4., 2.), dtype=tf.float32)
309
+
310
+ def tf_func(x):
311
+ return v + x
312
+ x = np.float32(123.)
313
+ tf_out = tf_func(x)
314
+
315
+ jax_func = jax.jit(jax2tf.call_tf(tf_func))
316
+ jax_out = jax_func(x)
317
+
318
+ self.assertAllClose(tf_out, jax_out, check_dtypes=False)
319
+
320
+ @_parameterized_jit
321
+ def test_with_var_write_error(self, with_jit=True):
322
+ if with_jit:
323
+ raise unittest.SkipTest("variable writes not yet working")
324
+ outer_var = tf.Variable(3., dtype=np.float32)
325
+
326
+ def fun_tf(x):
327
+ outer_var.assign(tf.constant(4.))
328
+ return x * outer_var + 1.
329
+
330
+ x = np.float32(2.)
331
+ res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
332
+ self.assertAllClose(x * 4. + 1, res, check_dtypes=False)
333
+
334
+ @_parameterized_jit
335
+ def test_with_tensor_capture(self, with_jit=True):
336
+ outer_tensor = tf.constant(3., dtype=np.float32)
337
+
338
+ def fun_tf(x):
339
+ return x * outer_tensor + 1.
340
+
341
+ x = np.float32(2.)
342
+ res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
343
+ self.assertAllClose(x * 3. + 1., res, check_dtypes=False)
344
+
345
+ @_parameterized_jit
346
+ def test_with_tensor_capture_x64(self, with_jit=True):
347
+ outer_tensor = tf.constant(3., dtype=np.float64)
348
+
349
+ def fun_tf(x):
350
+ return x * tf.cast(outer_tensor * 3.14, tf.float32) + 1.
351
+
352
+ x = np.float32(2.)
353
+ res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
354
+ self.assertAllClose(x * 3. * 3.14 + 1., res, check_dtypes=False)
355
+
356
+ @_parameterized_jit
357
+ def test_with_value_capture(self, with_jit=True):
358
+ outer_val = np.array(3., dtype=np.float32)
359
+
360
+ def fun_tf(x):
361
+ return x * outer_val + 1.
362
+
363
+ x = np.float32(2.)
364
+ res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
365
+ self.assertAllClose(x * 3. + 1., res, check_dtypes=False)
366
+
367
+ @_parameterized_jit
368
+ def test_with_multiple_capture(self, with_jit=True):
369
+ if jtu.test_device_matches(["gpu"]):
370
+ raise unittest.SkipTest("Test fails on GPU")
371
+ v2 = tf.Variable(2., dtype=np.float32)
372
+ v3 = tf.Variable(3., dtype=np.float32)
373
+ t4 = tf.constant(4., dtype=np.float32)
374
+ t5 = tf.constant(5., dtype=np.float32)
375
+
376
+ def fun_tf(x):
377
+ return (x * v3 + t4 + v2) * v3 + t5
378
+
379
+ x = np.float32(2.)
380
+ res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
381
+ self.assertAllClose((x * 3. + 4. + 2.) * 3. + 5., res, check_dtypes=False)
382
+
383
+ @_parameterized_jit
384
+ def test_grad(self, with_jit=False):
385
+ x = np.float32(3.)
386
+ res = _maybe_jit(with_jit, jax.grad(jax2tf.call_tf(tf.math.sin)))(x)
387
+ self.assertAllClose(np.cos(x), res)
388
+
389
+ @_parameterized_jit
390
+ def test_grad_pytree(self, with_jit=False):
391
+
392
+ def fun_tf(x: dict, y: tuple) -> tuple:
393
+ return x["first"] * x["second"] + 3. * y[0] + 4. * y[1]
394
+
395
+ x = dict(first=np.float32(3.), second=np.float32(4.))
396
+ y = (np.float32(5.), np.float32(6.))
397
+ grad_x = _maybe_jit(with_jit, jax.grad(jax2tf.call_tf(fun_tf)))(x, y)
398
+ self.assertAllClose(
399
+ dict(first=np.float32(4.), second=np.float32(3.)), grad_x)
400
+
401
+ def test_grad_nested(self):
402
+ # We embed the call_tf function in a larger function whose gradient we take
403
+ # It is relevant here that the cotangents flowing through the call_tf
404
+ # function are not scalars.
405
+
406
+ b = np.array([[11., 12., 13.], [21., 22., 23.]], dtype=np.float32) # [2, 3]
407
+ c = np.array([[31., 32.], [41., 42.], [51., 52.], [61., 62.]], dtype=np.float32) # [4, 2]
408
+ x_dict = dict(b=b, c=c) # b:[2, 3], c=[4, 2]
409
+ # res: dict(r:[4, 3], s:[4, 2])
410
+ def f_tf(x_dict):
411
+ return dict(r=tf.matmul(x_dict["c"], x_dict["b"]), s=7. * x_dict["c"])
412
+
413
+ @jax.jit # To recognize it in jaxpr
414
+ def f_jax(x_dict):
415
+ return dict(r=jnp.matmul(x_dict["c"], x_dict["b"]), s=7. * x_dict["c"])
416
+
417
+ def loss(functional, x_dict):
418
+ prediction = functional(x_dict) # r:[4, 3], s:[4, 2]
419
+ weights = np.array([1., 2., 3., 4.], dtype=np.float32) # [4]
420
+ weighted_pred = jnp.matmul(weights, prediction["r"]) # [3]
421
+ return jnp.sum(weighted_pred) + 4. * jnp.sum(prediction["s"])
422
+
423
+ g_fun_with_tf = jax.grad(partial(loss, jax2tf.call_tf(f_tf)))
424
+ g_fun_with_jax = jax.grad(partial(loss, f_jax))
425
+
426
+ g_tf = g_fun_with_tf(x_dict)
427
+ g_jax = g_fun_with_jax(x_dict)
428
+ self.assertAllClose(g_jax, g_tf)
429
+
430
+ def test_grad_int_argument(self):
431
+ # Similar to https://github.com/google/jax/issues/6975
432
+ # state is a pytree that contains a float array, an integer, and a boolean.
433
+ # The function returns the product param * x together with the state unchanged.
434
+ def f(param, state, x):
435
+ return param * x, state
436
+
437
+ param = np.array([0.7, 0.9], dtype=np.float32)
438
+ state = dict(array=np.float32(1.), counter=7, truth=True)
439
+ x = np.float32(3.)
440
+
441
+ # tf.function is important, without it the bug does not appear
442
+ f_call_tf = jax2tf.call_tf(f)
443
+ g_call_tf = jax.grad(lambda *args: jnp.sum(f_call_tf(*args)[0]))(param, state, x)
444
+ g = jax.grad(lambda *args: jnp.sum(f(*args)[0]))(param, state, x)
445
+ self.assertAllClose(g_call_tf, g)
446
+
447
+ def test_grad_int_argument_unused(self):
448
+ batch_size = 5
449
+ inputs = np.ones((batch_size, 3), dtype=np.float32)
450
+ rng = np.array([1, 2], dtype=np.uint32)
451
+ params = np.float32(.5)
452
+
453
+ # rng is integer, unused
454
+ def jax_model(params, rng, inputs):
455
+ return jnp.ones([batch_size, 2], dtype=jnp.float32)
456
+
457
+ tf_model = jax2tf.convert(jax_model, with_gradient=True)
458
+
459
+ def _loss_fn(inference_fn, params, rng, inputs):
460
+ prediction = inference_fn(params, rng, inputs)
461
+ return jnp.mean(prediction)
462
+
463
+ jax_loss_fn = partial(_loss_fn, jax_model)
464
+ jax_grad = jax.grad(jax_loss_fn)(params, rng, inputs)
465
+
466
+ paramsv = tf.Variable(params)
467
+ with tf.GradientTape() as tape:
468
+ tf_prediction = tf_model(paramsv, rng, inputs)
469
+ tf_loss = tf.reduce_mean(tf_prediction)
470
+
471
+ tf_grad = tape.gradient(tf_loss, paramsv)
472
+ self.assertAllClose(jax_grad, tf_grad.numpy())
473
+
474
+ call_tf_loss_fn = partial(_loss_fn, jax2tf.call_tf(tf_model))
475
+ call_tf_grad = jax.grad(call_tf_loss_fn)(params, rng, inputs)
476
+ self.assertAllClose(jax_grad, call_tf_grad)
477
+
478
+ def test_grad_with_float0_result(self):
479
+ # Gradient over integer-argument functions, with float0 result
480
+ def f_jax(x, y): # x is an int, y is a float; res is a (int, float)
481
+ return (2 * x, 2 * x + y * y)
482
+ def f_tf(x, y):
483
+ # TF needs explicit casts
484
+ return (2 * x, tf.cast(2 * x, dtype=y.dtype) + y * y)
485
+
486
+ def wrapper(functional, x, y): # x: i32
487
+ return jnp.sum(2. * functional(3 * x, 4. * y)[1])
488
+
489
+ grad_g = jax.grad(partial(wrapper, f_jax),
490
+ allow_int=True, argnums=(0, 1))
491
+ grad_g_call_tf = jax.grad(partial(wrapper, jax2tf.call_tf(f_tf)),
492
+ allow_int=True, argnums=(0, 1))
493
+
494
+ x = np.int32(2)
495
+ y = np.float32(3.)
496
+ g_jax = grad_g(x, y)
497
+ g_call_tf = grad_g_call_tf(x, y)
498
+ self.assertEqual(g_jax[0].dtype, dtypes.float0)
499
+ self.assertEqual(g_call_tf[0].dtype, dtypes.float0)
500
+ self.assertAllClose(g_jax[1], g_call_tf[1])
501
+
502
+ @_parameterized_jit
503
+ def test_grad_custom(self, with_jit=False):
504
+
505
+ @tf.custom_gradient
506
+ def func_square_tf(x):
507
+ # Like x ** 2, but with custom grad 3. * x
508
+ def grad(dy, variables=None):
509
+ # dy, = dys
510
+ return 3. * x * dy,
511
+
512
+ return x * x, grad
513
+
514
+ x = np.float32(4.)
515
+ grad_x = _maybe_jit(with_jit, jax.grad(jax2tf.call_tf(func_square_tf)))(x)
516
+ self.assertAllClose(np.float32(3.) * x, grad_x)
517
+
518
+ @parameterized.named_parameters(
519
+ dict(
520
+ testcase_name=f"_{degree=}{'_jit' if with_jit else ''}",
521
+ degree=degree,
522
+ with_jit=with_jit)
523
+ for degree in [1, 2, 3, 4]
524
+ for with_jit in [True, False])
525
+ def test_higher_order_grad(self, degree=2, with_jit=False):
526
+
527
+ def fun_tf(x):
528
+ return 2. * x * x * x
529
+
530
+ def fun_jax(x):
531
+ return 3. * _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
532
+
533
+ def fun_jax_pure(x):
534
+ return 3. * fun_tf(x)
535
+
536
+ grad_jax = fun_jax
537
+ grad_jax_pure = fun_jax_pure
538
+ for _ in range(degree):
539
+ grad_jax = jax.grad(grad_jax)
540
+ grad_jax_pure = jax.grad(grad_jax_pure)
541
+
542
+ res_jax = grad_jax(np.float32(5.))
543
+ logging.info("Grad of %s degree is %s", degree, res_jax)
544
+ self.assertAllClose(res_jax, grad_jax_pure(np.float32(5.)))
545
+
546
+ def test_pmap(self):
547
+ logging.info("Running test_pmap on %s devices", jax.local_device_count())
548
+
549
+ def plus_2_tf(x):
550
+ return tf.math.add(2., x)
551
+
552
+ def fun_jax(x):
553
+ return np.float32(3.) * jax2tf.call_tf(plus_2_tf)(x)
554
+
555
+ x = np.arange(jax.local_device_count(), dtype=np.float32)
556
+ res = jax.pmap(fun_jax)(x)
557
+ self.assertAllClose(np.float32(3. * (x + 2)), res)
558
+
559
+ def test_function_compile_time_constant_inputs(self):
560
+ # Call a function for which shape inference does not give an output
561
+ # shape.
562
+ x = np.array([1, 2, 3], dtype=np.int32)
563
+ def fun_tf(x): # x:i32[3]
564
+ # Indexing with a dynamic slice makes the TF shape inference return
565
+ # a partially known shape.
566
+ end_idx = x[1]
567
+ res = x[0:end_idx]
568
+ return res
569
+
570
+ # Call in eager mode. Should work!
571
+ res1 = jax2tf.call_tf(fun_tf)(x)
572
+ self.assertAllClose(x[0:x[1]], res1)
573
+
574
+ # Now under jit, should fail because the function is not compilable
575
+ with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
576
+ fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
577
+ fun_jax(x)
578
+
579
+ def test_experimental_get_compiler_ir_design_doc(self):
580
+ # Not a test of call_tf, but more of how experimental_get_compiler_ir works.
581
+ # Examples are from the design doc.
582
+
583
+ # Constant slice. This is the common case.
584
+ x = np.zeros((10,), dtype=np.int32)
585
+
586
+ def fun_tf(x):
587
+ begin = 0
588
+ return x[begin:5]
589
+
590
+ hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)()
591
+ self.assertIn("(arg0.1: s32[10]) -> s32[5]", hlo)
592
+
593
+ # Non-constant slice, but compile-time constant depending only on values.
594
+ x = np.zeros((10,), dtype=np.int32)
595
+
596
+ # Non-constant slice, but compile-time constant depending only on shapes.
597
+ x = np.zeros((10,), dtype=np.int32)
598
+
599
+ def fun_tf(x):
600
+ begin = tf.shape(x)[0] - 2 # begin is a compile-time constant, even if x is not
601
+ return x[begin:]
602
+
603
+ hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)()
604
+ self.assertIn("(arg0.1: s32[10]) -> s32[2]", hlo)
605
+
606
+ # Capture a variable
607
+ outer_var = tf.Variable(np.array([3.], dtype=np.float32))
608
+ x = np.array([2., 3., 4.], dtype=np.float32)
609
+
610
+ def fun_tf(x):
611
+ return x * tf.broadcast_to(outer_var, x.shape) + 1.
612
+
613
+ hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)()
614
+ self.assertIn("(arg0.1: f32[3], arg1.2: f32[1]) -> f32[3]", hlo)
615
+
616
+ # Capture a constant
617
+ outer_ct = np.array([3.], dtype=np.float32)
618
+ x = np.array([2., 3., 4.], dtype=np.float32)
619
+
620
+ def fun_tf(x):
621
+ return x * tf.broadcast_to(outer_ct, x.shape) + 1.
622
+
623
+ hlo = tf.function(fun_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(x)()
624
+ self.assertIn("(arg0.1: f32[3]) -> f32[3]", hlo)
625
+
626
+ # Call get_compiler_ir in a function context
627
+ x = np.array([2., 3., 4.], dtype=np.float32)
628
+
629
+ def fun_tf_outer(x):
630
+ x_const = tf.constant(0, shape=x.shape, dtype=x.dtype)
631
+ _ = tf.function(tf.math.sin, jit_compile=True, autograph=False).experimental_get_compiler_ir(x_const)()
632
+
633
+ # TODO(b/193754660)
634
+ # with self.assertRaisesRegex(
635
+ # TypeError, "An op outside of the function building code is being passed"):
636
+ # tf.function(fun_tf_outer)(x)
637
+ #
638
+ # with self.assertRaisesRegex(
639
+ # TypeError, "An op outside of the function building code is being passed"):
640
+ # tf.function(fun_tf_outer, jit_compile=True)(x)
641
+
642
+ # Call get_concrete_function in a graph context
643
+ def fun_tf_outer_2(x):
644
+ _ = tf.function(tf.math.sin, jit_compile=True).get_concrete_function(tf.TensorSpec(x.shape, x.dtype))
645
+ return x
646
+
647
+ # Outside of a function context, this works.
648
+ _ = tf.function(fun_tf_outer_2)(x)
649
+ _ = tf.function(fun_tf_outer_2, jit_compile=True)(x)
650
+
651
+ def test_repro_193754660(self):
652
+ # Try to reproduce b/193754660. I can't.
653
+ # We have to have tf.function(jax2tf.convert(jax2tf.call_tf(f_tf))).
654
+ # The get_compiler_ir will indeed fail for f_tf. Then we try to use
655
+ # shape inference for f_tf.
656
+ # I thought to use a f_tf that uses an op without shape inference, e.g.,
657
+ # tfxla.gather. If we wash it through a saved_model I expect that shape
658
+ # inference would not work on it. Instead, shape inference works!!!
659
+ x = np.array([0, 1, 2, 3, 4, 5], dtype=np.int32)
660
+ def f_jax(x):
661
+ return x[1]
662
+ f_tf = jax2tf.convert(f_jax)
663
+ f_tf_rt, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
664
+ f_jax2 = jax2tf.call_tf(f_tf_rt)
665
+ f_tf2 = jax2tf.convert(f_jax2)
666
+ res = tf.function(f_tf2, autograph=False)(x)
667
+ self.assertAllClose(res.numpy(), f_jax(x))
668
+
669
+ def test_effectful(self):
670
+ x = np.ones((3,), dtype=np.float32)
671
+ lower_effect = jax.jit(jax2tf.call_tf(tf.math.sin, has_side_effects=True)).lower(x)
672
+ self.assertNotEmpty(lower_effect._lowering.compile_args["unordered_effects"])
673
+
674
+ lower_no_effect = jax.jit(jax2tf.call_tf(tf.math.sin, has_side_effects=False)).lower(x)
675
+ self.assertEmpty(lower_no_effect._lowering.compile_args["unordered_effects"])
676
+
677
+ def test_module_documentation(self):
678
+ def cos_tf(x):
679
+ return tf.math.cos(x)
680
+
681
+ # Compute cos with TF and sin with JAX
682
+ def cos_tf_sin_jax(x):
683
+ return jax.numpy.sin(jax2tf.call_tf(cos_tf)(x))
684
+
685
+ # Calls `cos_tf` in TF eager mode
686
+ x = np.float32(1.)
687
+ cos_tf_sin_jax(x)
688
+
689
+ # Compiles `cos_tf` using TF and embeds the XLA computation into the JAX
690
+ # XLA computation (containing `sin`). The XLA compiler may even be able to
691
+ # fuse through JAX-TF computations.
692
+ jax.jit(cos_tf_sin_jax)(x)
693
+
694
+ # Uses TF gradient for `cos_tf` and JAX gradient for `sin`
695
+ jax.grad(cos_tf_sin_jax)(x)
696
+
697
+ logging.info(jax.make_jaxpr(cos_tf_sin_jax)(x))
698
+ logging.info(jax.xla_computation(cos_tf_sin_jax)(x).as_hlo_text())
699
+
700
+ def test_tf_gather(self):
701
+ """tf_gather gradient output is tf.IndexSlices."""
702
+ operand = jnp.array(np.random.uniform(size=(100, 128)))
703
+ indices = jnp.array(np.random.randint(low=0, high=100, size=(4000,)))
704
+
705
+ @tf.function(jit_compile=True, autograph=False)
706
+ def fun_tf(operand, indices):
707
+ return tf.experimental.numpy.std(tf.gather(operand, indices))
708
+
709
+ fun_jax = jax2tf.call_tf(fun_tf)
710
+ grad_fun_jax = jax.grad(fun_jax)
711
+ grad_res = grad_fun_jax(operand, indices)
712
+ self.assertEqual(grad_res.shape, (100, 128))
713
+
714
+ def test_output_shape_dtype_none(self):
715
+ x = jnp.zeros((10), dtype=jnp.float32)
716
+
717
+ @tf.function(jit_compile=True, autograph=False)
718
+ def fun_tf(x): # pylint: disable=unused-argument
719
+ return
720
+
721
+ fun_jax_1 = jax2tf.call_tf(fun_tf, output_shape_dtype=None)
722
+ fun_jax_2 = jax2tf.call_tf(fun_tf)
723
+ self.assertIsNone(fun_jax_1(x))
724
+ self.assertIsNone(fun_jax_2(x))
725
+ fun_jax_3 = jax2tf.call_tf(
726
+ fun_tf, output_shape_dtype=jax.ShapeDtypeStruct((10,), jnp.float32)
727
+ )
728
+ with self.assertRaisesRegex(
729
+ ValueError,
730
+ "The pytree of the TensorFlow function results does not match the"
731
+ " pytree of the declared output_shape_dtype",
732
+ ):
733
+ _ = fun_jax_3(x)
734
+
735
+ def test_output_shape_dtype_not_none(self):
736
+ x = jnp.zeros((10), dtype=jnp.float32)
737
+
738
+ @tf.function(jit_compile=True, autograph=False)
739
+ def fun_tf(x):
740
+ return x
741
+
742
+ fun_jax_1 = jax2tf.call_tf(
743
+ fun_tf, output_shape_dtype=jax.ShapeDtypeStruct((10,), jnp.float32)
744
+ )
745
+ fun_jax_2 = jax2tf.call_tf(fun_tf)
746
+ self.assertAllClose(fun_jax_1(x), fun_jax_2(x))
747
+
748
+ fun_jax_3 = jax2tf.call_tf(fun_tf, output_shape_dtype=None)
749
+ with self.assertRaisesRegex(
750
+ ValueError,
751
+ "The pytree of the TensorFlow function results does not match the"
752
+ " pytree of the declared output_shape_dtype",
753
+ ):
754
+ _ = fun_jax_3(x)
755
+
756
+ def test_multi_platform(self):
757
+ def tf_fun(x):
758
+ return tf.math.sin(x)
759
+
760
+ def f_jax(x):
761
+ return jnp.cos(jax2tf.call_tf(tf_fun)(jnp.cos(x)))
762
+ x = np.arange(12, dtype=np.float32).reshape((3, 4))
763
+
764
+ # Find platforms that are available for both JAX and TF
765
+ # Pick one device from each available platform
766
+ jax_platforms = []
767
+ for backend in ["cpu", "gpu", "tpu"]:
768
+ try:
769
+ devices = jax.devices(backend)
770
+ except RuntimeError:
771
+ devices = []
772
+ if devices:
773
+ jax_platforms.append(devices[0].platform)
774
+
775
+ jax_and_tf_platforms = (
776
+ set(jax_platforms) & {d.device_type.lower()
777
+ for d in self.__class__.tf_devices})
778
+
779
+ lowering_platforms = ("tpu", "cpu", "cuda")
780
+
781
+ exp = export.export(f_jax,
782
+ lowering_platforms=lowering_platforms)(x)
783
+ for jax_platform in jax_and_tf_platforms:
784
+ with self.subTest(jax_platform):
785
+ jax_device = jax.devices(jax_platform)[0]
786
+ x_device = jax.device_put(x, jax_device)
787
+ logging.info("Running harness natively on %s", jax_device)
788
+ native_res = f_jax(x_device)
789
+ logging.info("Running exported harness on %s", jax_device)
790
+ exported_res = export.call(exp)(x_device)
791
+ self.assertAllClose(native_res, exported_res)
792
+
793
+ def test_multi_platform_call_tf_graph(self):
794
+ def tf_fun(x):
795
+ return tf.math.sin(x)
796
+
797
+ def f_jax(x):
798
+ return jnp.cos(jax2tf.call_tf(tf_fun,
799
+ call_tf_graph=True,
800
+ ordered=True)(jnp.cos(x)))
801
+ x = np.arange(12, dtype=np.float32).reshape((3, 4))
802
+ # When we use call_tf_graph we can serialize for multiple platforms
803
+ lowering_platforms = ("tpu", "cpu", "cuda")
804
+ # We must use jax2tf.convert to run a call_tf(call_tf_graph)
805
+ # TODO(necula): if we remove the tf.function and we have multiple platforms
806
+ # then we attempt to lower call_tf multiple times and only the first
807
+ # lowering will have the proper side effects for the function_list.
808
+ f_tf = tf.function(jax2tf.convert(
809
+ f_jax,
810
+ native_serialization=True,
811
+ native_serialization_platforms=lowering_platforms))
812
+ for tf_device in self.__class__.tf_devices:
813
+ with self.subTest(tf_device.device_type):
814
+ logging.info(
815
+ f"Running on tf_device = {tf_device} of device_type = {tf_device.device_type}")
816
+ with tf.device(tf_device):
817
+ res = f_tf(x)
818
+ self.assertAllClose(res, f_jax(x))
819
+
820
+ @parameterized.named_parameters(
821
+ {"testcase_name": f"_type={type_.__name__}", "type_": type_}
822
+ for type_ in dlpack.SUPPORTED_DTYPES
823
+ )
824
+ def test_avoid_copy_between_gpu_and_cpu(self, type_):
825
+ try:
826
+ gpu_devices = jax.devices("gpu")
827
+ except RuntimeError:
828
+ gpu_devices = []
829
+ if not gpu_devices:
830
+ raise unittest.SkipTest("Test requires a GPU device.")
831
+
832
+ def tf_fun(x):
833
+ if type_ == jnp.bool_:
834
+ return tf.math.logical_or(x, True)
835
+ else:
836
+ return x + 1
837
+
838
+ jax_array_on_gpu = jnp.zeros([1], type_, device=gpu_devices[0])
839
+
840
+ # Since the input array is already on a GPU device, we expect that no memory
841
+ # copy occurs between GPU and CPU. Thus, we expect no errors raised by the
842
+ # transfer guard.
843
+ # There are two exceptions:
844
+ # First, when dtype is "int32". This is because almost all TensorFlow
845
+ # kernels for GPU devices keep int32 tensors in host memory.
846
+ # (https://github.com/tensorflow/tensorflow/blob/4eb3e36d1b0cd511e1677e740bd093f42365cf9f/tensorflow/python/eager/pywrap_tensor.cc#L352-L354)
847
+ # Hence, for "int32", we do expect a "host-to-device" copy.
848
+ # Second, when using PJRT C API runtime. This is because it currently skips dlpack
849
+ # to work around the "PJRT C API does not support GetDefaultLayout" runtime error.
850
+ # https://github.com/openxla/xla/blob/762bde36adf22792e91c38fe87cabe5af05bfadc/xla/pjrt/pjrt_c_api_client.h#L285-L289
851
+ @contextlib.contextmanager
852
+ def _transfer_guard(guard_level):
853
+ with contextlib.ExitStack() as stack:
854
+ stack.enter_context(jax.transfer_guard_device_to_device(guard_level))
855
+ stack.enter_context(jax.transfer_guard_device_to_host(guard_level))
856
+ if type_ != jnp.int32:
857
+ stack.enter_context(jax.transfer_guard_host_to_device(guard_level))
858
+ yield
859
+
860
+ with _transfer_guard("disallow_explicit"):
861
+ jax2tf.call_tf(tf_fun)(jax_array_on_gpu)
862
+
863
+
864
+ class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
865
+ "Reloading output of jax2tf into JAX with call_tf"
866
+ def setUp(self):
867
+ if tf is None:
868
+ raise unittest.SkipTest("Test requires tensorflow")
869
+ # TODO(b/171320191): this line works around a missing context initialization
870
+ # bug in TensorFlow.
871
+ _ = tf.add(1, 1)
872
+ super().setUp()
873
+
874
+ def test_simple(self):
875
+ f_jax = jnp.sin
876
+ f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))
877
+ x = np.float32(0.7)
878
+ self.assertAllClose(f_jax(x), f_jax_rt(x))
879
+
880
+ def test_pytree(self):
881
+ def f_jax(x): # x: dict(a=f32, b=f32)
882
+ return dict(a=x["a"]+1., b=x)
883
+ x = dict(a=0.7, b=0.8)
884
+ f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))
885
+ self.assertAllClose(f_jax(x), f_jax_rt(x))
886
+
887
+ def test_custom_grad(self):
888
+ @jax.custom_vjp
889
+ def f(x):
890
+ return x * x
891
+
892
+ # f_fwd: a -> (b, residual)
893
+ def f_fwd(x):
894
+ return f(x), np.float32(3.) * x
895
+ # f_bwd: (residual, CT b) -> [CT a]
896
+ def f_bwd(residual, ct_b):
897
+ return residual * ct_b,
898
+
899
+ f.defvjp(f_fwd, f_bwd)
900
+
901
+ f_rt = jax2tf.call_tf(jax2tf.convert(f, with_gradient=True))
902
+ x = np.float32(0.7)
903
+ self.assertAllClose(f(x), f_rt(x))
904
+ self.assertAllClose(jax.grad(f)(x), jax.grad(f_rt)(x))
905
+
906
+ def test_shape_poly(self):
907
+ f_jax = jnp.sin
908
+ f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax,
909
+ polymorphic_shapes=["(b, ...)"]))
910
+ x = np.array([0.7, 0.8], dtype=np.float32)
911
+ self.assertAllClose(f_jax(x), f_jax_rt(x))
912
+
913
+ def test_saved_model_simple(self):
914
+ x = np.array([0.7, 0.8], dtype=np.float32)
915
+ def f_jax(x):
916
+ return jnp.sin(x)
917
+
918
+ f_tf = jax2tf.convert(f_jax)
919
+ restored_tf, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
920
+ restored_jax = jax2tf.call_tf(restored_tf)
921
+ self.assertAllClose(f_jax(x), restored_jax(x))
922
+
923
+ def test_saved_model_variables(self):
924
+ param = np.array([1., 2.], dtype=np.float32)
925
+ x = np.array([0.7, 0.8], dtype=np.float32)
926
+ def f_jax(param, x):
927
+ return jnp.sin(x) + jnp.cos(param)
928
+
929
+ param_v = tf.Variable(param)
930
+ f_tf = jax2tf.convert(f_jax)
931
+ _, restored_model = tf_test_util.SaveAndLoadFunction(
932
+ lambda x: f_tf(param_v, x),
933
+ input_args=[x],
934
+ variables=[param_v])
935
+ restored_jax = jax2tf.call_tf(restored_model.f)
936
+ self.assertAllClose(f_jax(param, x), restored_jax(x))
937
+ self.assertAllClose(f_jax(param, x), jax.jit(restored_jax)(x))
938
+
939
+ def test_saved_model_shape_poly(self):
940
+ tracing_count = 0
941
+ x = np.array([0.7, 0.8], dtype=np.float32)
942
+ def f_jax(x):
943
+ nonlocal tracing_count
944
+ tracing_count += 1
945
+ return jnp.sin(x)
946
+
947
+ f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
948
+ res_jax = f_jax(x)
949
+ self.assertEqual(1, tracing_count)
950
+ # Will trace twice, it seems. Once to get the result signature, and once again
951
+ # for the actual saving.
952
+ restored_f, _ = tf_test_util.SaveAndLoadFunction(
953
+ f_tf, input_signature=[tf.TensorSpec([None], x.dtype)])
954
+ self.assertGreaterEqual(tracing_count, 2)
955
+ tracing_count = 0
956
+ f_jax_rt = jax2tf.call_tf(restored_f)
957
+ self.assertAllClose(res_jax, f_jax_rt(x))
958
+ # Ensure that restored_f works at other batch size as well
959
+ y = np.concatenate([x, x])
960
+ self.assertEqual(0, tracing_count)
961
+ res_jax_y = f_jax(y)
962
+ self.assertEqual(1, tracing_count)
963
+ # No more tracing for f_jax_rt
964
+ self.assertAllClose(res_jax_y, f_jax_rt(y))
965
+ self.assertEqual(1, tracing_count)
966
+
967
+ def test_custom_grad_saved_model(self):
968
+
969
+ @jax.custom_vjp
970
+ def f(x):
971
+ return x * x
972
+
973
+ # f_fwd: a -> (b, residual)
974
+ def f_fwd(x):
975
+ return f(x), np.float32(3.) * x
976
+ # f_bwd: (residual, CT b) -> [CT a]
977
+ def f_bwd(residual, ct_b):
978
+ return residual * ct_b,
979
+
980
+ f.defvjp(f_fwd, f_bwd)
981
+ def g(x):
982
+ return jnp.sum(f(x))
983
+
984
+ g_tf, _ = tf_test_util.SaveAndLoadFunction(
985
+ jax2tf.convert(g, with_gradient=True),
986
+ input_signature=[tf.TensorSpec(shape=(1,), dtype=tf.float32)],
987
+ )
988
+ g_rt = jax2tf.call_tf(g_tf)
989
+ x = np.array([0.7], dtype=np.float32)
990
+ self.assertAllClose(g(x), g_rt(x))
991
+ self.assertAllClose(jax.grad(g)(x), jax.grad(g_rt)(x))
992
+
993
+ def test_without_gradient_saved_model(self):
994
+ # Explicitly with_gradient=False
995
+ f_jax = jnp.sum
996
+
997
+ x = np.array([0.7, 0.8], dtype=np.float32)
998
+ f_tf, _ = tf_test_util.SaveAndLoadFunction(
999
+ jax2tf.convert(f_jax, with_gradient=False),
1000
+ input_args=[x])
1001
+ f_rt = jax2tf.call_tf(f_tf)
1002
+
1003
+ self.assertAllClose(f_jax(x), f_rt(x))
1004
+ with self.assertRaisesRegex(Exception,
1005
+ "Gradient explicitly disabled.*jax2tf-converted function does not support gradients. Use `with_gradient` parameter to enable gradients"):
1006
+ jax.grad(f_rt)(x)
1007
+
1008
+ def test_saved_model_no_gradients(self):
1009
+ # Save without gradients
1010
+ f_jax = jnp.sum
1011
+
1012
+ x = np.array([0.7, 0.8], dtype=np.float32)
1013
+ f_tf, _ = tf_test_util.SaveAndLoadFunction(
1014
+ jax2tf.convert(f_jax, with_gradient=True), input_args=[x],
1015
+ save_gradients=False)
1016
+ f_rt = jax2tf.call_tf(f_tf)
1017
+
1018
+ self.assertAllClose(f_jax(x), f_rt(x))
1019
+ # TODO: clean this up b/191117111: it should fail with a clear error
1020
+ # The following results in a confusing error:
1021
+ # TypeError: tf.Graph captured an external symbolic tensor.
1022
+ with self.assertRaises(TypeError):
1023
+ _ = jax.grad(f_rt)(x)
1024
+
1025
+ def test_call_tf_under_function_context(self):
1026
+ def fun_jax(x, y):
1027
+ z = jax2tf.call_tf(tf.math.sin)(x) + jnp.cos(y)
1028
+ return z
1029
+
1030
+ x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
1031
+ y = np.array([-0.5, 0.0, 0.5], dtype=np.float32)
1032
+
1033
+ converted_fun = tf.function(
1034
+ jax2tf.convert(fun_jax, native_serialization=True)
1035
+ )
1036
+ expected = np.sin(x) + np.cos(y)
1037
+ res = tf.function(converted_fun, jit_compile=True, autograph=False)(x, y)
1038
+ self.assertAllClose(expected, res.numpy(), atol=1e-5, rtol=1e-5)
1039
+
1040
+ @parameterized.named_parameters(
1041
+ dict(
1042
+ testcase_name=f"_{dtype.__name__}",
1043
+ dtype=dtype,
1044
+ )
1045
+ for dtype in set(jtu.dtypes.all_floating)
1046
+ )
1047
+ def test_all_floating_input_gradient(self, dtype):
1048
+ def tf_f(x):
1049
+ res = tf.math.sin(x)
1050
+ return tf.reduce_sum(res)
1051
+
1052
+ jax_f = jax2tf.call_tf(tf_f)
1053
+ tf_f_rt = jax2tf.convert(jax_f)
1054
+ x = jnp.array([5.0, 6.0, 7.0]).astype(dtype)
1055
+
1056
+ def assert_all_close_support_bfloat16(baseline, candidate):
1057
+ def conversion(x):
1058
+ # convert scalar to array and bfloat16 to float32
1059
+ # to support self.assertAllClose numpy array comparison.
1060
+ if x.shape == tf.TensorShape([]):
1061
+ x = tf.convert_to_tensor([x])
1062
+ if dtype == jnp.float16:
1063
+ x = tf.cast(x, tf.float32)
1064
+ return x
1065
+
1066
+ baseline = jax.tree_util.tree_map(conversion, baseline)
1067
+ candidate = jax.tree_util.tree_map(conversion, candidate)
1068
+ tol = (
1069
+ 1e-2
1070
+ if jtu.test_device_matches(["tpu"]) and dtype == np.float16
1071
+ else None
1072
+ )
1073
+ self.assertAllClose(baseline, candidate, atol=tol, rtol=tol)
1074
+
1075
+ # Eager mode
1076
+ assert_all_close_support_bfloat16(tf_f(x), tf_f_rt(x))
1077
+
1078
+ # Compiled function mode
1079
+ assert_all_close_support_bfloat16(
1080
+ tf.function(tf_f)(x), tf.function(tf_f_rt)(x)
1081
+ )
1082
+
1083
+ # Compiled function mode with jit_compiled=True
1084
+ assert_all_close_support_bfloat16(
1085
+ tf.function(tf_f, jit_compile=True)(x),
1086
+ tf.function(tf_f_rt, jit_compile=True)(x),
1087
+ )
1088
+
1089
+ # RoundTrip test for the gradient
1090
+ grad_fun_jax = jax.grad(jax2tf.call_tf(tf_f))
1091
+ grad_fun_jax_rt = jax2tf.call_tf(jax2tf.convert(grad_fun_jax))
1092
+
1093
+ # Eager mode
1094
+ assert_all_close_support_bfloat16(grad_fun_jax(x), grad_fun_jax_rt(x))
1095
+
1096
+ # Jit mode
1097
+ assert_all_close_support_bfloat16(
1098
+ jax.jit(grad_fun_jax)(x), jax.jit(grad_fun_jax_rt)(x)
1099
+ )
1100
+
1101
+ @parameterized.named_parameters(
1102
+ dict(
1103
+ testcase_name=f"_{dtype.__name__}",
1104
+ dtype=dtype,
1105
+ )
1106
+ for dtype in set(jtu.dtypes.complex)
1107
+ )
1108
+ def test_complex_input_gradient(self, dtype):
1109
+ def tf_f(x):
1110
+ res = tf.math.sin(x)
1111
+ return tf.reduce_sum(res)
1112
+
1113
+ x = jnp.array([(5.0 + 4.0j), (6.0 + 3.0j), (7.0 + 8.0j)]).astype(dtype)
1114
+
1115
+ jax_f = jax2tf.call_tf(tf_f)
1116
+ tf_f_rt = jax2tf.convert(jax_f)
1117
+
1118
+ # Eager mode
1119
+ self.assertAllClose(tf_f(x), tf_f_rt(x))
1120
+
1121
+ # tf.function context
1122
+ self.assertAllClose(tf.function(tf_f)(x), tf.function(tf_f_rt)(x))
1123
+
1124
+ # tf.function context with jit_compiled=True
1125
+ self.assertAllClose(
1126
+ tf.function(tf_f, jit_compile=True)(x),
1127
+ tf.function(tf_f_rt, jit_compile=True)(x),
1128
+ )
1129
+
1130
+ # RoundTrip test for the gradient
1131
+ grad_fun_jax = jax.grad(jax2tf.call_tf(tf_f), holomorphic=True)
1132
+ grad_fun_jax_rt = jax2tf.call_tf(jax2tf.convert(grad_fun_jax))
1133
+
1134
+ # Eager mode
1135
+ self.assertAllClose(grad_fun_jax(x), grad_fun_jax_rt(x))
1136
+
1137
+ # Jit mode
1138
+ self.assertAllClose(jax.jit(grad_fun_jax)(x), jax.jit(grad_fun_jax_rt)(x))
1139
+
1140
+
1141
+ class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
1142
+ "Reloading output of call_tf into TF with jax2tf."
1143
+
1144
+ def setUp(self):
1145
+ if tf is None:
1146
+ raise unittest.SkipTest("Test requires tensorflow")
1147
+ # TODO(b/171320191): this line works around a missing context initialization
1148
+ # bug in TensorFlow.
1149
+ _ = tf.add(1, 1)
1150
+ super().setUp()
1151
+
1152
+ def test_alternate(self):
1153
+ # Alternate sin/cos with sin in TF and cos in JAX
1154
+ f_tf_inner = tf.math.sin
1155
+ def f_jax(x_jax):
1156
+ y_jax = jnp.cos(x_jax)
1157
+ z_jax = jax2tf.call_tf(f_tf_inner)(y_jax)
1158
+ return jnp.cos(z_jax)
1159
+ def f_tf_outer(x_tf):
1160
+ y_tf = tf.math.sin(x_tf)
1161
+ z_tf = jax2tf.convert(f_jax)(y_tf)
1162
+ return tf.math.sin(z_tf)
1163
+
1164
+ x = np.float32(0.7)
1165
+
1166
+ self.assertAllClose(np.sin(np.cos(np.sin(np.cos(np.sin(x))))),
1167
+ f_tf_outer(x).numpy())
1168
+ xv = tf.Variable(x)
1169
+ with tf.GradientTape() as tape:
1170
+ res = f_tf_outer(xv)
1171
+ g_tf = tape.gradient(res, xv)
1172
+ _, gf = tf_test_util.ComputeTfValueAndGrad(f_tf_outer, (x,))
1173
+ # Eager
1174
+ expected_res = np.sin(np.cos(np.sin(np.cos(np.sin(x)))))
1175
+ self.assertAllClose(expected_res, f_tf_outer(x).numpy())
1176
+
1177
+ # Gradient
1178
+ expected_grad = (np.cos(np.cos(np.sin(np.cos(np.sin(x))))) *
1179
+ np.sin(np.sin(np.cos(np.sin(x)))) *
1180
+ np.cos(np.cos(np.sin(x))) *
1181
+ np.sin(np.sin(x)) *
1182
+ np.cos(x))
1183
+ self.assertAllClose(expected_grad, g_tf.numpy())
1184
+
1185
+ # Graph
1186
+ self.assertAllClose(expected_res,
1187
+ tf.function(f_tf_outer, autograph=False)(x).numpy())
1188
+
1189
+ # Compiled
1190
+ self.assertAllClose(expected_res,
1191
+ tf.function(f_tf_outer, autograph=False,
1192
+ jit_compile=True)(x).numpy())
1193
+
1194
+ def test_saved_model(self):
1195
+ x = np.array([.7, .8], dtype=np.float32)
1196
+ def fun_tf(x):
1197
+ return tf.math.sin(x)
1198
+ def fun_jax(x):
1199
+ return jax2tf.call_tf(fun_tf)(x)
1200
+
1201
+ # Now convert and save to SavedModel
1202
+ fun_tf_rt = jax2tf.convert(fun_jax)
1203
+ res = fun_tf_rt(x)
1204
+ self.assertAllClose(np.sin(x), res.numpy())
1205
+
1206
+ res = tf.function(fun_tf_rt, autograph=False)(x)
1207
+ self.assertAllClose(np.sin(x), res.numpy())
1208
+
1209
+ res = tf.function(fun_tf_rt, jit_compile=True, autograph=False)(x)
1210
+ self.assertAllClose(np.sin(x), res.numpy())
1211
+
1212
+ reloaded_f, _ = tf_test_util.SaveAndLoadFunction(
1213
+ fun_tf_rt, input_args=[x])
1214
+ res = reloaded_f(x)
1215
+ self.assertAllClose(np.sin(x), res.numpy())
1216
+
1217
+ def test_saved_model_polymorphic_input_static_output(self):
1218
+ x = np.array([.7, .8], dtype=np.float32)
1219
+ def fun_tf(x):
1220
+ return tf.math.reduce_sum(tf.math.sin(x))
1221
+ def fun_jax(x):
1222
+ return jax2tf.call_tf(fun_tf)(x)
1223
+
1224
+ # Now convert and save to SavedModel
1225
+ fun_tf_rt = jax2tf.convert(fun_jax)
1226
+ res = fun_tf_rt(x)
1227
+ self.assertAllClose(fun_tf(x), res.numpy())
1228
+
1229
+ res = tf.function(fun_tf_rt, autograph=False)(x)
1230
+ self.assertAllClose(fun_tf(x), res.numpy())
1231
+
1232
+ res = tf.function(fun_tf_rt, jit_compile=True, autograph=False)(x)
1233
+ self.assertAllClose(fun_tf(x), res.numpy())
1234
+
1235
+ reloaded_f, _ = tf_test_util.SaveAndLoadFunction(
1236
+ fun_tf_rt, input_args=[x])
1237
+ res = reloaded_f(x)
1238
+ self.assertAllClose(fun_tf(x), res.numpy())
1239
+
1240
+ def test_function_dynamic_shape(self):
1241
+ # Call a function for which shape inference does not give an output
1242
+ # shape.
1243
+ x = np.array([-1, 0, 1], dtype=np.int32)
1244
+ def fun_tf(x): # x:i32[3]
1245
+ # The shape depends on the value of x
1246
+ return tf.cond(x[0] >= 0, lambda: x, lambda: x[1:])
1247
+
1248
+ # Call in eager mode. Should work!
1249
+ res1 = jax2tf.call_tf(fun_tf)(x)
1250
+ expected = x[1:]
1251
+ self.assertAllClose(expected, res1, check_dtypes=False)
1252
+
1253
+ # Now under jit, should fail because the function is not compilable
1254
+ with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
1255
+ fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
1256
+ fun_jax(x)
1257
+
1258
+ # TODO(necula): this should work in op-by-op mode, but it fails because
1259
+ # jax2tf.convert does abstract evaluation.
1260
+ with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
1261
+ fun_tf_rt = jax2tf.convert(jax2tf.call_tf(fun_tf))
1262
+ fun_tf_rt(x)
1263
+
1264
+ @_parameterized_jit
1265
+ def test_shape_poly_static_output_shape(self, with_jit=True):
1266
+ if jax.config.jax2tf_default_native_serialization:
1267
+ raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
1268
+ x = np.array([0.7, 0.8], dtype=np.float32)
1269
+
1270
+ def fun_tf(x):
1271
+ return tf.math.reduce_sum(tf.math.sin(x))
1272
+
1273
+ fun_jax = jax2tf.call_tf(fun_tf)
1274
+ fun_tf_rt = _maybe_tf_jit(with_jit,
1275
+ jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
1276
+ self.assertAllClose(fun_tf(x), fun_tf_rt(x))
1277
+
1278
+ @_parameterized_jit
1279
+ def test_shape_poly(self, with_jit=False):
1280
+ if jax.config.jax2tf_default_native_serialization:
1281
+ raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
1282
+ x = np.array([7, 8, 9, 10], dtype=np.float32)
1283
+ def fun_jax(x):
1284
+ y = jax2tf.call_tf(tf.math.sin,
1285
+ output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)
1286
+ z = jnp.cos(y)
1287
+ w = jax2tf.call_tf(lambda z: tf.concat([z, z], axis=0),
1288
+ output_shape_dtype=jax.ShapeDtypeStruct((2 * z.shape[0],), z.dtype))(z)
1289
+ assert w.shape[0] == 2 * x.shape[0]
1290
+ return w
1291
+
1292
+ fun_tf_rt = _maybe_tf_jit(with_jit,
1293
+ jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
1294
+ res_tf = fun_tf_rt(x)
1295
+ self.assertAllClose(fun_jax(x), res_tf)
1296
+
1297
+ @_parameterized_jit
1298
+ def test_shape_poly_pytree_result(self, with_jit=True):
1299
+ if jax.config.jax2tf_default_native_serialization:
1300
+ raise unittest.SkipTest("TODO(b/268386622): call_tf with shape polymorphism and native serialization.")
1301
+ x = np.array([7, 8, 9, 10], dtype=np.float32)
1302
+ def fun_jax(x):
1303
+ # Returns a tuple
1304
+ y = jax2tf.call_tf(lambda x: (x, tf.concat([x, x], axis=0)),
1305
+ output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype),
1306
+ jax.ShapeDtypeStruct((2 * x.shape[0],), x.dtype)))(x)
1307
+ assert y[0].shape[0] == x.shape[0]
1308
+ assert y[1].shape[0] == 2 * x.shape[0]
1309
+ return y
1310
+
1311
+ fun_tf_rt = _maybe_tf_jit(with_jit,
1312
+ jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
1313
+ res_tf = fun_tf_rt(x)
1314
+ self.assertAllClose(fun_jax(x), res_tf)
1315
+
1316
+ @_parameterized_jit
1317
+ def test_shape_poly_error_no_output_shape_dtype(self, with_jit=True):
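+ # Without output_shape_dtype, call_tf cannot determine the result shape for the
+ # symbolic dimension 'b', so the converted function below is expected to fail.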
1318
+ x = np.array([7, 8, 9, 10], dtype=np.float32)
1319
+ def fun_jax(x):
1320
+ return jax2tf.call_tf(tf.math.sin)(x)
1321
+
1322
+ fun_tf_rt = _maybe_tf_jit(with_jit,
1323
+ jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
1324
+ with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
1325
+ fun_tf_rt(x)
1326
+
1327
+ @_parameterized_jit
1328
+ def test_shape_poly_error_mismatch_output_shape_dtype_tree(self, with_jit=False):
1329
+ x = np.array([7, 8, 9, 10], dtype=np.float32)
1330
+ def fun_jax(x):
1331
+ return jax2tf.call_tf(tf.math.sin,
1332
+ output_shape_dtype=(jax.ShapeDtypeStruct(x.shape, x.dtype),
1333
+ jax.ShapeDtypeStruct(x.shape, x.dtype)))(x)
1334
+
1335
+ fun_tf_rt = _maybe_tf_jit(with_jit,
1336
+ jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
1337
+
1338
+ with self.assertRaisesRegex(
1339
+ ValueError,
1340
+ "The pytree of the TensorFlow function results does not match the pytree of the declared output_shape_dtype"):
1341
+ fun_tf_rt(x)
1342
+
1343
+ @parameterized.named_parameters(
1344
+ _named_test(with_jit=with_jit, kind=kind)
1345
+ for with_jit in [True, False]
1346
+ for kind in ["bad_rank", "bad_dim", "bad_dtype", "bad_dtype_x64"])
1347
+ def test_shape_poly_error_mismatch_output_shape_dtype(self, with_jit=False, kind="bad_rank"):
1348
+ x = np.array([7, 8, 9, 10], dtype=np.float32)
1349
+
1350
+ if kind == "bad_rank":
1351
+ def fun_jax(x):
1352
+ return jax2tf.call_tf(lambda x: x,
1353
+ # Wrong shape rank
1354
+ output_shape_dtype=jax.ShapeDtypeStruct((), x.dtype))(x)
1355
+ elif kind == "bad_dim":
1356
+ def fun_jax(x):
1357
+ bad_shape = (5 + x.shape[0],)
1358
+ y = jax2tf.call_tf(lambda x: x,
1359
+ # Wrong dimension
1360
+ output_shape_dtype=jax.ShapeDtypeStruct(bad_shape, x.dtype))(x)
1361
+ # JAX will believe that the following is OK, leading to a downstream error in TF
1362
+ return y + jnp.ones(bad_shape, dtype=x.dtype)
1363
+ elif kind == "bad_dtype":
1364
+ def fun_jax(x):
1365
+ return jax2tf.call_tf(lambda x: x,
1366
+ output_shape_dtype=jax.ShapeDtypeStruct(x.shape, np.int32))(x)
1367
+ elif kind == "bad_dtype_x64":
1368
+ def fun_jax(x):
1369
+ return jax2tf.call_tf(lambda x: x * np.float64(3.),
1370
+ output_shape_dtype=jax.ShapeDtypeStruct(x.shape, np.float64))(x)
1371
+ else:
1372
+ assert False
1373
+ expect_ex = ValueError
1374
+ expect_error = r"The shapes or dtypes returned by the TensorFlow function do not match the declared output_shape_dtype"
1375
+
1376
+ # Call without shape polymorphism
1377
+ fun_tf_rt = _maybe_tf_jit(with_jit, jax2tf.convert(fun_jax))
1378
+ with self.assertRaisesRegex(expect_ex, expect_error):
1379
+ fun_tf_rt(x)
1380
+
1381
+ # Now with shape polymorphism
1382
+ if kind == "bad_dim" and with_jit:
1383
+ # TODO: in jit mode the error pops up later, at AddV2
1384
+ expect_error = "Dimensions must be equal, but are 4 and 9 for .* AddV2"
1385
+ if kind == "bad_dim" and jax.config.jax2tf_default_native_serialization:
1386
+ # TODO(b/268386622): call_tf with shape polymorphism and native serialization.
1387
+ expect_error = "Error compiling TensorFlow function"
1388
+ fun_tf_rt = _maybe_tf_jit(with_jit,
1389
+ jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."]))
1390
+ with self.assertRaisesRegex(expect_ex, expect_error):
1391
+ fun_tf_rt(x)
1392
+
1393
+ def test_inner_native_serialization(self):
1394
+ # Two nested jax2tf conversions, the inner one using native serialization
1395
+ x = np.ones((3,), dtype=np.float32)
1396
+ def f_inner_jax(x):
1397
+ return jnp.sin(x)
1398
+ def f_outer_jax(x):
1399
+ f_inner_tf = jax2tf.convert(f_inner_jax, native_serialization=True)
1400
+ return jnp.cos(jax2tf.call_tf(f_inner_tf)(x))
1401
+
1402
+ f_outer_tf = tf.function(
1403
+ jax2tf.convert(f_outer_jax, native_serialization=False),
1404
+ autograph=False)
1405
+ f_outer_graph = str(f_outer_tf.get_concrete_function(tf.convert_to_tensor(x)).graph.as_graph_def())
1406
+ # Quick way to check that there is an XlaCallModule op, and a Cos op, but no Sin op
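+ # (the Sin is hidden inside the XlaCallModule emitted for the natively
+ # serialized inner function)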
1407
+ self.assertIn('op: "Cos"', f_outer_graph)
1408
+ self.assertIn('op: "XlaCallModule"', f_outer_graph)
1409
+ self.assertNotIn('op: "Sin"', f_outer_graph)
1410
+
1411
+ @parameterized.named_parameters(
1412
+ _named_test(f2_function=f2_function, f2_saved_model=f2_saved_model,
1413
+ f4_function=f4_function, f4_saved_model=f4_saved_model)
1414
+ for f2_function in [True, False]
1415
+ for f2_saved_model in [True, False]
1416
+ for f4_function in [True, False]
1417
+ for f4_saved_model in [True, False])
1418
+ def test_several_round_trips(self,
1419
+ f2_function=False, f2_saved_model=False,
1420
+ f4_function=False, f4_saved_model=False):
1421
+ if (f2_saved_model and
1422
+ f4_saved_model and
1423
+ not jax.config.jax2tf_default_native_serialization):
1424
+ # TODO: Getting error Found invalid capture Tensor("jax2tf_vjp/jax2tf_arg_0:0", shape=(), dtype=float32) when saving custom gradients
1425
+ # when saving f4, but only with non-native serialization.
1426
+ raise unittest.SkipTest("TODO: error invalid capture when saving custom gradients")
1427
+ x = np.array(.7, dtype=np.float32)
1428
+ # f(n)(x) = 2. * x^n
1429
+ def f(n):
1430
+ def fn(x):
1431
+ acc = np.array(2., dtype=x.dtype)
1432
+ for i in range(n):
1433
+ acc *= x
1434
+ return acc
1435
+ return fn
1436
+
1437
+ f2_tf = lambda x: x * jax2tf.convert(f(1))(x)
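+ # f2_tf(x) = x * (2 * x) = 2 * x**2, i.e. f(2), built via one JAX->TF round trip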
1438
+ if f2_function:
1439
+ f2_tf = tf.function(f2_tf, autograph=False)
1440
+ if f2_saved_model:
1441
+ f2_tf, _ = tf_test_util.SaveAndLoadFunction(f2_tf, input_args=[x])
1442
+
1443
+ self.assertAllClose(f(2)(x), f2_tf(x).numpy())
1444
+ _, (g_f2_ft,) = tf_test_util.ComputeTfValueAndGrad(f2_tf, [x])
1445
+ self.assertAllClose(jax.grad(f(2))(x), g_f2_ft.numpy())
1446
+
1447
+ f3_jax = lambda x: x * jax2tf.call_tf(f2_tf)(x)
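+ # f3_jax(x) = x * f2_tf(x) = 2 * x**3, i.e. f(3), adding a TF->JAX round trip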
1448
+ self.assertAllClose(f(3)(x), f3_jax(x))
1449
+ self.assertAllClose(f(3)(x), jax.jit(f3_jax)(x))
1450
+ self.assertAllClose(jax.grad(f(3))(x), jax.grad(f3_jax)(x))
1451
+
1452
+ f4_tf = lambda x: x * jax2tf.convert(f3_jax)(x)
1453
+ self.assertAllClose(f(4)(x), f4_tf(x).numpy())
1454
+ _, (g_f4_ft,) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x])
1455
+ self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy())
1456
+
1457
+ if f4_function:
1458
+ f4_tf = tf.function(f4_tf, autograph=False)
1459
+ if f4_saved_model:
1460
+ f4_tf, _ = tf_test_util.SaveAndLoadFunction(f4_tf, input_args=[x])
1461
+ self.assertAllClose(f(4)(x), f4_tf(x).numpy())
1462
+ _, (g_f4_ft,) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x])
1463
+ self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy())
1464
+
1465
+ @classmethod
1466
+ def _walk_stablehlo_operations(cls, op, cb):
1467
+ """walk the stablehlo operation recursive with callback function."""
1468
+ cb(op)
1469
+ for region in op.operation.regions:
1470
+ for block in region:
1471
+ for op in block:
1472
+ cls._walk_stablehlo_operations(op, cb)
1473
+
1474
+ def test_call_tf_graph(self):
1475
+ const = tf.Variable(0.0, dtype=tf.float32)
1476
+
1477
+ @tf.function(jit_compile=True)
1478
+ def tf_func_1(x):
1479
+ return x * x + const
1480
+
1481
+ @tf.function
1482
+ def tf_func_2(x, y):
1483
+ return tf_func_1(x) + y
1484
+
1485
+ @tf.function
1486
+ def tf_func_3(x, y, z):
1487
+ return tf_func_2(x, y) + z, z
1488
+
1489
+ x = jnp.array(3.0, dtype=jnp.float32)
1490
+ y = jnp.array(3.0, dtype=jnp.float32)
1491
+ z = jnp.array(5.0, dtype=jnp.float32)
1492
+ f_jax = jax.jit(jax2tf.call_tf(tf_func_3, call_tf_graph=False))
1493
+ stablehlo_module = f_jax.lower(x, y, z).compiler_ir("stablehlo")
1494
+ self.assertNotIn("stablehlo.custom_call", str(stablehlo_module))
1495
+
1496
+ f_jax = jax.jit(
1497
+ jax2tf.call_tf(
1498
+ tf_func_3,
1499
+ call_tf_graph=True,
1500
+ )
1501
+ )
1502
+ with self.assertRaisesRegex(
1503
+ ValueError,
1504
+ "call_tf_graph=True only support exporting by jax2tf.convert currently",
1505
+ ):
1506
+ stablehlo_module = f_jax.lower(x, y, z).compiler_ir("stablehlo")
1507
+ self.assertIn("stablehlo.custom_call", str(stablehlo_module))
1508
+
1509
+ called_index_list = []
1510
+
1511
+ def _extract_info(op):
1512
+ if op.operation.name != "stablehlo.custom_call":
1513
+ return
1514
+ tf_backend_config = ir.DictAttr(op.attributes["tf.backend_config"])
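+ # NOTE (assumption): called_index is understood to identify the called TF
+ # function within the XlaCallModule function_list attribute.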
1515
+ called_index = ir.IntegerAttr(tf_backend_config["called_index"]).value
1516
+ called_index_list.append(called_index)
1517
+
1518
+ self._walk_stablehlo_operations(stablehlo_module, _extract_info)
1519
+ self.assertLen(called_index_list, 1)
1520
+
1521
+ @parameterized.named_parameters(
1522
+ dict(
1523
+ testcase_name="multiple_outputs",
1524
+ tf_f=lambda x: tf.py_function(np.sin, [x], tf.float32),
1525
+ output_shape_dtype=jax.ShapeDtypeStruct((10,), jnp.float32),
1526
+ ),
1527
+ dict(
1528
+ testcase_name="zero_outputs",
1529
+ tf_f=lambda x: print(tf.strings.length(tf.constant("hello, world"))),
1530
+ output_shape_dtype=None,
1531
+ ),
1532
+ )
1533
+ def test_call_tf_graph_non_compilable(self, tf_f, output_shape_dtype):
1534
+ inputs = jnp.ones([10], dtype=jnp.float32)
1535
+ called_index_list = []
1536
+ xla_call_module_list = []
1537
+
1538
+ def _extract_info(op):
1539
+ if op.operation.name != "stablehlo.custom_call":
1540
+ return
1541
+ tf_backend_config = ir.DictAttr(op.attributes["tf.backend_config"])
1542
+ called_index = ir.IntegerAttr(tf_backend_config["called_index"]).value
1543
+ called_index_list.append(called_index)
1544
+
1545
+ jax_f = jax2tf.call_tf(
1546
+ tf_f,
1547
+ call_tf_graph=True,
1548
+ output_shape_dtype=output_shape_dtype,
1549
+ )
1550
+
1551
+ # Eager mode
1552
+ self.assertAllClose(tf_f(inputs), jax_f(inputs))
1553
+
1554
+ # Jit mode
1555
+ stablehlo_module = None
1556
+ with self.assertRaisesRegex(
1557
+ ValueError,
1558
+ "call_tf_graph=True only support exporting by jax2tf.convert currently",
1559
+ ):
1560
+ stablehlo_module = jax.jit(jax_f).lower(inputs).compiler_ir("stablehlo")
1561
+ if stablehlo_module:
1562
+ self.assertIn(
1563
+ "stablehlo.custom_call @tf.call_tf_function",
1564
+ str(stablehlo_module),
1565
+ )
1566
+ self.assertIn("tf.backend_config", str(stablehlo_module))
1567
+ self._walk_stablehlo_operations(stablehlo_module, _extract_info)
1568
+ self.assertLen(called_index_list, 1)
1569
+
1570
+ # Test model exporting and reloading.
1571
+ # There is no runtime support yet, so it cannot run.
1572
+ tf_f_rt = jax2tf.convert(
1573
+ jax_f,
1574
+ native_serialization=True,
1575
+ with_gradient=False,
1576
+ )
1577
+ _, restored_model = tf_test_util.SaveAndLoadFunction(
1578
+ tf_f_rt, input_args=[inputs]
1579
+ )
1580
+ func_def = restored_model.f.concrete_functions[0]
1581
+
1582
+ for node_def in func_def.graph.as_graph_def().node:
1583
+ if node_def.op == "XlaCallModule":
1584
+ xla_call_module_list.append(node_def)
1585
+ # There is only one xla_call_module in the saved model.
1586
+ self.assertLen(xla_call_module_list, 1)
1587
+
1588
+ # Check the xla_call_module version and function_list attributes.
1589
+ xla_call_module = xla_call_module_list[0]
1590
+ self.assertGreaterEqual(xla_call_module.attr["version"].i, 5)
1591
+ self.assertIn("function_list", str(xla_call_module.attr))
1592
+ xla_call_module_list.clear()
1593
+ called_index_list.clear()
1594
+
1595
+ # If JAX calls the same TensorFlow function via `jax2tf.call_tf` twice,
1596
+ # it should produce two different TF concrete functions.
1597
+ def jax_f_2(x):
1598
+ res1 = jax2tf.call_tf(
1599
+ tf_f,
1600
+ call_tf_graph=True,
1601
+ output_shape_dtype=output_shape_dtype,
1602
+ )(x)
1603
+ res2 = jax2tf.call_tf(
1604
+ tf_f,
1605
+ call_tf_graph=True,
1606
+ output_shape_dtype=output_shape_dtype,
1607
+ )(x)
1608
+ return res1, res2
1609
+ stablehlo_module = None
1610
+ with self.assertRaisesRegex(ValueError, "call_tf_graph=True only support exporting by jax2tf.convert currently"):
1611
+ stablehlo_module = jax.jit(jax_f_2).lower(inputs).compiler_ir("stablehlo")
1612
+ if stablehlo_module:
1613
+ self._walk_stablehlo_operations(stablehlo_module, _extract_info)
1614
+ xla_call_module_list.clear()
1615
+
1616
+ def test_b279454591(self):
1617
+ """Test case when tensorflow function returns `StatefulPartitionedCall` op."""
1618
+ inputs = jnp.ones([10], dtype=jnp.float32)
1619
+
1620
+ # With one or more outputs, it is okay.
1621
+ def tf_f(x):
1622
+ y = tf.math.sin(3.0)
1623
+ tf.print(y)
1624
+ return x
1625
+
1626
+ jax_f = jax2tf.call_tf(
1627
+ tf.function(tf_f),
1628
+ call_tf_graph=True,
1629
+ )
1630
+ tf_f_rt = jax2tf.convert(
1631
+ jax_f,
1632
+ native_serialization=True,
1633
+ with_gradient=False,
1634
+ )
1635
+ _, _ = tf_test_util.SaveAndLoadFunction(tf_f_rt, input_args=[inputs])
1636
+
1637
+ # With zero outputs, it returns a `StatefulPartitionedCall` op instead.
1638
+ def tf_f_2():
1639
+ y = tf.math.sin(3.0)
1640
+ tf.print(y)
1641
+ return
1642
+
1643
+ jax_f_2 = jax2tf.call_tf(tf.function(tf_f_2), call_tf_graph=True)
1644
+ tf_f_rt_2 = jax2tf.convert(
1645
+ jax_f_2,
1646
+ native_serialization=True,
1647
+ with_gradient=False,
1648
+ )
1649
+ _, _ = tf_test_util.SaveAndLoadFunction(tf_f_rt_2, input_args=[])
1650
+
1651
+ @jtu.parameterized_filterable(
1652
+ kwargs=[dict(version=version) for version in [9]]
1653
+ )
1654
+ def test_call_tf_graph_ordered(self, *, version: int):
1655
+ with config.jax_serialization_version(version):
1656
+ logging.info(
1657
+ "Using JAX serialization version %s",
1658
+ jax.config.jax_serialization_version)
1659
+
1660
+ @tf.function
1661
+ def tf_print(x):
1662
+ tf.print(x)
1663
+
1664
+ call_tf_print = jax2tf.call_tf(
1665
+ tf_print,
1666
+ call_tf_graph=True,
1667
+ ordered=True,
1668
+ )
1669
+
1670
+ x = jnp.array(1.0, dtype=jnp.float32)
1671
+
1672
+ def body(i, x):
1673
+ call_tf_print(x)
1674
+ return x + 1
1675
+
1676
+ @jax.jit
1677
+ def f_jax(x):
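+ # The loop below issues four ordered call_tf prints; ordering is enforced by
+ # threading a token through the custom call (has_token_input_output), which
+ # _check_mlir_ops verifies.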
1678
+ return jax.lax.fori_loop(0, 4, body, x)
1679
+
1680
+ num_custom_calls = 0
1681
+
1682
+ def _check_mlir_ops(op):
1683
+ nonlocal num_custom_calls
1684
+
1685
+ if (
1686
+ op.operation.name == "stablehlo.custom_call"
1687
+ and ir.StringAttr(op.attributes["call_target_name"]).value
1688
+ == "tf.call_tf_function"
1689
+ ):
1690
+ num_custom_calls += 1
1691
+
1692
+ # The custom call op must have `has_token_input_output` attribute.
1693
+ tf_backend_config = ir.DictAttr(op.attributes["tf.backend_config"])
1694
+ self.assertTrue(
1695
+ ir.BoolAttr(tf_backend_config["has_token_input_output"]).value
1696
+ )
1697
+
1698
+ # Verify that the first argument/result of the custom call op is a token
1699
+ # type. This is a calling convention defined by `has_token_input_output`.
1700
+ self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
1701
+ self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
1702
+
1703
+ stablehlo_module = None
1704
+ with self.assertRaisesRegex(
1705
+ ValueError,
1706
+ "call_tf_graph=True only support exporting by jax2tf.convert currently",
1707
+ ):
1708
+ lower = f_jax.lower(x)
1709
+ self.assertNotEmpty(lower._lowering.compile_args["ordered_effects"])
1710
+ stablehlo_module = lower.compiler_ir("stablehlo")
1711
+ if stablehlo_module:
1712
+ self._walk_stablehlo_operations(stablehlo_module, _check_mlir_ops)
1713
+ self.assertEqual(num_custom_calls, 1)
1714
+
1715
+ f_tf = jax2tf.convert(
1716
+ f_jax,
1717
+ native_serialization=True,
1718
+ with_gradient=False,
1719
+ )
1720
+ _, restored_model = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
1721
+
1722
+ @jtu.parameterized_filterable(
1723
+ kwargs=[dict(poly=poly, version=version)
1724
+ for poly in [True, False]
1725
+ for version in [9]]
1726
+ )
1727
+ def test_call_tf_ordered_dead_inputs(self, *, poly: bool, version: int):
1728
+ with config.jax_serialization_version(version):
1729
+ logging.info(
1730
+ "Using JAX serialization version %s",
1731
+ jax.config.jax_serialization_version)
1732
+ def f_jax(x1, x_dead, x3):
1733
+ return (x1, jax2tf.call_tf(lambda x: tf.math.sin(x), ordered=True,
1734
+ call_tf_graph=True)(x3))
1735
+ if poly:
1736
+ polymorphic_shapes = ["b", None, None]
1737
+ else:
1738
+ polymorphic_shapes = None
1739
+ f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes)
1740
+ x1 = np.arange(3, dtype=np.float32)
1741
+ x_dead = np.arange(4, dtype=np.float32)
1742
+ x3 = np.arange(5, dtype=np.float32)
1743
+ self.assertAllClose(f_jax(x1, x_dead, x3),
1744
+ f_tf(x1, x_dead, x3))
1745
+
1746
+ @jtu.parameterized_filterable(
1747
+ kwargs=[dict(ordered=ordered, version=version)
1748
+ for ordered in [True, False]
1749
+ for version in [9]
1750
+ ]
1751
+ )
1752
+ def test_call_tf_graph_polymorphic(self, ordered: bool, version: int):
1753
+ with config.jax_serialization_version(version):
1754
+ logging.info(
1755
+ "Using JAX serialization version %s",
1756
+ jax.config.jax_serialization_version)
1757
+
1758
+ @tf.function(jit_compile=True, autograph=False)
1759
+ @partial(jax2tf.convert,
1760
+ with_gradient=False,
1761
+ native_serialization=True,
1762
+ polymorphic_shapes=["(b)"])
1763
+ @jax.jit
1764
+ def tf_f_2(x):
1765
+ tf_f = lambda x: print(tf.strings.length(tf.constant("hello, world")))
1766
+ jax2tf.call_tf(tf_f,
1767
+ call_tf_graph=True,
1768
+ ordered=ordered)(x)
1769
+ return x
1770
+
1771
+ x = np.arange(3, dtype=np.int32)
1772
+ _ = tf.function(tf_f_2, autograph=False).get_concrete_function(x)
1773
+
1774
+ # TODO(b/293927250): call_tf_graph=True only accepts a concrete_function. The
1775
+ # workaround here is to set `module.call = concrete_fn`.
1776
+ @unittest.skip(
1777
+ "The root cause here is because the XLACallModule.function_list attribute"
1778
+ " depends on JAX call_tf lowering. The 2nd time tf.SavedModel TF tracing"
1779
+ " will not trigger call_tf tracing since it was already cached. The"
1780
+ " solution is to create the `CallTFContext` to make TF tracing and JAX"
1781
+ " tracing work together correctly."
1782
+ )
1783
+ def test_call_tf_graph_save_and_load(self):
1784
+ def jax_func(x):
1785
+ def func_tf(x):
1786
+ return tf.math.sin(x)
1787
+
1788
+ return jnp.cos(
1789
+ jax2tf.call_tf(func_tf, output_shape_dtype=x, call_tf_graph=True)(x)
1790
+ )
1791
+ data_inputs = (np.array([0.5, 0.7], dtype=np.float32),)
1792
+
1793
+ def tf_func(the_input):
1794
+ res = jax2tf.convert(jax_func, native_serialization=True)(the_input)
1795
+ return tf.identity(res, name="the_result")
1796
+
1797
+ jit_tf_func = tf.function(
1798
+ tf_func,
1799
+ autograph=False,
1800
+ jit_compile=True,
1801
+ )
1802
+ # The next line is necessary to reproduce this issue. It triggers TF
1803
+ # ConcreteFunction tracing. Otherwise, you will fail with another error
1804
+ # `Found zero restored functions for caller function`.
1805
+ _ = jit_tf_func.get_concrete_function(*data_inputs)
1806
+ module = tf.Module()
1807
+ module.call = jit_tf_func # Switching to concrete_function works.
1808
+ root_dir = self.create_tempdir()
1809
+ saved_model_dir = os.path.join(root_dir, "saved_model")
1810
+ tf.saved_model.save(
1811
+ module,
1812
+ saved_model_dir,
1813
+ options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
1814
+ )
1815
+ loaded_model = tf.saved_model.load(saved_model_dir)
1816
+ res = loaded_model.call(*data_inputs)
1817
+ self.assertAllClose(jax_func(*data_inputs), res)
1818
+
1819
+
1820
+ if __name__ == "__main__":
1821
+ absltest.main(testLoader=jtu.JaxTestLoader())