ZAIDX11 committed on
Commit
6519678
·
verified ·
1 Parent(s): 35ede7c

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G_D_E_F_.py +13 -0
  2. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G_M_A_P_.py +148 -0
  3. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G_P_K_G_.py +133 -0
  4. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G_P_O_S_.py +14 -0
  5. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G_S_U_B_.py +13 -0
  6. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G_V_A_R_.py +5 -0
  7. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G__l_a_t.py +235 -0
  8. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G__l_o_c.py +85 -0
  9. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/H_V_A_R_.py +13 -0
  10. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/J_S_T_F_.py +13 -0
  11. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/L_T_S_H_.py +58 -0
  12. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/M_A_T_H_.py +13 -0
  13. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/M_E_T_A_.py +352 -0
  14. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/M_V_A_R_.py +13 -0
  15. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/otBase.py +1464 -0
  16. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/otConverters.py +2068 -0
  17. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/otData.py +0 -0
  18. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/otTables.py +0 -0
  19. external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/otTraverse.py +163 -0
  20. external/alphageometry/.venv-ag/Lib/site-packages/jax/_src/scipy/special.py +2574 -0
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G_D_E_F_.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .otBase import BaseTTXConverter
2
+
3
+
4
class table_G_D_E_F_(BaseTTXConverter):
    """Glyph Definition (``GDEF``) table.

    Stores the glyph properties consumed by OpenType Layout.
    Decompilation/compilation is inherited wholesale from
    :class:`BaseTTXConverter`; this subclass only binds the table tag.

    See also https://learn.microsoft.com/en-us/typography/opentype/spec/gdef
    """
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G_M_A_P_.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fontTools.misc import sstruct
2
+ from fontTools.misc.textTools import tobytes, tostr, safeEval
3
+ from . import DefaultTable
4
+
5
+ GMAPFormat = """
6
+ > # big endian
7
+ tableVersionMajor: H
8
+ tableVersionMinor: H
9
+ flags: H
10
+ recordsCount: H
11
+ recordsOffset: H
12
+ fontNameLength: H
13
+ """
14
+ # psFontName is a byte string which follows the record above. This is zero padded
15
+ # to the beginning of the records array. The recordsOffset is 32 bit aligned.
16
+
17
+ GMAPRecordFormat1 = """
18
+ > # big endian
19
+ UV: L
20
+ cid: H
21
+ gid: H
22
+ ggid: H
23
+ name: 32s
24
+ """
25
+
26
+
27
class GMAPRecord(object):
    """One glyphlet record of the ``GMAP`` table.

    Maps a Unicode value (``UV``) to a CID (``cid``), a glyph ID in the
    host font (``gid``), a glyph ID inside the glyphlet (``ggid``), and
    the glyphlet name (a fixed 32-byte field once compiled).
    """

    def __init__(self, uv=0, cid=0, gid=0, ggid=0, name=""):
        self.UV = uv
        self.cid = cid
        self.gid = gid
        self.ggid = ggid
        self.name = name

    def toXML(self, writer, ttFont):
        # Emit one <GMAPRecord> element with one child tag per field.
        writer.begintag("GMAPRecord")
        writer.newline()
        writer.simpletag("UV", value=self.UV)
        writer.newline()
        writer.simpletag("cid", value=self.cid)
        writer.newline()
        writer.simpletag("gid", value=self.gid)
        writer.newline()
        # NOTE(review): writes self.gid, not self.ggid — presumably meant
        # to be self.ggid. Left untouched because fromXML round-trips this
        # tag via setattr("glyphletGid", ...) either way; confirm against
        # upstream fontTools before changing.
        writer.simpletag("glyphletGid", value=self.gid)
        writer.newline()
        writer.simpletag("GlyphletName", value=self.name)
        writer.newline()
        writer.endtag("GMAPRecord")
        writer.newline()

    def fromXML(self, name, attrs, content, ttFont):
        # Restore a single field from one child element of <GMAPRecord>.
        value = attrs["value"]
        if name == "GlyphletName":
            self.name = value
        else:
            # Numeric fields are parsed with safeEval; note this sets the
            # attribute under the XML tag name (e.g. "glyphletGid").
            setattr(self, name, safeEval(value))

    def compile(self, ttFont):
        """Pack this record into its 42-byte binary form (GMAPRecordFormat1)."""
        if self.UV is None:
            self.UV = 0
        # name occupies a fixed 32-byte slot; NUL-pad in place (mutates self).
        nameLen = len(self.name)
        if nameLen < 32:
            self.name = self.name + "\0" * (32 - nameLen)
        data = sstruct.pack(GMAPRecordFormat1, self)
        return data

    def __repr__(self):
        return (
            "GMAPRecord[ UV: "
            + str(self.UV)
            + ", cid: "
            + str(self.cid)
            + ", gid: "
            + str(self.gid)
            + ", ggid: "
            + str(self.ggid)
            + ", Glyphlet Name: "
            + str(self.name)
            + " ]"
        )
81
+
82
+
83
class table_G_M_A_P_(DefaultTable.DefaultTable):
    """Glyphlets GMAP table

    The ``GMAP`` table is used by Adobe's SING Glyphlets.

    See also https://web.archive.org/web/20080627183635/http://www.adobe.com/devnet/opentype/gdk/topic.html
    """

    dependencies = []

    def decompile(self, data, ttFont):
        # Fixed 12-byte header (six uint16 fields), immediately followed by
        # the PostScript font name, then NUL padding up to recordsOffset.
        dummy, newData = sstruct.unpack2(GMAPFormat, data, self)
        self.psFontName = tostr(newData[: self.fontNameLength])
        assert (
            self.recordsOffset % 4
        ) == 0, "GMAP error: recordsOffset is not 32 bit aligned."
        # Records start at recordsOffset relative to the table start.
        newData = data[self.recordsOffset :]
        self.gmapRecords = []
        for i in range(self.recordsCount):
            gmapRecord, newData = sstruct.unpack2(
                GMAPRecordFormat1, newData, GMAPRecord()
            )
            # Strip the NUL padding from the fixed-width 32-byte name field.
            gmapRecord.name = gmapRecord.name.strip("\0")
            self.gmapRecords.append(gmapRecord)

    def compile(self, ttFont):
        # Recompute the header fields from the current record list/name;
        # 12 is the size of the GMAPFormat header, and the offset is rounded
        # up to the next multiple of 4 (records must be 32-bit aligned).
        self.recordsCount = len(self.gmapRecords)
        self.fontNameLength = len(self.psFontName)
        self.recordsOffset = 4 * (((self.fontNameLength + 12) + 3) // 4)
        data = sstruct.pack(GMAPFormat, self)
        data = data + tobytes(self.psFontName)
        # Zero-pad from the end of the name up to the aligned records offset.
        data = data + b"\0" * (self.recordsOffset - len(data))
        for record in self.gmapRecords:
            data = data + record.compile(ttFont)
        return data

    def toXML(self, writer, ttFont):
        # Header fields are dumped verbatim; most are recomputed on compile.
        writer.comment("Most of this table will be recalculated by the compiler")
        writer.newline()
        formatstring, names, fixes = sstruct.getformat(GMAPFormat)
        for name in names:
            value = getattr(self, name)
            writer.simpletag(name, value=value)
            writer.newline()
        writer.simpletag("PSFontName", value=self.psFontName)
        writer.newline()
        for gmapRecord in self.gmapRecords:
            gmapRecord.toXML(writer, ttFont)

    def fromXML(self, name, attrs, content, ttFont):
        if name == "GMAPRecord":
            if not hasattr(self, "gmapRecords"):
                self.gmapRecords = []
            gmapRecord = GMAPRecord()
            self.gmapRecords.append(gmapRecord)
            # Delegate each child element of the record to GMAPRecord.fromXML.
            for element in content:
                if isinstance(element, str):
                    continue
                name, attrs, content = element
                gmapRecord.fromXML(name, attrs, content, ttFont)
        else:
            value = attrs["value"]
            if name == "PSFontName":
                self.psFontName = value
            else:
                # Remaining header fields are numeric.
                setattr(self, name, safeEval(value))
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G_P_K_G_.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fontTools.misc import sstruct
2
+ from fontTools.misc.textTools import bytesjoin, safeEval, readHex
3
+ from . import DefaultTable
4
+ import sys
5
+ import array
6
+
7
+ GPKGFormat = """
8
+ > # big endian
9
+ version: H
10
+ flags: H
11
+ numGMAPs: H
12
+ numGlyplets: H
13
+ """
14
+ # psFontName is a byte string which follows the record above. This is zero padded
15
+ # to the beginning of the records array. The recordsOffset is 32 bit aligned.
16
+
17
+
18
class table_G_P_K_G_(DefaultTable.DefaultTable):
    """Glyphlets GPKG table

    The ``GPKG`` table is used by Adobe's SING Glyphlets.

    See also https://web.archive.org/web/20080627183635/http://www.adobe.com/devnet/opentype/gdk/topic.html
    """

    def decompile(self, data, ttFont):
        # Header, then two offset arrays (each with one extra sentinel entry
        # marking the end of the last item), then the GMAP and glyphlet blobs.
        dummy, newData = sstruct.unpack2(GPKGFormat, data, self)

        GMAPoffsets = array.array("I")
        endPos = (self.numGMAPs + 1) * 4
        GMAPoffsets.frombytes(newData[:endPos])
        if sys.byteorder != "big":
            # Table data is big-endian; array uses native byte order.
            GMAPoffsets.byteswap()
        self.GMAPs = []
        for i in range(self.numGMAPs):
            start = GMAPoffsets[i]
            end = GMAPoffsets[i + 1]
            # Offsets are from the start of the whole table ('data').
            self.GMAPs.append(data[start:end])
        pos = endPos
        endPos = pos + (self.numGlyplets + 1) * 4
        glyphletOffsets = array.array("I")
        glyphletOffsets.frombytes(newData[pos:endPos])
        if sys.byteorder != "big":
            glyphletOffsets.byteswap()
        self.glyphlets = []
        for i in range(self.numGlyplets):
            start = glyphletOffsets[i]
            end = glyphletOffsets[i + 1]
            self.glyphlets.append(data[start:end])

    def compile(self, ttFont):
        # Rebuild both offset arrays; each gets numItems+1 entries, the last
        # pointing one past the final blob.
        self.numGMAPs = len(self.GMAPs)
        self.numGlyplets = len(self.glyphlets)
        GMAPoffsets = [0] * (self.numGMAPs + 1)
        glyphletOffsets = [0] * (self.numGlyplets + 1)

        dataList = [sstruct.pack(GPKGFormat, self)]

        # First blob starts after header + both offset arrays.
        pos = len(dataList[0]) + (self.numGMAPs + 1) * 4 + (self.numGlyplets + 1) * 4
        GMAPoffsets[0] = pos
        for i in range(1, self.numGMAPs + 1):
            pos += len(self.GMAPs[i - 1])
            GMAPoffsets[i] = pos
        gmapArray = array.array("I", GMAPoffsets)
        if sys.byteorder != "big":
            gmapArray.byteswap()
        dataList.append(gmapArray.tobytes())

        # Glyphlet blobs follow the GMAP blobs; 'pos' already points there.
        glyphletOffsets[0] = pos
        for i in range(1, self.numGlyplets + 1):
            pos += len(self.glyphlets[i - 1])
            glyphletOffsets[i] = pos
        glyphletArray = array.array("I", glyphletOffsets)
        if sys.byteorder != "big":
            glyphletArray.byteswap()
        dataList.append(glyphletArray.tobytes())
        dataList += self.GMAPs
        dataList += self.glyphlets
        data = bytesjoin(dataList)
        return data

    def toXML(self, writer, ttFont):
        writer.comment("Most of this table will be recalculated by the compiler")
        writer.newline()
        formatstring, names, fixes = sstruct.getformat(GPKGFormat)
        for name in names:
            value = getattr(self, name)
            writer.simpletag(name, value=value)
            writer.newline()

        # Embedded GMAP and glyphlet blobs are opaque; dump them as hex.
        writer.begintag("GMAPs")
        writer.newline()
        for gmapData in self.GMAPs:
            writer.begintag("hexdata")
            writer.newline()
            writer.dumphex(gmapData)
            writer.endtag("hexdata")
            writer.newline()
        writer.endtag("GMAPs")
        writer.newline()

        writer.begintag("glyphlets")
        writer.newline()
        for glyphletData in self.glyphlets:
            writer.begintag("hexdata")
            writer.newline()
            writer.dumphex(glyphletData)
            writer.endtag("hexdata")
            writer.newline()
        writer.endtag("glyphlets")
        writer.newline()

    def fromXML(self, name, attrs, content, ttFont):
        if name == "GMAPs":
            if not hasattr(self, "GMAPs"):
                self.GMAPs = []
            for element in content:
                if isinstance(element, str):
                    continue
                itemName, itemAttrs, itemContent = element
                if itemName == "hexdata":
                    self.GMAPs.append(readHex(itemContent))
        elif name == "glyphlets":
            if not hasattr(self, "glyphlets"):
                self.glyphlets = []
            for element in content:
                if isinstance(element, str):
                    continue
                itemName, itemAttrs, itemContent = element
                if itemName == "hexdata":
                    self.glyphlets.append(readHex(itemContent))
        else:
            # Plain header fields (version, flags, counts).
            setattr(self, name, safeEval(attrs["value"]))
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G_P_O_S_.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .otBase import BaseTTXConverter
2
+
3
+
4
class table_G_P_O_S_(BaseTTXConverter):
    """Glyph Positioning (``GPOS``) table.

    Carries the advanced glyph-positioning data used by OpenType
    Layout features — mark attachment, cursive attachment, kerning,
    and other positional adjustments. All (de)serialization is
    inherited from :class:`BaseTTXConverter`.

    See also https://learn.microsoft.com/en-us/typography/opentype/spec/gpos
    """
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G_S_U_B_.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .otBase import BaseTTXConverter
2
+
3
+
4
class table_G_S_U_B_(BaseTTXConverter):
    """Glyph Substitution (``GSUB``) table.

    Holds the glyph-substitution rules used in OpenType Layout.
    All (de)serialization is inherited from :class:`BaseTTXConverter`.

    See also https://learn.microsoft.com/en-us/typography/opentype/spec/gsub
    """
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G_V_A_R_.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from ._g_v_a_r import table__g_v_a_r
2
+
3
+
4
class table_G_V_A_R_(table__g_v_a_r):
    """``GVAR`` table — a variant of the ``gvar`` glyph-variations table.

    Reuses all of table__g_v_a_r's logic, overriding only ``gid_size``.
    """

    # NOTE(review): presumably table__g_v_a_r consults gid_size for the
    # byte width of glyph IDs when (de)compiling — confirm in _g_v_a_r.py.
    gid_size = 3
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G__l_a_t.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fontTools.misc import sstruct
2
+ from fontTools.misc.fixedTools import floatToFixedToStr
3
+ from fontTools.misc.textTools import safeEval
4
+
5
+ # from itertools import *
6
+ from functools import partial
7
+ from . import DefaultTable
8
+ from . import grUtils
9
+ import struct
10
+
11
+
12
+ Glat_format_0 = """
13
+ > # big endian
14
+ version: 16.16F
15
+ """
16
+
17
+ Glat_format_3 = """
18
+ >
19
+ version: 16.16F
20
+ compression:L # compression scheme or reserved
21
+ """
22
+
23
+ Glat_format_1_entry = """
24
+ >
25
+ attNum: B # Attribute number of first attribute
26
+ num: B # Number of attributes in this run
27
+ """
28
+ Glat_format_23_entry = """
29
+ >
30
+ attNum: H # Attribute number of first attribute
31
+ num: H # Number of attributes in this run
32
+ """
33
+
34
+ Glat_format_3_octabox_metrics = """
35
+ >
36
+ subboxBitmap: H # Which subboxes exist on 4x4 grid
37
+ diagNegMin: B # Defines minimum negatively-sloped diagonal (si)
38
+ diagNegMax: B # Defines maximum negatively-sloped diagonal (sa)
39
+ diagPosMin: B # Defines minimum positively-sloped diagonal (di)
40
+ diagPosMax: B # Defines maximum positively-sloped diagonal (da)
41
+ """
42
+
43
+ Glat_format_3_subbox_entry = """
44
+ >
45
+ left: B # xi
46
+ right: B # xa
47
+ bottom: B # yi
48
+ top: B # ya
49
+ diagNegMin: B # Defines minimum negatively-sloped diagonal (si)
50
+ diagNegMax: B # Defines maximum negatively-sloped diagonal (sa)
51
+ diagPosMin: B # Defines minimum positively-sloped diagonal (di)
52
+ diagPosMax: B # Defines maximum positively-sloped diagonal (da)
53
+ """
54
+
55
+
56
class _Object:
    """Bare attribute container used as the target of sstruct.unpack2."""

    pass
58
+
59
+
60
class _Dict(dict):
    """dict subclass so extra attributes (e.g. .octabox) can be attached."""

    pass
62
+
63
+
64
class table_G__l_a_t(DefaultTable.DefaultTable):
    """Graphite Glyph Attributes table

    Stores one run-length-encoded attribute dictionary per glyph, sliced
    out of the data by the offsets in the companion ``Gloc`` table.
    Format 1 uses byte-sized entry headers, formats 2/3 use 16-bit ones,
    and format 3 additionally supports compression and octabox metrics.

    See also https://graphite.sil.org/graphite_techAbout#graphite-font-tables
    """

    def __init__(self, tag=None):
        DefaultTable.DefaultTable.__init__(self, tag)
        # Compression scheme for version >= 3 tables (0 = uncompressed).
        self.scheme = 0

    def decompile(self, data, ttFont):
        sstruct.unpack2(Glat_format_0, data, self)
        self.version = float(floatToFixedToStr(self.version, precisionBits=16))
        if self.version <= 1.9:
            decoder = partial(self.decompileAttributes12, fmt=Glat_format_1_entry)
        elif self.version <= 2.9:
            decoder = partial(self.decompileAttributes12, fmt=Glat_format_23_entry)
        elif self.version >= 3.0:
            # Version 3 data may be compressed; low bit of 'compression'
            # flags the presence of octabox metrics per glyph.
            (data, self.scheme) = grUtils.decompress(data)
            sstruct.unpack2(Glat_format_3, data, self)
            self.hasOctaboxes = (self.compression & 1) == 1
            decoder = self.decompileAttributes3

        # Gloc holds one offset per glyph plus an end sentinel; consecutive
        # pairs delimit each glyph's attribute run within 'data'.
        gloc = ttFont["Gloc"]
        self.attributes = {}
        count = 0
        for s, e in zip(gloc, gloc[1:]):
            self.attributes[ttFont.getGlyphName(count)] = decoder(data[s:e])
            count += 1

    def decompileAttributes12(self, data, fmt):
        """Decode a run-length list of (attNum, num) headers + int16 values."""
        attributes = _Dict()
        while len(data) > 3:
            e, data = sstruct.unpack2(fmt, data, _Object())
            keys = range(e.attNum, e.attNum + e.num)
            if len(data) >= 2 * e.num:
                vals = struct.unpack_from((">%dh" % e.num), data)
                attributes.update(zip(keys, vals))
            data = data[2 * e.num :]
        return attributes

    def decompileAttributes3(self, data):
        """Format-3 decode: optional octabox metrics, then format-2/3 runs."""
        if self.hasOctaboxes:
            o, data = sstruct.unpack2(Glat_format_3_octabox_metrics, data, _Object())
            # One subbox record per set bit in the 4x4 subbox bitmap.
            numsub = bin(o.subboxBitmap).count("1")
            o.subboxes = []
            for b in range(numsub):
                if len(data) >= 8:
                    subbox, data = sstruct.unpack2(
                        Glat_format_3_subbox_entry, data, _Object()
                    )
                    o.subboxes.append(subbox)
        attrs = self.decompileAttributes12(data, Glat_format_23_entry)
        if self.hasOctaboxes:
            attrs.octabox = o
        return attrs

    def compile(self, ttFont):
        data = sstruct.pack(Glat_format_0, self)
        if self.version <= 1.9:
            encoder = partial(self.compileAttributes12, fmt=Glat_format_1_entry)
        elif self.version <= 2.9:
            # BUGFIX: was Glat_format_1_entry, but decompile reads version
            # 2.x entry headers with Glat_format_23_entry (16-bit fields);
            # compiling with the 8-bit format corrupted round-trips for
            # attribute numbers or run lengths > 255.
            encoder = partial(self.compileAttributes12, fmt=Glat_format_23_entry)
        elif self.version >= 3.0:
            self.compression = (self.scheme << 27) + (1 if self.hasOctaboxes else 0)
            data = sstruct.pack(Glat_format_3, self)
            encoder = self.compileAttributes3

        # Record each glyph's start offset (plus a final end sentinel) and
        # push them into the companion Gloc table.
        glocs = []
        for n in range(len(self.attributes)):
            glocs.append(len(data))
            data += encoder(self.attributes[ttFont.getGlyphName(n)])
        glocs.append(len(data))
        ttFont["Gloc"].set(glocs)

        if self.version >= 3.0:
            data = grUtils.compress(self.scheme, data)
        return data

    def compileAttributes12(self, attrs, fmt):
        """Encode attribute runs as (attNum, num) headers + int16 values."""
        data = b""
        for e in grUtils.entries(attrs):
            data += sstruct.pack(fmt, {"attNum": e[0], "num": e[1]}) + struct.pack(
                (">%dh" % len(e[2])), *e[2]
            )
        return data

    def compileAttributes3(self, attrs):
        if self.hasOctaboxes:
            o = attrs.octabox
            data = sstruct.pack(Glat_format_3_octabox_metrics, o)
            numsub = bin(o.subboxBitmap).count("1")
            for b in range(numsub):
                data += sstruct.pack(Glat_format_3_subbox_entry, o.subboxes[b])
        else:
            # BUGFIX: was data = "" (str), which raised TypeError when
            # concatenated with the bytes returned by compileAttributes12.
            data = b""
        return data + self.compileAttributes12(attrs, Glat_format_23_entry)

    def toXML(self, writer, ttFont):
        writer.simpletag("version", version=self.version, compressionScheme=self.scheme)
        writer.newline()
        for n, a in sorted(
            self.attributes.items(), key=lambda x: ttFont.getGlyphID(x[0])
        ):
            writer.begintag("glyph", name=n)
            writer.newline()
            if hasattr(a, "octabox"):
                o = a.octabox
                formatstring, names, fixes = sstruct.getformat(
                    Glat_format_3_octabox_metrics
                )
                vals = {}
                for k in names:
                    if k == "subboxBitmap":
                        continue
                    # Byte-valued metrics are shown as percentages of 255.
                    vals[k] = "{:.3f}%".format(getattr(o, k) * 100.0 / 255)
                vals["bitmap"] = "{:0X}".format(o.subboxBitmap)
                writer.begintag("octaboxes", **vals)
                writer.newline()
                formatstring, names, fixes = sstruct.getformat(
                    Glat_format_3_subbox_entry
                )
                for s in o.subboxes:
                    vals = {}
                    for k in names:
                        vals[k] = "{:.3f}%".format(getattr(s, k) * 100.0 / 255)
                    writer.simpletag("octabox", **vals)
                    writer.newline()
                writer.endtag("octaboxes")
                writer.newline()
            for k, v in sorted(a.items()):
                writer.simpletag("attribute", index=k, value=v)
                writer.newline()
            writer.endtag("glyph")
            writer.newline()

    def fromXML(self, name, attrs, content, ttFont):
        if name == "version":
            self.version = float(safeEval(attrs["version"]))
            self.scheme = int(safeEval(attrs["compressionScheme"]))
        if name != "glyph":
            return
        if not hasattr(self, "attributes"):
            self.attributes = {}
        gname = attrs["name"]
        attributes = _Dict()
        for element in content:
            if not isinstance(element, tuple):
                continue
            tag, attrs, subcontent = element
            if tag == "attribute":
                k = int(safeEval(attrs["index"]))
                v = int(safeEval(attrs["value"]))
                attributes[k] = v
            elif tag == "octaboxes":
                self.hasOctaboxes = True
                o = _Object()
                o.subboxBitmap = int(attrs["bitmap"], 16)
                o.subboxes = []
                del attrs["bitmap"]
                # Convert the "%" strings back to 0-255 byte values.
                for k, v in attrs.items():
                    setattr(o, k, int(float(v[:-1]) * 255.0 / 100.0 + 0.5))
                for element in subcontent:
                    if not isinstance(element, tuple):
                        continue
                    (tag, attrs, subcontent) = element
                    so = _Object()
                    for k, v in attrs.items():
                        setattr(so, k, int(float(v[:-1]) * 255.0 / 100.0 + 0.5))
                    o.subboxes.append(so)
                attributes.octabox = o
        self.attributes[gname] = attributes
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/G__l_o_c.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fontTools.misc import sstruct
2
+ from fontTools.misc.textTools import safeEval
3
+ from . import DefaultTable
4
+ import array
5
+ import sys
6
+
7
+
8
+ Gloc_header = """
9
+ > # big endian
10
+ version: 16.16F # Table version
11
+ flags: H # bit 0: 1=long format, 0=short format
12
+ # bit 1: 1=attribute names, 0=no names
13
+ numAttribs: H # NUmber of attributes
14
+ """
15
+
16
+
17
class table_G__l_o_c(DefaultTable.DefaultTable):
    """Graphite Index to Glyph Attributes table

    Holds the per-glyph offsets into the companion ``Glat`` table
    (plus an end sentinel), and optionally a list of attribute IDs.

    See also https://graphite.sil.org/graphite_techAbout#graphite-font-tables
    """

    dependencies = ["Glat"]

    def __init__(self, tag=None):
        DefaultTable.DefaultTable.__init__(self, tag)
        self.attribIds = None
        self.numAttribs = 0

    def decompile(self, data, ttFont):
        _, data = sstruct.unpack2(Gloc_header, data, self)
        flags = self.flags
        # flags is transient: it is recomputed from the data on compile.
        del self.flags
        # Bit 0 selects 32-bit ("long") vs 16-bit ("short") offsets.
        self.locations = array.array("I" if flags & 1 else "H")
        # (flags & 2) is 0 or 2, so numAttribs * (flags & 2) is exactly the
        # byte length of the trailing attribute-ID array (2 bytes each).
        self.locations.frombytes(data[: len(data) - self.numAttribs * (flags & 2)])
        if sys.byteorder != "big":
            # Table data is big-endian; array stores native byte order.
            self.locations.byteswap()
        self.attribIds = array.array("H")
        if flags & 2:
            self.attribIds.frombytes(data[-self.numAttribs * 2 :])
            if sys.byteorder != "big":
                self.attribIds.byteswap()

    def compile(self, ttFont):
        # Recompute flags: bit 1 = attribute IDs present, bit 0 = long format.
        data = sstruct.pack(
            Gloc_header,
            dict(
                version=1.0,
                flags=(bool(self.attribIds) << 1) + (self.locations.typecode == "I"),
                numAttribs=self.numAttribs,
            ),
        )
        # Swap to big-endian for output, then swap back to leave the
        # in-memory arrays untouched.
        if sys.byteorder != "big":
            self.locations.byteswap()
        data += self.locations.tobytes()
        if sys.byteorder != "big":
            self.locations.byteswap()
        if self.attribIds:
            if sys.byteorder != "big":
                self.attribIds.byteswap()
            data += self.attribIds.tobytes()
            if sys.byteorder != "big":
                self.attribIds.byteswap()
        return data

    def set(self, locations):
        """Replace the offsets; chooses long format iff any offset >= 2**16."""
        long_format = max(locations) >= 65536
        self.locations = array.array("I" if long_format else "H", locations)

    def toXML(self, writer, ttFont):
        # Only the attribute count survives the XML round-trip; the
        # locations themselves are regenerated from Glat on compile.
        writer.simpletag("attributes", number=self.numAttribs)
        writer.newline()

    def fromXML(self, name, attrs, content, ttFont):
        if name == "attributes":
            self.numAttribs = int(safeEval(attrs["number"]))

    def __getitem__(self, index):
        return self.locations[index]

    def __len__(self):
        return len(self.locations)

    def __iter__(self):
        return iter(self.locations)
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/H_V_A_R_.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .otBase import BaseTTXConverter
2
+
3
+
4
class table_H_V_A_R_(BaseTTXConverter):
    """Horizontal Metrics Variations (``HVAR``) table.

    Records the variation of horizontal glyph metrics across a
    variable font's design space. All (de)serialization is inherited
    from :class:`BaseTTXConverter`.

    See also https://learn.microsoft.com/en-us/typography/opentype/spec/hvar
    """
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/J_S_T_F_.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .otBase import BaseTTXConverter
2
+
3
+
4
class table_J_S_T_F_(BaseTTXConverter):
    """Justification (``JSTF``) table.

    Holds glyph substitution and positioning data applied when
    justifying text. All (de)serialization is inherited from
    :class:`BaseTTXConverter`.

    See also https://learn.microsoft.com/en-us/typography/opentype/spec/jstf
    """
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/L_T_S_H_.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fontTools.misc.textTools import safeEval
2
+ from . import DefaultTable
3
+ import struct
4
+ import array
5
+
6
+ # XXX I've lowered the strictness, to make sure Apple's own Chicago
7
+ # XXX gets through. They're looking into it, I hope to raise the standards
8
+ # XXX back to normal eventually.
9
+
10
+
11
class table_L_T_S_H_(DefaultTable.DefaultTable):
    """Linear Threshold table

    The ``LTSH`` table records, for each glyph, the ppem size at which
    its advance width may be scaled linearly, despite any TrueType
    instructions that would otherwise alter the advance width.

    See also https://learn.microsoft.com/en-us/typography/opentype/spec/ltsh
    """

    def decompile(self, data, ttFont):
        """Parse a 4-byte header followed by one yPel byte per glyph."""
        version, numGlyphs = struct.unpack(">HH", data[:4])
        payload = data[4:]
        assert version == 0, "unknown version: %s" % version
        assert (len(payload) % numGlyphs) < 4, "numGlyphs doesn't match data length"
        # ouch: the assertion is not true in Chicago!
        # assert numGlyphs == ttFont['maxp'].numGlyphs
        pels = array.array("B")
        pels.frombytes(payload)
        self.yPels = {ttFont.getGlyphName(i): pels[i] for i in range(numGlyphs)}

    def compile(self, ttFont):
        """Serialize back to binary, ordering the yPel bytes by glyph ID."""
        version = 0
        numGlyphs = len(self.yPels)
        values = [0] * numGlyphs
        # ouch: the assertion is not true in Chicago!
        # assert len(self.yPels) == ttFont['maxp'].numGlyphs == numGlyphs
        for glyphName, yPel in self.yPels.items():
            values[ttFont.getGlyphID(glyphName)] = yPel
        header = struct.pack(">HH", version, numGlyphs)
        return header + array.array("B", values).tobytes()

    def toXML(self, writer, ttFont):
        """Emit one <yPel> element per glyph, sorted by glyph name."""
        for glyphName in sorted(self.yPels.keys()):
            writer.simpletag("yPel", name=glyphName, value=self.yPels[glyphName])
            writer.newline()

    def fromXML(self, name, attrs, content, ttFont):
        if not hasattr(self, "yPels"):
            self.yPels = {}
        if name != "yPel":
            return  # ignore unknown tags
        self.yPels[attrs["name"]] = safeEval(attrs["value"])
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/M_A_T_H_.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .otBase import BaseTTXConverter
2
+
3
+
4
class table_M_A_T_H_(BaseTTXConverter):
    """Mathematical Typesetting (``MATH``) table.

    Supplies the constants, glyph metrics, and variant data needed to
    typeset glyphs in mathematical formulas. All (de)serialization is
    inherited from :class:`BaseTTXConverter`.

    See also https://learn.microsoft.com/en-us/typography/opentype/spec/math
    """
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/M_E_T_A_.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fontTools.misc import sstruct
2
+ from fontTools.misc.textTools import byteord, safeEval
3
+ from . import DefaultTable
4
+ import pdb
5
+ import struct
6
+
7
+
8
+ METAHeaderFormat = """
9
+ > # big endian
10
+ tableVersionMajor: H
11
+ tableVersionMinor: H
12
+ metaEntriesVersionMajor: H
13
+ metaEntriesVersionMinor: H
14
+ unicodeVersion: L
15
+ metaFlags: H
16
+ nMetaRecs: H
17
+ """
18
+ # This record is followed by nMetaRecs of METAGlyphRecordFormat.
19
+ # This in turn is followed by as many METAStringRecordFormat entries
20
+ # as specified by the METAGlyphRecordFormat entries
21
+ # this is followed by the strings specified in the METAStringRecordFormat
22
+ METAGlyphRecordFormat = """
23
+ > # big endian
24
+ glyphID: H
25
+ nMetaEntry: H
26
+ """
27
+ # This record is followed by a variable data length field:
28
+ # USHORT or ULONG hdrOffset
29
+ # Offset from start of META table to the beginning
30
+ # of this glyphs array of ns Metadata string entries.
31
+ # Size determined by metaFlags field
32
+ # METAGlyphRecordFormat entries must be sorted by glyph ID
33
+
34
+ METAStringRecordFormat = """
35
+ > # big endian
36
+ labelID: H
37
+ stringLen: H
38
+ """
39
+ # This record is followed by a variable data length field:
40
+ # USHORT or ULONG stringOffset
41
+ # METAStringRecordFormat entries must be sorted in order of labelID
42
+ # There may be more than one entry with the same labelID
43
+ # There may be more than one string with the same content.
44
+
45
+ # Strings shall be Unicode UTF-8 encoded, and null-terminated.
46
+
47
# Known SING glyphlet metadata label IDs and their human-readable names.
METALabelDict = {
    0: "MojikumiX4051",  # An integer in the range 1-20
    1: "UNIUnifiedBaseChars",
    2: "BaseFontName",
    3: "Language",
    4: "CreationDate",
    5: "FoundryName",
    6: "FoundryCopyright",
    7: "OwnerURI",
    8: "WritingScript",
    10: "StrokeCount",
    11: "IndexingRadical",
}


def getLabelString(labelID):
    """Return the readable name for *labelID*, or "Unknown label" if unmapped."""
    return str(METALabelDict.get(labelID, "Unknown label"))
68
+
69
+
70
class table_M_E_T_A_(DefaultTable.DefaultTable):
    """Glyphlets META table

    The ``META`` table is used by Adobe's SING Glyphlets.

    See also https://web.archive.org/web/20080627183635/http://www.adobe.com/devnet/opentype/gdk/topic.html
    """

    dependencies = []

    def decompile(self, data, ttFont):
        """Unpack the binary table into GlyphRecord/StringRecord objects.

        The ``metaFlags`` header field selects the width of every offset
        field in the table: 0 means USHORT offsets, 1 means ULONG offsets.
        """
        dummy, newData = sstruct.unpack2(METAHeaderFormat, data, self)
        self.glyphRecords = []
        for i in range(self.nMetaRecs):
            glyphRecord, newData = sstruct.unpack2(
                METAGlyphRecordFormat, newData, GlyphRecord()
            )
            if self.metaFlags == 0:
                [glyphRecord.offset] = struct.unpack(">H", newData[:2])
                newData = newData[2:]
            elif self.metaFlags == 1:
                # Fixed: a ULONG offset is 4 bytes and must be unpacked with
                # ">L"; unpacking ">H" from a 4-byte slice always raised
                # struct.error for metaFlags == 1 tables.
                [glyphRecord.offset] = struct.unpack(">L", newData[:4])
                newData = newData[4:]
            else:
                assert 0, (
                    "The metaFlags field in the META table header has a value other than 0 or 1 :"
                    + str(self.metaFlags)
                )
            glyphRecord.stringRecs = []
            # String records for this glyph live at glyphRecord.offset from
            # the start of the table.
            newData = data[glyphRecord.offset:]
            for j in range(glyphRecord.nMetaEntry):
                stringRec, newData = sstruct.unpack2(
                    METAStringRecordFormat, newData, StringRecord()
                )
                if self.metaFlags == 0:
                    [stringRec.offset] = struct.unpack(">H", newData[:2])
                    newData = newData[2:]
                else:
                    # Fixed: same ">H" vs ">L" width mismatch as above.
                    [stringRec.offset] = struct.unpack(">L", newData[:4])
                    newData = newData[4:]
                stringRec.string = data[
                    stringRec.offset : stringRec.offset + stringRec.stringLen
                ]
                glyphRecord.stringRecs.append(stringRec)
            self.glyphRecords.append(glyphRecord)

    def compile(self, ttFont):
        """Assemble the binary table.

        Offsets are laid out with the current ``metaFlags`` width. If any
        offset exceeds 65535 with USHORT offsets, metaFlags is bumped to
        ULONG offsets and the layout restarts; if ULONG turns out to be
        unnecessary, it shrinks back — hence the retry loop.
        """
        offsetOK = 0
        self.nMetaRecs = len(self.glyphRecords)
        count = 0
        while offsetOK != 1:
            count = count + 1
            if count > 4:
                # Layout should converge within a few passes.
                # NOTE(review): dropping into the debugger here is a leftover
                # development aid — consider raising an error instead.
                pdb.set_trace()
            metaData = sstruct.pack(METAHeaderFormat, self)
            # Each glyph record is 6 bytes plus a 2- or 4-byte offset field.
            stringRecsOffset = len(metaData) + self.nMetaRecs * (
                6 + 2 * (self.metaFlags & 1)
            )
            stringRecSize = 6 + 2 * (self.metaFlags & 1)
            for glyphRec in self.glyphRecords:
                glyphRec.offset = stringRecsOffset
                if (glyphRec.offset > 65535) and ((self.metaFlags & 1) == 0):
                    # Offset does not fit a USHORT: switch to ULONG offsets
                    # and start over.
                    self.metaFlags = self.metaFlags + 1
                    offsetOK = -1
                    break
                metaData = metaData + glyphRec.compile(self)
                stringRecsOffset = stringRecsOffset + (
                    glyphRec.nMetaEntry * stringRecSize
                )
                # this will be the String Record offset for the next GlyphRecord.
            if offsetOK == -1:
                offsetOK = 0
                continue

            # metaData now contains the header and all of the GlyphRecords. Its length should be
            # the offset to the first StringRecord.
            stringOffset = stringRecsOffset
            for glyphRec in self.glyphRecords:
                assert glyphRec.offset == len(
                    metaData
                ), "Glyph record offset did not compile correctly! for rec:" + str(
                    glyphRec
                )
                for stringRec in glyphRec.stringRecs:
                    stringRec.offset = stringOffset
                    if (stringRec.offset > 65535) and ((self.metaFlags & 1) == 0):
                        self.metaFlags = self.metaFlags + 1
                        offsetOK = -1
                        break
                    metaData = metaData + stringRec.compile(self)
                    stringOffset = stringOffset + stringRec.stringLen
            if offsetOK == -1:
                offsetOK = 0
                continue

            if ((self.metaFlags & 1) == 1) and (stringOffset < 65536):
                # ULONG offsets turned out to be unnecessary; shrink back to
                # USHORT and re-lay-out.
                self.metaFlags = self.metaFlags - 1
                continue
            else:
                offsetOK = 1

        # metaData now contains the header and all of the GlyphRecords and all of the String Records.
        # Its length should be the offset to the first string datum.
        for glyphRec in self.glyphRecords:
            for stringRec in glyphRec.stringRecs:
                assert stringRec.offset == len(
                    metaData
                ), "String offset did not compile correctly! for string:" + str(
                    stringRec.string
                )
                metaData = metaData + stringRec.string

        return metaData

    def toXML(self, writer, ttFont):
        """Dump the header fields and all glyph records as TTX XML."""
        writer.comment(
            "Lengths and number of entries in this table will be recalculated by the compiler"
        )
        writer.newline()
        formatstring, names, fixes = sstruct.getformat(METAHeaderFormat)
        for name in names:
            value = getattr(self, name)
            writer.simpletag(name, value=value)
            writer.newline()
        for glyphRec in self.glyphRecords:
            glyphRec.toXML(writer, ttFont)

    def fromXML(self, name, attrs, content, ttFont):
        """Rebuild the table from TTX XML elements."""
        if name == "GlyphRecord":
            if not hasattr(self, "glyphRecords"):
                self.glyphRecords = []
            glyphRec = GlyphRecord()
            self.glyphRecords.append(glyphRec)
            for element in content:
                if isinstance(element, str):
                    continue
                name, attrs, content = element
                glyphRec.fromXML(name, attrs, content, ttFont)
            # Offsets are recomputed at compile time; only the entry count
            # must be kept in sync here.
            glyphRec.offset = -1
            glyphRec.nMetaEntry = len(glyphRec.stringRecs)
        else:
            setattr(self, name, safeEval(attrs["value"]))
212
+
213
+
214
class GlyphRecord(object):
    """One per-glyph record of the META table: a glyph ID plus its string records."""

    def __init__(self):
        # -1 marks "not yet assigned"; real values are set by decompile()
        # or fromXML()/compile().
        self.glyphID = -1
        self.nMetaEntry = -1
        self.offset = -1
        self.stringRecs = []

    def toXML(self, writer, ttFont):
        """Write this record (and its string records) as TTX XML."""
        writer.begintag("GlyphRecord")
        writer.newline()
        writer.simpletag("glyphID", value=self.glyphID)
        writer.newline()
        writer.simpletag("nMetaEntry", value=self.nMetaEntry)
        writer.newline()
        for stringRec in self.stringRecs:
            stringRec.toXML(writer, ttFont)
        writer.endtag("GlyphRecord")
        writer.newline()

    def fromXML(self, name, attrs, content, ttFont):
        """Populate this record from a TTX XML element."""
        if name == "StringRecord":
            stringRec = StringRecord()
            self.stringRecs.append(stringRec)
            for element in content:
                if isinstance(element, str):
                    continue
                # NOTE(review): the *outer* name/attrs/content are passed
                # through here (StringRecord.fromXML iterates content itself),
                # so each non-string child triggers a full re-parse of the
                # same content; the final state is the same, but this looks
                # redundant — confirm intent before changing.
                stringRec.fromXML(name, attrs, content, ttFont)
                stringRec.stringLen = len(stringRec.string)
        else:
            setattr(self, name, safeEval(attrs["value"]))

    def compile(self, parentTable):
        """Return the packed record; offset width follows parentTable.metaFlags."""
        data = sstruct.pack(METAGlyphRecordFormat, self)
        # metaFlags 0 -> USHORT offset, 1 -> ULONG offset.
        if parentTable.metaFlags == 0:
            datum = struct.pack(">H", self.offset)
        elif parentTable.metaFlags == 1:
            datum = struct.pack(">L", self.offset)
        data = data + datum
        return data

    def __repr__(self):
        return (
            "GlyphRecord[ glyphID: "
            + str(self.glyphID)
            + ", nMetaEntry: "
            + str(self.nMetaEntry)
            + ", offset: "
            + str(self.offset)
            + " ]"
        )
264
+
265
+
266
+ # XXX The following two functions are really broken around UTF-8 vs Unicode
267
+
268
+
269
def mapXMLToUTF8(string):
    """Convert a string containing XML numeric character references
    (``&#xNN;`` or ``&amp;#xNN;``) back into UTF-8 encoded bytes.

    Characters outside such references are passed through unchanged.
    (See the XXX note above: this pair of helpers conflates UTF-8 and
    Unicode handling.)
    """
    uString = str()
    strLen = len(string)
    i = 0
    while i < strLen:
        prefixLen = 0
        if string[i : i + 3] == "&#x":
            prefixLen = 3
        elif string[i : i + 7] == "&amp;#x":
            prefixLen = 7
        if prefixLen:
            i = i + prefixLen
            j = i
            # Scan up to the closing ';' of the reference.
            while string[i] != ";":
                i = i + 1
            valStr = string[j:i]

            # Fixed: parse the hex digits with int(..., 16) instead of
            # eval("0x" + valStr); eval on data read from an XML file is
            # unsafe and unnecessary. Same result for valid hex digits.
            uString = uString + chr(int(valStr, 16))
        else:
            uString = uString + chr(byteord(string[i]))
        i = i + 1

    return uString.encode("utf_8")
292
+
293
+
294
def mapUTF8toXML(string):
    """Decode UTF-8 *string* (bytes) and replace every character outside
    printable ASCII (0x20..0x7F exclusive) with an XML numeric character
    reference (``&#xNN;``, lowercase hex)."""
    decoded = string.decode("utf_8")
    parts = []
    for ch in decoded:
        code = ord(ch)
        if 0x1F < code < 0x80:
            parts.append(ch)
        else:
            parts.append("&#x%x;" % code)
    return "".join(parts)
304
+
305
+
306
class StringRecord(object):
    """A single labelID/string entry inside a META GlyphRecord.

    Attributes (set by decompile/fromXML): labelID, stringLen, offset,
    and string (UTF-8 encoded bytes).
    """

    def toXML(self, writer, ttFont):
        """Write this record as TTX XML, with the label name as a comment."""
        writer.begintag("StringRecord")
        writer.newline()
        writer.simpletag("labelID", value=self.labelID)
        writer.comment(getLabelString(self.labelID))
        writer.newline()
        writer.newline()
        writer.simpletag("string", value=mapUTF8toXML(self.string))
        writer.newline()
        writer.endtag("StringRecord")
        writer.newline()

    def fromXML(self, name, attrs, content, ttFont):
        """Populate this record from the children of a StringRecord element."""
        for element in content:
            if isinstance(element, str):
                continue
            name, attrs, content = element
            value = attrs["value"]
            if name == "string":
                self.string = mapXMLToUTF8(value)
            else:
                setattr(self, name, safeEval(value))

    def compile(self, parentTable):
        """Return the packed record; offset width follows parentTable.metaFlags."""
        data = sstruct.pack(METAStringRecordFormat, self)
        # metaFlags 0 -> USHORT offset, 1 -> ULONG offset.
        if parentTable.metaFlags == 0:
            datum = struct.pack(">H", self.offset)
        elif parentTable.metaFlags == 1:
            datum = struct.pack(">L", self.offset)
        data = data + datum
        return data

    def __repr__(self):
        return (
            "StringRecord [ labelID: "
            + str(self.labelID)
            + " aka "
            + getLabelString(self.labelID)
            + ", offset: "
            + str(self.offset)
            + ", length: "
            + str(self.stringLen)
            + ", string: "
            # Fixed: self.string is a bytes object; concatenating it
            # directly to str raised TypeError. Decode for display,
            # replacing any undecodable bytes.
            + self.string.decode("utf_8", "replace")
            + " ]"
        )
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/M_V_A_R_.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .otBase import BaseTTXConverter
2
+
3
+
4
class table_M_V_A_R_(BaseTTXConverter):
    """Metrics Variations table

    The ``MVAR`` table contains variation information for font-wide
    metrics in a variable font.

    See also https://learn.microsoft.com/en-us/typography/opentype/spec/mvar
    """

    # All (de)compile and XML behavior is inherited from BaseTTXConverter.
    pass
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/otBase.py ADDED
@@ -0,0 +1,1464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fontTools.config import OPTIONS
2
+ from fontTools.misc.textTools import Tag, bytesjoin
3
+ from .DefaultTable import DefaultTable
4
+ from enum import IntEnum
5
+ import sys
6
+ import array
7
+ import struct
8
+ import logging
9
+ from functools import lru_cache
10
+ from typing import Iterator, NamedTuple, Optional, Tuple
11
+
12
log = logging.getLogger(__name__)

# Optional dependency: uharfbuzz exposes hb.repack(), used below to pack
# GSUB/GPOS faster while keeping more subtable sharing.
have_uharfbuzz = False
try:
    import uharfbuzz as hb

    # repack method added in uharfbuzz >= 0.23; if uharfbuzz *can* be
    # imported but repack method is missing, behave as if uharfbuzz
    # is not available (fallback to the slower Python implementation)
    have_uharfbuzz = callable(getattr(hb, "repack", None))
except ImportError:
    pass

# Config key (see fontTools.config) controlling whether hb.repack is used;
# read from font.cfg in BaseTTXConverter.compile().
USE_HARFBUZZ_REPACKER = OPTIONS[f"{__name__}:USE_HARFBUZZ_REPACKER"]
26
+
27
+
28
class OverflowErrorRecord(object):
    """Plain record describing exactly where an offset overflow occurred."""

    def __init__(self, overflowTuple):
        # overflowTuple: (tableType, LookupListIndex, SubTableIndex,
        #                 itemName, itemIndex)
        self.tableType = overflowTuple[0]
        self.LookupListIndex = overflowTuple[1]
        self.SubTableIndex = overflowTuple[2]
        self.itemName = overflowTuple[3]
        self.itemIndex = overflowTuple[4]

    def __repr__(self):
        details = (
            self.tableType,
            "LookupIndex:",
            self.LookupListIndex,
            "SubTableIndex:",
            self.SubTableIndex,
            "ItemName:",
            self.itemName,
            "ItemIndex:",
            self.itemIndex,
        )
        return str(details)
+
51
+
52
class OTLOffsetOverflowError(Exception):
    """Raised when a subtable offset does not fit in its offset field.

    ``value`` holds the OverflowErrorRecord describing where the overflow
    happened, so callers can pick a repacking strategy.
    """

    def __init__(self, overflowErrorRecord):
        self.value = overflowErrorRecord

    def __str__(self):
        return "%r" % (self.value,)
58
+
59
+
60
class RepackerState(IntEnum):
    # Repacking control flow is implemented using a state machine. The state machine table:
    #
    # State       | Packing Success | Packing Failed | Exception Raised |
    # ------------+-----------------+----------------+------------------+
    # PURE_FT     | Return result   | PURE_FT        | Return failure   |
    # HB_FT       | Return result   | HB_FT          | FT_FALLBACK      |
    # FT_FALLBACK | HB_FT           | FT_FALLBACK    | Return failure   |

    # Pack only with fontTools, don't allow sharing between extensions.
    PURE_FT = 1

    # Attempt to pack with harfbuzz (allowing sharing between extensions)
    # use fontTools to attempt overflow resolution.
    HB_FT = 2

    # Fallback if HB/FT packing gets stuck. Pack only with fontTools, don't allow sharing between
    # extensions.
    FT_FALLBACK = 3
+
80
+
81
class BaseTTXConverter(DefaultTable):
    """Generic base class for TTX table converters. It functions as an
    adapter between the TTX (ttLib actually) table model and the model
    we use for OpenType tables, which is necessarily subtly different.
    """

    def decompile(self, data, font):
        """Create an object from the binary data. Called automatically on access."""
        from . import otTables

        reader = OTTableReader(data, tableTag=self.tableTag)
        tableClass = getattr(otTables, self.tableTag)
        self.table = tableClass()
        self.table.decompile(reader, font)

    def compile(self, font):
        """Compiles the table into binary. Called automatically on save."""

        # General outline:
        # Create a top-level OTTableWriter for the GPOS/GSUB table.
        # 	Call the compile method for the the table
        # 		for each 'converter' record in the table converter list
        # 			call converter's write method for each item in the value.
        # 				- For simple items, the write method adds a string to the
        # 				writer's self.items list.
        # 				- For Struct/Table/Subtable items, it add first adds new writer to the
        # 				to the writer's self.items, then calls the item's compile method.
        # 				This creates a tree of writers, rooted at the GUSB/GPOS writer, with
        # 				each writer representing a table, and the writer.items list containing
        # 				the child data strings and writers.
        # 	call the getAllData method
        # 		call _doneWriting, which removes duplicates
        # 		call _gatherTables. This traverses the tables, adding unique occurences to a flat list of tables
        # 		Traverse the flat list of tables, calling getDataLength on each to update their position
        # 		Traverse the flat list of tables again, calling getData each get the data in the table, now that
        # 		pos's and offset are known.

        # 		If a lookup subtable overflows an offset, we have to start all over.
        overflowRecord = None
        # this is 3-state option: default (None) means automatically use hb.repack or
        # silently fall back if it fails; True, use it and raise error if not possible
        # or it errors out; False, don't use it, even if you can.
        use_hb_repack = font.cfg[USE_HARFBUZZ_REPACKER]
        if self.tableTag in ("GSUB", "GPOS"):
            if use_hb_repack is False:
                log.debug(
                    "hb.repack disabled, compiling '%s' with pure-python serializer",
                    self.tableTag,
                )
            elif not have_uharfbuzz:
                if use_hb_repack is True:
                    raise ImportError("No module named 'uharfbuzz'")
                else:
                    assert use_hb_repack is None
                    log.debug(
                        "uharfbuzz not found, compiling '%s' with pure-python serializer",
                        self.tableTag,
                    )

        if (
            use_hb_repack in (None, True)
            and have_uharfbuzz
            and self.tableTag in ("GSUB", "GPOS")
        ):
            state = RepackerState.HB_FT
        else:
            state = RepackerState.PURE_FT

        hb_first_error_logged = False
        lastOverflowRecord = None
        while True:
            try:
                writer = OTTableWriter(tableTag=self.tableTag)
                self.table.compile(writer, font)
                if state == RepackerState.HB_FT:
                    return self.tryPackingHarfbuzz(writer, hb_first_error_logged)
                elif state == RepackerState.PURE_FT:
                    return self.tryPackingFontTools(writer)
                elif state == RepackerState.FT_FALLBACK:
                    # Run packing with FontTools only, but don't return the result as it will
                    # not be optimally packed. Once a successful packing has been found, state is
                    # changed back to harfbuzz packing to produce the final, optimal, packing.
                    self.tryPackingFontTools(writer)
                    log.debug(
                        "Re-enabling sharing between extensions and switching back to "
                        "harfbuzz+fontTools packing."
                    )
                    state = RepackerState.HB_FT

            except OTLOffsetOverflowError as e:
                hb_first_error_logged = True
                ok = self.tryResolveOverflow(font, e, lastOverflowRecord)
                lastOverflowRecord = e.value

                if ok:
                    continue

                if state is RepackerState.HB_FT:
                    log.debug(
                        "Harfbuzz packing out of resolutions, disabling sharing between extensions and "
                        "switching to fontTools only packing."
                    )
                    state = RepackerState.FT_FALLBACK
                else:
                    raise

    def tryPackingHarfbuzz(self, writer, hb_first_error_logged):
        """Serialize via hb.repack; on failure fall back to getAllData
        without duplicate removal (harfbuzz already merged duplicates)."""
        try:
            log.debug("serializing '%s' with hb.repack", self.tableTag)
            return writer.getAllDataUsingHarfbuzz(self.tableTag)
        except (ValueError, MemoryError, hb.RepackerError) as e:
            # Only log hb repacker errors the first time they occur in
            # the offset-overflow resolution loop, they are just noisy.
            # Maybe we can revisit this if/when uharfbuzz actually gives
            # us more info as to why hb.repack failed...
            if not hb_first_error_logged:
                error_msg = f"{type(e).__name__}"
                if str(e) != "":
                    error_msg += f": {e}"
                log.warning(
                    "hb.repack failed to serialize '%s', attempting fonttools resolutions "
                    "; the error message was: %s",
                    self.tableTag,
                    error_msg,
                )
                hb_first_error_logged = True
            return writer.getAllData(remove_duplicate=False)

    def tryPackingFontTools(self, writer):
        """Serialize with the pure-Python assembler."""
        return writer.getAllData()

    def tryResolveOverflow(self, font, e, lastOverflowRecord):
        """Attempt to fix an offset overflow; return truthy on success.

        Gives up immediately when the same record overflows twice in a row.
        """
        ok = 0
        if lastOverflowRecord == e.value:
            # Oh well...
            return ok

        overflowRecord = e.value
        log.info("Attempting to fix OTLOffsetOverflowError %s", e)

        if overflowRecord.itemName is None:
            from .otTables import fixLookupOverFlows

            ok = fixLookupOverFlows(font, overflowRecord)
        else:
            from .otTables import fixSubTableOverFlows

            ok = fixSubTableOverFlows(font, overflowRecord)

        if ok:
            return ok

        # Try upgrading lookup to Extension and hope
        # that cross-lookup sharing not happening would
        # fix overflow...
        from .otTables import fixLookupOverFlows

        return fixLookupOverFlows(font, overflowRecord)

    def toXML(self, writer, font):
        """Write the decompiled table as TTX XML."""
        self.table.toXML2(writer, font)

    def fromXML(self, name, attrs, content, font):
        """Rebuild the table object from TTX XML elements."""
        from . import otTables

        if not hasattr(self, "table"):
            tableClass = getattr(otTables, self.tableTag)
            self.table = tableClass()
        self.table.fromXML(name, attrs, content, font)
        self.table.populateDefaults()

    def ensureDecompiled(self, recurse=True):
        """Force lazy-loaded subtables to be fully decompiled."""
        self.table.ensureDecompiled(recurse=recurse)
+
255
+
256
# https://github.com/fonttools/fonttools/pull/2285#issuecomment-834652928
# The reader/writer code below assumes 4-byte "i" items; fail loudly at
# import time on platforms where that does not hold.
assert len(struct.pack("i", 0)) == 4
assert array.array("i").itemsize == 4, "Oops, file a bug against fonttools."
259
+
260
+
261
class OTTableReader(object):
    """Cursor-style helper for pulling binary values out of an OpenType table.

    ``offset`` is the reader's origin within ``data``; ``pos`` is the
    current read position (absolute). ``localState`` is a small
    copy-on-write dict used to pass contextual values between converters.
    """

    __slots__ = ("data", "offset", "pos", "localState", "tableTag")

    def __init__(self, data, localState=None, offset=0, tableTag=None):
        self.data = data
        self.offset = offset
        self.pos = offset
        self.localState = localState
        self.tableTag = tableTag

    def advance(self, count):
        """Move the read cursor forward by *count* bytes."""
        self.pos += count

    def seek(self, pos):
        """Move the read cursor to the absolute position *pos*."""
        self.pos = pos

    def copy(self):
        """Return a new reader over the same data, at the same position."""
        clone = self.__class__(self.data, self.localState, self.offset, self.tableTag)
        clone.pos = self.pos
        return clone

    def getSubReader(self, offset):
        """Return a reader whose origin is *offset* bytes past this one's."""
        return self.__class__(
            self.data, self.localState, self.offset + offset, self.tableTag
        )

    def readValue(self, typecode, staticSize):
        """Read one big-endian value of struct code *typecode* and advance."""
        start = self.pos
        end = start + staticSize
        (value,) = struct.unpack(f">{typecode}", self.data[start:end])
        self.pos = end
        return value

    def readArray(self, typecode, staticSize, count):
        """Read *count* big-endian values as a list and advance."""
        start = self.pos
        end = start + count * staticSize
        values = array.array(typecode, self.data[start:end])
        # array.array stores native-endian; swap on little-endian hosts.
        if sys.byteorder != "big":
            values.byteswap()
        self.pos = end
        return values.tolist()

    def readInt8(self):
        return self.readValue("b", staticSize=1)

    def readInt8Array(self, count):
        return self.readArray("b", staticSize=1, count=count)

    def readShort(self):
        return self.readValue("h", staticSize=2)

    def readShortArray(self, count):
        return self.readArray("h", staticSize=2, count=count)

    def readLong(self):
        return self.readValue("i", staticSize=4)

    def readLongArray(self, count):
        return self.readArray("i", staticSize=4, count=count)

    def readUInt8(self):
        return self.readValue("B", staticSize=1)

    def readUInt8Array(self, count):
        return self.readArray("B", staticSize=1, count=count)

    def readUShort(self):
        return self.readValue("H", staticSize=2)

    def readUShortArray(self, count):
        return self.readArray("H", staticSize=2, count=count)

    def readULong(self):
        return self.readValue("I", staticSize=4)

    def readULongArray(self, count):
        return self.readArray("I", staticSize=4, count=count)

    def readUInt24(self):
        """Read a 3-byte unsigned big-endian integer and advance."""
        start = self.pos
        end = start + 3
        # Prepend a zero byte so the 3-byte value unpacks as a 4-byte long.
        (value,) = struct.unpack(">l", b"\0" + self.data[start:end])
        self.pos = end
        return value

    def readUInt24Array(self, count):
        return [self.readUInt24() for _ in range(count)]

    def readTag(self):
        """Read a 4-byte Tag and advance."""
        start = self.pos
        end = start + 4
        value = Tag(self.data[start:end])
        assert len(value) == 4, value
        self.pos = end
        return value

    def readData(self, count):
        """Read *count* raw bytes and advance."""
        start = self.pos
        end = start + count
        chunk = self.data[start:end]
        self.pos = end
        return chunk

    def __setitem__(self, name, value):
        # Copy-on-write so sub-readers sharing localState are unaffected.
        updated = dict(self.localState) if self.localState else {}
        updated[name] = value
        self.localState = updated

    def __getitem__(self, name):
        return self.localState and self.localState[name]

    def __contains__(self, name):
        return self.localState and name in self.localState
+
376
+
377
class OffsetToWriter(object):
    """Pairs a subtable writer with the byte width of the offset field
    that will point at it (2, 3 or 4)."""

    def __init__(self, subWriter, offsetSize):
        self.subWriter = subWriter
        self.offsetSize = offsetSize

    def __eq__(self, other):
        if type(other) is not type(self):
            return NotImplemented
        return (self.subWriter, self.offsetSize) == (
            other.subWriter,
            other.offsetSize,
        )

    def __hash__(self):
        # only works after self._doneWriting() has been called
        return hash((self.subWriter, self.offsetSize))
+
391
+
392
+ class OTTableWriter(object):
393
+ """Helper class to gather and assemble data for OpenType tables."""
394
+
395
    def __init__(self, localState=None, tableTag=None):
        # Collected output: byte strings, count references, and
        # OffsetToWriter links to child writers.
        self.items = []
        # Absolute position of this table in the final blob; assigned
        # during the assembly passes.
        self.pos = None
        self.localState = localState
        self.tableTag = tableTag
        self.parent = None
        self.name = "<none>"
+
403
+ def __setitem__(self, name, value):
404
+ state = self.localState.copy() if self.localState else dict()
405
+ state[name] = value
406
+ self.localState = state
407
+
408
+ def __getitem__(self, name):
409
+ return self.localState[name]
410
+
411
+ def __delitem__(self, name):
412
+ del self.localState[name]
413
+
414
+ # assembler interface
415
+
416
+ def getDataLength(self):
417
+ """Return the length of this table in bytes, without subtables."""
418
+ l = 0
419
+ for item in self.items:
420
+ if hasattr(item, "getCountData"):
421
+ l += item.size
422
+ elif hasattr(item, "subWriter"):
423
+ l += item.offsetSize
424
+ else:
425
+ l = l + len(item)
426
+ return l
427
+
428
+ def getData(self):
429
+ """Assemble the data for this writer/table, without subtables."""
430
+ items = list(self.items) # make a shallow copy
431
+ pos = self.pos
432
+ numItems = len(items)
433
+ for i in range(numItems):
434
+ item = items[i]
435
+
436
+ if hasattr(item, "subWriter"):
437
+ if item.offsetSize == 4:
438
+ items[i] = packULong(item.subWriter.pos - pos)
439
+ elif item.offsetSize == 2:
440
+ try:
441
+ items[i] = packUShort(item.subWriter.pos - pos)
442
+ except struct.error:
443
+ # provide data to fix overflow problem.
444
+ overflowErrorRecord = self.getOverflowErrorRecord(
445
+ item.subWriter
446
+ )
447
+
448
+ raise OTLOffsetOverflowError(overflowErrorRecord)
449
+ elif item.offsetSize == 3:
450
+ items[i] = packUInt24(item.subWriter.pos - pos)
451
+ else:
452
+ raise ValueError(item.offsetSize)
453
+
454
+ return bytesjoin(items)
455
+
456
+ def getDataForHarfbuzz(self):
457
+ """Assemble the data for this writer/table with all offset field set to 0"""
458
+ items = list(self.items)
459
+ packFuncs = {2: packUShort, 3: packUInt24, 4: packULong}
460
+ for i, item in enumerate(items):
461
+ if hasattr(item, "subWriter"):
462
+ # Offset value is not needed in harfbuzz repacker, so setting offset to 0 to avoid overflow here
463
+ if item.offsetSize in packFuncs:
464
+ items[i] = packFuncs[item.offsetSize](0)
465
+ else:
466
+ raise ValueError(item.offsetSize)
467
+
468
+ return bytesjoin(items)
469
+
470
+ def __hash__(self):
471
+ # only works after self._doneWriting() has been called
472
+ return hash(self.items)
473
+
474
+ def __ne__(self, other):
475
+ result = self.__eq__(other)
476
+ return result if result is NotImplemented else not result
477
+
478
+ def __eq__(self, other):
479
+ if type(self) != type(other):
480
+ return NotImplemented
481
+ return self.items == other.items
482
+
483
+ def _doneWriting(self, internedTables, shareExtension=False):
484
+ # Convert CountData references to data string items
485
+ # collapse duplicate table references to a unique entry
486
+ # "tables" are OTTableWriter objects.
487
+
488
+ # For Extension Lookup types, we can
489
+ # eliminate duplicates only within the tree under the Extension Lookup,
490
+ # as offsets may exceed 64K even between Extension LookupTable subtables.
491
+ isExtension = hasattr(self, "Extension")
492
+
493
+ # Certain versions of Uniscribe reject the font if the GSUB/GPOS top-level
494
+ # arrays (ScriptList, FeatureList, LookupList) point to the same, possibly
495
+ # empty, array. So, we don't share those.
496
+ # See: https://github.com/fonttools/fonttools/issues/518
497
+ dontShare = hasattr(self, "DontShare")
498
+
499
+ if isExtension and not shareExtension:
500
+ internedTables = {}
501
+
502
+ items = self.items
503
+ for i, item in enumerate(items):
504
+ if hasattr(item, "getCountData"):
505
+ items[i] = item.getCountData()
506
+ elif hasattr(item, "subWriter"):
507
+ item.subWriter._doneWriting(
508
+ internedTables, shareExtension=shareExtension
509
+ )
510
+ # At this point, all subwriters are hashable based on their items.
511
+ # (See hash and comparison magic methods above.) So the ``setdefault``
512
+ # call here will return the first writer object we've seen with
513
+ # equal content, or store it in the dictionary if it's not been
514
+ # seen yet. We therefore replace the subwriter object with an equivalent
515
+ # object, which deduplicates the tree.
516
+ if not dontShare:
517
+ items[i].subWriter = internedTables.setdefault(
518
+ item.subWriter, item.subWriter
519
+ )
520
+ self.items = tuple(items)
521
+
522
+ def _gatherTables(self, tables, extTables, done):
523
+ # Convert table references in self.items tree to a flat
524
+ # list of tables in depth-first traversal order.
525
+ # "tables" are OTTableWriter objects.
526
+ # We do the traversal in reverse order at each level, in order to
527
+ # resolve duplicate references to be the last reference in the list of tables.
528
+ # For extension lookups, duplicate references can be merged only within the
529
+ # writer tree under the extension lookup.
530
+
531
+ done[id(self)] = True
532
+
533
+ numItems = len(self.items)
534
+ iRange = list(range(numItems))
535
+ iRange.reverse()
536
+
537
+ isExtension = hasattr(self, "Extension")
538
+
539
+ selfTables = tables
540
+
541
+ if isExtension:
542
+ assert (
543
+ extTables is not None
544
+ ), "Program or XML editing error. Extension subtables cannot contain extensions subtables"
545
+ tables, extTables, done = extTables, None, {}
546
+
547
+ # add Coverage table if it is sorted last.
548
+ sortCoverageLast = False
549
+ if hasattr(self, "sortCoverageLast"):
550
+ # Find coverage table
551
+ for i in range(numItems):
552
+ item = self.items[i]
553
+ if (
554
+ hasattr(item, "subWriter")
555
+ and getattr(item.subWriter, "name", None) == "Coverage"
556
+ ):
557
+ sortCoverageLast = True
558
+ break
559
+ if id(item.subWriter) not in done:
560
+ item.subWriter._gatherTables(tables, extTables, done)
561
+ else:
562
+ # We're a new parent of item
563
+ pass
564
+
565
+ for i in iRange:
566
+ item = self.items[i]
567
+ if not hasattr(item, "subWriter"):
568
+ continue
569
+
570
+ if (
571
+ sortCoverageLast
572
+ and (i == 1)
573
+ and getattr(item.subWriter, "name", None) == "Coverage"
574
+ ):
575
+ # we've already 'gathered' it above
576
+ continue
577
+
578
+ if id(item.subWriter) not in done:
579
+ item.subWriter._gatherTables(tables, extTables, done)
580
+ else:
581
+ # Item is already written out by other parent
582
+ pass
583
+
584
+ selfTables.append(self)
585
+
586
+ def _gatherGraphForHarfbuzz(self, tables, obj_list, done, objidx, virtual_edges):
587
+ real_links = []
588
+ virtual_links = []
589
+ item_idx = objidx
590
+
591
+ # Merge virtual_links from parent
592
+ for idx in virtual_edges:
593
+ virtual_links.append((0, 0, idx))
594
+
595
+ sortCoverageLast = False
596
+ coverage_idx = 0
597
+ if hasattr(self, "sortCoverageLast"):
598
+ # Find coverage table
599
+ for i, item in enumerate(self.items):
600
+ if getattr(item, "name", None) == "Coverage":
601
+ sortCoverageLast = True
602
+ if id(item) not in done:
603
+ coverage_idx = item_idx = item._gatherGraphForHarfbuzz(
604
+ tables, obj_list, done, item_idx, virtual_edges
605
+ )
606
+ else:
607
+ coverage_idx = done[id(item)]
608
+ virtual_edges.append(coverage_idx)
609
+ break
610
+
611
+ child_idx = 0
612
+ offset_pos = 0
613
+ for i, item in enumerate(self.items):
614
+ if hasattr(item, "subWriter"):
615
+ pos = offset_pos
616
+ elif hasattr(item, "getCountData"):
617
+ offset_pos += item.size
618
+ continue
619
+ else:
620
+ offset_pos = offset_pos + len(item)
621
+ continue
622
+
623
+ if id(item.subWriter) not in done:
624
+ child_idx = item_idx = item.subWriter._gatherGraphForHarfbuzz(
625
+ tables, obj_list, done, item_idx, virtual_edges
626
+ )
627
+ else:
628
+ child_idx = done[id(item.subWriter)]
629
+
630
+ real_edge = (pos, item.offsetSize, child_idx)
631
+ real_links.append(real_edge)
632
+ offset_pos += item.offsetSize
633
+
634
+ tables.append(self)
635
+ obj_list.append((real_links, virtual_links))
636
+ item_idx += 1
637
+ done[id(self)] = item_idx
638
+ if sortCoverageLast:
639
+ virtual_edges.pop()
640
+
641
+ return item_idx
642
+
643
    def getAllDataUsingHarfbuzz(self, tableTag):
        """The Whole table is represented as a Graph.
        Assemble graph data and call Harfbuzz repacker to pack the table.
        Harfbuzz repacker is faster and retain as much sub-table sharing as possible, see also:
        https://github.com/harfbuzz/harfbuzz/blob/main/docs/repacker.md
        The input format for hb.repack() method is explained here:
        https://github.com/harfbuzz/uharfbuzz/blob/main/src/uharfbuzz/_harfbuzz.pyx#L1149
        """
        internedTables = {}
        # Merge identical subtables first (sharing across Extension subtables
        # too) so the graph contains each distinct object only once.
        self._doneWriting(internedTables, shareExtension=True)
        tables = []
        obj_list = []
        done = {}
        objidx = 0
        virtual_edges = []
        self._gatherGraphForHarfbuzz(tables, obj_list, done, objidx, virtual_edges)
        # Gather all data in two passes: the absolute positions of all
        # subtable are needed before the actual data can be assembled.
        pos = 0
        for table in tables:
            table.pos = pos
            pos = pos + table.getDataLength()

        data = []
        for table in tables:
            tableData = table.getDataForHarfbuzz()
            data.append(tableData)

        # repack_with_tag is only available in newer uharfbuzz releases;
        # fall back to the tag-less variant otherwise.
        if hasattr(hb, "repack_with_tag"):
            return hb.repack_with_tag(str(tableTag), data, obj_list)
        else:
            return hb.repack(data, obj_list)
675
+
676
+ def getAllData(self, remove_duplicate=True):
677
+ """Assemble all data, including all subtables."""
678
+ if remove_duplicate:
679
+ internedTables = {}
680
+ self._doneWriting(internedTables)
681
+ tables = []
682
+ extTables = []
683
+ done = {}
684
+ self._gatherTables(tables, extTables, done)
685
+ tables.reverse()
686
+ extTables.reverse()
687
+ # Gather all data in two passes: the absolute positions of all
688
+ # subtable are needed before the actual data can be assembled.
689
+ pos = 0
690
+ for table in tables:
691
+ table.pos = pos
692
+ pos = pos + table.getDataLength()
693
+
694
+ for table in extTables:
695
+ table.pos = pos
696
+ pos = pos + table.getDataLength()
697
+
698
+ data = []
699
+ for table in tables:
700
+ tableData = table.getData()
701
+ data.append(tableData)
702
+
703
+ for table in extTables:
704
+ tableData = table.getData()
705
+ data.append(tableData)
706
+
707
+ return bytesjoin(data)
708
+
709
+ # interface for gathering data, as used by table.compile()
710
+
711
+ def getSubWriter(self):
712
+ subwriter = self.__class__(self.localState, self.tableTag)
713
+ subwriter.parent = (
714
+ self # because some subtables have idential values, we discard
715
+ )
716
+ # the duplicates under the getAllData method. Hence some
717
+ # subtable writers can have more than one parent writer.
718
+ # But we just care about first one right now.
719
+ return subwriter
720
+
721
+ def writeValue(self, typecode, value):
722
+ self.items.append(struct.pack(f">{typecode}", value))
723
+
724
+ def writeArray(self, typecode, values):
725
+ a = array.array(typecode, values)
726
+ if sys.byteorder != "big":
727
+ a.byteswap()
728
+ self.items.append(a.tobytes())
729
+
730
+ def writeInt8(self, value):
731
+ assert -128 <= value < 128, value
732
+ self.items.append(struct.pack(">b", value))
733
+
734
    def writeInt8Array(self, values):
        """Append *values* as a big-endian array of signed bytes."""
        self.writeArray("b", values)
736
+
737
+ def writeShort(self, value):
738
+ assert -32768 <= value < 32768, value
739
+ self.items.append(struct.pack(">h", value))
740
+
741
    def writeShortArray(self, values):
        """Append *values* as a big-endian array of signed 16-bit integers."""
        self.writeArray("h", values)
743
+
744
+ def writeLong(self, value):
745
+ self.items.append(struct.pack(">i", value))
746
+
747
    def writeLongArray(self, values):
        """Append *values* as a big-endian array of signed 32-bit integers."""
        self.writeArray("i", values)
749
+
750
+ def writeUInt8(self, value):
751
+ assert 0 <= value < 256, value
752
+ self.items.append(struct.pack(">B", value))
753
+
754
    def writeUInt8Array(self, values):
        """Append *values* as an array of unsigned bytes."""
        self.writeArray("B", values)
756
+
757
+ def writeUShort(self, value):
758
+ assert 0 <= value < 0x10000, value
759
+ self.items.append(struct.pack(">H", value))
760
+
761
    def writeUShortArray(self, values):
        """Append *values* as a big-endian array of unsigned 16-bit integers."""
        self.writeArray("H", values)
763
+
764
+ def writeULong(self, value):
765
+ self.items.append(struct.pack(">I", value))
766
+
767
    def writeULongArray(self, values):
        """Append *values* as a big-endian array of unsigned 32-bit integers."""
        self.writeArray("I", values)
769
+
770
+ def writeUInt24(self, value):
771
+ assert 0 <= value < 0x1000000, value
772
+ b = struct.pack(">L", value)
773
+ self.items.append(b[1:])
774
+
775
    def writeUInt24Array(self, values):
        """Append each of *values* as a big-endian unsigned 24-bit integer."""
        for value in values:
            self.writeUInt24(value)
778
+
779
    def writeTag(self, tag):
        """Append a 4-byte OpenType tag (e.g. a script or feature tag)."""
        tag = Tag(tag).tobytes()
        assert len(tag) == 4, tag
        self.items.append(tag)
783
+
784
    def writeSubTable(self, subWriter, offsetSize):
        """Append an offset placeholder of *offsetSize* bytes pointing at
        *subWriter*; the actual offset is resolved when data is assembled."""
        self.items.append(OffsetToWriter(subWriter, offsetSize))
786
+
787
    def writeCountReference(self, table, name, size=2, value=None):
        """Append a deferred count placeholder and return the CountReference,
        so the actual count can be filled in later (see CountReference)."""
        ref = CountReference(table, name, size=size, value=value)
        self.items.append(ref)
        return ref
791
+
792
+ def writeStruct(self, format, values):
793
+ data = struct.pack(*(format,) + values)
794
+ self.items.append(data)
795
+
796
    def writeData(self, data):
        """Append raw bytes (or a pre-built item such as a CountReference)."""
        self.items.append(data)
798
+
799
    def getOverflowErrorRecord(self, item):
        """Build an OverflowErrorRecord locating *item* within the lookup
        structure of this table, for reporting offset-overflow errors.

        Depending on whether self is the LookupList, a Lookup, a (Ext)SubTable
        or something nested deeper, the lookup/subtable indices are taken from
        different levels of the parent chain.
        """
        LookupListIndex = SubTableIndex = itemName = itemIndex = None
        if self.name == "LookupList":
            LookupListIndex = item.repeatIndex
        elif self.name == "Lookup":
            LookupListIndex = self.repeatIndex
            SubTableIndex = item.repeatIndex
        else:
            itemName = getattr(item, "name", "<none>")
            if hasattr(item, "repeatIndex"):
                itemIndex = item.repeatIndex
            if self.name == "SubTable":
                LookupListIndex = self.parent.repeatIndex
                SubTableIndex = self.repeatIndex
            elif self.name == "ExtSubTable":
                LookupListIndex = self.parent.parent.repeatIndex
                SubTableIndex = self.parent.repeatIndex
            else:  # who knows how far below the SubTable level we are! Climb back up to the nearest subtable.
                itemName = ".".join([self.name, itemName])
                p1 = self.parent
                while p1 and p1.name not in ["ExtSubTable", "SubTable"]:
                    itemName = ".".join([p1.name, itemName])
                    p1 = p1.parent
                if p1:
                    if p1.name == "ExtSubTable":
                        LookupListIndex = p1.parent.parent.repeatIndex
                        SubTableIndex = p1.parent.repeatIndex
                    else:
                        LookupListIndex = p1.parent.repeatIndex
                        SubTableIndex = p1.repeatIndex

        return OverflowErrorRecord(
            (self.tableTag, LookupListIndex, SubTableIndex, itemName, itemIndex)
        )
833
+
834
+
835
class CountReference(object):
    """A reference to a Count value, not a count of references.

    Stores (table, name) so that a count field can be written before the
    counted array is known; the value is filled in (or checked) later.
    """

    def __init__(self, table, name, size=None, value=None):
        self.table = table
        self.name = name
        self.size = size  # byte width of the serialized count: 1, 2 or 4
        if value is not None:
            self.setValue(value)

    def setValue(self, value):
        """Record *value*; if a value was already recorded, it must match."""
        current = self.table[self.name]
        if current is None:
            self.table[self.name] = value
        else:
            assert current == value, (self.name, current, value)

    def getValue(self):
        """Return the currently recorded count (possibly None)."""
        return self.table[self.name]

    def getCountData(self):
        """Serialize the count as big-endian bytes of width self.size.

        An unset count (None) is written as zero.
        """
        count = self.table[self.name]
        if count is None:
            count = 0
        packer = {1: packUInt8, 2: packUShort, 4: packULong}[self.size]
        return packer(count)
861
+
862
+
863
def packUInt8(value):
    """Pack *value* as a single unsigned byte."""
    return struct.pack(">B", value)
865
+
866
+
867
def packUShort(value):
    """Pack *value* as a big-endian unsigned 16-bit integer."""
    return struct.pack(">H", value)
869
+
870
+
871
def packULong(value):
    """Pack *value* as a big-endian unsigned 32-bit integer."""
    assert 0 <= value < 0x100000000, value
    return struct.pack(">I", value)
874
+
875
+
876
def packUInt24(value):
    """Pack *value* as a big-endian unsigned 24-bit integer (three bytes)."""
    assert 0 <= value < 0x1000000, value
    # Pack as a 4-byte uint and drop the (zero) high byte.
    return struct.pack(">L", value)[1:]
879
+
880
+
881
class BaseTable(object):
    """Generic base class for all OpenType (sub)tables.

    Instances may be 'lazy': while self.__dict__ still holds a 'reader'
    (and 'font'), the table has not been decompiled yet; any attribute
    access triggers decompilation via __getattr__.
    """

    def __getattr__(self, attr):
        # Lazy decompilation: only called for attributes not yet present.
        reader = self.__dict__.get("reader")
        if reader:
            del self.reader
            font = self.font
            del self.font
            self.decompile(reader, font)
            return getattr(self, attr)

        raise AttributeError(attr)

    def ensureDecompiled(self, recurse=False):
        """Force decompilation now; with *recurse*, also of all subtables."""
        reader = self.__dict__.get("reader")
        if reader:
            del self.reader
            font = self.font
            del self.font
            self.decompile(reader, font)
        if recurse:
            for subtable in self.iterSubTables():
                subtable.value.ensureDecompiled(recurse)

    def __getstate__(self):
        # before copying/pickling 'lazy' objects, make a shallow copy of OTTableReader
        # https://github.com/fonttools/fonttools/issues/2965
        if "reader" in self.__dict__:
            state = self.__dict__.copy()
            state["reader"] = self.__dict__["reader"].copy()
            return state
        return self.__dict__

    @classmethod
    def getRecordSize(cls, reader):
        """Return the fixed byte size of one record of this table, or
        NotImplemented when the size cannot be determined statically."""
        totalSize = 0
        for conv in cls.converters:
            size = conv.getRecordSize(reader)
            if size is NotImplemented:
                return NotImplemented
            countValue = 1
            if conv.repeat:
                if conv.repeat in reader:
                    countValue = reader[conv.repeat] + conv.aux
                else:
                    return NotImplemented
            totalSize += size * countValue
        return totalSize

    def getConverters(self):
        """Return the converter list describing this table's fields."""
        return self.converters

    def getConverterByName(self, name):
        """Return the converter for field *name* (KeyError if unknown)."""
        return self.convertersByName[name]

    def populateDefaults(self, propagator=None):
        """Fill in default values for missing attributes (empty lists for
        repeated fields, None for nullable offsets, converter DEFAULTs)."""
        for conv in self.getConverters():
            if conv.repeat:
                if not hasattr(self, conv.name):
                    setattr(self, conv.name, [])
                countValue = len(getattr(self, conv.name)) - conv.aux
                try:
                    # Looking the count converter up just to see whether the
                    # count is a local field; KeyError means it is propagated.
                    count_conv = self.getConverterByName(conv.repeat)
                    setattr(self, conv.repeat, countValue)
                except KeyError:
                    # conv.repeat is a propagated count
                    if propagator and conv.repeat in propagator:
                        propagator[conv.repeat].setValue(countValue)
            else:
                if conv.aux and not eval(conv.aux, None, self.__dict__):
                    continue
                if hasattr(self, conv.name):
                    continue  # Warn if it should NOT be present?!
                if hasattr(conv, "writeNullOffset"):
                    setattr(self, conv.name, None)  # Warn?
                # elif not conv.isCount:
                #     # Warn?
                #     pass
                if hasattr(conv, "DEFAULT"):
                    # OptionalValue converters (e.g. VarIndex)
                    setattr(self, conv.name, conv.DEFAULT)

    def decompile(self, reader, font):
        """Parse this table's fields from *reader* into instance attributes."""
        self.readFormat(reader)
        table = {}
        self.__rawTable = table  # for debugging
        for conv in self.getConverters():
            # A few field kinds need a type-specific converter chosen at
            # read time, based on values parsed earlier in this table.
            if conv.name == "SubTable":
                conv = conv.getConverter(reader.tableTag, table["LookupType"])
            if conv.name == "ExtSubTable":
                conv = conv.getConverter(reader.tableTag, table["ExtensionLookupType"])
            if conv.name == "FeatureParams":
                conv = conv.getConverter(reader["FeatureTag"])
            if conv.name == "SubStruct":
                conv = conv.getConverter(reader.tableTag, table["MorphType"])
            try:
                if conv.repeat:
                    if isinstance(conv.repeat, int):
                        countValue = conv.repeat
                    elif conv.repeat in table:
                        countValue = table[conv.repeat]
                    else:
                        # conv.repeat is a propagated count
                        countValue = reader[conv.repeat]
                    countValue += conv.aux
                    table[conv.name] = conv.readArray(reader, font, table, countValue)
                else:
                    if conv.aux and not eval(conv.aux, None, table):
                        continue
                    table[conv.name] = conv.read(reader, font, table)
                    if conv.isPropagated:
                        reader[conv.name] = table[conv.name]
            except Exception as e:
                # Tag the exception with the field name for easier debugging.
                name = conv.name
                e.args = e.args + (name,)
                raise

        if hasattr(self, "postRead"):
            self.postRead(table, font)
        else:
            self.__dict__.update(table)

        del self.__rawTable  # succeeded, get rid of debugging info

    def compile(self, writer, font):
        """Serialize this table's fields into *writer*."""
        self.ensureDecompiled()
        # TODO Following hack to be removed by rewriting how FormatSwitching tables
        # are handled.
        # https://github.com/fonttools/fonttools/pull/2238#issuecomment-805192631
        if hasattr(self, "preWrite"):
            deleteFormat = not hasattr(self, "Format")
            table = self.preWrite(font)
            deleteFormat = deleteFormat and hasattr(self, "Format")
        else:
            deleteFormat = False
            table = self.__dict__.copy()

        # some count references may have been initialized in a custom preWrite; we set
        # these in the writer's state beforehand (instead of sequentially) so they will
        # be propagated to all nested subtables even if the count appears in the current
        # table only *after* the offset to the subtable that it is counting.
        for conv in self.getConverters():
            if conv.isCount and conv.isPropagated:
                value = table.get(conv.name)
                if isinstance(value, CountReference):
                    writer[conv.name] = value

        if hasattr(self, "sortCoverageLast"):
            writer.sortCoverageLast = 1

        if hasattr(self, "DontShare"):
            writer.DontShare = True

        if hasattr(self.__class__, "LookupType"):
            writer["LookupType"].setValue(self.__class__.LookupType)

        self.writeFormat(writer)
        for conv in self.getConverters():
            value = table.get(
                conv.name
            )  # TODO Handle defaults instead of defaulting to None!
            if conv.repeat:
                if value is None:
                    value = []
                countValue = len(value) - conv.aux
                if isinstance(conv.repeat, int):
                    assert len(value) == conv.repeat, "expected %d values, got %d" % (
                        conv.repeat,
                        len(value),
                    )
                elif conv.repeat in table:
                    CountReference(table, conv.repeat, value=countValue)
                else:
                    # conv.repeat is a propagated count
                    writer[conv.repeat].setValue(countValue)
                try:
                    conv.writeArray(writer, font, table, value)
                except Exception as e:
                    e.args = e.args + (conv.name + "[]",)
                    raise
            elif conv.isCount:
                # Special-case Count values.
                # Assumption: a Count field will *always* precede
                # the actual array(s).
                # We need a default value, as it may be set later by a nested
                # table. We will later store it here.
                # We add a reference: by the time the data is assembled
                # the Count value will be filled in.
                # We ignore the current count value since it will be recomputed,
                # unless it's a CountReference that was already initialized in a custom preWrite.
                if isinstance(value, CountReference):
                    ref = value
                    ref.size = conv.staticSize
                    writer.writeData(ref)
                    table[conv.name] = ref.getValue()
                else:
                    ref = writer.writeCountReference(table, conv.name, conv.staticSize)
                    table[conv.name] = None
                if conv.isPropagated:
                    writer[conv.name] = ref
            elif conv.isLookupType:
                # We make sure that subtables have the same lookup type,
                # and that the type is the same as the one set on the
                # Lookup object, if any is set.
                if conv.name not in table:
                    table[conv.name] = None
                ref = writer.writeCountReference(
                    table, conv.name, conv.staticSize, table[conv.name]
                )
                writer["LookupType"] = ref
            else:
                if conv.aux and not eval(conv.aux, None, table):
                    continue
                try:
                    conv.write(writer, font, table, value)
                except Exception as e:
                    name = value.__class__.__name__ if value is not None else conv.name
                    e.args = e.args + (name,)
                    raise
                if conv.isPropagated:
                    writer[conv.name] = value

        if deleteFormat:
            del self.Format

    def readFormat(self, reader):
        # No format selector by default; format-switching subclasses override.
        pass

    def writeFormat(self, writer):
        # No format selector by default; format-switching subclasses override.
        pass

    def toXML(self, xmlWriter, font, attrs=None, name=None):
        """Write this table as an XML element (tag plus all fields)."""
        tableName = name if name else self.__class__.__name__
        if attrs is None:
            attrs = []
        if hasattr(self, "Format"):
            attrs = attrs + [("Format", self.Format)]
        xmlWriter.begintag(tableName, attrs)
        xmlWriter.newline()
        self.toXML2(xmlWriter, font)
        xmlWriter.endtag(tableName)
        xmlWriter.newline()

    def toXML2(self, xmlWriter, font):
        # Simpler variant of toXML, *only* for the top level tables (like GPOS, GSUB).
        # This is because in TTX our parent writes our main tag, and in otBase.py we
        # do it ourselves. I think I'm getting schizophrenic...
        for conv in self.getConverters():
            if conv.repeat:
                value = getattr(self, conv.name, [])
                for i, item in enumerate(value):
                    conv.xmlWrite(xmlWriter, font, item, conv.name, [("index", i)])
            else:
                if conv.aux and not eval(conv.aux, None, vars(self)):
                    continue
                value = getattr(
                    self, conv.name, None
                )  # TODO Handle defaults instead of defaulting to None!
                conv.xmlWrite(xmlWriter, font, value, conv.name, [])

    def fromXML(self, name, attrs, content, font):
        """Set one field of this table from a parsed XML element."""
        try:
            conv = self.getConverterByName(name)
        except KeyError:
            raise  # XXX on KeyError, raise nice error
        value = conv.xmlRead(attrs, content, font)
        # Some manually-written tables have a conv.repeat of ""
        # to represent lists. Hence comparing to None here to
        # allow those lists to be read correctly from XML.
        if conv.repeat is not None:
            seq = getattr(self, conv.name, None)
            if seq is None:
                seq = []
                setattr(self, conv.name, seq)
            seq.append(value)
        else:
            setattr(self, conv.name, value)

    def __ne__(self, other):
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __eq__(self, other):
        if type(self) != type(other):
            return NotImplemented

        # Compare fully decompiled state, not lazy readers.
        self.ensureDecompiled()
        other.ensureDecompiled()

        return self.__dict__ == other.__dict__

    class SubTableEntry(NamedTuple):
        """See BaseTable.iterSubTables()"""

        name: str
        value: "BaseTable"
        index: Optional[int] = None  # index into given array, None for single values

    def iterSubTables(self) -> Iterator[SubTableEntry]:
        """Yield (name, value, index) namedtuples for all subtables of current table.

        A sub-table is an instance of BaseTable (or subclass thereof) that is a child
        of self, the current parent table.
        The tuples also contain the attribute name (str) of the of parent table to get
        a subtable, and optionally, for lists of subtables (i.e. attributes associated
        with a converter that has a 'repeat'), an index into the list containing the
        given subtable value.
        This method can be useful to traverse trees of otTables.
        """
        for conv in self.getConverters():
            name = conv.name
            value = getattr(self, name, None)
            if value is None:
                continue
            if isinstance(value, BaseTable):
                yield self.SubTableEntry(name, value)
            elif isinstance(value, list):
                yield from (
                    self.SubTableEntry(name, v, index=i)
                    for i, v in enumerate(value)
                    if isinstance(v, BaseTable)
                )

    # instance (not @class)method for consistency with FormatSwitchingBaseTable
    def getVariableAttrs(self):
        return getVariableAttrs(self.__class__)
1208
+
1209
+
1210
class FormatSwitchingBaseTable(BaseTable):
    """Minor specialization of BaseTable, for tables that have multiple
    formats, eg. CoverageFormat1 vs. CoverageFormat2."""

    @classmethod
    def getRecordSize(cls, reader):
        # Record size depends on the format, which isn't known statically.
        return NotImplemented

    def getConverters(self):
        try:
            fmt = self.Format
        except AttributeError:
            # some FormatSwitchingBaseTables (e.g. Coverage) no longer have 'Format'
            # attribute after fully decompiled, only gain one in preWrite before being
            # recompiled. In the decompiled state, these hand-coded classes defined in
            # otTables.py lose their format-specific nature and gain more high-level
            # attributes that are not tied to converters.
            return []
        return self.converters.get(self.Format, [])

    def getConverterByName(self, name):
        # Converters are keyed by format first, then by field name.
        return self.convertersByName[self.Format][name]

    def readFormat(self, reader):
        # The format selector is a uint16 preceding the table data.
        self.Format = reader.readUShort()

    def writeFormat(self, writer):
        writer.writeUShort(self.Format)

    def toXML(self, xmlWriter, font, attrs=None, name=None):
        BaseTable.toXML(self, xmlWriter, font, attrs, name)

    def getVariableAttrs(self):
        return getVariableAttrs(self.__class__, self.Format)
1244
+
1245
+
1246
class UInt8FormatSwitchingBaseTable(FormatSwitchingBaseTable):
    """Variant whose format selector is a uint8 rather than a uint16."""

    def readFormat(self, reader):
        self.Format = reader.readUInt8()

    def writeFormat(self, writer):
        writer.writeUInt8(self.Format)
1252
+
1253
+
1254
# Map an otData format-selector type to the matching base class.
formatSwitchingBaseTables = {
    "uint16": FormatSwitchingBaseTable,
    "uint8": UInt8FormatSwitchingBaseTable,
}
1258
+
1259
+
1260
def getFormatSwitchingBaseTableClass(formatType):
    """Return the format-switching base class for *formatType*.

    Raises TypeError for anything other than the supported selector types
    ('uint16' and 'uint8').
    """
    if formatType not in formatSwitchingBaseTables:
        raise TypeError(f"Unsupported format type: {formatType!r}")
    return formatSwitchingBaseTables[formatType]
1265
+
1266
+
1267
# memoize since these are parsed from otData.py, thus stay constant
@lru_cache()
def getVariableAttrs(cls: BaseTable, fmt: Optional[int] = None) -> Tuple[str]:
    """Return sequence of variable table field names (can be empty).

    Attributes are deemed "variable" when their otData.py's description contain
    'VarIndexBase + {offset}', e.g. COLRv1 PaintVar* tables.
    """
    if not issubclass(cls, BaseTable):
        raise TypeError(cls)
    if issubclass(cls, FormatSwitchingBaseTable):
        if fmt is None:
            raise TypeError(f"'fmt' is required for format-switching {cls.__name__}")
        converters = cls.convertersByName[fmt]
    else:
        converters = cls.convertersByName
    # assume if no 'VarIndexBase' field is present, table has no variable fields
    if "VarIndexBase" not in converters:
        return ()
    varAttrs = {}
    for name, conv in converters.items():
        offset = conv.getVarIndexOffset()
        if offset is not None:
            varAttrs[name] = offset
    # Sort attribute names by their VarIndexBase offset.
    return tuple(sorted(varAttrs, key=varAttrs.__getitem__))
1292
+
1293
+
1294
+ #
1295
+ # Support for ValueRecords
1296
+ #
1297
+ # This data type is so different from all other OpenType data types that
1298
+ # it requires quite a bit of code for itself. It even has special support
1299
+ # in OTTableReader and OTTableWriter...
1300
+ #
1301
+
1302
# One entry per bit of the GPOS ValueFormat flags word, in bit order.
valueRecordFormat = [
    #   Mask         Name            isDevice signed
    (0x0001, "XPlacement", 0, 1),
    (0x0002, "YPlacement", 0, 1),
    (0x0004, "XAdvance", 0, 1),
    (0x0008, "YAdvance", 0, 1),
    (0x0010, "XPlaDevice", 1, 0),
    (0x0020, "YPlaDevice", 1, 0),
    (0x0040, "XAdvDevice", 1, 0),
    (0x0080, "YAdvDevice", 1, 0),
    # reserved:
    (0x0100, "Reserved1", 0, 0),
    (0x0200, "Reserved2", 0, 0),
    (0x0400, "Reserved3", 0, 0),
    (0x0800, "Reserved4", 0, 0),
    (0x1000, "Reserved5", 0, 0),
    (0x2000, "Reserved6", 0, 0),
    (0x4000, "Reserved7", 0, 0),
    (0x8000, "Reserved8", 0, 0),
]
1322
+
1323
+
1324
def _buildDict():
    """Index valueRecordFormat by name: name -> (mask, isDevice, signed)."""
    return {
        name: (mask, isDevice, signed)
        for mask, name, isDevice, signed in valueRecordFormat
    }


valueRecordFormatDict = _buildDict()
1332
+
1333
+
1334
class ValueRecordFactory(object):
    """Given a format code, this object convert ValueRecords."""

    def __init__(self, valueFormat):
        # Keep only the (name, isDevice, signed) triples whose bit is set
        # in *valueFormat*, in bit order.
        format = []
        for mask, name, isDevice, signed in valueRecordFormat:
            if valueFormat & mask:
                format.append((name, isDevice, signed))
        self.format = format

    def __len__(self):
        # Number of 16-bit fields in one serialized ValueRecord.
        return len(self.format)

    def readValueRecord(self, reader, font):
        """Read one ValueRecord from *reader*; returns None for an empty format."""
        format = self.format
        if not format:
            return None
        valueRecord = ValueRecord()
        for name, isDevice, signed in format:
            if signed:
                value = reader.readShort()
            else:
                value = reader.readUShort()
            if isDevice:
                # Device fields are offsets to a Device subtable; 0 means absent.
                if value:
                    from . import otTables

                    subReader = reader.getSubReader(value)
                    value = getattr(otTables, name)()
                    value.decompile(subReader, font)
                else:
                    value = None
            setattr(valueRecord, name, value)
        return valueRecord

    def writeValueRecord(self, writer, font, valueRecord):
        """Write *valueRecord*'s fields (in this factory's format) to *writer*."""
        for name, isDevice, signed in self.format:
            value = getattr(valueRecord, name, 0)
            if isDevice:
                if value:
                    # Non-null device table: emit an offset to a subtable.
                    subWriter = writer.getSubWriter()
                    writer.writeSubTable(subWriter, offsetSize=2)
                    value.compile(subWriter, font)
                else:
                    writer.writeUShort(0)
            elif signed:
                writer.writeShort(value)
            else:
                writer.writeUShort(value)
1383
+
1384
+
1385
class ValueRecord(object):
    # see ValueRecordFactory
    # Holds the adjustment fields of a GPOS ValueRecord as plain attributes;
    # only the fields present in the value format exist on an instance.

    def __init__(self, valueFormat=None, src=None):
        if valueFormat is not None:
            # Initialize exactly the fields selected by *valueFormat*:
            # device fields default to None, numeric fields to 0.
            for mask, name, isDevice, signed in valueRecordFormat:
                if valueFormat & mask:
                    setattr(self, name, None if isDevice else 0)
            if src is not None:
                # Copy over only the fields that the format selected.
                for key, val in src.__dict__.items():
                    if not hasattr(self, key):
                        continue
                    setattr(self, key, val)
        elif src is not None:
            self.__dict__ = src.__dict__.copy()

    def getFormat(self):
        """Return the value-format bitmask covering all present fields."""
        format = 0
        for name in self.__dict__.keys():
            format = format | valueRecordFormatDict[name][0]
        return format

    def getEffectiveFormat(self):
        """Like getFormat, but only counting fields with a truthy value."""
        format = 0
        for name, value in self.__dict__.items():
            if value:
                format = format | valueRecordFormatDict[name][0]
        return format

    def toXML(self, xmlWriter, font, valueName, attrs=None):
        if attrs is None:
            simpleItems = []
        else:
            simpleItems = list(attrs)
        for mask, name, isDevice, format in valueRecordFormat[:4]:  # "simple" values
            if hasattr(self, name):
                simpleItems.append((name, getattr(self, name)))
        deviceItems = []
        for mask, name, isDevice, format in valueRecordFormat[4:8]:  # device records
            if hasattr(self, name):
                device = getattr(self, name)
                if device is not None:
                    deviceItems.append((name, device))
        if deviceItems:
            # Device records become child elements; simple values attributes.
            xmlWriter.begintag(valueName, simpleItems)
            xmlWriter.newline()
            for name, deviceRecord in deviceItems:
                if deviceRecord is not None:
                    deviceRecord.toXML(xmlWriter, font, name=name)
            xmlWriter.endtag(valueName)
            xmlWriter.newline()
        else:
            xmlWriter.simpletag(valueName, simpleItems)
            xmlWriter.newline()

    def fromXML(self, name, attrs, content, font):
        from . import otTables

        # Simple values arrive as integer attributes...
        for k, v in attrs.items():
            setattr(self, k, int(v))
        # ...device records as child elements named after the otTables class.
        for element in content:
            if not isinstance(element, tuple):
                continue
            name, attrs, content = element
            value = getattr(otTables, name)()
            for elem2 in content:
                if not isinstance(elem2, tuple):
                    continue
                name2, attrs2, content2 = elem2
                value.fromXML(name2, attrs2, content2, font)
            setattr(self, name, value)

    def __ne__(self, other):
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __eq__(self, other):
        if type(self) != type(other):
            return NotImplemented
        return self.__dict__ == other.__dict__
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/otConverters.py ADDED
@@ -0,0 +1,2068 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fontTools.misc.fixedTools import (
2
+ fixedToFloat as fi2fl,
3
+ floatToFixed as fl2fi,
4
+ floatToFixedToStr as fl2str,
5
+ strToFixedToFloat as str2fl,
6
+ ensureVersionIsLong as fi2ve,
7
+ versionToFixed as ve2fi,
8
+ )
9
+ from fontTools.ttLib.tables.TupleVariation import TupleVariation
10
+ from fontTools.misc.roundTools import nearestMultipleShortestRepr, otRound
11
+ from fontTools.misc.textTools import bytesjoin, tobytes, tostr, pad, safeEval
12
+ from fontTools.misc.lazyTools import LazyList
13
+ from fontTools.ttLib import OPTIMIZE_FONT_SPEED, getSearchRange
14
+ from .otBase import (
15
+ CountReference,
16
+ FormatSwitchingBaseTable,
17
+ OTTableReader,
18
+ OTTableWriter,
19
+ ValueRecordFactory,
20
+ )
21
+ from .otTables import (
22
+ lookupTypes,
23
+ VarCompositeGlyph,
24
+ AATStateTable,
25
+ AATState,
26
+ AATAction,
27
+ ContextualMorphAction,
28
+ LigatureMorphAction,
29
+ InsertionMorphAction,
30
+ MorxSubtable,
31
+ ExtendMode as _ExtendMode,
32
+ CompositeMode as _CompositeMode,
33
+ NO_VARIATION_INDEX,
34
+ )
35
+ from itertools import zip_longest, accumulate
36
+ from functools import partial
37
+ from types import SimpleNamespace
38
+ import re
39
+ import struct
40
+ from typing import Optional
41
+ import logging
42
+
43
+
44
log = logging.getLogger(__name__)


def istuple(t):
    """Return True iff *t* is a tuple instance.

    PEP 8 (E731): a named function is preferred over assigning a lambda.
    """
    return isinstance(t, tuple)
46
+
47
+
48
def buildConverters(tableSpec, tableNamespace):
    """Given a table spec from otData.py, build a converter object for each
    field of the table. This is called for each table in otData.py, and
    the results are assigned to the corresponding class in otTables.py.

    Returns a ``(converters, convertersByName)`` pair: the ordered list of
    converter instances, plus a name-keyed mapping (which also contains
    extra entries for every concrete subtable / FeatureParams class).
    """
    converters = []
    convertersByName = {}
    for tp, name, repeat, aux, descr in tableSpec:
        tableName = name
        if name.startswith("ValueFormat"):
            assert tp == "uint16"
            converterClass = ValueFormat
        elif name.endswith("Count") or name in ("StructLength", "MorphType"):
            # Counts/lengths are recomputed at compile time.
            converterClass = {
                "uint8": ComputedUInt8,
                "uint16": ComputedUShort,
                "uint32": ComputedULong,
            }[tp]
        elif name == "SubTable":
            converterClass = SubTable
        elif name == "ExtSubTable":
            converterClass = ExtSubTable
        elif name == "SubStruct":
            converterClass = SubStruct
        elif name == "FeatureParams":
            converterClass = FeatureParams
        elif name in ("CIDGlyphMapping", "GlyphCIDMapping"):
            converterClass = StructWithLength
        else:
            # Fixed E713: use "x not in y" instead of "not x in y".
            if tp not in converterMapping and "(" not in tp:
                tableName = tp
                converterClass = Struct
            else:
                # "Template" specs like OffsetTo(AType) are evaluated here.
                converterClass = eval(tp, tableNamespace, converterMapping)

        conv = converterClass(name, repeat, aux, description=descr)

        if conv.tableClass:
            # A "template" such as OffsetTo(AType) knows the table class already
            tableClass = conv.tableClass
        elif tp in ("MortChain", "MortSubtable", "MorxChain"):
            tableClass = tableNamespace.get(tp)
        else:
            tableClass = tableNamespace.get(tableName)

        if not conv.tableClass:
            conv.tableClass = tableClass

        if name in ["SubTable", "ExtSubTable", "SubStruct"]:
            conv.lookupTypes = tableNamespace["lookupTypes"]
            # also create reverse mapping
            for t in conv.lookupTypes.values():
                for cls in t.values():
                    convertersByName[cls.__name__] = Table(name, repeat, aux, cls)
        if name == "FeatureParams":
            conv.featureParamTypes = tableNamespace["featureParamTypes"]
            conv.defaultFeatureParams = tableNamespace["FeatureParams"]
            for cls in conv.featureParamTypes.values():
                convertersByName[cls.__name__] = Table(name, repeat, aux, cls)
        converters.append(conv)
        assert name not in convertersByName, name
        convertersByName[name] = conv
    return converters, convertersByName
110
+
111
+
112
class BaseConverter(object):
    """Base class for converter objects. Apart from the constructor, this
    is an abstract class.

    A converter knows how to move one table field between its binary,
    in-memory, and XML representations.
    """

    def __init__(self, name, repeat, aux, tableClass=None, *, description=""):
        self.name = name
        self.repeat = repeat
        self.aux = aux
        # For non-repeated fields, 'aux' is a condition expression evaluated
        # against the table dict at read/write time; pre-compile it once.
        if self.aux and not self.repeat:
            self.aux = compile(self.aux, "<string>", "eval")
        self.tableClass = tableClass
        # Fields that hold the element count (or record size) of some array.
        self.isCount = name.endswith("Count") or name in [
            "DesignAxisRecordSize",
            "ValueRecordSize",
        ]
        self.isLookupType = name.endswith("LookupType") or name == "MorphType"
        # Fields whose values must be made visible to nested subtables.
        self.isPropagated = name in [
            "ClassCount",
            "Class2Count",
            "FeatureTag",
            "SettingsCount",
            "VarRegionCount",
            "MappingCount",
            "RegionAxisCount",
            "DesignAxisCount",
            "DesignAxisRecordSize",
            "AxisValueCount",
            "ValueRecordSize",
            "AxisCount",
            "BaseGlyphRecordCount",
            "LayerRecordCount",
            "AxisIndicesList",
        ]
        self.description = description

    def readArray(self, reader, font, tableDict, count):
        """Read an array of values from the reader."""
        # Lazy decoding only pays off for larger arrays, and is only possible
        # when this converter has a known fixed per-record size.
        lazy = font.lazy and count > 8
        if lazy:
            recordSize = self.getRecordSize(reader)
            if recordSize is NotImplemented:
                lazy = False
        if not lazy:
            l = []
            for i in range(count):
                l.append(self.read(reader, font, tableDict))
            return l
        else:

            def get_read_item():
                # Snapshot the reader state in a closure so items can be
                # decoded on demand after the outer reader has moved on.
                reader_copy = reader.copy()
                pos = reader.pos

                def read_item(i):
                    reader_copy.seek(pos + i * recordSize)
                    return self.read(reader_copy, font, {})

                return read_item

            read_item = get_read_item()
            l = LazyList(read_item for i in range(count))
            # Skip over the array as if it had been read eagerly.
            reader.advance(count * recordSize)

            return l

    def getRecordSize(self, reader):
        # Converters with a fixed encoded size advertise it via 'staticSize';
        # NotImplemented signals a variable-size record (disables laziness).
        if hasattr(self, "staticSize"):
            return self.staticSize
        return NotImplemented

    def read(self, reader, font, tableDict):
        """Read a value from the reader."""
        raise NotImplementedError(self)

    def writeArray(self, writer, font, tableDict, values):
        try:
            for i, value in enumerate(values):
                self.write(writer, font, tableDict, value, i)
        except Exception as e:
            # Append the failing array index to ease debugging of bad data.
            e.args = e.args + (i,)
            raise

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        """Write a value to the writer."""
        raise NotImplementedError(self)

    def xmlRead(self, attrs, content, font):
        """Read a value from XML."""
        raise NotImplementedError(self)

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        """Write a value to XML."""
        raise NotImplementedError(self)

    # Matches e.g. "VarIndexBase + 2" in a field's otData description.
    varIndexBasePlusOffsetRE = re.compile(r"VarIndexBase\s*\+\s*(\d+)")

    def getVarIndexOffset(self) -> Optional[int]:
        """If description has `VarIndexBase + {offset}`, return the offset else None."""
        m = self.varIndexBasePlusOffsetRE.search(self.description)
        if not m:
            return None
        return int(m.group(1))
214
+
215
+
216
class SimpleValue(BaseConverter):
    """Converter for scalar fields whose XML form is a single ``value`` attribute.

    Subclasses customize serialization by overriding ``toString``/``fromString``.
    """

    @staticmethod
    def toString(value):
        # Identity by default; subclasses override for custom formatting.
        return value

    @staticmethod
    def fromString(value):
        # Identity by default; subclasses override for custom parsing.
        return value

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        serialized = self.toString(value)
        xmlWriter.simpletag(name, attrs + [("value", serialized)])
        xmlWriter.newline()

    def xmlRead(self, attrs, content, font):
        return self.fromString(attrs["value"])
231
+
232
+
233
class OptionalValue(SimpleValue):
    """SimpleValue that omits the ``value`` attribute when it equals DEFAULT."""

    DEFAULT = None

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        if value != self.DEFAULT:
            attrs.append(("value", self.toString(value)))
        xmlWriter.simpletag(name, attrs)
        xmlWriter.newline()

    def xmlRead(self, attrs, content, font):
        # Missing attribute means the default was elided on write.
        try:
            raw = attrs["value"]
        except KeyError:
            return self.DEFAULT
        return self.fromString(raw)
246
+
247
+
248
class IntValue(SimpleValue):
    """Integer value; parsed with base auto-detection (0x.., 0o.., decimal)."""

    @staticmethod
    def fromString(value):
        # base 0 lets int() honour 0x/0o/0b prefixes used in TTX dumps.
        return int(value, 0)
252
+
253
+
254
class Long(IntValue):
    """Signed 32-bit integer; delegates to the reader/writer primitives."""

    staticSize = 4

    def read(self, reader, font, tableDict):
        return reader.readLong()

    def readArray(self, reader, font, tableDict, count):
        return reader.readLongArray(count)

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        writer.writeLong(value)

    def writeArray(self, writer, font, tableDict, values):
        writer.writeLongArray(values)
268
+
269
+
270
class ULong(IntValue):
    """Unsigned 32-bit integer; delegates to the reader/writer primitives."""

    staticSize = 4

    def read(self, reader, font, tableDict):
        return reader.readULong()

    def readArray(self, reader, font, tableDict, count):
        return reader.readULongArray(count)

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        writer.writeULong(value)

    def writeArray(self, writer, font, tableDict, values):
        writer.writeULongArray(values)
284
+
285
+
286
class Flags32(ULong):
    """32-bit flags field; rendered as a zero-padded hex string in XML."""

    @staticmethod
    def toString(value):
        return "0x%08X" % value
290
+
291
+
292
class VarIndex(OptionalValue, ULong):
    """32-bit variation index; elided from XML when equal to NO_VARIATION_INDEX."""

    DEFAULT = NO_VARIATION_INDEX
294
+
295
+
296
class Short(IntValue):
    """Signed 16-bit integer; delegates to the reader/writer primitives."""

    staticSize = 2

    def read(self, reader, font, tableDict):
        return reader.readShort()

    def readArray(self, reader, font, tableDict, count):
        return reader.readShortArray(count)

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        writer.writeShort(value)

    def writeArray(self, writer, font, tableDict, values):
        writer.writeShortArray(values)
310
+
311
+
312
class UShort(IntValue):
    """Unsigned 16-bit integer; delegates to the reader/writer primitives."""

    staticSize = 2

    def read(self, reader, font, tableDict):
        return reader.readUShort()

    def readArray(self, reader, font, tableDict, count):
        return reader.readUShortArray(count)

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        writer.writeUShort(value)

    def writeArray(self, writer, font, tableDict, values):
        writer.writeUShortArray(values)
326
+
327
+
328
class Int8(IntValue):
    """Signed 8-bit integer; delegates to the reader/writer primitives."""

    staticSize = 1

    def read(self, reader, font, tableDict):
        return reader.readInt8()

    def readArray(self, reader, font, tableDict, count):
        return reader.readInt8Array(count)

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        writer.writeInt8(value)

    def writeArray(self, writer, font, tableDict, values):
        writer.writeInt8Array(values)
342
+
343
+
344
class UInt8(IntValue):
    """Unsigned 8-bit integer; delegates to the reader/writer primitives."""

    staticSize = 1

    def read(self, reader, font, tableDict):
        return reader.readUInt8()

    def readArray(self, reader, font, tableDict, count):
        return reader.readUInt8Array(count)

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        writer.writeUInt8(value)

    def writeArray(self, writer, font, tableDict, values):
        writer.writeUInt8Array(values)
358
+
359
+
360
class UInt24(IntValue):
    """Unsigned 24-bit (3-byte) integer; no bulk array primitives exist."""

    staticSize = 3

    def read(self, reader, font, tableDict):
        return reader.readUInt24()

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        writer.writeUInt24(value)
368
+
369
+
370
class ComputedInt(IntValue):
    """Integer that is recomputed at compile time (counts, lengths).

    It is dumped to XML only as an informational comment, never parsed back.
    """

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        if value is None:
            return
        xmlWriter.comment("%s=%s" % (name, value))
        xmlWriter.newline()
375
+
376
+
377
class ComputedUInt8(ComputedInt, UInt8):
    """Computed count stored as an unsigned 8-bit integer."""

    pass
379
+
380
+
381
class ComputedUShort(ComputedInt, UShort):
    """Computed count stored as an unsigned 16-bit integer."""

    pass
383
+
384
+
385
class ComputedULong(ComputedInt, ULong):
    """Computed count stored as an unsigned 32-bit integer."""

    pass
387
+
388
+
389
class Tag(SimpleValue):
    """Four-character OpenType tag (e.g. 'GPOS', 'kern')."""

    staticSize = 4

    def read(self, reader, font, tableDict):
        return reader.readTag()

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        writer.writeTag(value)
397
+
398
+
399
class GlyphID(SimpleValue):
    """16-bit glyph ID, exposed as a glyph *name* via the font's glyph order."""

    staticSize = 2
    typecode = "H"  # struct format code; overridden by GlyphID32

    def readArray(self, reader, font, tableDict, count):
        # Bulk name lookup is much faster than per-item getGlyphName calls.
        return font.getGlyphNameMany(
            reader.readArray(self.typecode, self.staticSize, count)
        )

    def read(self, reader, font, tableDict):
        return font.getGlyphName(reader.readValue(self.typecode, self.staticSize))

    def writeArray(self, writer, font, tableDict, values):
        writer.writeArray(self.typecode, font.getGlyphIDMany(values))

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        writer.writeValue(self.typecode, font.getGlyphID(value))
416
+
417
+
418
class GlyphID32(GlyphID):
    """32-bit glyph ID variant (used e.g. by some AAT tables)."""

    staticSize = 4
    typecode = "L"
421
+
422
+
423
class NameID(UShort):
    """A 'name' table ID; XML output adds the resolved name as a comment."""

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        xmlWriter.simpletag(name, attrs + [("value", value)])
        # Only attempt resolution for nonzero IDs on a real font object.
        nameTable = font.get("name") if font and value else None
        if nameTable:
            debugName = nameTable.getDebugName(value)
            xmlWriter.write(" ")
            if debugName:
                xmlWriter.comment(debugName)
            else:
                xmlWriter.comment("missing from name table")
                log.warning("name id %d missing from name table" % value)
        xmlWriter.newline()
437
+
438
+
439
class STATFlags(UShort):
    """STAT AxisValue flags; known bits are echoed as an XML comment."""

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        xmlWriter.simpletag(name, attrs + [("value", value)])
        # Decode the bits defined by the STAT specification.
        flags = [
            flagName
            for bit, flagName in (
                (0x01, "OlderSiblingFontAttribute"),
                (0x02, "ElidableAxisValueName"),
            )
            if value & bit
        ]
        if flags:
            xmlWriter.write(" ")
            xmlWriter.comment(" ".join(flags))
        xmlWriter.newline()
451
+
452
+
453
class FloatValue(SimpleValue):
    """Floating-point value; parsed from XML with plain float()."""

    @staticmethod
    def fromString(value):
        return float(value)
457
+
458
+
459
class DeciPoints(FloatValue):
    """Value stored as tenths of a point in an unsigned 16-bit field."""

    staticSize = 2

    def read(self, reader, font, tableDict):
        return reader.readUShort() / 10

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        # round() before scaling back to integer tenths.
        writer.writeUShort(round(value * 10))
467
+
468
+
469
class BaseFixedValue(FloatValue):
    """Base class for fixed-point values (Fixed, F2Dot14).

    Subclasses set the storage size, the number of fractional bits, and the
    reader/writer method names used for the underlying integer.
    """

    staticSize = NotImplemented
    precisionBits = NotImplemented
    readerMethod = NotImplemented
    writerMethod = NotImplemented

    def read(self, reader, font, tableDict):
        return self.fromInt(getattr(reader, self.readerMethod)())

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        getattr(writer, self.writerMethod)(self.toInt(value))

    @classmethod
    def fromInt(cls, value):
        # Raw fixed-point integer -> float.
        return fi2fl(value, cls.precisionBits)

    @classmethod
    def toInt(cls, value):
        # Float -> raw fixed-point integer.
        return fl2fi(value, cls.precisionBits)

    @classmethod
    def fromString(cls, value):
        return str2fl(value, cls.precisionBits)

    @classmethod
    def toString(cls, value):
        return fl2str(value, cls.precisionBits)
496
+
497
+
498
class Fixed(BaseFixedValue):
    """32-bit signed fixed-point value with 16 fractional bits (16.16)."""

    staticSize = 4
    precisionBits = 16
    readerMethod = "readLong"
    writerMethod = "writeLong"
503
+
504
+
505
class F2Dot14(BaseFixedValue):
    """16-bit signed fixed-point value with 14 fractional bits (2.14)."""

    staticSize = 2
    precisionBits = 14
    readerMethod = "readShort"
    writerMethod = "writeShort"
510
+
511
+
512
class Angle(F2Dot14):
    # angles are specified in degrees, and encoded as F2Dot14 fractions of half
    # circle: e.g. 1.0 => 180, -0.5 => -90, -2.0 => -360, etc.
    bias = 0.0
    # Smallest representable angle step: 1/2**14 of a half circle.
    factor = 1.0 / (1 << 14) * 180  # 0.010986328125

    @classmethod
    def fromInt(cls, value):
        return (super().fromInt(value) + cls.bias) * 180

    @classmethod
    def toInt(cls, value):
        return super().toInt((value / 180) - cls.bias)

    @classmethod
    def fromString(cls, value):
        # quantize to nearest multiples of minimum fixed-precision angle
        return otRound(float(value) / cls.factor) * cls.factor

    @classmethod
    def toString(cls, value):
        return nearestMultipleShortestRepr(value, cls.factor)
534
+
535
+
536
class BiasedAngle(Angle):
    # A bias of 1.0 is used in the representation of start and end angles
    # of COLRv1 PaintSweepGradients to allow for encoding +360deg
    bias = 1.0
540
+
541
+
542
class Version(SimpleValue):
    """Table version number stored as a 32-bit Fixed value.

    Values are normalized to the "long" form (via fi2ve) when writing, and
    rendered in XML as a 0x-prefixed hex string.
    """

    staticSize = 4

    def read(self, reader, font, tableDict):
        return reader.readLong()

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        # Normalize legacy float-style versions to the long representation.
        writer.writeLong(fi2ve(value))

    @staticmethod
    def fromString(value):
        return ve2fi(value)

    @staticmethod
    def toString(value):
        return "0x%08x" % value

    @staticmethod
    def fromFloat(v):
        # Convert a float version (e.g. 1.5) to 16.16 fixed.
        return fl2fi(v, 16)
564
+
565
+
566
class Char64(SimpleValue):
    """An ASCII string with up to 64 characters.

    Unused character positions are filled with 0x00 bytes.
    Used in Apple AAT fonts in the `gcid` table.
    """

    staticSize = 64

    def read(self, reader, font, tableDict):
        raw = reader.readData(self.staticSize)
        # Field is NUL-padded: keep only the bytes before the first NUL.
        raw, _, _ = raw.partition(b"\0")
        decoded = tostr(raw, encoding="ascii", errors="replace")
        if decoded != tostr(raw, encoding="ascii", errors="ignore"):
            log.warning('replaced non-ASCII characters in "%s"' % decoded)
        return decoded

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        encoded = tobytes(value, encoding="ascii", errors="replace")
        if encoded != tobytes(value, encoding="ascii", errors="ignore"):
            log.warning('replacing non-ASCII characters in "%s"' % value)
        if len(encoded) > self.staticSize:
            log.warning(
                'truncating overlong "%s" to %d bytes' % (value, self.staticSize)
            )
        # Pad with NULs and clip to exactly staticSize bytes.
        padded = (encoded + b"\0" * self.staticSize)[: self.staticSize]
        writer.writeData(padded)
595
+
596
+
597
class Struct(BaseConverter):
    """Converter for an embedded (inline) subtable of type ``tableClass``."""

    def getRecordSize(self, reader):
        return self.tableClass and self.tableClass.getRecordSize(reader)

    def read(self, reader, font, tableDict):
        table = self.tableClass()
        table.decompile(reader, font)
        return table

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        value.compile(writer, font)

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        if value is None:
            if attrs:
                # If there are attributes (probably index), then
                # don't drop this even if it's NULL. It will mess
                # up the array indices of the containing element.
                xmlWriter.simpletag(name, attrs + [("empty", 1)])
                xmlWriter.newline()
            else:
                pass  # NULL table, ignore
        else:
            value.toXML(xmlWriter, font, attrs, name=name)

    def xmlRead(self, attrs, content, font):
        if "empty" in attrs and safeEval(attrs["empty"]):
            # Placeholder written by xmlWrite for NULL array members.
            return None
        table = self.tableClass()
        Format = attrs.get("Format")
        if Format is not None:
            table.Format = int(Format)

        noPostRead = not hasattr(table, "postRead")
        if noPostRead:
            # TODO Cache table.hasPropagated.
            # Register propagated fields on the font so that nested tables
            # parsed below can resolve counts declared at this level.
            cleanPropagation = False
            for conv in table.getConverters():
                if conv.isPropagated:
                    cleanPropagation = True
                    if not hasattr(font, "_propagator"):
                        font._propagator = {}
                    propagator = font._propagator
                    assert conv.name not in propagator, (conv.name, propagator)
                    setattr(table, conv.name, None)
                    propagator[conv.name] = CountReference(table.__dict__, conv.name)

        for element in content:
            if isinstance(element, tuple):
                name, attrs, content = element
                table.fromXML(name, attrs, content, font)
            else:
                pass  # ignore whitespace/character data between elements

        table.populateDefaults(propagator=getattr(font, "_propagator", None))

        if noPostRead:
            if cleanPropagation:
                # Unregister what we registered above, dropping the dict
                # entirely once it becomes empty.
                for conv in table.getConverters():
                    if conv.isPropagated:
                        propagator = font._propagator
                        del propagator[conv.name]
                        if not propagator:
                            del font._propagator

        return table

    def __repr__(self):
        return "Struct of " + repr(self.tableClass)
666
+
667
+
668
class StructWithLength(Struct):
    """Struct preceded by an explicit StructLength byte-length field."""

    def read(self, reader, font, tableDict):
        pos = reader.pos
        table = self.tableClass()
        table.decompile(reader, font)
        # Honour the declared length even if decompile consumed fewer bytes.
        reader.seek(pos + table.StructLength)
        return table

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        # Locate the StructLength field among the table's converters so we
        # can patch its value after the actual length is known.
        for convIndex, conv in enumerate(value.getConverters()):
            if conv.name == "StructLength":
                break
        lengthIndex = len(writer.items) + convIndex
        if isinstance(value, FormatSwitchingBaseTable):
            lengthIndex += 1  # implicit Format field
        # Write a recognizable placeholder of the right size first.
        deadbeef = {1: 0xDE, 2: 0xDEAD, 4: 0xDEADBEEF}[conv.staticSize]

        before = writer.getDataLength()
        value.StructLength = deadbeef
        value.compile(writer, font)
        length = writer.getDataLength() - before
        # Re-encode the real length and splice it over the placeholder.
        lengthWriter = writer.getSubWriter()
        conv.write(lengthWriter, font, tableDict, length)
        assert writer.items[lengthIndex] == b"\xde\xad\xbe\xef"[: conv.staticSize]
        writer.items[lengthIndex] = lengthWriter.getAllData()
693
+
694
+
695
class Table(Struct):
    """Struct referenced through a 16-bit offset from the parent table.

    An offset of 0 encodes a NULL subtable.
    """

    staticSize = 2

    def readOffset(self, reader):
        return reader.readUShort()

    def writeNullOffset(self, writer):
        writer.writeUShort(0)

    def read(self, reader, font, tableDict):
        offset = self.readOffset(reader)
        if not offset:
            return None  # NULL subtable
        table = self.tableClass()
        subReader = reader.getSubReader(offset)
        if font.lazy:
            # Defer decompilation until the table is actually accessed.
            table.reader = subReader
            table.font = font
        else:
            table.decompile(subReader, font)
        return table

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        if value is None:
            self.writeNullOffset(writer)
            return
        subWriter = writer.getSubWriter()
        subWriter.name = self.name
        if repeatIndex is not None:
            subWriter.repeatIndex = repeatIndex
        writer.writeSubTable(subWriter, offsetSize=self.staticSize)
        value.compile(subWriter, font)
727
+
728
+
729
class LTable(Table):
    """Table referenced through a 32-bit (ULong) offset."""

    staticSize = 4

    def readOffset(self, reader):
        return reader.readULong()

    def writeNullOffset(self, writer):
        writer.writeULong(0)
737
+
738
+
739
# Table pointed to by a 24-bit, 3-byte long offset
class Table24(Table):
    """Table referenced through a 24-bit (UInt24) offset."""

    staticSize = 3

    def readOffset(self, reader):
        return reader.readUInt24()

    def writeNullOffset(self, writer):
        writer.writeUInt24(0)
748
+
749
+
750
+ # TODO Clean / merge the SubTable and SubStruct
751
+
752
+
753
class SubStruct(Struct):
    """Inline subtable whose concrete class depends on the lookup type."""

    def getConverter(self, tableType, lookupType):
        cls = self.lookupTypes[tableType][lookupType]
        return self.__class__(self.name, self.repeat, self.aux, cls)

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        # Emit under the subtable's own class name, not the field name.
        super().xmlWrite(xmlWriter, font, value, None, attrs)
760
+
761
+
762
class SubTable(Table):
    """Offset subtable whose concrete class depends on the lookup type."""

    def getConverter(self, tableType, lookupType):
        cls = self.lookupTypes[tableType][lookupType]
        return self.__class__(self.name, self.repeat, self.aux, cls)

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        # Emit under the subtable's own class name, not the field name.
        super().xmlWrite(xmlWriter, font, value, None, attrs)
769
+
770
+
771
class ExtSubTable(LTable, SubTable):
    """Extension subtable: a SubTable behind a 32-bit offset."""

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        writer.Extension = True  # actually, mere presence of the field flags it as an Ext Subtable writer.
        Table.write(self, writer, font, tableDict, value, repeatIndex)
775
+
776
+
777
class FeatureParams(Table):
    """FeatureParams table whose concrete class depends on the feature tag."""

    def getConverter(self, featureTag):
        cls = self.featureParamTypes.get(featureTag, self.defaultFeatureParams)
        return self.__class__(self.name, self.repeat, self.aux, cls)
781
+
782
+
783
class ValueFormat(IntValue):
    """GPOS ValueFormat field.

    Reading/writing the format registers a ValueRecordFactory on the
    reader/writer, so that sibling ValueRecord fields know their layout.
    """

    staticSize = 2

    def __init__(self, name, repeat, aux, tableClass=None, *, description=""):
        BaseConverter.__init__(
            self, name, repeat, aux, tableClass, description=description
        )
        # Slot key is "ValueFormat1" or "ValueFormat2", per the field name.
        suffix = "2" if name.endswith("2") else "1"
        self.which = "ValueFormat" + suffix

    def read(self, reader, font, tableDict):
        fmt = reader.readUShort()
        reader[self.which] = ValueRecordFactory(fmt)
        return fmt

    def write(self, writer, font, tableDict, format, repeatIndex=None):
        writer.writeUShort(format)
        writer[self.which] = ValueRecordFactory(format)
800
+
801
+
802
class ValueRecord(ValueFormat):
    """A GPOS ValueRecord whose layout comes from the ValueRecordFactory
    registered on the reader/writer by a preceding ValueFormat converter."""

    def getRecordSize(self, reader):
        # Each set bit in the format contributes one 16-bit field.
        return 2 * len(reader[self.which])

    def read(self, reader, font, tableDict):
        factory = reader[self.which]
        return factory.readValueRecord(reader, font)

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        factory = writer[self.which]
        factory.writeValueRecord(writer, font, value)

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        if value is not None:
            value.toXML(xmlWriter, font, self.name, attrs)
        # else: NULL record, nothing to emit

    def xmlRead(self, attrs, content, font):
        from .otBase import ValueRecord

        record = ValueRecord()
        record.fromXML(None, attrs, content, font)
        return record
824
+
825
+
826
+ class AATLookup(BaseConverter):
827
+ BIN_SEARCH_HEADER_SIZE = 10
828
+
829
+ def __init__(self, name, repeat, aux, tableClass, *, description=""):
830
+ BaseConverter.__init__(
831
+ self, name, repeat, aux, tableClass, description=description
832
+ )
833
+ if issubclass(self.tableClass, SimpleValue):
834
+ self.converter = self.tableClass(name="Value", repeat=None, aux=None)
835
+ else:
836
+ self.converter = Table(
837
+ name="Value", repeat=None, aux=None, tableClass=self.tableClass
838
+ )
839
+
840
+ def read(self, reader, font, tableDict):
841
+ format = reader.readUShort()
842
+ if format == 0:
843
+ return self.readFormat0(reader, font)
844
+ elif format == 2:
845
+ return self.readFormat2(reader, font)
846
+ elif format == 4:
847
+ return self.readFormat4(reader, font)
848
+ elif format == 6:
849
+ return self.readFormat6(reader, font)
850
+ elif format == 8:
851
+ return self.readFormat8(reader, font)
852
+ else:
853
+ assert False, "unsupported lookup format: %d" % format
854
+
855
+ def write(self, writer, font, tableDict, value, repeatIndex=None):
856
+ values = list(
857
+ sorted([(font.getGlyphID(glyph), val) for glyph, val in value.items()])
858
+ )
859
+ # TODO: Also implement format 4.
860
+ formats = list(
861
+ sorted(
862
+ filter(
863
+ None,
864
+ [
865
+ self.buildFormat0(writer, font, values),
866
+ self.buildFormat2(writer, font, values),
867
+ self.buildFormat6(writer, font, values),
868
+ self.buildFormat8(writer, font, values),
869
+ ],
870
+ )
871
+ )
872
+ )
873
+ # We use the format ID as secondary sort key to make the output
874
+ # deterministic when multiple formats have same encoded size.
875
+ dataSize, lookupFormat, writeMethod = formats[0]
876
+ pos = writer.getDataLength()
877
+ writeMethod()
878
+ actualSize = writer.getDataLength() - pos
879
+ assert (
880
+ actualSize == dataSize
881
+ ), "AATLookup format %d claimed to write %d bytes, but wrote %d" % (
882
+ lookupFormat,
883
+ dataSize,
884
+ actualSize,
885
+ )
886
+
887
+ @staticmethod
888
+ def writeBinSearchHeader(writer, numUnits, unitSize):
889
+ writer.writeUShort(unitSize)
890
+ writer.writeUShort(numUnits)
891
+ searchRange, entrySelector, rangeShift = getSearchRange(
892
+ n=numUnits, itemSize=unitSize
893
+ )
894
+ writer.writeUShort(searchRange)
895
+ writer.writeUShort(entrySelector)
896
+ writer.writeUShort(rangeShift)
897
+
898
+ def buildFormat0(self, writer, font, values):
899
+ numGlyphs = len(font.getGlyphOrder())
900
+ if len(values) != numGlyphs:
901
+ return None
902
+ valueSize = self.converter.staticSize
903
+ return (
904
+ 2 + numGlyphs * valueSize,
905
+ 0,
906
+ lambda: self.writeFormat0(writer, font, values),
907
+ )
908
+
909
+ def writeFormat0(self, writer, font, values):
910
+ writer.writeUShort(0)
911
+ for glyphID_, value in values:
912
+ self.converter.write(
913
+ writer, font, tableDict=None, value=value, repeatIndex=None
914
+ )
915
+
916
+ def buildFormat2(self, writer, font, values):
917
+ segStart, segValue = values[0]
918
+ segEnd = segStart
919
+ segments = []
920
+ for glyphID, curValue in values[1:]:
921
+ if glyphID != segEnd + 1 or curValue != segValue:
922
+ segments.append((segStart, segEnd, segValue))
923
+ segStart = segEnd = glyphID
924
+ segValue = curValue
925
+ else:
926
+ segEnd = glyphID
927
+ segments.append((segStart, segEnd, segValue))
928
+ valueSize = self.converter.staticSize
929
+ numUnits, unitSize = len(segments) + 1, valueSize + 4
930
+ return (
931
+ 2 + self.BIN_SEARCH_HEADER_SIZE + numUnits * unitSize,
932
+ 2,
933
+ lambda: self.writeFormat2(writer, font, segments),
934
+ )
935
+
936
+ def writeFormat2(self, writer, font, segments):
937
+ writer.writeUShort(2)
938
+ valueSize = self.converter.staticSize
939
+ numUnits, unitSize = len(segments), valueSize + 4
940
+ self.writeBinSearchHeader(writer, numUnits, unitSize)
941
+ for firstGlyph, lastGlyph, value in segments:
942
+ writer.writeUShort(lastGlyph)
943
+ writer.writeUShort(firstGlyph)
944
+ self.converter.write(
945
+ writer, font, tableDict=None, value=value, repeatIndex=None
946
+ )
947
+ writer.writeUShort(0xFFFF)
948
+ writer.writeUShort(0xFFFF)
949
+ writer.writeData(b"\x00" * valueSize)
950
+
951
+ def buildFormat6(self, writer, font, values):
952
+ valueSize = self.converter.staticSize
953
+ numUnits, unitSize = len(values), valueSize + 2
954
+ return (
955
+ 2 + self.BIN_SEARCH_HEADER_SIZE + (numUnits + 1) * unitSize,
956
+ 6,
957
+ lambda: self.writeFormat6(writer, font, values),
958
+ )
959
+
960
+ def writeFormat6(self, writer, font, values):
961
+ writer.writeUShort(6)
962
+ valueSize = self.converter.staticSize
963
+ numUnits, unitSize = len(values), valueSize + 2
964
+ self.writeBinSearchHeader(writer, numUnits, unitSize)
965
+ for glyphID, value in values:
966
+ writer.writeUShort(glyphID)
967
+ self.converter.write(
968
+ writer, font, tableDict=None, value=value, repeatIndex=None
969
+ )
970
+ writer.writeUShort(0xFFFF)
971
+ writer.writeData(b"\x00" * valueSize)
972
+
973
+ def buildFormat8(self, writer, font, values):
974
+ minGlyphID, maxGlyphID = values[0][0], values[-1][0]
975
+ if len(values) != maxGlyphID - minGlyphID + 1:
976
+ return None
977
+ valueSize = self.converter.staticSize
978
+ return (
979
+ 6 + len(values) * valueSize,
980
+ 8,
981
+ lambda: self.writeFormat8(writer, font, values),
982
+ )
983
+
984
+ def writeFormat8(self, writer, font, values):
985
+ firstGlyphID = values[0][0]
986
+ writer.writeUShort(8)
987
+ writer.writeUShort(firstGlyphID)
988
+ writer.writeUShort(len(values))
989
+ for _, value in values:
990
+ self.converter.write(
991
+ writer, font, tableDict=None, value=value, repeatIndex=None
992
+ )
993
+
994
+ def readFormat0(self, reader, font):
995
+ numGlyphs = len(font.getGlyphOrder())
996
+ data = self.converter.readArray(reader, font, tableDict=None, count=numGlyphs)
997
+ return {font.getGlyphName(k): value for k, value in enumerate(data)}
998
+
999
    def readFormat2(self, reader, font):
        # Format 2 (segment single): binary-search units of
        # (lastGlyph, firstGlyph, value); lastGlyph == 0xFFFF marks
        # the terminating sentinel unit.
        mapping = {}
        pos = reader.pos - 2  # start of table is at UShort for format
        unitSize, numUnits = reader.readUShort(), reader.readUShort()
        assert unitSize >= 4 + self.converter.staticSize, unitSize
        for i in range(numUnits):
            # +12 skips the format word plus the 10-byte binary-search header.
            reader.seek(pos + i * unitSize + 12)
            last = reader.readUShort()
            first = reader.readUShort()
            value = self.converter.read(reader, font, tableDict=None)
            if last != 0xFFFF:
                # Every glyph in the segment shares the same value.
                for k in range(first, last + 1):
                    mapping[font.getGlyphName(k)] = value
        return mapping
1013
+
1014
    def readFormat4(self, reader, font):
        # Format 4 (segment array): each unit is (lastGlyph, firstGlyph,
        # offset), where offset points at a per-glyph value array relative
        # to the start of the lookup table.
        mapping = {}
        pos = reader.pos - 2  # start of table is at UShort for format
        unitSize = reader.readUShort()
        assert unitSize >= 6, unitSize
        for i in range(reader.readUShort()):
            # +12 skips the format word plus the 10-byte binary-search header.
            reader.seek(pos + i * unitSize + 12)
            last = reader.readUShort()
            first = reader.readUShort()
            offset = reader.readUShort()
            if last != 0xFFFF:  # 0xFFFF marks the terminating sentinel unit
                dataReader = reader.getSubReader(0)  # relative to current position
                dataReader.seek(pos + offset)  # relative to start of table
                data = self.converter.readArray(
                    dataReader, font, tableDict=None, count=last - first + 1
                )
                for k, v in enumerate(data):
                    mapping[font.getGlyphName(first + k)] = v
        return mapping
1033
+
1034
    def readFormat6(self, reader, font):
        # Format 6 (single table): binary-search units of (glyphID, value);
        # glyph 0xFFFF marks the terminating sentinel unit.
        mapping = {}
        pos = reader.pos - 2  # start of table is at UShort for format
        unitSize = reader.readUShort()
        assert unitSize >= 2 + self.converter.staticSize, unitSize
        for i in range(reader.readUShort()):
            # +12 skips the format word plus the 10-byte binary-search header.
            reader.seek(pos + i * unitSize + 12)
            glyphID = reader.readUShort()
            value = self.converter.read(reader, font, tableDict=None)
            if glyphID != 0xFFFF:
                mapping[font.getGlyphName(glyphID)] = value
        return mapping
1046
+
1047
+ def readFormat8(self, reader, font):
1048
+ first = reader.readUShort()
1049
+ count = reader.readUShort()
1050
+ data = self.converter.readArray(reader, font, tableDict=None, count=count)
1051
+ return {font.getGlyphName(first + k): value for (k, value) in enumerate(data)}
1052
+
1053
+ def xmlRead(self, attrs, content, font):
1054
+ value = {}
1055
+ for element in content:
1056
+ if isinstance(element, tuple):
1057
+ name, a, eltContent = element
1058
+ if name == "Lookup":
1059
+ value[a["glyph"]] = self.converter.xmlRead(a, eltContent, font)
1060
+ return value
1061
+
1062
+ def xmlWrite(self, xmlWriter, font, value, name, attrs):
1063
+ xmlWriter.begintag(name, attrs)
1064
+ xmlWriter.newline()
1065
+ for glyph, value in sorted(value.items()):
1066
+ self.converter.xmlWrite(
1067
+ xmlWriter, font, value=value, name="Lookup", attrs=[("glyph", glyph)]
1068
+ )
1069
+ xmlWriter.endtag(name)
1070
+ xmlWriter.newline()
1071
+
1072
+
1073
# The AAT 'ankr' table has an unusual structure: an offset to an AATLookup
# followed by an offset to a glyph data table. Unusually, the offsets in
# the AATLookup are not relative to the beginning of the 'ankr' table,
# but relative to the glyph data table. So, to find the anchor data for
# a glyph, one needs to add the offset to the data table to the offset
# found in the AATLookup, and then use the sum of these two offsets to
# find the actual data.
1080
class AATLookupWithDataOffset(BaseConverter):
    """Converter for an AATLookup whose stored offsets are relative to a
    separate glyph data table (the structure used by the AAT 'ankr' table).
    """

    def read(self, reader, font, tableDict):
        lookupOffset = reader.readULong()
        dataOffset = reader.readULong()
        lookupReader = reader.getSubReader(lookupOffset)
        lookup = AATLookup("DataOffsets", None, None, UShort)
        offsets = lookup.read(lookupReader, font, tableDict)
        result = {}
        for glyph, offset in offsets.items():
            # Lookup offsets are relative to the data table, so both
            # offsets must be summed to reach the actual data.
            dataReader = reader.getSubReader(offset + dataOffset)
            item = self.tableClass()
            item.decompile(dataReader, font)
            result[glyph] = item
        return result

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        # We do not work with OTTableWriter sub-writers because
        # the offsets in our AATLookup are relative to our data
        # table, for which we need to provide an offset value itself.
        # It might have been possible to somehow make a kludge for
        # performing this indirect offset computation directly inside
        # OTTableWriter. But this would have made the internal logic
        # of OTTableWriter even more complex than it already is,
        # so we decided to roll our own offset computation for the
        # contents of the AATLookup and associated data table.
        offsetByGlyph, offsetByData, dataLen = {}, {}, 0
        compiledData = []
        for glyph in sorted(value, key=font.getGlyphID):
            subWriter = OTTableWriter()
            value[glyph].compile(subWriter, font)
            data = subWriter.getAllData()
            offset = offsetByData.get(data, None)
            # Bug-prone identity check fixed: compare against None with
            # `is`, not `==` (PEP 8; `==` may invoke arbitrary __eq__).
            if offset is None:
                # First time we see this blob; identical blobs for
                # different glyphs share a single copy in the data table.
                offset = dataLen
                dataLen = dataLen + len(data)
                offsetByData[data] = offset
                compiledData.append(data)
            offsetByGlyph[glyph] = offset
        # For calculating the offsets to our AATLookup and data table,
        # we can use the regular OTTableWriter infrastructure.
        lookupWriter = writer.getSubWriter()
        lookup = AATLookup("DataOffsets", None, None, UShort)
        lookup.write(lookupWriter, font, tableDict, offsetByGlyph, None)

        dataWriter = writer.getSubWriter()
        writer.writeSubTable(lookupWriter, offsetSize=4)
        writer.writeSubTable(dataWriter, offsetSize=4)
        for d in compiledData:
            dataWriter.writeData(d)

    def xmlRead(self, attrs, content, font):
        lookup = AATLookup("DataOffsets", None, None, self.tableClass)
        return lookup.xmlRead(attrs, content, font)

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        lookup = AATLookup("DataOffsets", None, None, self.tableClass)
        lookup.xmlWrite(xmlWriter, font, value, name, attrs)
1137
+
1138
+
1139
class MorxSubtableConverter(BaseConverter):
    """Converter for one subtable of the AAT 'morx' table.

    The subtable's coverage byte interleaves several conceptual fields
    (text direction, processing order, and the low nibble that this code
    folds into the upper bits of ``Reserved``); this converter packs and
    unpacks that byte on both the binary and the XML path.
    """

    _PROCESSING_ORDERS = {
        # bits 30 and 28 of morx.CoverageFlags; see morx spec
        (False, False): "LayoutOrder",
        (True, False): "ReversedLayoutOrder",
        (False, True): "LogicalOrder",
        (True, True): "ReversedLogicalOrder",
    }

    _PROCESSING_ORDERS_REVERSED = {val: key for key, val in _PROCESSING_ORDERS.items()}

    def __init__(self, name, repeat, aux, tableClass=None, *, description=""):
        BaseConverter.__init__(
            self, name, repeat, aux, tableClass, description=description
        )

    def _setTextDirectionFromCoverageFlags(self, flags, subtable):
        # Bit 0x20 ("any direction") takes precedence over 0x80 (vertical).
        if (flags & 0x20) != 0:
            subtable.TextDirection = "Any"
        elif (flags & 0x80) != 0:
            subtable.TextDirection = "Vertical"
        else:
            subtable.TextDirection = "Horizontal"

    def read(self, reader, font, tableDict):
        pos = reader.pos
        m = MorxSubtable()
        m.StructLength = reader.readULong()
        flags = reader.readUInt8()
        orderKey = ((flags & 0x40) != 0, (flags & 0x10) != 0)
        m.ProcessingOrder = self._PROCESSING_ORDERS[orderKey]
        self._setTextDirectionFromCoverageFlags(flags, m)
        m.Reserved = reader.readUShort()
        # Preserve the low nibble of the coverage byte in the upper bits
        # of Reserved so a binary round trip keeps it intact.
        m.Reserved |= (flags & 0xF) << 16
        m.MorphType = reader.readUInt8()
        m.SubFeatureFlags = reader.readULong()
        tableClass = lookupTypes["morx"].get(m.MorphType)
        if tableClass is None:
            assert False, "unsupported 'morx' lookup type %s" % m.MorphType
        # To decode AAT ligatures, we need to know the subtable size.
        # The easiest way to pass this along is to create a new reader
        # that works on just the subtable as its data.
        headerLength = reader.pos - pos
        data = reader.data[reader.pos : reader.pos + m.StructLength - headerLength]
        assert len(data) == m.StructLength - headerLength
        subReader = OTTableReader(data=data, tableTag=reader.tableTag)
        m.SubStruct = tableClass()
        m.SubStruct.decompile(subReader, font)
        reader.seek(pos + m.StructLength)
        return m

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        xmlWriter.begintag(name, attrs)
        xmlWriter.newline()
        xmlWriter.comment("StructLength=%d" % value.StructLength)
        xmlWriter.newline()
        xmlWriter.simpletag("TextDirection", value=value.TextDirection)
        xmlWriter.newline()
        xmlWriter.simpletag("ProcessingOrder", value=value.ProcessingOrder)
        xmlWriter.newline()
        if value.Reserved != 0:
            xmlWriter.simpletag("Reserved", value="0x%04x" % value.Reserved)
            xmlWriter.newline()
        xmlWriter.comment("MorphType=%d" % value.MorphType)
        xmlWriter.newline()
        xmlWriter.simpletag("SubFeatureFlags", value="0x%08x" % value.SubFeatureFlags)
        xmlWriter.newline()
        value.SubStruct.toXML(xmlWriter, font)
        xmlWriter.endtag(name)
        xmlWriter.newline()

    def xmlRead(self, attrs, content, font):
        m = MorxSubtable()
        covFlags = 0
        m.Reserved = 0
        for eltName, eltAttrs, eltContent in filter(istuple, content):
            if eltName == "CoverageFlags":
                # Only in XML from old versions of fonttools.
                covFlags = safeEval(eltAttrs["value"])
                orderKey = ((covFlags & 0x40) != 0, (covFlags & 0x10) != 0)
                m.ProcessingOrder = self._PROCESSING_ORDERS[orderKey]
                self._setTextDirectionFromCoverageFlags(covFlags, m)
            elif eltName == "ProcessingOrder":
                m.ProcessingOrder = eltAttrs["value"]
                assert m.ProcessingOrder in self._PROCESSING_ORDERS_REVERSED, (
                    "unknown ProcessingOrder: %s" % m.ProcessingOrder
                )
            elif eltName == "TextDirection":
                m.TextDirection = eltAttrs["value"]
                assert m.TextDirection in {"Horizontal", "Vertical", "Any"}, (
                    "unknown TextDirection %s" % m.TextDirection
                )
            elif eltName == "Reserved":
                m.Reserved = safeEval(eltAttrs["value"])
            elif eltName == "SubFeatureFlags":
                m.SubFeatureFlags = safeEval(eltAttrs["value"])
            elif eltName.endswith("Morph"):
                m.fromXML(eltName, eltAttrs, eltContent, font)
            else:
                assert False, eltName
        m.Reserved = (covFlags & 0xF) << 16 | m.Reserved
        return m

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        covFlags = (value.Reserved & 0x000F0000) >> 16
        reverseOrder, logicalOrder = self._PROCESSING_ORDERS_REVERSED[
            value.ProcessingOrder
        ]
        covFlags |= 0x80 if value.TextDirection == "Vertical" else 0
        covFlags |= 0x40 if reverseOrder else 0
        covFlags |= 0x20 if value.TextDirection == "Any" else 0
        covFlags |= 0x10 if logicalOrder else 0
        value.CoverageFlags = covFlags
        # StructLength is only known after compiling the subtable, so a
        # recognizable placeholder is written first and patched afterwards.
        lengthIndex = len(writer.items)
        before = writer.getDataLength()
        value.StructLength = 0xDEADBEEF
        # The high nibble of value.Reserved is actually encoded
        # into coverageFlags, so we need to clear it here.
        origReserved = value.Reserved  # including high nibble
        value.Reserved = value.Reserved & 0xFFFF  # without high nibble
        value.compile(writer, font)
        value.Reserved = origReserved  # restore original value
        assert writer.items[lengthIndex] == b"\xde\xad\xbe\xef"
        length = writer.getDataLength() - before
        writer.items[lengthIndex] = struct.pack(">L", length)
1264
+
1265
+
1266
+ # https://developer.apple.com/fonts/TrueType-Reference-Manual/RM06/Chap6Tables.html#ExtendedStateHeader
1267
+ # TODO: Untangle the implementation of the various lookup-specific formats.
1268
class STXHeader(BaseConverter):
    """Converter for AAT extended state tables: a header of offsets to a
    glyph-class lookup, a state array, an entry table, and (depending on
    the action class) per-glyph lookups, actions, or ligature data.
    """

    def __init__(self, name, repeat, aux, tableClass, *, description=""):
        BaseConverter.__init__(
            self, name, repeat, aux, tableClass, description=description
        )
        assert issubclass(self.tableClass, AATAction)
        self.classLookup = AATLookup("GlyphClasses", None, None, UShort)
        # Only contextual morph actions carry an extra per-glyph lookup.
        if issubclass(self.tableClass, ContextualMorphAction):
            self.perGlyphLookup = AATLookup("PerGlyphLookup", None, None, GlyphID)
        else:
            self.perGlyphLookup = None

    def read(self, reader, font, tableDict):
        table = AATStateTable()
        pos = reader.pos
        classTableReader = reader.getSubReader(0)
        stateArrayReader = reader.getSubReader(0)
        entryTableReader = reader.getSubReader(0)
        actionReader = None
        ligaturesReader = None
        table.GlyphClassCount = reader.readULong()
        # All header offsets are relative to the start of the state table.
        classTableReader.seek(pos + reader.readULong())
        stateArrayReader.seek(pos + reader.readULong())
        entryTableReader.seek(pos + reader.readULong())
        if self.perGlyphLookup is not None:
            perGlyphTableReader = reader.getSubReader(0)
            perGlyphTableReader.seek(pos + reader.readULong())
        if issubclass(self.tableClass, LigatureMorphAction):
            actionReader = reader.getSubReader(0)
            actionReader.seek(pos + reader.readULong())
            ligComponentReader = reader.getSubReader(0)
            ligComponentReader.seek(pos + reader.readULong())
            ligaturesReader = reader.getSubReader(0)
            ligaturesReader.seek(pos + reader.readULong())
            # Component count is inferred from the gap between the two
            # offsets (each component is a UShort).
            numLigComponents = (ligaturesReader.pos - ligComponentReader.pos) // 2
            assert numLigComponents >= 0
            table.LigComponents = ligComponentReader.readUShortArray(numLigComponents)
            table.Ligatures = self._readLigatures(ligaturesReader, font)
        elif issubclass(self.tableClass, InsertionMorphAction):
            actionReader = reader.getSubReader(0)
            actionReader.seek(pos + reader.readULong())
        table.GlyphClasses = self.classLookup.read(classTableReader, font, tableDict)
        # The state count is likewise inferred from the distance between
        # the state array and the entry table (2 bytes per class entry).
        numStates = int(
            (entryTableReader.pos - stateArrayReader.pos) / (table.GlyphClassCount * 2)
        )
        for stateIndex in range(numStates):
            state = AATState()
            table.States.append(state)
            for glyphClass in range(table.GlyphClassCount):
                entryIndex = stateArrayReader.readUShort()
                state.Transitions[glyphClass] = self._readTransition(
                    entryTableReader, entryIndex, font, actionReader
                )
        if self.perGlyphLookup is not None:
            table.PerGlyphLookups = self._readPerGlyphLookups(
                table, perGlyphTableReader, font
            )
        return table

    def _readTransition(self, reader, entryIndex, font, actionReader):
        # Each entry is self.tableClass().staticSize bytes inside the
        # entry table; decode it in place without disturbing `reader`.
        transition = self.tableClass()
        entryReader = reader.getSubReader(
            reader.pos + entryIndex * transition.staticSize
        )
        transition.decompile(entryReader, font, actionReader)
        return transition

    def _readLigatures(self, reader, font):
        # The ligature list runs to the end of the subtable data.
        limit = len(reader.data)
        numLigatureGlyphs = (limit - reader.pos) // 2
        return font.getGlyphNameMany(reader.readUShortArray(numLigatureGlyphs))

    def _countPerGlyphLookups(self, table):
        # Somewhat annoyingly, the morx table does not encode
        # the size of the per-glyph table. So we need to find
        # the maximum value that MorphActions use as index
        # into this table.
        numLookups = 0
        for state in table.States:
            for t in state.Transitions.values():
                if isinstance(t, ContextualMorphAction):
                    if t.MarkIndex != 0xFFFF:
                        numLookups = max(numLookups, t.MarkIndex + 1)
                    if t.CurrentIndex != 0xFFFF:
                        numLookups = max(numLookups, t.CurrentIndex + 1)
        return numLookups

    def _readPerGlyphLookups(self, table, reader, font):
        pos = reader.pos
        lookups = []
        for _ in range(self._countPerGlyphLookups(table)):
            lookupReader = reader.getSubReader(0)
            lookupReader.seek(pos + reader.readULong())
            lookups.append(self.perGlyphLookup.read(lookupReader, font, {}))
        return lookups

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        # Compile every sub-structure first, then lay them out in a fixed
        # order so the header offsets can be computed up front.
        glyphClassWriter = OTTableWriter()
        self.classLookup.write(
            glyphClassWriter, font, tableDict, value.GlyphClasses, repeatIndex=None
        )
        glyphClassData = pad(glyphClassWriter.getAllData(), 2)
        glyphClassCount = max(value.GlyphClasses.values()) + 1
        glyphClassTableOffset = 16  # size of STXHeader
        if self.perGlyphLookup is not None:
            glyphClassTableOffset += 4

        glyphClassTableOffset += self.tableClass.actionHeaderSize
        actionData, actionIndex = self.tableClass.compileActions(font, value.States)
        stateArrayData, entryTableData = self._compileStates(
            font, value.States, glyphClassCount, actionIndex
        )
        stateArrayOffset = glyphClassTableOffset + len(glyphClassData)
        entryTableOffset = stateArrayOffset + len(stateArrayData)
        perGlyphOffset = entryTableOffset + len(entryTableData)
        perGlyphData = pad(self._compilePerGlyphLookups(value, font), 4)
        if actionData is not None:
            actionOffset = entryTableOffset + len(entryTableData)
        else:
            actionOffset = None

        ligaturesOffset, ligComponentsOffset = None, None
        ligComponentsData = self._compileLigComponents(value, font)
        ligaturesData = self._compileLigatures(value, font)
        if ligComponentsData is not None:
            # Ligature data and per-glyph lookups never coexist.
            assert len(perGlyphData) == 0
            ligComponentsOffset = actionOffset + len(actionData)
            ligaturesOffset = ligComponentsOffset + len(ligComponentsData)

        writer.writeULong(glyphClassCount)
        writer.writeULong(glyphClassTableOffset)
        writer.writeULong(stateArrayOffset)
        writer.writeULong(entryTableOffset)
        if self.perGlyphLookup is not None:
            writer.writeULong(perGlyphOffset)
        if actionOffset is not None:
            writer.writeULong(actionOffset)
        if ligComponentsOffset is not None:
            writer.writeULong(ligComponentsOffset)
            writer.writeULong(ligaturesOffset)
        writer.writeData(glyphClassData)
        writer.writeData(stateArrayData)
        writer.writeData(entryTableData)
        writer.writeData(perGlyphData)
        if actionData is not None:
            writer.writeData(actionData)
        if ligComponentsData is not None:
            writer.writeData(ligComponentsData)
        if ligaturesData is not None:
            writer.writeData(ligaturesData)

    def _compileStates(self, font, states, glyphClassCount, actionIndex):
        # Returns (stateArrayData, entryTableData); identical entries are
        # deduplicated and shared by index.
        stateArrayWriter = OTTableWriter()
        entries, entryIDs = [], {}
        for state in states:
            for glyphClass in range(glyphClassCount):
                transition = state.Transitions[glyphClass]
                entryWriter = OTTableWriter()
                transition.compile(entryWriter, font, actionIndex)
                entryData = entryWriter.getAllData()
                assert (
                    len(entryData) == transition.staticSize
                ), "%s has staticSize %d, " "but actually wrote %d bytes" % (
                    repr(transition),
                    transition.staticSize,
                    len(entryData),
                )
                entryIndex = entryIDs.get(entryData)
                if entryIndex is None:
                    entryIndex = len(entries)
                    entryIDs[entryData] = entryIndex
                    entries.append(entryData)
                stateArrayWriter.writeUShort(entryIndex)
        stateArrayData = pad(stateArrayWriter.getAllData(), 4)
        entryTableData = pad(bytesjoin(entries), 4)
        return stateArrayData, entryTableData

    def _compilePerGlyphLookups(self, table, font):
        if self.perGlyphLookup is None:
            return b""
        numLookups = self._countPerGlyphLookups(table)
        assert len(table.PerGlyphLookups) == numLookups, (
            "len(AATStateTable.PerGlyphLookups) is %d, "
            "but the actions inside the table refer to %d"
            % (len(table.PerGlyphLookups), numLookups)
        )
        writer = OTTableWriter()
        for lookup in table.PerGlyphLookups:
            lookupWriter = writer.getSubWriter()
            self.perGlyphLookup.write(lookupWriter, font, {}, lookup, None)
            writer.writeSubTable(lookupWriter, offsetSize=4)
        return writer.getAllData()

    def _compileLigComponents(self, table, font):
        if not hasattr(table, "LigComponents"):
            return None
        writer = OTTableWriter()
        for component in table.LigComponents:
            writer.writeUShort(component)
        return writer.getAllData()

    def _compileLigatures(self, table, font):
        if not hasattr(table, "Ligatures"):
            return None
        writer = OTTableWriter()
        for glyphName in table.Ligatures:
            writer.writeUShort(font.getGlyphID(glyphName))
        return writer.getAllData()

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        xmlWriter.begintag(name, attrs)
        xmlWriter.newline()
        xmlWriter.comment("GlyphClassCount=%s" % value.GlyphClassCount)
        xmlWriter.newline()
        for g, klass in sorted(value.GlyphClasses.items()):
            xmlWriter.simpletag("GlyphClass", glyph=g, value=klass)
            xmlWriter.newline()
        for stateIndex, state in enumerate(value.States):
            xmlWriter.begintag("State", index=stateIndex)
            xmlWriter.newline()
            for glyphClass, trans in sorted(state.Transitions.items()):
                trans.toXML(
                    xmlWriter,
                    font=font,
                    attrs={"onGlyphClass": glyphClass},
                    name="Transition",
                )
            xmlWriter.endtag("State")
            xmlWriter.newline()
        for i, lookup in enumerate(value.PerGlyphLookups):
            xmlWriter.begintag("PerGlyphLookup", index=i)
            xmlWriter.newline()
            for glyph, val in sorted(lookup.items()):
                xmlWriter.simpletag("Lookup", glyph=glyph, value=val)
                xmlWriter.newline()
            xmlWriter.endtag("PerGlyphLookup")
            xmlWriter.newline()
        if hasattr(value, "LigComponents"):
            xmlWriter.begintag("LigComponents")
            xmlWriter.newline()
            for i, val in enumerate(getattr(value, "LigComponents")):
                xmlWriter.simpletag("LigComponent", index=i, value=val)
                xmlWriter.newline()
            xmlWriter.endtag("LigComponents")
            xmlWriter.newline()
        self._xmlWriteLigatures(xmlWriter, font, value, name, attrs)
        xmlWriter.endtag(name)
        xmlWriter.newline()

    def _xmlWriteLigatures(self, xmlWriter, font, value, name, attrs):
        if not hasattr(value, "Ligatures"):
            return
        xmlWriter.begintag("Ligatures")
        xmlWriter.newline()
        for i, g in enumerate(getattr(value, "Ligatures")):
            xmlWriter.simpletag("Ligature", index=i, glyph=g)
            xmlWriter.newline()
        xmlWriter.endtag("Ligatures")
        xmlWriter.newline()

    def xmlRead(self, attrs, content, font):
        table = AATStateTable()
        for eltName, eltAttrs, eltContent in filter(istuple, content):
            if eltName == "GlyphClass":
                glyph = eltAttrs["glyph"]
                value = eltAttrs["value"]
                table.GlyphClasses[glyph] = safeEval(value)
            elif eltName == "State":
                state = self._xmlReadState(eltAttrs, eltContent, font)
                table.States.append(state)
            elif eltName == "PerGlyphLookup":
                lookup = self.perGlyphLookup.xmlRead(eltAttrs, eltContent, font)
                table.PerGlyphLookups.append(lookup)
            elif eltName == "LigComponents":
                table.LigComponents = self._xmlReadLigComponents(
                    eltAttrs, eltContent, font
                )
            elif eltName == "Ligatures":
                table.Ligatures = self._xmlReadLigatures(eltAttrs, eltContent, font)
        table.GlyphClassCount = max(table.GlyphClasses.values()) + 1
        return table

    def _xmlReadState(self, attrs, content, font):
        state = AATState()
        for eltName, eltAttrs, eltContent in filter(istuple, content):
            if eltName == "Transition":
                glyphClass = safeEval(eltAttrs["onGlyphClass"])
                transition = self.tableClass()
                transition.fromXML(eltName, eltAttrs, eltContent, font)
                state.Transitions[glyphClass] = transition
        return state

    def _xmlReadLigComponents(self, attrs, content, font):
        ligComponents = []
        for eltName, eltAttrs, _eltContent in filter(istuple, content):
            if eltName == "LigComponent":
                ligComponents.append(safeEval(eltAttrs["value"]))
        return ligComponents

    def _xmlReadLigatures(self, attrs, content, font):
        ligs = []
        for eltName, eltAttrs, _eltContent in filter(istuple, content):
            if eltName == "Ligature":
                ligs.append(eltAttrs["glyph"])
        return ligs
1573
+
1574
+
1575
class CIDGlyphMap(BaseConverter):
    """Converter for a CID-to-glyph-name mapping stored as a UShort array
    indexed by CID; the value 0xFFFF marks an unmapped CID.
    """

    def read(self, reader, font, tableDict):
        numCIDs = reader.readUShort()
        glyphIDs = reader.readUShortArray(numCIDs)
        mapping = {}
        for cid, gid in enumerate(glyphIDs):
            if gid != 0xFFFF:  # skip the "no glyph" sentinel
                mapping[cid] = font.getGlyphName(gid)
        return mapping

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        byCID = {cid: font.getGlyphID(g) for cid, g in value.items()}
        total = max(byCID) + 1 if byCID else 0
        writer.writeUShort(total)
        # Gaps in the CID range are filled with the 0xFFFF sentinel.
        for cid in range(total):
            writer.writeUShort(byCID.get(cid, 0xFFFF))

    def xmlRead(self, attrs, content, font):
        mapping = {}
        for tag, tagAttrs, _children in filter(istuple, content):
            if tag == "CID":
                mapping[safeEval(tagAttrs["cid"])] = tagAttrs["glyph"].strip()
        return mapping

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        xmlWriter.begintag(name, attrs)
        xmlWriter.newline()
        for cid, glyphName in sorted(value.items()):
            if glyphName is not None and glyphName != 0xFFFF:
                xmlWriter.simpletag("CID", cid=cid, glyph=glyphName)
                xmlWriter.newline()
        xmlWriter.endtag(name)
        xmlWriter.newline()
1607
+
1608
+
1609
class GlyphCIDMap(BaseConverter):
    """Converter for a glyph-name-to-CID mapping stored as a UShort array
    indexed by glyph ID; the value 0xFFFF marks a glyph with no CID.
    """

    def read(self, reader, font, tableDict):
        glyphOrder = font.getGlyphOrder()
        count = reader.readUShort()
        cids = reader.readUShortArray(count)
        if count > len(glyphOrder):
            log.warning(
                "GlyphCIDMap has %d elements, "
                "but the font has only %d glyphs; "
                "ignoring the rest" % (count, len(glyphOrder))
            )
        # Only glyph IDs present in both the array and the font are usable.
        usable = min(len(cids), len(glyphOrder))
        return {
            glyphOrder[gid]: cids[gid]
            for gid in range(usable)
            if cids[gid] != 0xFFFF
        }

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        byGlyphID = {}
        for glyphName, cid in value.items():
            if cid is not None and cid != 0xFFFF:
                byGlyphID[font.getGlyphID(glyphName)] = cid
        total = max(byGlyphID) + 1 if byGlyphID else 0
        writer.writeUShort(total)
        # Glyphs without a CID are filled with the 0xFFFF sentinel.
        for gid in range(total):
            writer.writeUShort(byGlyphID.get(gid, 0xFFFF))

    def xmlRead(self, attrs, content, font):
        mapping = {}
        for tag, tagAttrs, _children in filter(istuple, content):
            if tag == "CID":
                mapping[tagAttrs["glyph"]] = safeEval(tagAttrs["value"])
        return mapping

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        xmlWriter.begintag(name, attrs)
        xmlWriter.newline()
        for glyphName, cid in sorted(value.items()):
            if cid is not None and cid != 0xFFFF:
                xmlWriter.simpletag("CID", glyph=glyphName, value=cid)
                xmlWriter.newline()
        xmlWriter.endtag(name)
        xmlWriter.newline()
1654
+
1655
+
1656
class DeltaValue(BaseConverter):
    """Converter for the packed DeltaValue array of a Device table.

    Deltas are packed 2, 4 or 8 bits each (DeltaFormat 1/2/3) into a
    stream of UShort words, most significant bits first.
    """

    def read(self, reader, font, tableDict):
        StartSize = tableDict["StartSize"]
        EndSize = tableDict["EndSize"]
        DeltaFormat = tableDict["DeltaFormat"]
        assert DeltaFormat in (1, 2, 3), "illegal DeltaFormat"
        nItems = EndSize - StartSize + 1
        nBits = 1 << DeltaFormat  # 2, 4 or 8 bits per delta
        minusOffset = 1 << nBits
        mask = (1 << nBits) - 1
        signMask = 1 << (nBits - 1)

        DeltaValue = []
        tmp, shift = 0, 0
        for i in range(nItems):
            if shift == 0:
                # Current word exhausted; fetch the next 16 bits.
                tmp, shift = reader.readUShort(), 16
            shift = shift - nBits
            value = (tmp >> shift) & mask
            if value & signMask:
                value = value - minusOffset  # sign-extend two's complement
            DeltaValue.append(value)
        return DeltaValue

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        StartSize = tableDict["StartSize"]
        EndSize = tableDict["EndSize"]
        DeltaFormat = tableDict["DeltaFormat"]
        DeltaValue = value
        assert DeltaFormat in (1, 2, 3), "illegal DeltaFormat"
        nItems = EndSize - StartSize + 1
        nBits = 1 << DeltaFormat
        assert len(DeltaValue) == nItems
        mask = (1 << nBits) - 1

        tmp, shift = 0, 16
        for value in DeltaValue:
            shift = shift - nBits
            tmp = tmp | ((value & mask) << shift)
            if shift == 0:
                writer.writeUShort(tmp)
                tmp, shift = 0, 16
        if shift != 16:
            # Flush the final, partially filled word.
            writer.writeUShort(tmp)

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        xmlWriter.simpletag(name, attrs + [("value", value)])
        xmlWriter.newline()

    def xmlRead(self, attrs, content, font):
        return safeEval(attrs["value"])
1707
+
1708
+
1709
class VarIdxMapValue(BaseConverter):
    """Converter for the packed mapping data of a DeltaSetIndexMap.

    EntryFormat encodes both the number of inner-index bits (low nibble)
    and the byte size of each packed entry (bits 4-5).
    """

    def read(self, reader, font, tableDict):
        fmt = tableDict["EntryFormat"]
        nItems = tableDict["MappingCount"]

        innerBits = 1 + (fmt & 0x000F)
        innerMask = (1 << innerBits) - 1
        outerMask = 0xFFFFFFFF - innerMask
        outerShift = 16 - innerBits

        entrySize = 1 + ((fmt & 0x0030) >> 4)
        # Dispatch on entry byte size (1-4 bytes per packed entry).
        readArray = {
            1: reader.readUInt8Array,
            2: reader.readUShortArray,
            3: reader.readUInt24Array,
            4: reader.readULongArray,
        }[entrySize]

        # Expand each packed entry so the outer index lands in the upper
        # 16 bits and the inner index in the lower 16 bits.
        return [
            (((raw & outerMask) << outerShift) | (raw & innerMask))
            for raw in readArray(nItems)
        ]

    def write(self, writer, font, tableDict, value, repeatIndex=None):
        fmt = tableDict["EntryFormat"]
        mapping = value
        writer["MappingCount"].setValue(len(mapping))

        innerBits = 1 + (fmt & 0x000F)
        innerMask = (1 << innerBits) - 1
        outerShift = 16 - innerBits

        entrySize = 1 + ((fmt & 0x0030) >> 4)
        writeArray = {
            1: writer.writeUInt8Array,
            2: writer.writeUShortArray,
            3: writer.writeUInt24Array,
            4: writer.writeULongArray,
        }[entrySize]

        # Re-pack each 32-bit index by shifting the outer part down next
        # to the inner-index bits.
        writeArray(
            [
                (((idx & 0xFFFF0000) >> outerShift) | (idx & innerMask))
                for idx in mapping
            ]
        )
1755
+
1756
+
1757
class VarDataValue(BaseConverter):
    """Converter for one row of delta values in an ItemVariationStore
    VarData subtable.

    The high bit of NumShorts selects wide storage (long/short pairs)
    instead of the default short/int8 pairs; see
    https://github.com/fonttools/fonttools/issues/2279.
    """

    def read(self, reader, font, tableDict):
        values = []

        regionCount = tableDict["VarRegionCount"]
        wordCount = tableDict["NumShorts"]

        # https://github.com/fonttools/fonttools/issues/2279
        longWords = bool(wordCount & 0x8000)
        wordCount = wordCount & 0x7FFF

        if longWords:
            readBigArray, readSmallArray = reader.readLongArray, reader.readShortArray
        else:
            readBigArray, readSmallArray = reader.readShortArray, reader.readInt8Array

        n1, n2 = min(regionCount, wordCount), max(regionCount, wordCount)
        values.extend(readBigArray(n1))
        values.extend(readSmallArray(n2 - n1))
        if n2 > regionCount:  # Padding
            del values[regionCount:]

        return values

    def write(self, writer, font, tableDict, values, repeatIndex=None):
        regionCount = tableDict["VarRegionCount"]
        wordCount = tableDict["NumShorts"]

        # https://github.com/fonttools/fonttools/issues/2279
        longWords = bool(wordCount & 0x8000)
        wordCount = wordCount & 0x7FFF

        (writeBigArray, writeSmallArray) = {
            False: (writer.writeShortArray, writer.writeInt8Array),
            True: (writer.writeLongArray, writer.writeShortArray),
        }[longWords]

        n1, n2 = min(regionCount, wordCount), max(regionCount, wordCount)
        writeBigArray(values[:n1])
        writeSmallArray(values[n1:regionCount])
        if n2 > regionCount:  # Padding
            # Bug fix: the padding must use the locally selected
            # small-array writer. The writer object itself has no
            # `writeSmallArray` method, so the previous
            # `writer.writeSmallArray(...)` raised AttributeError
            # whenever wordCount exceeded regionCount.
            writeSmallArray([0] * (n2 - regionCount))

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        xmlWriter.simpletag(name, attrs + [("value", value)])
        xmlWriter.newline()

    def xmlRead(self, attrs, content, font):
        return safeEval(attrs["value"])
1806
+
1807
+
1808
class TupleValues:
    """Converter for packed tuple-variation delta streams."""

    def read(self, data, font):
        # Decode the packed deltas; the first element of the result is the
        # delta list itself (the remainder is stream state we discard).
        return TupleVariation.decompileDeltas_(None, data)[0]

    def write(self, writer, font, tableDict, values, repeatIndex=None):
        # Size-optimal packing by default; the OPTIMIZE_FONT_SPEED config
        # flag trades bytes for faster consumption.
        speedFirst = font.cfg[OPTIMIZE_FONT_SPEED]
        packed = TupleVariation.compileDeltaValues_(
            values, optimizeSize=not speedFirst
        )
        return bytes(packed)

    def xmlRead(self, attrs, content, font):
        return safeEval(attrs["value"])

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        xmlWriter.simpletag(name, attrs + [("value", value)])
        xmlWriter.newline()
1824
+
1825
+
1826
class CFF2Index(BaseConverter):
    """Converter for a CFF2 INDEX structure: a count, an offset size, an
    offset array, and concatenated item data.

    Items may be decoded either via an ``itemClass`` (a table-like object
    with decompile/compile) or via an ``itemConverterClass`` (a converter
    with read/write); at most one of the two is used.
    """

    def __init__(
        self,
        name,
        repeat,
        aux,
        tableClass=None,
        *,
        itemClass=None,
        itemConverterClass=None,
        description="",
    ):
        BaseConverter.__init__(
            self, name, repeat, aux, tableClass, description=description
        )
        # Exactly one (or neither) of these is expected to be set; raw
        # bytes are passed through when both are None.
        self._itemClass = itemClass
        self._converter = (
            itemConverterClass() if itemConverterClass is not None else None
        )

    def read(self, reader, font, tableDict):
        # INDEX header: 32-bit item count, then (if non-empty) a 1-byte
        # offset size.
        count = reader.readULong()
        if count == 0:
            return []
        offSize = reader.readUInt8()

        def getReadArray(reader, offSize):
            # Offsets are unsigned ints of 1..4 bytes each.
            return {
                1: reader.readUInt8Array,
                2: reader.readUShortArray,
                3: reader.readUInt24Array,
                4: reader.readULongArray,
            }[offSize]

        readArray = getReadArray(reader, offSize)

        # Small indexes are decoded eagerly; larger ones lazily unless the
        # font disables laziness.
        lazy = font.lazy is not False and count > 8
        if not lazy:
            # count+1 offsets delimit count items; offsets are 1-based
            # relative to the byte before the data region.
            offsets = readArray(count + 1)
            items = []
            lastOffset = offsets.pop(0)
            reader.readData(lastOffset - 1)  # In case first offset is not 1

            for offset in offsets:
                assert lastOffset <= offset
                item = reader.readData(offset - lastOffset)

                if self._itemClass is not None:
                    obj = self._itemClass()
                    obj.decompile(item, font, reader.localState)
                    item = obj
                elif self._converter is not None:
                    item = self._converter.read(item, font)

                items.append(item)
                lastOffset = offset
            return items
        else:

            def get_read_item():
                # Capture an independent reader so later seeks do not
                # disturb the caller's position.
                reader_copy = reader.copy()
                offset_pos = reader.pos
                # Data region starts after the count+1 offsets, minus one
                # because offsets are 1-based.
                data_pos = offset_pos + (count + 1) * offSize - 1
                readArray = getReadArray(reader_copy, offSize)

                def read_item(i):
                    # Read the two offsets bounding item i, then the data
                    # between them.
                    reader_copy.seek(offset_pos + i * offSize)
                    offsets = readArray(2)
                    reader_copy.seek(data_pos + offsets[0])
                    item = reader_copy.readData(offsets[1] - offsets[0])

                    if self._itemClass is not None:
                        obj = self._itemClass()
                        obj.decompile(item, font, reader_copy.localState)
                        item = obj
                    elif self._converter is not None:
                        item = self._converter.read(item, font)
                    return item

                return read_item

            read_item = get_read_item()
            l = LazyList([read_item] * count)

            # TODO: Advance reader

            return l

    def write(self, writer, font, tableDict, values, repeatIndex=None):
        items = values

        writer.writeULong(len(items))
        if not len(items):
            # An empty INDEX is just the zero count; no offSize/offsets.
            return

        if self._itemClass is not None:
            items = [item.compile(font) for item in items]
        elif self._converter is not None:
            items = [
                self._converter.write(writer, font, tableDict, item, i)
                for i, item in enumerate(items)
            ]

        # Build 1-based cumulative offsets (count+1 of them).
        offsets = [len(item) for item in items]
        offsets = list(accumulate(offsets, initial=1))

        # Smallest offset size that can represent the final offset.
        lastOffset = offsets[-1]
        offSize = (
            1
            if lastOffset < 0x100
            else 2 if lastOffset < 0x10000 else 3 if lastOffset < 0x1000000 else 4
        )
        writer.writeUInt8(offSize)

        writeArray = {
            1: writer.writeUInt8Array,
            2: writer.writeUShortArray,
            3: writer.writeUInt24Array,
            4: writer.writeULongArray,
        }[offSize]

        writeArray(offsets)
        for item in items:
            writer.writeData(item)

    def xmlRead(self, attrs, content, font):
        if self._itemClass is not None:
            obj = self._itemClass()
            obj.fromXML(None, attrs, content, font)
            return obj
        elif self._converter is not None:
            return self._converter.xmlRead(attrs, content, font)
        else:
            # Raw-bytes indexes have no XML round-trip support.
            raise NotImplementedError()

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        if self._itemClass is not None:
            for i, item in enumerate(value):
                item.toXML(xmlWriter, font, [("index", i)], name)
        elif self._converter is not None:
            for i, item in enumerate(value):
                self._converter.xmlWrite(
                    xmlWriter, font, item, name, attrs + [("index", i)]
                )
        else:
            raise NotImplementedError()
1972
+
1973
+
1974
class LookupFlag(UShort):
    """UShort converter that annotates LookupFlag values with an XML comment
    naming the set bits."""

    # (bit mask, flag name) pairs, in the order they are reported.
    _FLAG_NAMES = [
        (0x01, "rightToLeft"),
        (0x02, "ignoreBaseGlyphs"),
        (0x04, "ignoreLigatures"),
        (0x08, "ignoreMarks"),
        (0x10, "useMarkFilteringSet"),
    ]

    def xmlWrite(self, xmlWriter, font, value, name, attrs):
        xmlWriter.simpletag(name, attrs + [("value", value)])
        flags = [label for bit, label in self._FLAG_NAMES if value & bit]
        # The high byte carries the mark attachment type, if any.
        if value & 0xFF00:
            flags.append("markAttachmentType[%i]" % (value >> 8))
        if flags:
            xmlWriter.comment(" ".join(flags))
        xmlWriter.newline()
1993
+
1994
+
1995
class _UInt8Enum(UInt8):
    # Base for uint8 fields whose values map onto a Python enum; concrete
    # subclasses set `enumClass` to the enum type.
    enumClass = NotImplemented

    def read(self, reader, font, tableDict):
        # Wrap the raw integer in the enum so callers see symbolic members.
        return self.enumClass(super().read(reader, font, tableDict))

    @classmethod
    def fromString(cls, value):
        # XML attribute string -> enum member (serialized names are
        # lowercase; enum member names are uppercase).
        return getattr(cls.enumClass, value.upper())

    @classmethod
    def toString(cls, value):
        # Enum member (or raw int) -> lowercase XML attribute string.
        return cls.enumClass(value).name.lower()
2008
+
2009
+
2010
class ExtendMode(_UInt8Enum):
    # uint8 converter backed by the _ExtendMode enum.
    enumClass = _ExtendMode
2012
+
2013
+
2014
class CompositeMode(_UInt8Enum):
    # uint8 converter backed by the _CompositeMode enum.
    enumClass = _CompositeMode
2016
+
2017
+
2018
# Maps otData type names to converter classes.  Entries under the
# '"Template" types' section are factories: given the referenced table
# class they return a configured converter (via functools.partial).
converterMapping = {
    # type class
    "int8": Int8,
    "int16": Short,
    "int32": Long,
    "uint8": UInt8,
    "uint16": UShort,
    "uint24": UInt24,
    "uint32": ULong,
    "char64": Char64,
    "Flags32": Flags32,
    "VarIndex": VarIndex,
    "Version": Version,
    "Tag": Tag,
    "GlyphID": GlyphID,
    "GlyphID32": GlyphID32,
    "NameID": NameID,
    "DeciPoints": DeciPoints,
    "Fixed": Fixed,
    "F2Dot14": F2Dot14,
    "Angle": Angle,
    "BiasedAngle": BiasedAngle,
    "struct": Struct,
    "Offset": Table,
    "LOffset": LTable,
    "Offset24": Table24,
    "ValueRecord": ValueRecord,
    "DeltaValue": DeltaValue,
    "VarIdxMapValue": VarIdxMapValue,
    "VarDataValue": VarDataValue,
    "LookupFlag": LookupFlag,
    "ExtendMode": ExtendMode,
    "CompositeMode": CompositeMode,
    "STATFlags": STATFlags,
    "TupleList": partial(CFF2Index, itemConverterClass=TupleValues),
    "VarCompositeGlyphList": partial(CFF2Index, itemClass=VarCompositeGlyph),
    # AAT
    "CIDGlyphMap": CIDGlyphMap,
    "GlyphCIDMap": GlyphCIDMap,
    "MortChain": StructWithLength,
    "MortSubtable": StructWithLength,
    "MorxChain": StructWithLength,
    "MorxSubtable": MorxSubtableConverter,
    # "Template" types
    "AATLookup": lambda C: partial(AATLookup, tableClass=C),
    "AATLookupWithDataOffset": lambda C: partial(AATLookupWithDataOffset, tableClass=C),
    "STXHeader": lambda C: partial(STXHeader, tableClass=C),
    "OffsetTo": lambda C: partial(Table, tableClass=C),
    "LOffsetTo": lambda C: partial(LTable, tableClass=C),
    "LOffset24To": lambda C: partial(Table24, tableClass=C),
}
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/otData.py ADDED
The diff for this file is too large to render. See raw diff
 
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/otTables.py ADDED
The diff for this file is too large to render. See raw diff
 
external/alphageometry/.venv-ag/Lib/site-packages/fontTools/ttLib/tables/otTraverse.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Methods for traversing trees of otData-driven OpenType tables."""
2
+
3
+ from collections import deque
4
+ from typing import Callable, Deque, Iterable, List, Optional, Tuple
5
+ from .otBase import BaseTable
6
+
7
+
8
+ __all__ = [
9
+ "bfs_base_table",
10
+ "dfs_base_table",
11
+ "SubTablePath",
12
+ ]
13
+
14
+
15
class SubTablePath(Tuple[BaseTable.SubTableEntry, ...]):
    """A path of SubTableEntry items from the root down to one subtable."""

    def __str__(self) -> str:
        # Render as dotted attribute access, with [i] suffixes for
        # entries that carry an index.
        return ".".join(
            entry.name if entry.index is None else f"{entry.name}[{entry.index}]"
            for entry in self
        )
24
+
25
+
26
# Frontier-mutation hook: given the current frontier deque and a list of
# newly discovered paths, insert the new paths (front for depth-first,
# back for breadth-first traversal).
AddToFrontierFn = Callable[[Deque[SubTablePath], List[SubTablePath]], None]
28
+
29
+
30
def dfs_base_table(
    root: BaseTable,
    root_accessor: Optional[str] = None,
    skip_root: bool = False,
    predicate: Optional[Callable[[SubTablePath], bool]] = None,
    iter_subtables_fn: Optional[
        Callable[[BaseTable], Iterable[BaseTable.SubTableEntry]]
    ] = None,
) -> Iterable[SubTablePath]:
    """Walk a tree of BaseTables depth-first, yielding the path to each node.

    Args:
        root (BaseTable): the root of the tree.
        root_accessor (Optional[str]): attribute name for the root table, if
            any (mostly useful for debugging).
        skip_root (Optional[bool]): if True, the root itself is not visited,
            only its children.
        predicate (Optional[Callable[[SubTablePath], bool]]): filter for
            paths. When it returns True the path is yielded and its subtables
            are traversed; when False both the path and its subtree are
            skipped.
        iter_subtables_fn (Optional[Callable[[BaseTable], Iterable[BaseTable.SubTableEntry]]]):
            function used to enumerate a table's subtables; defaults to
            BaseTable.iterSubTables().

    Yields:
        SubTablePath: tuples of BaseTable.SubTableEntry(name, table, index)
        namedtuples, one per visited node. The last entry is the current
        subtable; earlier entries are its ancestors up to the root.
    """

    def _push_front(frontier, new_entries):
        # Prepend children while preserving their original order, which is
        # what makes the shared walker depth-first.
        frontier.extendleft(reversed(new_entries))

    yield from _traverse_ot_data(
        root,
        root_accessor,
        skip_root,
        predicate,
        _push_front,
        iter_subtables_fn,
    )
68
+
69
+
70
def bfs_base_table(
    root: BaseTable,
    root_accessor: Optional[str] = None,
    skip_root: bool = False,
    predicate: Optional[Callable[[SubTablePath], bool]] = None,
    iter_subtables_fn: Optional[
        Callable[[BaseTable], Iterable[BaseTable.SubTableEntry]]
    ] = None,
) -> Iterable[SubTablePath]:
    """Walk a tree of BaseTables breadth-first, yielding the path to each node.

    Args:
        root (BaseTable): the root of the tree.
        root_accessor (Optional[str]): attribute name for the root table, if
            any (mostly useful for debugging).
        skip_root (Optional[bool]): if True, the root itself is not visited,
            only its children.
        predicate (Optional[Callable[[SubTablePath], bool]]): filter for
            paths. When it returns True the path is yielded and its subtables
            are traversed; when False both the path and its subtree are
            skipped.
        iter_subtables_fn (Optional[Callable[[BaseTable], Iterable[BaseTable.SubTableEntry]]]):
            function used to enumerate a table's subtables; defaults to
            BaseTable.iterSubTables().

    Yields:
        SubTablePath: tuples of BaseTable.SubTableEntry(name, table, index)
        namedtuples, one per visited node. The last entry is the current
        subtable; earlier entries are its ancestors up to the root.
    """

    def _push_back(frontier, new_entries):
        # Append children at the end of the queue: breadth-first order.
        frontier.extend(new_entries)

    yield from _traverse_ot_data(
        root,
        root_accessor,
        skip_root,
        predicate,
        _push_back,
        iter_subtables_fn,
    )
109
+
110
+
111
def _traverse_ot_data(
    root: BaseTable,
    root_accessor: Optional[str],
    skip_root: bool,
    predicate: Optional[Callable[[SubTablePath], bool]],
    add_to_frontier_fn: AddToFrontierFn,
    iter_subtables_fn: Optional[
        Callable[[BaseTable], Iterable[BaseTable.SubTableEntry]]
    ] = None,
) -> Iterable[SubTablePath]:
    """Shared walker behind dfs_base_table/bfs_base_table.

    The visit order is entirely determined by ``add_to_frontier_fn``.
    No visited-set is kept: general otData cannot cycle (forward-offset
    only), so the tree is finite and acyclic.
    """
    if root_accessor is None:
        root_accessor = type(root).__name__

    if predicate is None:

        def predicate(path):
            return True

    if iter_subtables_fn is None:

        def iter_subtables_fn(table):
            return table.iterSubTables()

    frontier: Deque[SubTablePath] = deque()

    root_entry = BaseTable.SubTableEntry(root_accessor, root)
    if skip_root:
        # Seed the frontier with the root's children only.
        add_to_frontier_fn(
            frontier,
            [(root_entry, child) for child in iter_subtables_fn(root)],
        )
    else:
        frontier.append((root_entry,))

    while frontier:
        # Each path is a tuple of SubTableEntry items; the last one is the
        # node being visited.
        path = frontier.popleft()

        if not predicate(path):
            continue

        yield SubTablePath(path)

        current = path[-1].value
        add_to_frontier_fn(
            frontier,
            [path + (child,) for child in iter_subtables_fn(current)],
        )
external/alphageometry/.venv-ag/Lib/site-packages/jax/_src/scipy/special.py ADDED
@@ -0,0 +1,2574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from functools import partial
18
+ import operator
19
+ from typing import cast, Any
20
+
21
+ import numpy as np
22
+
23
+ import jax.numpy as jnp
24
+ from jax import jit
25
+ from jax import jvp
26
+ from jax import vmap
27
+ from jax import lax
28
+
29
+ from jax._src import core
30
+ from jax._src import custom_derivatives
31
+ from jax._src import dtypes
32
+ from jax._src.lax.lax import _const as _lax_const
33
+ from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact
34
+ from jax._src.ops import special as ops_special
35
+ from jax._src.third_party.scipy.betaln import betaln as _betaln_impl
36
+ from jax._src.typing import Array, ArrayLike
37
+
38
+
39
def gammaln(x: ArrayLike) -> Array:
    r"""Natural log of the absolute value of the gamma function.

    JAX implementation of :obj:`scipy.special.gammaln`:

    .. math::
       \mathrm{gammaln}(x) = \log(|\Gamma(x)|)

    where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function.

    Args:
      x: arraylike, real valued.

    Returns:
      Array of log-gamma values.

    Notes:
      Complex-valued inputs are not supported.

    See Also:
      - :func:`jax.scipy.special.gamma`
      - :func:`jax.scipy.special.gammasgn`
    """
    arr, = promote_args_inexact("gammaln", x)
    return lax.lgamma(arr)
65
+
66
+
67
def gammasgn(x: ArrayLike) -> Array:
    r"""Sign of the gamma function.

    JAX implementation of :obj:`scipy.special.gammasgn`:

    .. math::
       \mathrm{gammasgn}(x) = \begin{cases}
         +1 & \Gamma(x) > 0 \\
         -1 & \Gamma(x) < 0
       \end{cases}

    :math:`\Gamma(x)` is never zero, so no third case is needed.

    Args:
      x: arraylike, real valued.

    Returns:
      Array of +/-1 values giving the sign of the gamma function.

    See Also:
      - :func:`jax.scipy.special.gamma`
      - :func:`jax.scipy.special.gammaln`
    """
    arr, = promote_args_inexact("gammasgn", x)
    floor_val = lax.floor(arr)
    # Gamma is positive for x > 0, at the poles' convention (x == floor),
    # and between poles whose floor is even; negative otherwise.
    is_positive = (arr > 0) | (arr == floor_val) | (floor_val % 2 == 0)
    return jnp.where(is_positive, 1.0, -1.0)
95
+
96
+
97
def gamma(x: ArrayLike) -> Array:
    r"""The gamma function.

    JAX implementation of :obj:`scipy.special.gamma`. For :math:`\Re(z)>0`,

    .. math::
       \Gamma(z) = \int_0^\infty t^{z-1}e^{-t}\mathrm{d}t

    extended by analytic continuation elsewhere. For positive integers
    :math:`n`, :math:`\Gamma(n) = (n - 1)!`.

    Args:
      x: arraylike, real valued.

    Returns:
      Array of gamma-function values.

    Notes:
      Unlike the scipy version, complex-valued inputs are not supported.

    See Also:
      - :func:`jax.scipy.special.factorial`
      - :func:`jax.scipy.special.gammaln`
      - :func:`jax.scipy.special.gammasgn`
    """
    arr, = promote_args_inexact("gamma", x)
    # exp(log|Gamma|) restores the magnitude; gammasgn restores the sign.
    magnitude = lax.exp(lax.lgamma(arr))
    return gammasgn(arr) * magnitude
132
+
133
+
134
def betaln(a: ArrayLike, b: ArrayLike) -> Array:
    r"""Natural log of the absolute value of the beta function.

    JAX implementation of :obj:`scipy.special.betaln`:

    .. math::
       \mathrm{betaln}(a, b) = \log B(a, b)

    where :math:`B` is the :func:`~jax.scipy.special.beta` function.

    Args:
      a: arraylike, real-valued. Parameter *a* of the beta distribution.
      b: arraylike, real-valued. Parameter *b* of the beta distribution.

    Returns:
      Array of log-beta values.

    See Also:
      :func:`jax.scipy.special.beta`
    """
    pa, pb = promote_args_inexact("betaln", a, b)
    return _betaln_impl(pa, pb)
157
+
158
+
159
def factorial(n: ArrayLike, exact: bool = False) -> Array:
    r"""Factorial function.

    JAX implementation of :obj:`scipy.special.factorial`:

    .. math::
       \mathrm{factorial}(n) = n! = \prod_{k=1}^n k

    Args:
      n: arraylike, values for which the factorial is computed elementwise.
      exact: bool, only ``exact=False`` is supported.

    Returns:
      Array of (float-valued) factorials; negative inputs map to 0.

    Notes:
      Computed via the gamma function, so results are floating point. Exact
      integer factorials are not supported: beyond ``n=20`` they overflow
      64-bit integers, the largest available to JAX.

    See Also:
      :func:`jax.scipy.special.gamma`
    """
    if exact:
        raise NotImplementedError("factorial with exact=True")
    arr, = promote_args_inexact("factorial", n)
    # n! = Gamma(n + 1); negative arguments are defined to be 0 here.
    result = lax.exp(lax.lgamma(arr + 1))
    return jnp.where(arr < 0, 0, result)
188
+
189
+
190
def beta(x: ArrayLike, y: ArrayLike) -> Array:
    r"""The beta function.

    JAX implementation of :obj:`scipy.special.beta`:

    .. math::
       B(a, b) = \frac{\Gamma(a)\Gamma(b)}{\Gamma(a + b)}

    where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function.

    Args:
      x: arraylike, real-valued. Parameter *a* of the beta distribution.
      y: arraylike, real-valued. Parameter *b* of the beta distribution.

    Returns:
      Array of beta-function values.

    See Also:
      - :func:`jax.scipy.special.gamma`
      - :func:`jax.scipy.special.betaln`
    """
    px, py = promote_args_inexact("beta", x, y)
    # betaln works in log space; recover the sign from the three gammas.
    total_sign = gammasgn(px) * gammasgn(py) * gammasgn(px + py)
    return total_sign * lax.exp(betaln(px, py))
215
+
216
+
217
def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
    r"""The regularized incomplete beta function.

    JAX implementation of :obj:`scipy.special.betainc`:

    .. math::
       \mathrm{betainc}(a, b, x) = B(a, b)\int_0^x t^{a-1}(1-t^{b-1})\mathrm{d}t

    where :math:`B(a, b)` is the :func:`~jax.scipy.special.beta` function.

    Args:
      a: arraylike, real-valued. Parameter *a* of the beta distribution.
      b: arraylike, real-valued. Parameter *b* of the beta distribution.
      x: arraylike, real-valued. Upper limit of the integration.

    Returns:
      Array of betainc values.

    See Also:
      - :func:`jax.scipy.special.beta`
      - :func:`jax.scipy.special.betaln`
    """
    pa, pb, px = promote_args_inexact("betainc", a, b, x)
    return lax.betainc(pa, pb, px)
242
+
243
+
244
def digamma(x: ArrayLike) -> Array:
    r"""The digamma function.

    JAX implementation of :obj:`scipy.special.digamma`:

    .. math::
       \psi(z) = \frac{\mathrm{d}}{\mathrm{d}z}\log \Gamma(z)

    where :math:`\Gamma(z)` is the :func:`~jax.scipy.special.gamma` function.

    Args:
      x: arraylike, real-valued.

    Returns:
      Array of digamma values.

    Notes:
      Only real-valued inputs are accepted.

    See also:
      - :func:`jax.scipy.special.gamma`
      - :func:`jax.scipy.special.polygamma`
    """
    arr, = promote_args_inexact("digamma", x)
    return lax.digamma(arr)
270
+
271
+
272
def gammainc(a: ArrayLike, x: ArrayLike) -> Array:
    r"""The regularized lower incomplete gamma function.

    JAX implementation of :obj:`scipy.special.gammainc`:

    .. math::
       \mathrm{gammainc}(x; a) = \frac{1}{\Gamma(a)}\int_0^x t^{a-1}e^{-t}\mathrm{d}t

    where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function.

    Args:
      a: arraylike, real-valued. Positive shape parameter of the gamma
        distribution.
      x: arraylike, real-valued. Non-negative upper limit of integration.

    Returns:
      Array of gammainc values.

    See Also:
      - :func:`jax.scipy.special.gamma`
      - :func:`jax.scipy.special.gammaincc`
    """
    pa, px = promote_args_inexact("gammainc", a, x)
    return lax.igamma(pa, px)
296
+
297
+
298
def gammaincc(a: ArrayLike, x: ArrayLike) -> Array:
    r"""The regularized upper incomplete gamma function.

    JAX implementation of :obj:`scipy.special.gammaincc`:

    .. math::
       \mathrm{gammaincc}(x; a) = \frac{1}{\Gamma(a)}\int_x^\infty t^{a-1}e^{-t}\mathrm{d}t

    where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function.

    Args:
      a: arraylike, real-valued. Positive shape parameter of the gamma
        distribution.
      x: arraylike, real-valued. Non-negative lower limit of integration.

    Returns:
      Array of gammaincc values.

    See Also:
      - :func:`jax.scipy.special.gamma`
      - :func:`jax.scipy.special.gammainc`
    """
    pa, px = promote_args_inexact("gammaincc", a, x)
    return lax.igammac(pa, px)
322
+
323
+
324
def erf(x: ArrayLike) -> Array:
    r"""The Gauss error function.

    JAX implementation of :obj:`scipy.special.erf`:

    .. math::
       \mathrm{erf}(x) = \frac{2}{\sqrt\pi} \int_{0}^x e^{-t^2} \mathrm{d}t

    Args:
      x: arraylike, real-valued.

    Returns:
      Array of error-function values.

    Notes:
      Only real-valued inputs are supported.

    See also:
      - :func:`jax.scipy.special.erfc`
      - :func:`jax.scipy.special.erfinv`
    """
    arr, = promote_args_inexact("erf", x)
    return lax.erf(arr)
348
+
349
+
350
def erfc(x: ArrayLike) -> Array:
    r"""The complement of the error function.

    JAX implementation of :obj:`scipy.special.erfc`:

    .. math::
       \mathrm{erfc}(x) = \frac{2}{\sqrt\pi} \int_{x}^\infty e^{-t^2} \mathrm{d}t

    Equivalent to ``1 - erf(x)``, where :func:`~jax.scipy.special.erf` is the
    error function.

    Args:
      x: arraylike, real-valued.

    Returns:
      Array of complementary-error-function values.

    Notes:
      Only real-valued inputs are supported.

    See also:
      - :func:`jax.scipy.special.erf`
      - :func:`jax.scipy.special.erfinv`
    """
    arr, = promote_args_inexact("erfc", x)
    return lax.erfc(arr)
377
+
378
+
379
def erfinv(x: ArrayLike) -> Array:
  """The inverse error function.

  JAX implementation of :obj:`scipy.special.erfinv`; returns the inverse of
  :func:`~jax.scipy.special.erf`.

  Args:
    x: arraylike, real-valued.

  Returns:
    array containing values of the inverse error function.

  Notes:
    The JAX version only supports real-valued inputs.

  See also:
    - :func:`jax.scipy.special.erf`
    - :func:`jax.scipy.special.erfc`
  """
  (arg,) = promote_args_inexact("erfinv", x)
  return lax.erf_inv(arg)
401
+
402
+
403
@custom_derivatives.custom_jvp
def logit(x: ArrayLike) -> Array:
  r"""The logit (log-odds) function.

  JAX implementation of :obj:`scipy.special.logit`, defined as

  .. math::

     \mathrm{logit}(p) = \log\frac{p}{1 - p}

  Args:
    x: arraylike, real-valued.

  Returns:
    array containing values of the logit function.
  """
  (p,) = promote_args_inexact("logit", x)
  one = _lax_const(p, 1)
  return lax.log(lax.div(p, lax.sub(one, p)))


def _logit_jvp(g, ans, x):
  # d/dx logit(x) = 1 / (x * (1 - x))
  one = _lax_const(x, 1)
  return lax.div(g, lax.mul(x, lax.sub(one, x)))

logit.defjvps(_logit_jvp)
423
+
424
+
425
def expit(x: ArrayLike) -> Array:
  r"""The logistic sigmoid function.

  JAX implementation of :obj:`scipy.special.expit`, defined as

  .. math::

     \mathrm{expit}(x) = \frac{1}{1 + e^{-x}}

  Args:
    x: arraylike, real-valued.

  Returns:
    array containing values of the expit function.
  """
  (arg,) = promote_args_inexact("expit", x)
  return lax.logistic(arg)
442
+
443
+
444
+ logsumexp = ops_special.logsumexp
445
+
446
+
447
+ @custom_derivatives.custom_jvp
448
+ def xlogy(x: ArrayLike, y: ArrayLike) -> Array:
449
+ """Compute x*log(y), returning 0 for x=0.
450
+
451
+ JAX implementation of :obj:`scipy.special.xlogy`.
452
+
453
+ This is defined to return zero when :math:`(x, y) = (0, 0)`, with a custom
454
+ derivative rule so that automatic differentiation is well-defined at this point.
455
+
456
+ Args:
457
+ x: arraylike, real-valued.
458
+ y: arraylike, real-valued.
459
+
460
+ Returns:
461
+ array containing xlogy values.
462
+
463
+ See also:
464
+ :func:`jax.scipy.special.xlog1py`
465
+ """
466
+ # Note: xlogy(0, 0) should return 0 according to the function documentation.
467
+ x, y = promote_args_inexact("xlogy", x, y)
468
+ x_ok = x != 0.
469
+ return jnp.where(x_ok, lax.mul(x, lax.log(y)), jnp.zeros_like(x))
470
+
471
def _xlogy_jvp(primals, tangents):
  # JVP for xlogy: d(x log y) = dx * log(y) + dy * x / y.
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = xlogy(x, y)
  tangent_out = x_dot * lax.log(y) + y_dot * x / y
  return primal_out, tangent_out.astype(primal_out.dtype)
xlogy.defjvp(_xlogy_jvp)
477
+
478
+
479
+ @custom_derivatives.custom_jvp
480
+ def xlog1py(x: ArrayLike, y: ArrayLike) -> Array:
481
+ """Compute x*log(1 + y), returning 0 for x=0.
482
+
483
+ JAX implementation of :obj:`scipy.special.xlog1py`.
484
+
485
+ This is defined to return 0 when :math:`(x, y) = (0, -1)`, with a custom
486
+ derivative rule so that automatic differentiation is well-defined at this point.
487
+
488
+ Args:
489
+ x: arraylike, real-valued.
490
+ y: arraylike, real-valued.
491
+
492
+ Returns:
493
+ array containing xlog1py values.
494
+
495
+ See also:
496
+ :func:`jax.scipy.special.xlogy`
497
+ """
498
+ # Note: xlog1py(0, -1) should return 0 according to the function documentation.
499
+ x, y = promote_args_inexact("xlog1py", x, y)
500
+ x_ok = x != 0.
501
+ return jnp.where(x_ok, lax.mul(x, lax.log1p(y)), jnp.zeros_like(x))
502
+
503
def _xlog1py_jvp(primals, tangents):
  # JVP for xlog1py: d(x log1p(y)) = dx * log1p(y) + dy * x / (1 + y).
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = xlog1py(x, y)
  tangent_out = x_dot * lax.log1p(y) + y_dot * x / (1 + y)
  return primal_out, tangent_out.astype(primal_out.dtype)
xlog1py.defjvp(_xlog1py_jvp)
509
+
510
@custom_derivatives.custom_jvp
def _xlogx(x):
  """Compute ``x * log(x)``, with derivatives well-defined at ``x == 0``."""
  return xlogy(x, x)
514
+
515
def _xlogx_jvp(primals, tangents):
  # d/dx (x log x) = log(x) + 1.
  (x,), (x_dot,) = primals, tangents
  return _xlogx(x), x_dot * (lax.log(x) + 1)
_xlogx.defjvp(_xlogx_jvp)
520
+
521
+
522
def entr(x: ArrayLike) -> Array:
  r"""The entropy function

  JAX implementation of :obj:`scipy.special.entr`.

  .. math::

     \mathrm{entr}(x) = \begin{cases}
       -x\log(x) & x > 0 \\
       0 & x = 0\\
       -\infty & x < 0
     \end{cases}

  Args:
    x: arraylike, real-valued.

  Returns:
    array containing entropy values.

  See also:
    - :func:`jax.scipy.special.kl_div`
    - :func:`jax.scipy.special.rel_entr`
  """
  x, = promote_args_inexact("entr", x)
  # Negative inputs map to -inf; otherwise return -x*log(x), which _xlogx
  # defines as 0 at x == 0.
  return lax.select(lax.lt(x, _lax_const(x, 0)),
                    lax.full_like(x, -np.inf),
                    lax.neg(_xlogx(x)))
549
+
550
+
551
def multigammaln(a: ArrayLike, d: ArrayLike) -> Array:
  r"""The natural log of the multivariate gamma function.

  JAX implementation of :func:`scipy.special.multigammaln`, defined as

  .. math::

     \mathrm{multigammaln}(a, d) = \log\Gamma_d(a)

  with

  .. math::

     \Gamma_d(a) = \pi^{d(d-1)/4}\prod_{i=1}^d\Gamma(a-(i-1)/2)

  where :math:`\Gamma(x)` is the :func:`~jax.scipy.special.gamma` function.

  Args:
    a: arraylike, real-valued.
    d: int, the dimension of the integration space.

  Returns:
    array containing values of the log-multigamma function.

  See also:
    - :func:`jax.scipy.special.gamma`
  """
  # d must be a concrete Python int: it fixes the number of gammaln terms.
  d = core.concrete_or_error(int, d, "d argument of multigammaln")
  a, d_ = promote_args_inexact("multigammaln", a, d)

  # log(pi) * d * (d - 1) / 4
  quarter = _lax_const(a, 0.25)
  one = _lax_const(a, 1)
  constant = lax.mul(lax.mul(lax.mul(quarter, d_), lax.sub(d_, one)),
                     lax.log(_lax_const(a, np.pi)))
  # Offsets (i - 1) / 2 for i = 1..d, broadcast against a trailing axis of a.
  offsets = lax.div(jnp.arange(d, dtype=d_.dtype), _lax_const(a, 2))
  summed = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) -
                           jnp.expand_dims(offsets, axis=tuple(range(a.ndim)))),
                   axis=-1)
  return summed + constant
589
+
590
+
591
def kl_div(
    p: ArrayLike,
    q: ArrayLike,
) -> Array:
  r"""The Kullback-Leibler divergence.

  JAX implementation of :obj:`scipy.special.kl_div`.

  .. math::

     \mathrm{kl\_div}(p, q) = \begin{cases}
       p\log(p/q)-p+q & p>0,q>0\\
       q & p=0,q\ge 0\\
       \infty & \mathrm{otherwise}
     \end{cases}

  Args:
    p: arraylike, real-valued.
    q: arraylike, real-valued.

  Returns:
    array of KL-divergence values

  See also:
    - :func:`jax.scipy.special.entr`
    - :func:`jax.scipy.special.rel_entr`
  """
  p, q = promote_args_inexact("kl_div", p, q)
  zero = _lax_const(p, 0.0)
  # Masks for the three docstring cases: the main support (p>0, q>0) and the
  # boundary (p=0, q>=0); everything else yields +inf.
  both_gt_zero_mask = lax.bitwise_and(lax.gt(p, zero), lax.gt(q, zero))
  one_zero_mask = lax.bitwise_and(lax.eq(p, zero), lax.ge(q, zero))

  # Substitute 1 outside the support so log/xlogy do not produce NaNs; those
  # lanes are overwritten by the final `where`.
  safe_p = jnp.where(both_gt_zero_mask, p, 1)
  safe_q = jnp.where(both_gt_zero_mask, q, 1)

  # p*log(p) - p*log(q) + q - p on the support.
  log_val = lax.sub(
      lax.add(
          lax.sub(_xlogx(safe_p), xlogy(safe_p, safe_q)),
          safe_q,
      ),
      safe_p,
  )
  result = jnp.where(
      both_gt_zero_mask, log_val, jnp.where(one_zero_mask, q, np.inf)
  )
  return result
637
+
638
+
639
def rel_entr(
    p: ArrayLike,
    q: ArrayLike,
) -> Array:
  r"""The elementwise relative entropy.

  JAX implementation of :obj:`scipy.special.rel_entr`, defined as

  .. math::

     \mathrm{rel\_entr}(p, q) = \begin{cases}
       p\log(p/q) & p>0,q>0\\
       0 & p=0,q\ge 0\\
       \infty & \mathrm{otherwise}
     \end{cases}

  Args:
    p: arraylike, real-valued.
    q: arraylike, real-valued.

  Returns:
    array of relative entropy values.

  See also:
    - :func:`jax.scipy.special.entr`
    - :func:`jax.scipy.special.kl_div`
  """
  p, q = promote_args_inexact("rel_entr", p, q)
  zero = _lax_const(p, 0.0)
  # Main support (p>0, q>0) and boundary (p=0, q>=0); all else yields +inf.
  support_mask = lax.bitwise_and(lax.gt(p, zero), lax.gt(q, zero))
  boundary_mask = lax.bitwise_and(lax.eq(p, zero), lax.ge(q, zero))

  # Substitute 1 outside the support so the logs stay NaN-free; those lanes
  # are replaced by the final `where`.
  p_safe = jnp.where(support_mask, p, 1)
  q_safe = jnp.where(support_mask, q, 1)
  # p*log(p) - p*log(q) on the support.
  log_term = lax.sub(_xlogx(p_safe), xlogy(p_safe, q_safe))
  return jnp.where(support_mask, log_term,
                   jnp.where(boundary_mask, zero, jnp.inf))
678
+
679
# Coefficients (2k)! / B_{2k}, where B_{2k} are the Bernoulli numbers, for
# k = 1..16. Used by the Euler-Maclaurin tail correction in
# _zeta_series_expansion below.
# These exact rationals were obtained using https://www.wolframalpha.com
_BERNOULLI_COEFS = [
    12,
    -720,
    30240,
    -1209600,
    47900160,
    -1307674368000 / 691,
    74724249600,
    -10670622842880000 / 3617,
    5109094217170944000 / 43867,
    -802857662698291200000 / 174611,
    14101100039391805440000 / 77683,
    -1693824136731743669452800000 / 236364091,
    186134520519971831808000000 / 657931,
    -37893265687455865519472640000000 / 3392780147,
    759790291646040068357842010112000000 / 1723168255201,
    -134196726836183700385281186201600000000 / 7709321041217,
]
699
+
700
+
701
+ @custom_derivatives.custom_jvp
702
+ def zeta(x: ArrayLike, q: ArrayLike | None = None) -> Array:
703
+ r"""The Hurwitz zeta function.
704
+
705
+ JAX implementation of :func:`scipy.special.zeta`. JAX does not implement
706
+ the Riemann zeta function (i.e. ``q = None``).
707
+
708
+ .. math::
709
+
710
+ \zeta(x, q) = \sum_{n=0}^\infty \frac{1}{(n + q)^x}
711
+
712
+ Args:
713
+ x: arraylike, real-valued
714
+ q: arraylike, real-valued
715
+
716
+ Returns:
717
+ array of zeta function values
718
+ """
719
+ if q is None:
720
+ raise NotImplementedError(
721
+ "Riemann zeta function not implemented; pass q != None to compute the Hurwitz Zeta function.")
722
+ x, q = promote_args_inexact("zeta", x, q)
723
+ return lax.zeta(x, q)
724
+
725
+
726
+ # There is no general closed-form derivative for the zeta function, so we compute
727
+ # derivatives via a series expansion
728
def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array:
  """Hurwitz zeta via a truncated series; used only to define zeta's JVP."""
  if q is None:
    raise NotImplementedError(
        "Riemann zeta function not implemented; pass q != None to compute the Hurwitz Zeta function.")
  # Reference: Johansson, Fredrik.
  # "Rigorous high-precision computation of the Hurwitz zeta function and its derivatives."
  # Numerical Algorithms 69.2 (2015): 253-270.
  # https://arxiv.org/abs/1309.2877 - formula (5)
  # here we keep the same notation as in reference
  s, a = promote_args_inexact("zeta", x, q)
  dtype = lax.dtype(a).type
  s_, a_ = jnp.expand_dims(s, -1), jnp.expand_dims(a, -1)
  # precision ~ N, M
  N = M = dtype(8) if lax.dtype(a) == jnp.float32 else dtype(16)
  assert M <= len(_BERNOULLI_COEFS)
  k = jnp.expand_dims(np.arange(N, dtype=N.dtype), tuple(range(a.ndim)))
  # S: the truncated direct sum over the first N terms of the series.
  S = jnp.sum((a_ + k) ** -s_, -1)
  # I: the integral approximation of the remaining tail.
  I = lax.div((a + N) ** (dtype(1) - s), s - dtype(1))
  # T: the Euler-Maclaurin correction term T0 * (1/2 + sum of Bernoulli terms).
  T0 = (a + N) ** -s
  m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim)))
  s_over_a = (s_ + m) / (a_ + N)
  T1 = jnp.cumprod(s_over_a, -1)[..., ::2]
  # Clip to avoid inf * 0 -> NaN after division by the Bernoulli coefficients.
  T1 = jnp.clip(T1, max=jnp.finfo(dtype).max)
  coefs = np.expand_dims(np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype),
                         tuple(range(a.ndim)))
  T1 = T1 / coefs
  T = T0 * (dtype(0.5) + T1.sum(-1))
  return S + I + T

# There is no general closed-form derivative for the zeta function, so the
# JVP differentiates the series expansion above.
zeta.defjvp(partial(jvp, _zeta_series_expansion))
758
+
759
+
760
def polygamma(n: ArrayLike, x: ArrayLike) -> Array:
  r"""The polygamma function.

  JAX implementation of :func:`scipy.special.polygamma`: the n-th derivative
  of the digamma function,

  .. math::

     \mathrm{polygamma}(n, x) = \psi^{(n)}(x) = \frac{\mathrm{d}^n}{\mathrm{d}x^n}\log \Gamma(x)

  where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function.

  Args:
    n: arraylike, integer-valued. The order of the derivative.
    x: arraylike, real-valued. The value at which to evaluate the function.

  Returns:
    array

  See also:
    - :func:`jax.scipy.special.gamma`
    - :func:`jax.scipy.special.digamma`
  """
  # n must carry an integer dtype before being promoted to inexact.
  assert jnp.issubdtype(lax.dtype(n), jnp.integer)
  order, arg = promote_args_inexact("polygamma", n, x)
  return lax.polygamma(order, arg)
785
+
786
+
787
+ # Normal distributions
788
+
789
+ # Functions "ndtr" and "ndtri" are derived from calculations made in:
790
+ # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
791
+ # The "spence" function is also based on the Cephes library with
792
+ # the corresponding spence.c file located in the tarball:
793
+ # https://netlib.org/cephes/misc.tgz
794
+ # In the following email exchange, the author gives his consent to redistribute
795
+ # derived works under an Apache 2.0 license.
796
+ #
797
+ # From: Stephen Moshier <steve@moshier.net>
798
+ # Date: Sat, Jun 9, 2018 at 2:36 PM
799
+ # Subject: Re: Licensing cephes under Apache (BSD-like) license.
800
+ # To: rif <rif@google.com>
801
+ #
802
+ #
803
+ #
804
+ # Hello Rif,
805
+ #
806
+ # Yes, Google may distribute Cephes files under the Apache 2 license.
807
+ #
808
+ # If clarification is needed, I do not favor BSD over other free licenses.
809
+ # I would agree that Apache 2 seems to cover the concern you mentioned
810
+ # about sublicensees.
811
+ #
812
+ # Best wishes for good luck with your projects!
813
+ # Steve Moshier
814
+ #
815
+ #
816
+ #
817
+ # On Thu, 31 May 2018, rif wrote:
818
+ #
819
+ # > Hello Steve.
820
+ # > My name is Rif. I work on machine learning software at Google.
821
+ # >
822
+ # > Your cephes software continues to be incredibly useful and widely used. I
823
+ # > was wondering whether it would be permissible for us to use the Cephes code
824
+ # > under the Apache 2.0 license, which is extremely similar in permissions to
825
+ # > the BSD license (Wikipedia comparisons). This would be quite helpful to us
826
+ # > in terms of avoiding multiple licenses on software.
827
+ # >
828
+ # > I'm sorry to bother you with this (I can imagine you're sick of hearing
829
+ # > about this by now), but I want to be absolutely clear we're on the level and
830
+ # > not misusing your important software. In former conversation with Eugene
831
+ # > Brevdo (ebrevdo@google.com), you wrote "If your licensing is similar to BSD,
832
+ # > the formal way that has been handled is simply to add a statement to the
833
+ # > effect that you are incorporating the Cephes software by permission of the
834
+ # > author." I wanted to confirm that (a) we could use the Apache license, (b)
835
+ # > that we don't need to (and probably you don't want to) keep getting
836
+ # > contacted about individual uses, because your intent is generally to allow
837
+ # > this software to be reused under "BSD-like" license, and (c) you're OK
838
+ # > letting incorporators decide whether a license is sufficiently BSD-like?
839
+ # >
840
+ # > Best,
841
+ # >
842
+ # > rif
843
+ # >
844
+ # >
845
+ # >
846
+
847
# log_ndtr uses different functions over the ranges
# (-infty, lower](lower, upper](upper, infty); the dtype-specific segment
# boundaries below are consumed by log_ndtr.
# Lower bound values were chosen by examining where the support of ndtr
# appears to be zero, relative to scipy's (which is always 64bit). They were
# then made more conservative just to be safe. (Conservative means use the
# expansion more than we probably need to.)
_LOGNDTR_FLOAT64_LOWER = np.array(-20, np.float64)
_LOGNDTR_FLOAT32_LOWER = np.array(-10, np.float32)

# Upper bound values were chosen by examining for which values of 'x'
# Log[cdf(x)] is 0, after which point we need to use the approximation
# Log[cdf(x)] = Log[1 - cdf(-x)] approx -cdf(-x). We chose a value slightly
# conservative, meaning we use the approximation earlier than needed.
_LOGNDTR_FLOAT64_UPPER = np.array(8, np.float64)
_LOGNDTR_FLOAT32_UPPER = np.array(5, np.float32)
862
+
863
+
864
def ndtr(x: ArrayLike) -> Array:
  r"""Cumulative distribution function of the standard normal.

  JAX implementation of :obj:`scipy.special.ndtr`: the area under the
  Gaussian probability density function, integrated from minus infinity to x:

  .. math::
    \begin{align}
    \mathrm{ndtr}(x) =&
      \ \frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt \\
    =&\ \frac{1}{2} (1 + \mathrm{erf}(\frac{x}{\sqrt{2}})) \\
    =&\ \frac{1}{2} \mathrm{erfc}(\frac{x}{\sqrt{2}})
    \end{align}

  Args:
    x: An array of type `float32`, `float64`.

  Returns:
    An array with `dtype=x.dtype`.

  Raises:
    TypeError: if `x` is not floating-type.
  """
  arr = jnp.asarray(x)
  if lax.dtype(arr) not in (jnp.float32, jnp.float64):
    raise TypeError(
        "x.dtype={} is not supported, see docstring for supported types."
        .format(lax.dtype(arr)))
  return _ndtr(arr)
896
+
897
+
898
def _ndtr(x: ArrayLike) -> Array:
  """Implements ndtr core logic."""
  dtype = lax.dtype(x).type
  half_sqrt_2 = dtype(0.5) * np.sqrt(2., dtype=dtype)
  w = x * half_sqrt_2
  z = lax.abs(w)
  # Three regimes, chosen for the accuracy of erf/erfc in each:
  #   |w| < 1/sqrt(2):  0.5 * (1 + erf(w))     (erf is accurate near 0)
  #   w > 0:            0.5 * (2 - erfc(|w|))
  #   otherwise:        0.5 * erfc(|w|)        (erfc is accurate in the tail)
  y = lax.select(lax.lt(z, half_sqrt_2),
                 dtype(1.) + lax.erf(w),
                 lax.select(lax.gt(w, dtype(0.)),
                            dtype(2.) - lax.erfc(z),
                            lax.erfc(z)))
  return dtype(0.5) * y
910
+
911
+
912
def ndtri(p: ArrayLike) -> Array:
  r"""Inverse CDF (quantile function) of the standard normal distribution.

  JAX implementation of :obj:`scipy.special.ndtri`: returns `x` such that the
  area under the standard normal PDF from :math:`-\infty` to `x` equals `p`.

  The computation uses a piece-wise rational approximation, based on the
  implementation in netlib.

  Args:
    p: an array of type `float32`, `float64`.

  Returns:
    an array with `dtype=p.dtype`.

  Raises:
    TypeError: if `p` is not floating-type.
  """
  if lax.dtype(p) not in (jnp.float32, jnp.float64):
    raise TypeError(
        "x.dtype={} is not supported, see docstring for supported types."
        .format(lax.dtype(p)))
  return _ndtri(p)
938
+
939
+
940
def _ndtri(p: ArrayLike) -> Array:
  """Implements ndtri core logic."""

  # Constants used in piece-wise rational approximations. Taken from the cephes
  # library:
  # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
  # Each list is reversed so that _create_polynomial (Horner's method, constant
  # term first) can consume it directly.
  p0 = list(reversed([-5.99633501014107895267E1,
                      9.80010754185999661536E1,
                      -5.66762857469070293439E1,
                      1.39312609387279679503E1,
                      -1.23916583867381258016E0]))
  q0 = list(reversed([1.0,
                      1.95448858338141759834E0,
                      4.67627912898881538453E0,
                      8.63602421390890590575E1,
                      -2.25462687854119370527E2,
                      2.00260212380060660359E2,
                      -8.20372256168333339912E1,
                      1.59056225126211695515E1,
                      -1.18331621121330003142E0]))
  p1 = list(reversed([4.05544892305962419923E0,
                      3.15251094599893866154E1,
                      5.71628192246421288162E1,
                      4.40805073893200834700E1,
                      1.46849561928858024014E1,
                      2.18663306850790267539E0,
                      -1.40256079171354495875E-1,
                      -3.50424626827848203418E-2,
                      -8.57456785154685413611E-4]))
  q1 = list(reversed([1.0,
                      1.57799883256466749731E1,
                      4.53907635128879210584E1,
                      4.13172038254672030440E1,
                      1.50425385692907503408E1,
                      2.50464946208309415979E0,
                      -1.42182922854787788574E-1,
                      -3.80806407691578277194E-2,
                      -9.33259480895457427372E-4]))
  p2 = list(reversed([3.23774891776946035970E0,
                      6.91522889068984211695E0,
                      3.93881025292474443415E0,
                      1.33303460815807542389E0,
                      2.01485389549179081538E-1,
                      1.23716634817820021358E-2,
                      3.01581553508235416007E-4,
                      2.65806974686737550832E-6,
                      6.23974539184983293730E-9]))
  q2 = list(reversed([1.0,
                      6.02427039364742014255E0,
                      3.67983563856160859403E0,
                      1.37702099489081330271E0,
                      2.16236993594496635890E-1,
                      1.34204006088543189037E-2,
                      3.28014464682127739104E-4,
                      2.89247864745380683936E-6,
                      6.79019408009981274425E-9]))

  dtype = lax.dtype(p).type
  shape = jnp.shape(p)

  def _create_polynomial(var, coeffs):
    """Compute n_th order polynomial via Horner's method."""
    coeffs = np.array(coeffs, dtype)
    if not coeffs.size:
      return jnp.zeros_like(var)
    return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var


  # Exploit symmetry: for p in the upper tail, work with 1 - p instead.
  # (-np.expm1(-2.) equals 1 - exp(-2), computed without cancellation.)
  maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.)), dtype(1.) - p, p)
  # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
  # later on. The result from the computation when p == 0 is not used so any
  # number that doesn't result in NaNs is fine.
  sanitized_mcp = jnp.where(
      maybe_complement_p == dtype(0.),
      jnp.full(shape, dtype(0.5)),
      maybe_complement_p)

  # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
  w = sanitized_mcp - dtype(0.5)
  ww = lax.square(w)
  x_for_big_p = w + w * ww * (_create_polynomial(ww, p0)
                              / _create_polynomial(ww, q0))
  x_for_big_p *= -dtype(np.sqrt(2. * np.pi))

  # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
  # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
  # arrays based on whether p < exp(-32).
  z = lax.sqrt(dtype(-2.) * lax.log(sanitized_mcp))
  first_term = z - lax.log(z) / z
  second_term_small_p = (
      _create_polynomial(dtype(1.) / z, p2) /
      _create_polynomial(dtype(1.) / z, q2) / z)
  second_term_otherwise = (
      _create_polynomial(dtype(1.) / z, p1) /
      _create_polynomial(dtype(1.) / z, q1) / z)
  x_for_small_p = first_term - second_term_small_p
  x_otherwise = first_term - second_term_otherwise

  # z >= 8 corresponds to p < exp(-32), where the (p2, q2) fit applies.
  x = jnp.where(sanitized_mcp > dtype(np.exp(-2.)),
                x_for_big_p,
                jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise))

  # Undo the symmetry reflection applied to upper-tail inputs.
  x = jnp.where(p > dtype(1. - np.exp(-2.)), x, -x)
  infinity = jnp.full(shape, dtype(np.inf))
  # By definition, ndtri(0) = -inf and ndtri(1) = +inf.
  x_fix_boundaries = jnp.where(
      p == dtype(0.0), -infinity, jnp.where(p == dtype(1.0), infinity, x))
  return x_fix_boundaries
1047
+
1048
+
1049
@partial(custom_derivatives.custom_jvp, nondiff_argnums=(1,))
def log_ndtr(x: ArrayLike, series_order: int = 3) -> Array:
  r"""Log Normal distribution function.

  JAX implementation of :obj:`scipy.special.log_ndtr`.

  For details of the Normal distribution function see `ndtr`.

  This function calculates :math:`\log(\mathrm{ndtr}(x))` by either calling
  :math:`\log(\mathrm{ndtr}(x))` or using an asymptotic series. Specifically:

  - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on
    :math:`\log(1-x) \approx -x, x \ll 1`.
  - For `lower_segment < x <= upper_segment`, use the existing `ndtr` technique
    and take a log.
  - For `x <= lower_segment`, we use the series approximation of `erf` to compute
    the log CDF directly.

  The `lower_segment` is set based on the precision of the input:

  .. math::
    \begin{align}
    \mathit{lower\_segment} =&
      \ \begin{cases}
        -20 & x.\mathrm{dtype}=\mathit{float64} \\
        -10 & x.\mathrm{dtype}=\mathit{float32} \\
        \end{cases} \\
    \mathit{upper\_segment} =&
      \ \begin{cases}
        8& x.\mathrm{dtype}=\mathit{float64} \\
        5& x.\mathrm{dtype}=\mathit{float32} \\
        \end{cases}
    \end{align}


  When `x < lower_segment`, the `ndtr` asymptotic series approximation is:

  .. math::
    \begin{align}
     \mathrm{ndtr}(x) =&\  \mathit{scale} * (1 + \mathit{sum}) + R_N \\
     \mathit{scale}   =&\  \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\
     \mathit{sum}     =&\  \sum_{n=1}^N {-1}^n (2n-1)!! / (x^2)^n \\
     R_N     =&\  O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3})
    \end{align}

  where :math:`(2n-1)!! = (2n-1) (2n-3) (2n-5)  ...  (3) (1)` is a
  `double-factorial
  <https://en.wikipedia.org/wiki/Double_factorial>`_ operator.


  Args:
    x: an array of type `float32`, `float64`.
    series_order: Positive Python integer. Maximum depth to
      evaluate the asymptotic expansion. This is the `N` above.

  Returns:
    an array with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is not handled.
    TypeError: if `series_order` is a not Python `integer.`
    ValueError:  if `series_order` is not in `[0, 30]`.
  """
  # Validate series_order eagerly: it is a static (non-diff) Python int.
  if not isinstance(series_order, int):
    raise TypeError("series_order must be a Python integer.")
  if series_order < 0:
    raise ValueError("series_order must be non-negative.")
  if series_order > 30:
    raise ValueError("series_order must be <= 30.")

  x_arr = jnp.asarray(x)
  dtype = lax.dtype(x_arr)

  # Pick the dtype-specific segment boundaries (see module-level constants).
  if dtype == jnp.float64:
    lower_segment: np.ndarray = _LOGNDTR_FLOAT64_LOWER
    upper_segment: np.ndarray = _LOGNDTR_FLOAT64_UPPER
  elif dtype == jnp.float32:
    lower_segment = _LOGNDTR_FLOAT32_LOWER
    upper_segment = _LOGNDTR_FLOAT32_UPPER
  else:
    raise TypeError(f"x.dtype={np.dtype(dtype)} is not supported.")

  # The basic idea here was ported from:
  #   https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
  # We copy the main idea, with a few changes
  # * For x >> 1, and X ~ Normal(0, 1),
  #     Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
  #   which extends the range of validity of this function.
  # * We use one fixed series_order for all of 'x', rather than adaptive.
  # * Our docstring properly reflects that this is an asymptotic series, not a
  #   Taylor series. We also provided a correct bound on the remainder.
  # * We need to use the max/min in the _log_ndtr_lower arg to avoid nan when
  #   x=0. This happens even though the branch is unchosen because when x=0
  #   the gradient of a select involves the calculation 1*dy+0*(-inf)=nan
  #   regardless of whether dy is finite. Note that the minimum is a NOP if
  #   the branch is chosen.
  return jnp.where(
      lax.gt(x_arr, upper_segment),
      -_ndtr(-x_arr),  # log(1-x) ~= -x, x << 1
      jnp.where(lax.gt(x_arr, lower_segment),
                lax.log(_ndtr(lax.max(x_arr, lower_segment))),
                _log_ndtr_lower(lax.min(x_arr, lower_segment),
                                series_order)))
1152
+
1153
def _log_ndtr_jvp(series_order, primals, tangents):
  # d/dx log(ndtr(x)) = pdf(x) / ndtr(x), computed in log space for stability.
  x, = primals
  t, = tangents
  primal_out = log_ndtr(x, series_order=series_order)
  tangent_out = lax.mul(t, lax.exp(lax.sub(_norm_logpdf(x), primal_out)))
  return primal_out, tangent_out
log_ndtr.defjvp(_log_ndtr_jvp)
1159
+
1160
def _log_ndtr_lower(x, series_order):
  """Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`."""
  dtype = lax.dtype(x).type
  x_2 = lax.square(x)
  # Log of the scale term exp(-x^2/2) / (-x * sqrt(2*pi)) that multiplies
  # (1 + sum); see the expansion in the log_ndtr docstring.
  log_scale = -dtype(0.5) * x_2 - lax.log(-x) - dtype(0.5 * np.log(2. * np.pi))
  return log_scale + lax.log(_log_ndtr_asymptotic_series(x, series_order))
1167
+
1168
+
1169
def _log_ndtr_asymptotic_series(x, series_order):
  """Calculates the asymptotic series used in log_ndtr."""
  dtype = lax.dtype(x).type
  if series_order <= 0:
    return np.array(1, dtype)
  x_2 = lax.square(x)
  # The series alternates in sign: odd-n terms are subtracted and even-n terms
  # added, so accumulate them separately and combine at the end.
  even_sum = jnp.zeros_like(x)
  odd_sum = jnp.zeros_like(x)
  x_2n = x_2  # Start with x^{2*1} = x^{2*n} with n = 1.
  for n in range(1, series_order + 1):
    # n-th term: (2n-1)!! / x^{2n}.
    y = np.array(_double_factorial(2 * n - 1), dtype) / x_2n
    if n % 2:
      odd_sum += y
    else:
      even_sum += y
    x_2n *= x_2
  return dtype(1.) + even_sum - odd_sum
1186
+
1187
+
1188
+ def _double_factorial(n: int) -> np.ndarray:
1189
+ """The double factorial function for small Python integer `n`."""
1190
+ return np.prod(np.arange(n, 1, -2))
1191
+
1192
+
1193
# log(sqrt(2*pi)): normalization constant of the standard normal pdf.
_norm_logpdf_constant = np.log(np.sqrt(2 * np.pi))

def _norm_logpdf(x):
  """Log-density of the standard normal distribution: -x**2/2 - log(sqrt(2*pi))."""
  neg_half = _lax_const(x, -0.5)
  log_normalizer = _lax_const(x, _norm_logpdf_constant)
  return lax.sub(lax.mul(neg_half, lax.square(x)), log_normalizer)
1199
+
1200
+
1201
def i0e(x: ArrayLike) -> Array:
  r"""Exponentially scaled modified Bessel function of order zero.

  JAX implementation of :obj:`scipy.special.i0e`, defined as

  .. math::

     \mathrm{i0e}(x) = e^{-|x|} I_0(x)

  where :math:`I_0(x)` is the modified Bessel function :func:`~jax.scipy.special.i0`.

  Args:
    x: array, real-valued

  Returns:
    array of bessel function values.

  See also:
    - :func:`jax.scipy.special.i0`
    - :func:`jax.scipy.special.i1`
    - :func:`jax.scipy.special.i1e`
  """
  (arg,) = promote_args_inexact("i0e", x)
  return lax.bessel_i0e(arg)
1225
+
1226
+
1227
def i0(x: ArrayLike) -> Array:
  r"""Modified Bessel function of order zero.

  JAX implementation of :obj:`scipy.special.i0`, defined as

  .. math::

     \mathrm{i0}(x) = I_0(x) = \sum_{k=0}^\infty \frac{(x^2/4)^k}{(k!)^2}

  Args:
    x: array, real-valued

  Returns:
    array of bessel function values.

  See also:
    - :func:`jax.scipy.special.i0e`
    - :func:`jax.scipy.special.i1`
    - :func:`jax.scipy.special.i1e`
  """
  (arg,) = promote_args_inexact("i0", x)
  # Undo the exponential scaling of the numerically stable i0e kernel.
  return lax.mul(lax.exp(lax.abs(arg)), lax.bessel_i0e(arg))
1249
+
1250
+
1251
def i1e(x: ArrayLike) -> Array:
  r"""Exponentially scaled modified Bessel function of order one.

  JAX implementation of :obj:`scipy.special.i1e`, defined as

  .. math::

     \mathrm{i1e}(x) = e^{-|x|} I_1(x)

  where :math:`I_1(x)` is the modified Bessel function :func:`~jax.scipy.special.i1`.

  Args:
    x: array, real-valued

  Returns:
    array of bessel function values

  See also:
    - :func:`jax.scipy.special.i0`
    - :func:`jax.scipy.special.i0e`
    - :func:`jax.scipy.special.i1`
  """
  (arg,) = promote_args_inexact("i1e", x)
  return lax.bessel_i1e(arg)
1275
+
1276
+
1277
def i1(x: ArrayLike) -> Array:
  r"""Modified Bessel function of order one.

  JAX implementation of :obj:`scipy.special.i1`, defined as

  .. math::

     \mathrm{i1}(x) = I_1(x) = \frac{1}{2}x\sum_{k=0}^\infty\frac{(x^2/4)^k}{k!(k+1)!}

  Args:
    x: array, real-valued

  Returns:
    array of bessel function values

  See also:
    - :func:`jax.scipy.special.i0`
    - :func:`jax.scipy.special.i0e`
    - :func:`jax.scipy.special.i1e`
  """
  (arg,) = promote_args_inexact("i1", x)
  # Undo the exponential scaling of the numerically stable i1e kernel.
  return lax.mul(lax.exp(lax.abs(arg)), lax.bessel_i1e(arg))
1299
+
1300
+ def _bessel_jn_scan_body_fun(carry, k):
1301
+ f0, f1, bs, z = carry
1302
+ f = 2.0 * (k + 1.0) * f1 / z - f0
1303
+
1304
+ def true_fn_update_bs(u):
1305
+ bs, f = u
1306
+ return bs + 2.0 * f
1307
+
1308
+ def false_fn_update_bs(u):
1309
+ bs, _ = u
1310
+ return bs
1311
+
1312
+ bs = lax.cond(jnp.mod(k, 2) == 0, true_fn_update_bs,
1313
+ false_fn_update_bs, operand=(bs, f))
1314
+
1315
+ f0 = f1
1316
+ f1 = f
1317
+ return (f0, f1, bs, z), f
1318
+
1319
+
1320
def _bessel_jn(z: ArrayLike, *, v: int, n_iter: int=50) -> Array:
  # Miller's backward-recurrence algorithm: run the three-term recurrence
  # downward from order n_iter with arbitrary seed values, then normalize with
  # the running sum accumulated in `bs` by the scan body.
  f0 = _lax_const(z, 0.0)
  f1 = _lax_const(z, 1E-16)  # Tiny nonzero seed for the recurrence.
  f = _lax_const(z, 0.0)
  bs = _lax_const(z, 0.0)

  (_, _, bs, _), j_vals = lax.scan(
      f=_bessel_jn_scan_body_fun, init=(f0, f1, bs, z),
      xs=lax.iota(lax.dtype(z), n_iter+1), reverse=True)

  f = j_vals[0]  # Use the value at the last iteration.
  j_vals = j_vals[:v+1]  # Keep orders 0..v only.
  j_vals = j_vals / (bs - f)  # Normalize by the accumulated even-order sum.

  return j_vals
1335
+
1336
+
1337
@partial(jit, static_argnames=["v", "n_iter"])
def bessel_jn(z: ArrayLike, *, v: int, n_iter: int=50) -> Array:
  """Bessel function of the first kind of integer order and real argument.

  Reference:
  Shanjie Zhang and Jian-Ming Jin. Computation of special functions.
  Wiley-Interscience, 1996.

  Args:
    z: The sampling point(s) at which the Bessel function of the first kind are
      computed.
    v: The order (int) of the Bessel function.
    n_iter: The number of iterations required for updating the function
      values. As a rule of thumb, `n_iter` is the smallest nonnegative integer
      that satisfies the condition
      `int(0.5 * log10(6.28 + n_iter) - n_iter * log10(1.36 + abs(z) / n_iter)) > 20`.
      Details in `BJNDD` (https://people.sc.fsu.edu/~jburkardt/f77_src/special_functions/special_functions.f)

  Returns:
    An array of shape `(v+1, *z.shape)` containing the values of the Bessel
    function of orders 0, 1, ..., v. The return type matches the type of `z`.

  Raises:
    TypeError if `v` is not integer.
    ValueError if elements of array `z` are not float.
  """
  z = jnp.asarray(z)
  # Promote to an inexact dtype; complex inputs are rejected just below.
  z, = promote_dtypes_inexact(z)
  z_dtype = lax.dtype(z)
  if dtypes.issubdtype(z_dtype, complex):
    raise ValueError("complex input not supported.")

  # v and n_iter fix static computation sizes, so they must be concrete
  # (non-traced) values under transformations.
  v = core.concrete_or_error(operator.index, v, 'Argument v of bessel_jn.')
  n_iter = core.concrete_or_error(int, n_iter, 'Argument n_iter of bessel_jn.')

  # _bessel_jn handles a scalar z; vmap once per input dimension to support
  # arbitrary input shapes, then move the orders axis to the front.
  bessel_jn_fun = partial(_bessel_jn, v=v, n_iter=n_iter)
  for _ in range(z.ndim):
    bessel_jn_fun = vmap(bessel_jn_fun)
  return jnp.moveaxis(bessel_jn_fun(z), -1, 0)
def _gen_recurrence_mask(
    l_max: int, is_normalized: bool, dtype: Any
) -> tuple[Array, Array]:
  """Builds the coefficient masks used by the remaining-entry ALF recurrence.

  "Remaining" entries are those off the diagonal and first off-diagonal of
  the (order, degree) grid filled by :func:`_gen_associated_legendre`.

  Args:
    l_max: see `gen_normalized_legendre`.
    is_normalized: True if the recurrence mask is used by normalized
      associated Legendre functions.
    dtype: dtype of the returned masks.

  Returns:
    A pair of ``(l_max+1, l_max+1, l_max+1)`` coefficient tensors; entries
    off the plane ``i + j == k`` are zero.
  """
  size = l_max + 1
  order_mat, degree_mat = jnp.meshgrid(
      jnp.arange(size, dtype=dtype),
      jnp.arange(size, dtype=dtype),
      indexing='ij')

  # Recurrence coefficients for every (order, degree) pair.
  if is_normalized:
    deg_sq = degree_mat * degree_mat
    ord_sq = order_mat * order_mat
    two_deg = 2.0 * degree_mat
    deg_m1_sq = (degree_mat - 1.0) * (degree_mat - 1.0)
    coeff0 = jnp.sqrt((4.0 * deg_sq - 1.0) / (deg_sq - ord_sq))
    coeff1 = jnp.sqrt(((two_deg + 1.0) * (deg_m1_sq - ord_sq)) /
                      ((two_deg - 3.0) * (deg_sq - ord_sq)))
  else:
    coeff0 = (2.0 * degree_mat - 1.0) / (degree_mat - order_mat)
    coeff1 = (degree_mat + order_mat - 1.0) / (degree_mat - order_mat)

  # Keep only the entries with degree strictly above the order (offsets 1
  # and 2 for the two coefficient families).
  zeros_2d = jnp.zeros((size, size), dtype=dtype)
  idx0 = jnp.triu_indices(size, 1)
  idx1 = jnp.triu_indices(size, 2)
  mask0_2d = zeros_2d.at[idx0].set(coeff0[idx0])
  mask1_2d = zeros_2d.at[idx1].set(coeff1[idx1])

  # 3D selector equal to 1 exactly on the plane i + j == k, 0 elsewhere.
  i, j, k = jnp.ogrid[:size, :size, :size]
  plane = (i + j - k == 0).astype(dtype)

  d0_mask_3d = jnp.einsum('jk,ijk->ijk', mask0_2d, plane)
  d1_mask_3d = jnp.einsum('jk,ijk->ijk', mask1_2d, plane)

  return (d0_mask_3d, d1_mask_3d)
@partial(jit, static_argnums=(2))
def _gen_derivatives(p: Array,
                     x: Array,
                     is_normalized: bool) -> Array:
  """Generates derivatives of associated Legendre functions of the first kind.

  Args:
    p: The 3D array containing the values of associated Legendre functions; the
      dimensions are in the sequence of order (m), degree (l), and evaluation
      points.
    x: A vector of type `float32` or `float64` containing the sampled points.
    is_normalized: True if the associated Legendre functions are normalized.
  Returns:
    The 3D array representing the derivatives of associated Legendre functions
    of the first kind.
  """

  num_m, num_l, num_x = p.shape

  # Shifted copies of p used by the derivative recurrences below.
  # p_{l-1}^m.
  p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :num_l, :]

  # p_{l-1}^{m+2}.
  p_mp2_lm1 = jnp.pad(p_m_lm1, ((0, 2), (0, 0), (0, 0)))[2:num_m + 2, :, :]

  # p_{l-1}^{m-2}.
  p_mm2_lm1 = jnp.pad(p_m_lm1, ((2, 0), (0, 0), (0, 0)))[:num_m, :, :]

  # Derivative computation requires negative orders.
  if is_normalized:
    raise NotImplementedError(
        'Negative orders for normalization is not implemented yet.')
  else:
    # Fill in the m = -1 and m = -2 entries (stored at rows 1 and 0 of the
    # shifted array) from positive orders; presumably via the unnormalized
    # negative-order identity P_l^{-m} ∝ P_l^m — TODO confirm.
    if num_l > 1:
      l_vec = jnp.arange(1, num_l - 1, dtype=x.dtype)
      p_p1 = p[1, 1:num_l - 1, :]
      coeff = -1.0 / ((l_vec + 1) * l_vec)
      update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1)
      p_mm2_lm1 = p_mm2_lm1.at[1, 2:num_l, :].set(update_p_p1)

    if num_l > 2:
      l_vec = jnp.arange(2, num_l - 1, dtype=x.dtype)
      p_p2 = p[2, 2:num_l - 1, :]
      coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec * (l_vec - 1))
      update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2)
      p_mm2_lm1 = p_mm2_lm1.at[0, 3:num_l, :].set(update_p_p2)

  m_mat, l_mat = jnp.meshgrid(
      jnp.arange(num_m, dtype=x.dtype),
      jnp.arange(num_l, dtype=x.dtype),
      indexing='ij')

  coeff_zeros = jnp.zeros((num_m, num_l), dtype=x.dtype)
  upper_0_indices = jnp.triu_indices(num_m, 0, num_l)
  zero_vec = jnp.zeros((num_l,), dtype=x.dtype)

  # Coefficients for reconstructing p_l^{m-1}; row m = 1 is zeroed out and
  # handled separately below (singularity at m = 1).
  a0 = -0.5 / (m_mat - 1.0)
  a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices])
  a0_masked = a0_masked.at[1, :].set(zero_vec)

  b0 = l_mat + m_mat
  c0 = a0 * (b0 - 2.0) * (b0 - 1.0)
  c0_masked = coeff_zeros.at[upper_0_indices].set(c0[upper_0_indices])
  c0_masked = c0_masked.at[1, :].set(zero_vec)

  # p_l^{m-1}.
  p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) +
             jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1))

  d0 = -0.5 / (m_mat + 1.0)
  d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices])
  e0 = d0 * b0 * (b0 + 1.0)
  e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices])

  # p_l^{m+1}.
  p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) +
             jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1))

  # Combine the order-shifted values into the derivative.
  f0 = b0 * (l_mat - m_mat + 1.0) / 2.0
  f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices])
  p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked, p_mm1_l) - 0.5 * p_mp1_l

  # Special treatment of the singularity at m = 1.
  if num_m > 1:
    l_vec = jnp.arange(num_l, dtype=p.dtype)
    g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :])
    if num_l > 2:
      g0 = g0 - p[2, :, :]
    # NOTE(review): 1/sqrt(1 - x^2) is singular at |x| = 1; callers
    # presumably supply quadrature points in the open interval (-1, 1).
    p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0)
    p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0)
    p_derivative = p_derivative.at[1, 0, :].set(0)

  return p_derivative
@partial(jit, static_argnums=(0, 2))
def _gen_associated_legendre(l_max: int,
                             x: Array,
                             is_normalized: bool) -> Array:
  r"""Computes associated Legendre functions (ALFs) of the first kind.

  The ALFs of the first kind are used in spherical harmonics. The spherical
  harmonic of degree `l` and order `m` can be written as
  `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
  normalization factor and θ and φ are the colatitude and longitude,
  respectively. `N_l^m` is chosen in the way that the spherical harmonics form
  a set of orthonormal basis functions of L^2(S^2). For the computational
  efficiency of spherical harmonics transform, the normalization factor is
  used in the computation of the ALFs. In addition, normalizing `P_l^m`
  avoids overflow/underflow and achieves better numerical stability. Three
  recurrence relations are used in the computation.

  Args:
    l_max: The maximum degree of the associated Legendre function. Both the
      degrees and orders are `[0, 1, 2, ..., l_max]`.
    x: A vector of type `float32`, `float64` containing the sampled points in
      spherical coordinates, at which the ALFs are computed; `x` is essentially
      `cos(θ)`. For the numerical integration used by the spherical harmonics
      transforms, `x` contains the quadrature points in the interval of
      `[-1, 1]`. There are several approaches to provide the quadrature points:
      Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
      method (`scipy.special.roots_chebyu`), and Driscoll & Healy
      method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
      transforms and convolutions on the 2-sphere." Advances in applied
      mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
      points are nearly equal-spaced along θ and provide exact discrete
      orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose
      operation, `W` is a diagonal matrix containing the quadrature weights,
      and `I` is the identity matrix. The Gauss-Chebyshev points are equally
      spaced, which only provide approximate discrete orthogonality. The
      Driscoll & Healy quadrature points are equally spaced and provide the
      exact discrete orthogonality. The number of sampling points is required to
      be twice as the number of frequency points (modes) in the Driscoll & Healy
      approach, which enables FFT and achieves a fast spherical harmonics
      transform.
    is_normalized: True if the associated Legendre functions are normalized.
      With normalization, `N_l^m` is applied such that the spherical harmonics
      form a set of orthonormal basis functions of L^2(S^2).

  Returns:
    The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
    of the ALFs at `x`; the dimensions in the sequence of order, degree, and
    evaluation points.
  """
  p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]), dtype=x.dtype)

  # Per-degree factors feeding the diagonal (f_a) and first off-diagonal
  # (f_b) recurrences; the normalized variant folds N_l^m into them.
  a_idx = jnp.arange(1, l_max + 1, dtype=x.dtype)
  b_idx = jnp.arange(l_max, dtype=x.dtype)
  if is_normalized:
    initial_value: ArrayLike = 0.5 / jnp.sqrt(jnp.pi) # The initial value p(0,0).
    f_a = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / a_idx))
    f_b = jnp.sqrt(2.0 * b_idx + 3.0)
  else:
    initial_value = 1.0 # The initial value p(0,0).
    f_a = jnp.cumprod(1.0 - 2.0 * a_idx)
    f_b = 2.0 * b_idx + 1.0

  p = p.at[(0, 0)].set(initial_value)

  # Compute the diagonal entries p(l,l) with recurrence.
  y = jnp.cumprod(
      jnp.broadcast_to(jnp.sqrt(1.0 - x * x), (l_max, x.shape[0])),
      axis=0)
  p_diag = initial_value * jnp.einsum('i,ij->ij', f_a, y)
  diag_indices = jnp.diag_indices(l_max + 1)
  p = p.at[(diag_indices[0][1:], diag_indices[1][1:])].set(p_diag)

  # Compute the off-diagonal entries with recurrence.
  p_offdiag = jnp.einsum('ij,ij->ij',
                         jnp.einsum('i,j->ij', f_b, x),
                         p[jnp.diag_indices(l_max)])
  offdiag_indices = (diag_indices[0][:l_max], diag_indices[1][:l_max] + 1)
  p = p.at[offdiag_indices].set(p_offdiag)

  # Compute the remaining entries with recurrence.
  d0_mask_3d, d1_mask_3d = _gen_recurrence_mask(
      l_max, is_normalized=is_normalized, dtype=x.dtype)

  def body_fun(i, p_val):
    # Each iteration fills the entries on the anti-diagonal plane selected
    # by the masks from the two previously filled planes (degree shifts of
    # 1 and 2, realized via jnp.roll along the degree axis).
    coeff_0 = d0_mask_3d[i]
    coeff_1 = d1_mask_3d[i]
    h = (jnp.einsum('ij,ijk->ijk',
                    coeff_0,
                    jnp.einsum(
                        'ijk,k->ijk', jnp.roll(p_val, shift=1, axis=1), x)) -
         jnp.einsum('ij,ijk->ijk', coeff_1, jnp.roll(p_val, shift=2, axis=1)))
    p_val = p_val + h
    return p_val

  # TODO(jakevdp): use some sort of fixed-point procedure here instead?
  p = p.astype(jnp.result_type(p, x, d0_mask_3d))
  if l_max > 1:
    p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p)

  return p
def lpmn(m: int, n: int, z: Array) -> tuple[Array, Array]:
  """The associated Legendre functions (ALFs) of the first kind.

  Args:
    m: The maximum order of the associated Legendre functions.
    n: The maximum degree (often called ``l``); both degrees and orders run
      over ``[0, 1, 2, ..., l_max]``. Currently must equal ``m``.
    z: A 1D vector of type `float32` or `float64` containing the sampling
      points at which the ALFs are computed.

  Returns:
    A 2-tuple of 3D arrays of shape ``(l_max + 1, l_max + 1, len(z))`` with
    the values and the derivatives of the (unnormalized) ALFs of the first
    kind. The return type matches the type of ``z``.

  Raises:
    TypeError: if elements of ``z`` are not in (float32, float64).
    ValueError: if ``z`` is not 1D.
    NotImplementedError: if ``m != n``.
  """
  dtype = lax.dtype(z)
  if dtype not in (jnp.float32, jnp.float64):
    raise TypeError(
      'z.dtype={} is not supported, see docstring for supported types.'
      .format(dtype))

  if z.ndim != 1:
    raise ValueError('z must be a 1D array.')

  m = core.concrete_or_error(int, m, 'Argument m of lpmn.')
  n = core.concrete_or_error(int, n, 'Argument n of lpmn.')

  if m != n:
    raise NotImplementedError('Computations for m!=n are not yet supported.')

  max_degree = n
  # lpmn returns the unnormalized functions (is_normalized=False).
  values = _gen_associated_legendre(max_degree, z, False)
  derivatives = _gen_derivatives(values, z, False)
  return (values, derivatives)
def lpmn_values(m: int, n: int, z: Array, is_normalized: bool) -> Array:
  r"""Values of the associated Legendre functions (ALFs) of the first kind.

  Unlike :func:`lpmn`, only the function values are computed, not the
  derivatives. The ALFs appear in spherical harmonics: the harmonic of
  degree `l` and order `m` is
  :math:`Y_l^m(\theta, \phi) = N_l^m * P_l^m(\cos \theta) * \exp(i m \phi)`,
  where :math:`N_l^m` is the normalization factor making the harmonics an
  orthonormal basis of :math:`L^2(S^2)`; θ and φ are the colatitude and
  longitude. Normalizing :math:`P_l^m` avoids overflow/underflow and
  achieves better numerical stability.

  Args:
    m: The maximum order of the associated Legendre functions.
    n: The maximum degree (often called ``l``); both degrees and orders run
      over ``[0, 1, 2, ..., l_max]``. Currently must equal ``m``.
    z: A 1D vector of type `float32` or `float64` containing the sampling
      points at which the ALFs are computed.
    is_normalized: True to apply the :math:`N_l^m` normalization.

  Returns:
    A 3D array of shape ``(l_max + 1, l_max + 1, len(z))`` containing the
    ALF values; the return type matches the type of ``z``.

  Raises:
    TypeError: if elements of ``z`` are not in (float32, float64).
    ValueError: if ``z`` is not 1D.
    NotImplementedError: if ``m != n``.
  """
  dtype = lax.dtype(z)
  if dtype not in (jnp.float32, jnp.float64):
    raise TypeError(
      'z.dtype={} is not supported, see docstring for supported types.'
      .format(dtype))

  if z.ndim != 1:
    raise ValueError('z must be a 1D array.')

  m = core.concrete_or_error(int, m, 'Argument m of lpmn.')
  n = core.concrete_or_error(int, n, 'Argument n of lpmn.')

  if m != n:
    raise NotImplementedError('Computations for m!=n are not yet supported.')

  return _gen_associated_legendre(n, z, is_normalized)
@partial(jit, static_argnums=(4,))
def _sph_harm(m: Array,
              n: Array,
              theta: Array,
              phi: Array,
              n_max: int) -> Array:
  """Jitted core of :func:`sph_harm`; ``n_max`` is a static argument."""

  cos_colatitude = jnp.cos(phi)

  # Normalized associated Legendre values for all orders/degrees up to
  # n_max; degrees beyond n_max are clipped by the indexed gather below.
  alf = _gen_associated_legendre(n_max, cos_colatitude, True)
  alf_vals = alf.at[abs(m), n, jnp.arange(len(n))].get(mode="clip")

  # e^{i |m| theta}, built as explicit real/imaginary parts.
  angle = abs(m) * theta
  phase = lax.complex(jnp.cos(angle), jnp.sin(angle))
  harmonics = lax.complex(alf_vals * jnp.real(phase),
                          alf_vals * jnp.imag(phase))

  # Negative orders via the conjugation symmetry Y_n^{-m} = (-1)^m conj(Y_n^m).
  harmonics = jnp.where(m < 0,
                        (-1.0)**abs(m) * jnp.conjugate(harmonics),
                        harmonics)

  return harmonics
def sph_harm(m: Array,
             n: Array,
             theta: Array,
             phi: Array,
             n_max: int | None = None) -> Array:
  r"""Computes the spherical harmonics.

  The spherical harmonic of degree `n` and order `m` is
  :math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)`,
  where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!}
  {4 \pi \left(n+m\right)!}}` is the normalization factor; :math:`\phi` and
  :math:`\theta` are the colatitude and longitude. :math:`N_n^m` makes the
  harmonics an orthonormal basis of :math:`L^2(S^2)`. Relative to
  ``scipy.special.sph_harm``, the JAX version takes one extra argument,
  ``n_max``.

  Args:
    m: The order of the harmonic; must have `|m| <= n`. Return values for
      `|m| > n` are undefined.
    n: The degree of the harmonic; must have `n >= 0`. (The standard
      notation is `l`; `n` is used for consistency with
      `scipy.special.sph_harm`.) Return values for `n < 0` are undefined.
    theta: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].
    phi: The polar (colatitudinal) coordinate; must be in [0, pi].
    n_max: The maximum degree `max(n)`. If smaller than the true maximum of
      `n`, results are clipped to `n_max`.

  Returns:
    A 1D array containing the spherical harmonics at (m, n, theta, phi).
  """

  if jnp.isscalar(phi):
    phi = jnp.array([phi])

  # n_max fixes the size of the Legendre table, so it must be concrete.
  max_degree = np.max(n) if n_max is None else n_max
  max_degree = core.concrete_or_error(
      int, max_degree, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
      'be statically specified to use `sph_harm` within JAX transformations.')

  return _sph_harm(m, n, theta, phi, max_degree)
# Exponential integrals.
# These algorithms are ported from the files ei.c and expn.c in the Cephes
# mathematical library:
# https://fossies.org/dox/cephes-math-28/ei_8c_source.html
# https://fossies.org/dox/cephes-math-28/expn_8c_source.html
def _expint1(x: Array) -> Array:
  """Ei(x) on 0 < x <= 2: rational approximation plus log term (Cephes ei.c)."""
  # Numerator / denominator coefficients from Cephes.
  num = [
      -5.350447357812542947283e0,
      2.185049168816613393830e2,
      -4.176572384826693777058e3,
      5.541176756393557601232e4,
      -3.313381331178144034309e5,
      1.592627163384945414220e6,
  ]
  den = [
      1.0,
      -5.250547959112862969197e1,
      1.259616186786790571525e3,
      -1.756549581973534652631e4,
      1.493062117002725991967e5,
      -7.294949239640527645655e5,
      1.592627163384945429726e6,
  ]
  num_arr = jnp.array(num, dtype=x.dtype)
  den_arr = jnp.array(den, dtype=x.dtype)
  ratio = jnp.polyval(num_arr, x) / jnp.polyval(den_arr, x)
  return x * ratio + jnp.euler_gamma + jnp.log(x)
def _eval_expint_k(A: list[float], B: list[float], x: Array) -> Array:
  """Shared Cephes tail evaluator for Ei on the intervals with x >= 2.

  Computes exp(x)/x * (1 + (1/x) * P(1/x)/Q(1/x)) for the coefficient
  tables ``A`` (P) and ``B`` (Q).
  """
  num = jnp.array(A, dtype=x.dtype)
  den = jnp.array(B, dtype=x.dtype)
  one = _lax_const(x, 1.0)
  recip = one / x
  frac = jnp.polyval(num, recip) / jnp.polyval(den, recip)
  frac = recip * frac + one
  return jnp.exp(x) * recip * frac
def _expint2(x: Array) -> Array:
  """Ei(x) on 2 <= x < 4 (Cephes ei.c coefficient table)."""
  # 2 <= x < 4
  A = [
      1.981808503259689673238e-2,
      -1.271645625984917501326e0,
      -2.088160335681228318920e0,
      2.755544509187936721172e0,
      -4.409507048701600257171e-1,
      4.665623805935891391017e-2,
      -1.545042679673485262580e-3,
      7.059980605299617478514e-5,
  ]
  B = [
      1.0,
      1.476498670914921440652e0,
      5.629177174822436244827e-1,
      1.699017897879307263248e-1,
      2.291647179034212017463e-2,
      4.450150439728752875043e-3,
      1.727439612206521482874e-4,
      3.953167195549672482304e-5,
  ]
  return _eval_expint_k(A, B, x)
def _expint3(x: Array) -> Array:
  """Ei(x) on 4 <= x <= 8 (Cephes ei.c coefficient table)."""
  # 4 <= x <= 8
  A = [
      -1.373215375871208729803e0,
      -7.084559133740838761406e-1,
      1.580806855547941010501e0,
      -2.601500427425622944234e-1,
      2.994674694113713763365e-2,
      -1.038086040188744005513e-3,
      4.371064420753005429514e-5,
      2.141783679522602903795e-6,
  ]
  B = [
      1.0,
      8.585231423622028380768e-1,
      4.483285822873995129957e-1,
      7.687932158124475434091e-2,
      2.449868241021887685904e-2,
      8.832165941927796567926e-4,
      4.590952299511353531215e-4,
      -4.729848351866523044863e-6,
      2.665195537390710170105e-6,
  ]
  return _eval_expint_k(A, B, x)
def _expint4(x: Array) -> Array:
  """Ei(x) on 8 <= x <= 16 (Cephes ei.c coefficient table)."""
  # 8 <= x <= 16
  A = [
      -2.106934601691916512584e0,
      1.732733869664688041885e0,
      -2.423619178935841904839e-1,
      2.322724180937565842585e-2,
      2.372880440493179832059e-4,
      -8.343219561192552752335e-5,
      1.363408795605250394881e-5,
      -3.655412321999253963714e-7,
      1.464941733975961318456e-8,
      6.176407863710360207074e-10,
  ]
  B = [
      1.0,
      -2.298062239901678075778e-1,
      1.105077041474037862347e-1,
      -1.566542966630792353556e-2,
      2.761106850817352773874e-3,
      -2.089148012284048449115e-4,
      1.708528938807675304186e-5,
      -4.459311796356686423199e-7,
      1.394634930353847498145e-8,
      6.150865933977338354138e-10,
  ]
  return _eval_expint_k(A, B, x)
def _expint5(x: Array) -> Array:
  """Ei(x) on 16 <= x <= 32 (Cephes ei.c coefficient table)."""
  # 16 <= x <= 32
  A = [
      -2.458119367674020323359e-1,
      -1.483382253322077687183e-1,
      7.248291795735551591813e-2,
      -1.348315687380940523823e-2,
      1.342775069788636972294e-3,
      -7.942465637159712264564e-5,
      2.644179518984235952241e-6,
      -4.239473659313765177195e-8,
  ]
  B = [
      1.0,
      -1.044225908443871106315e-1,
      -2.676453128101402655055e-1,
      9.695000254621984627876e-2,
      -1.601745692712991078208e-2,
      1.496414899205908021882e-3,
      -8.462452563778485013756e-5,
      2.728938403476726394024e-6,
      -4.239462431819542051337e-8,
  ]
  return _eval_expint_k(A, B, x)
def _expint6(x: Array) -> Array:
  """Ei(x) on 32 <= x <= 64 (Cephes ei.c coefficient table)."""
  # 32 <= x <= 64
  A = [
      1.212561118105456670844e-1,
      -5.823133179043894485122e-1,
      2.348887314557016779211e-1,
      -3.040034318113248237280e-2,
      1.510082146865190661777e-3,
      -2.523137095499571377122e-5,
  ]
  B = [
      1.0,
      -1.002252150365854016662e0,
      2.928709694872224144953e-1,
      -3.337004338674007801307e-2,
      1.560544881127388842819e-3,
      -2.523137093603234562648e-5,
  ]
  return _eval_expint_k(A, B, x)
def _expint7(x: Array) -> Array:
  """Ei(x) for x > 64 (Cephes ei.c coefficient table)."""
  # x > 64
  A = [
      -7.657847078286127362028e-1,
      6.886192415566705051750e-1,
      -2.132598113545206124553e-1,
      3.346107552384193813594e-2,
      -3.076541477344756050249e-3,
      1.747119316454907477380e-4,
      -6.103711682274170530369e-6,
      1.218032765428652199087e-7,
      -1.086076102793290233007e-9,
  ]
  B = [
      1.0,
      -1.888802868662308731041e0,
      1.066691687211408896850e0,
      -2.751915982306380647738e-1,
      3.930852688233823569726e-2,
      -3.414684558602365085394e-3,
      1.866844370703555398195e-4,
      -6.345146083130515357861e-6,
      1.239754287483206878024e-7,
      -1.086076102793126632978e-9,
  ]
  return _eval_expint_k(A, B, x)
def _expi_pos(x: Array) -> Array:
  """Ei(x) for x >= 0, dispatching to the per-interval Cephes approximations."""
  const = _lax_const
  # Interval edges: (0, 2], then (2, 4], (4, 8], ..., (32, 64]; inputs
  # above 64 match no condition and fall through to _expint7.
  edges = [(0, 2)] + [(2 ** i, 2 ** (i + 1)) for i in range(1, 6)]
  conds = [(const(x, lo) < x) & (x <= const(x, hi)) for lo, hi in edges]
  branches = [_expint1, _expint2, _expint3, _expint4, _expint5, _expint6,
              _expint7]
  return jnp.piecewise(x, conds, branches)
def _expi_neg(x: Array) -> Array:
  """Ei(x) for x < 0, via the reflection identity Ei(x) = -E_1(-x)."""
  reflected = -x
  return -exp1(reflected)
@custom_derivatives.custom_jvp
@jit
def expi(x: ArrayLike) -> Array:
  r"""Exponential integral function.

  JAX implementation of :obj:`scipy.special.expi`, defined as

  .. math::

     \mathrm{expi}(x) = \int_{-\infty}^x \frac{e^t}{t} \mathrm{d}t

  Args:
    x: arraylike, real-valued

  Returns:
    array of expi values

  See also:
    - :func:`jax.scipy.special.expn`
    - :func:`jax.scipy.special.exp1`
  """
  arr, = promote_args_inexact("expi", x)
  # Negative and non-negative inputs use different Cephes approximations.
  return jnp.piecewise(arr, [arr < 0], [_expi_neg, _expi_pos])
@expi.defjvp
@jit
def expi_jvp(primals, tangents):
  """JVP rule for :func:`expi`: d/dx Ei(x) = exp(x) / x."""
  x, = primals
  x_dot, = tangents
  primal_out = expi(x)
  tangent_out = jnp.exp(x) / x * x_dot
  return primal_out, tangent_out
def _expn1(n: int, x_in: ArrayLike) -> Array:
  # exponential integral En
  """E_n(x) via the Cephes expn.c power series, used by :func:`expn` for the
  inputs matching none of its other branches (small positive x).

  Accumulates series terms in a while_loop until the relative contribution
  of a term drops to machine epsilon, then combines with the log/psi term.
  """
  _c = _lax_const
  x = jnp.asarray(x_in)
  MACHEP = jnp.finfo(x.dtype).eps

  zero = _c(x, 0.0)
  one = _c(x, 1.0)
  # psi = -gamma - ln(x) + sum_{i=1}^{n-1} 1/i (harmonic/digamma term).
  psi = -jnp.euler_gamma - jnp.log(x)
  psi = lax.fori_loop(_c(n, 1), n, lambda i, psi: psi + one / i, psi)
  # For n == 1 the 1/(1 - n) seed would divide by zero; substitute n1 = 2
  # and seed ans with 0 instead.
  n1 = jnp.where(n == _c(n, 1), one + one, n)
  init = dict(
      x=x,
      z=-x,
      xk=zero,
      yk=one,
      pk=one - n,
      ans=jnp.where(n == _c(n, 1), zero, one / (one - n1)),
      t=jnp.inf,
  )

  def body(d):
    # Next power-series term: yk = (-x)^k / k!, pk = k - n + 1 denominator.
    d["xk"] += one
    d["yk"] *= d["z"] / d["xk"]
    d["pk"] += one
    d["ans"] += jnp.where(d["pk"] != zero, d["yk"] / d["pk"], zero)
    # Relative size of the last contribution, used as the stop criterion.
    d["t"] = jnp.where(d["ans"] != zero, abs(d["yk"] / d["ans"]), one)
    return d

  def cond(d):
    return (d["x"] > _c(d["x"], 0.0)) & (d["t"] > MACHEP)

  d = lax.while_loop(cond, body, init)
  t = n
  r = n - _c(n, 1)
  # (-x)^{n-1} * psi / (n-1)!  minus the accumulated series.
  return d["z"] ** r * psi / jnp.exp(gammaln(t)) - d["ans"]
def _expn2(n: int, x: Array) -> Array:
  # x > 1.
  """E_n(x) for x > 1 via the Cephes expn.c continued-fraction expansion.

  Iterates the convergents p_k/q_k in a while_loop until their relative
  change falls to machine epsilon, rescaling by BIG to avoid overflow.
  """
  _c = _lax_const
  # Rescaling threshold from Cephes (2^57).
  BIG = _c(x, 1.44115188075855872e17)
  MACHEP = jnp.finfo(BIG.dtype).eps  # ?
  zero = _c(x, 0.0)
  one = _c(x, 1.0)

  init = dict(
      k=_c(n, 1),
      pkm2=one,
      qkm2=x,
      pkm1=one,
      qkm1=x + n,
      ans=one / (x + n),
      t=_c(x, jnp.inf),
      r=zero,
      x=x,
  )

  def body(d):
    x = d["x"]
    d["k"] += _c(d["k"], 1)
    k = d["k"]
    # Odd and even steps use different continued-fraction coefficients.
    odd = k % _c(k, 2) == _c(k, 1)
    yk = jnp.where(odd, one, x)
    xk = jnp.where(odd, n + (k - _c(k, 1)) / _c(k, 2), k / _c(k, 2))
    pk = d["pkm1"] * yk + d["pkm2"] * xk
    qk = d["qkm1"] * yk + d["qkm2"] * xk
    nz = qk != zero
    d["r"] = r = jnp.where(nz, pk / qk, d["r"])
    # Relative change between successive convergents (stop criterion).
    d["t"] = jnp.where(nz, abs((d["ans"] - r) / r), one)
    d["ans"] = jnp.where(nz, r, d["ans"])
    d["pkm2"] = d["pkm1"]
    d["pkm1"] = pk
    d["qkm2"] = d["qkm1"]
    d["qkm1"] = qk
    # Rescale all four state terms when they grow too large; the ratio
    # pk/qk (and hence the result) is unchanged.
    is_big = abs(pk) > BIG
    for s in "pq":
      for i in "12":
        key = s + "km" + i
        d[key] = jnp.where(is_big, d[key] / BIG, d[key])
    return d

  def cond(d):
    return (d["x"] > _c(d["k"], 0)) & (d["t"] > MACHEP)

  d = lax.while_loop(cond, body, init)
  return d["ans"] * jnp.exp(-x)
def _expn3(n: int, x: Array) -> Array:
  """E_n(x) for very large n (>= 5000): asymptotic expansion in 1/(x + n)."""
  const = _lax_const
  one = const(x, 1.0)
  shifted = x + n
  inv_sq = one / (shifted * shifted)
  order = n
  # Horner-style accumulation of the correction terms (Cephes expn.c).
  acc = inv_sq * order * (const(x, 6) * x * x - const(x, 8) * order * x
                          + order * order)
  acc = inv_sq * (acc + order * (order - const(x, 2) * x))
  acc = inv_sq * (acc + order)
  return (acc + one) * jnp.exp(-x) / shifted
@partial(custom_derivatives.custom_jvp, nondiff_argnums=(0,))
@jnp.vectorize
@jit
def expn(n: ArrayLike, x: ArrayLike) -> Array:
  r"""Generalized exponential integral function.

  JAX implementation of :obj:`scipy.special.expn`.

  .. math::

     \mathrm{expn}(x) = E_n(x) = x^{n-1}\int_x^\infty\frac{e^{-t}}{t^n}\mathrm{d}t

  Args:
    n: arraylike, real-valued
    x: arraylike, real-valued

  Returns:
    array of expn values

  See also:
    - :func:`jax.scipy.special.expi`
    - :func:`jax.scipy.special.exp1`
  """
  n, x = promote_args_inexact("expn", n, x)
  _c = _lax_const
  zero = _c(x, 0)
  one = _c(x, 1)
  # Branch predicates for jnp.piecewise; where several overlap, the later
  # entry takes precedence (numpy.piecewise semantics).
  conds = [
      (n < _c(n, 0)) | (x < zero),
      (x == zero) & (n < _c(n, 2)),
      (x == zero) & (n >= _c(n, 2)),
      (n == _c(n, 0)) & (x >= zero),
      (n >= _c(n, 5000)),
      (x > one),
  ]
  # Substitute n1 = 2n when n == 1 so the 1/(n-1) value below is finite;
  # that branch's result is only selected when n >= 2 anyway.
  n1 = jnp.where(n == _c(n, 1), n + n, n)
  # One value per condition above, plus a trailing default (_expn1) for
  # points matching none of the conditions (0 < x <= 1, 1 <= n < 5000).
  vals = [
      jnp.nan,
      jnp.inf,
      one / n1, # prevent div by zero
      jnp.exp(-x) / x,
      partial(_expn3, n),
      partial(_expn2, n),
      partial(_expn1, n),
  ]
  ret = jnp.piecewise(x, conds, vals)
  return ret
@expn.defjvp
@jit
def expn_jvp(n, primals, tangents):
  """JVP rule for :func:`expn`: d/dx E_n(x) = -E_{n-1}(x)."""
  x, = primals
  x_dot, = tangents
  n_minus_one = lax.sub(n, _lax_const(n, 1))
  tangent_out = lax.mul(lax.neg(x_dot), expn(n_minus_one, x))
  return expn(n, x), tangent_out
def exp1(x: ArrayLike) -> Array:
  r"""Exponential integral function :math:`E_1`.

  JAX implementation of :obj:`scipy.special.exp1`, defined as

  .. math::

     \mathrm{exp1}(x) = E_1(x) = x^{n-1}\int_x^\infty\frac{e^{-t}}{t}\mathrm{d}t


  Args:
    x: arraylike, real-valued

  Returns:
    array of exp1 values

  See also:
    - :func:`jax.scipy.special.expi`
    - :func:`jax.scipy.special.expn`
  """
  arr, = promote_args_inexact("exp1", x)
  # E_1 is the n = 1 member of the generalized family. Casting because
  # custom_jvp generic does not work correctly with mypy.
  return cast(Array, expn(1, arr))
def _spence_poly(w: Array) -> Array:
  """Rational approximation for Spence's function near w = 0.

  Evaluates -w * A(w) / B(w) where A and B are degree-7 polynomials
  (coefficients from the Cephes library), valid after the argument
  reduction performed by the caller.
  """
  numer_coeffs = jnp.array([4.65128586073990045278E-5,
                            7.31589045238094711071E-3,
                            1.33847639578309018650E-1,
                            8.79691311754530315341E-1,
                            2.71149851196553469920E0,
                            4.25697156008121755724E0,
                            3.29771340985225106936E0,
                            1.00000000000000000126E0,
                            ], dtype=w.dtype)
  denom_coeffs = jnp.array([6.90990488912553276999E-4,
                            2.54043763932544379113E-2,
                            2.82974860602568089943E-1,
                            1.41172597751831069617E0,
                            3.63800533345137075418E0,
                            5.03278880143316990390E0,
                            3.54771340985225096217E0,
                            9.99999999999999998740E-1,
                            ], dtype=w.dtype)
  numerator = jnp.polyval(numer_coeffs, w)
  denominator = jnp.polyval(denom_coeffs, w)
  return -w * numerator / denominator
2252
+
2253
+
2254
def _spence_calc(x: Array) -> Array:
  # Argument reduction for Spence's function (mirrors the Cephes spence
  # routine): map x into a range where the rational approximation
  # _spence_poly is accurate, then undo the mapping with dilog identities.
  x2_bool = x > 2.0
  # For x > 2, evaluate at 1/x (inversion identity).
  x = jnp.piecewise(x, [x2_bool],
                    [lambda x: 1.0 / x, lambda x: x])

  x1_5_bool = x > 1.5
  x_5_bool = x < 0.5
  # The inversion (x > 2) and reflection (x > 1.5) reductions share the
  # same log^2 correction below, so fold the two flags together.
  x2_bool = x2_bool | x1_5_bool

  # Shift the reduced argument so the polynomial is evaluated near 0.
  w = jnp.piecewise(x,
                    [x1_5_bool, x_5_bool],
                    [lambda x: 1.0 / x - 1.0,
                     lambda x: -x,
                     lambda x: x - 1.0])

  y = _spence_poly(w)
  # Correction for arguments reduced from x < 0.5.
  y_flag_one = jnp.pi ** 2 / 6.0 - jnp.log(x) * jnp.log(1.0 - x) - y
  y = jnp.where(x_5_bool, y_flag_one, y)
  # Correction for arguments reduced by inversion/reflection.
  y_flag_two = -0.5 * jnp.log(x) ** 2 - y
  return jnp.where(x2_bool, y_flag_two, y)
2274
+
2275
+
2276
def _spence(x: Array) -> Array:
  """Dispatch Spence's function: nan for x < 0, exact values at 0 and 1,
  otherwise the general computation."""
  conds = [x < 0.0, x == 1.0, x == 0.0]
  funcs = [jnp.nan, 0, jnp.pi ** 2 / 6, _spence_calc]
  return jnp.piecewise(x, conds, funcs)
2280
+
2281
+
2282
def spence(x: Array) -> Array:
  r"""Spence's function, also known as the dilogarithm for real values.

  JAX implementation of :obj:`scipy.special.spence`.

  It is defined to be:

  .. math::
    \mathrm{spence}(x) = \begin{equation}
      \int_1^x \frac{\log(t)}{1 - t}dt
    \end{equation}

  Unlike the SciPy implementation, this is only defined for positive
  real values of `x`. For negative values, `NaN` is returned.

  Args:
    x: An array of type `float32` or `float64`.

  Returns:
    An array with `dtype=x.dtype` containing the computed values of
    Spence's function.

  Raises:
    TypeError: if elements of array `x` are not in (float32, float64).

  Notes:
    There is a different convention which defines Spence's function by the
    integral:

    .. math::
      \begin{equation}
        -\int_0^z \frac{\log(1 - t)}{t}dt
      \end{equation}

    This is our spence(1 - z).
  """
  x = jnp.asarray(x)
  dtype = lax.dtype(x)
  # Restrict to the two float dtypes the Cephes-derived coefficients support.
  if dtype not in (jnp.float32, jnp.float64):
    raise TypeError(
        f"x.dtype={dtype} is not supported, see docstring for supported types.")
  return _spence(x)
2324
+
2325
+
2326
+ def bernoulli(n: int) -> Array:
2327
+ """Generate the first N Bernoulli numbers.
2328
+
2329
+ JAX implementation of :func:`scipy.special.bernoulli`.
2330
+
2331
+ Args:
2332
+ n: integer, the number of Bernoulli terms to generate.
2333
+
2334
+ Returns:
2335
+ Array containing the first ``n`` Bernoulli numbers.
2336
+
2337
+ Notes:
2338
+ ``bernoulli`` generates numbers using the :math:`B_n^-` convention,
2339
+ such that :math:`B_1=-1/2`.
2340
+ """
2341
+ # Generate Bernoulli numbers using the Chowla and Hartung algorithm.
2342
+ n = core.concrete_or_error(operator.index, n, "Argument n of bernoulli")
2343
+ if n < 0:
2344
+ raise ValueError("n must be a non-negative integer.")
2345
+ b3 = jnp.array([1, -1/2, 1/6])
2346
+ if n < 3:
2347
+ return b3[:n + 1]
2348
+ bn = jnp.zeros(n + 1).at[:3].set(b3)
2349
+ m = jnp.arange(4, n + 1, 2, dtype=bn.dtype)
2350
+ q1 = (1. / jnp.pi ** 2) * jnp.cumprod(-(m - 1) * m / 4 / jnp.pi ** 2)
2351
+ k = jnp.arange(2, 50, dtype=bn.dtype) # Choose 50 because 2 ** -50 < 1E-15
2352
+ q2 = jnp.sum(k[:, None] ** -m[None, :], axis=0)
2353
+ return bn.at[4::2].set(q1 * (1 + q2))
2354
+
2355
+
2356
@custom_derivatives.custom_jvp
def poch(z: ArrayLike, m: ArrayLike) -> Array:
  r"""The Pochhammer symbol.

  JAX implementation of :obj:`scipy.special.poch`.

  .. math::

     \mathrm{poch}(z, m) = (z)_m = \frac{\Gamma(z + m)}{\Gamma(z)}

  where :math:`\Gamma(z)` is the :func:`~jax.scipy.special.gamma` function.

  Args:
    z: arraylike, real-valued
    m: arraylike, real-valued

  Returns:
    array of Pochhammer values.

  Notes:
    The JAX version supports only real-valued inputs.
  """
  z, m = promote_args_inexact("poch", z, m)
  # (z)_0 == 1 by convention; select it explicitly since the gamma ratio
  # can be indeterminate at the poles.
  gamma_ratio = gamma(z + m) / gamma(z)
  return jnp.where(m == 0., jnp.array(1, dtype=z.dtype), gamma_ratio)
2381
+
2382
+
2383
def _poch_z_derivative(z, m):
  """Partial derivative of poch w.r.t. z: (psi(z+m) - psi(z)) * (z)_m.

  Defined in :
  https://functions.wolfram.com/GammaBetaErf/Pochhammer/20/01/01/
  """
  psi_diff = digamma(z + m) - digamma(z)
  return psi_diff * poch(z, m)
2390
+
2391
+
2392
def _poch_m_derivative(z, m):
  """Partial derivative of poch w.r.t. m: psi(z+m) * (z)_m.

  Defined in :
  https://functions.wolfram.com/GammaBetaErf/Pochhammer/20/01/02/
  """
  psi_val = digamma(z + m)
  return psi_val * poch(z, m)
2399
+
2400
+
2401
# Register forward-mode rules for poch, one per positional argument, using
# the closed-form digamma-based derivatives defined above.
poch.defjvps(
    lambda z_dot, primal_out, z, m: _poch_z_derivative(z, m) * z_dot,
    lambda m_dot, primal_out, z, m: _poch_m_derivative(z, m) * m_dot,
)
2405
+
2406
+
2407
def _hyp1f1_serie(a, b, x):
  """
  Compute the 1F1 hypergeometric function using the taylor expansion
  See Eq. 3.2 and associated method (a) from PEARSON, OLVER & PORTER 2014
  https://doi.org/10.48550/arXiv.1407.7786
  """
  eps = jnp.finfo(x.dtype).eps

  def step(state):
    total, k, term = state
    # Accumulate the current term, then advance it by the series ratio.
    return total + term, k + 1, term * (a + k) / (b + k) * x / (k + 1)

  def keep_going(state):
    total, k, term = state
    # Stop after 250 terms or once the relative term size drops below eps.
    return (k < 250) & (lax.abs(term) / lax.abs(total) > eps)

  initial_state = 1, 1, a / b * x
  return lax.while_loop(keep_going, step, initial_state)[0]
2432
+
2433
+
2434
def _hyp1f1_asymptotic(a, b, x):
  """
  Compute the 1F1 hypergeometric function using asymptotic expansion
  See Eq. 3.8 and simplification for real inputs from PEARSON, OLVER & PORTER 2014
  https://doi.org/10.48550/arXiv.1407.7786
  """
  eps = jnp.finfo(x.dtype).eps

  def step(state):
    total, k, term = state
    # Accumulate, then advance the term by the asymptotic-series ratio.
    return total + term, k + 1, term * (b - a + k) * (1 - a + k) / (k + 1) / x

  def keep_going(state):
    total, k, term = state
    # Stop after 250 terms or once the relative term size drops below eps.
    return (k < 250) & (lax.abs(term) / lax.abs(total) > eps)

  initial_state = 1, 1, (b - a) * (1 - a) / x
  tail = lax.while_loop(keep_going, step, initial_state)[0]

  return gamma(b) / gamma(a) * lax.exp(x) * x ** (a - b) * tail
2460
+
2461
+
2462
@jit
@jnp.vectorize
def _hyp1f1_a_derivative(a, b, x):
  """
  Define it as a serie using :
  https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/01/
  """
  eps = jnp.finfo(x.dtype).eps

  def step(state):
    total, k, term = state
    # Each term of the base series is weighted by the digamma difference.
    total = total + term * (digamma(a + k) - digamma(a))
    return total, k + 1, term * (a + k) / (b + k) * x / (k + 1)

  def keep_going(state):
    total, k, term = state
    return (k < 250) & (lax.abs(term) / lax.abs(total) > eps)

  initial_state = 0, 1, a / b * x
  return lax.while_loop(keep_going, step, initial_state)[0]
2488
+
2489
+
2490
@jit
@jnp.vectorize
def _hyp1f1_b_derivative(a, b, x):
  """
  Define it as a serie using :
  https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/02/
  """
  eps = jnp.finfo(x.dtype).eps

  def step(state):
    total, k, term = state
    # Each term of the base series is weighted by the digamma difference.
    total = total + term * (digamma(b) - digamma(b + k))
    return total, k + 1, term * (a + k) / (b + k) * x / (k + 1)

  def keep_going(state):
    total, k, term = state
    return (k < 250) & (lax.abs(term) / lax.abs(total) > eps)

  initial_state = 0, 1, a / b * x
  return lax.while_loop(keep_going, step, initial_state)[0]
2516
+
2517
+
2518
@jit
def _hyp1f1_x_derivative(a, b, x):
  """
  Closed-form derivative in x, from:
  https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/04/
  """
  ratio = a / b
  return ratio * hyp1f1(a + 1, b + 1, x)
2526
+
2527
+
2528
@custom_derivatives.custom_jvp
@jit
@jnp.vectorize
def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
  r"""The 1F1 hypergeometric function.

  JAX implementation of :obj:`scipy.special.hyp1f1`.

  .. math::

     \mathrm{hyp1f1}(a, b, x) = {}_1F_1(x;a, b) = \sum_{k=0}^\infty \frac{(a)_k}{(b)_kk!}x^k

  where :math:`(\cdot)_k` is the Pochhammer symbol (refer to :func:`~jax.scipy.special.poch`).

  The JAX version only accepts positive and real inputs. Values of ``a``, ``b``,
  and ``x``, leading to high values of 1F1 may lead to erroneous results;
  consider enabling double precision in this case. The convention for
  ``a = b = 0`` is ``1``, unlike in scipy's implementation.

  Args:
    a: arraylike, real-valued
    b: arraylike, real-valued
    x: arraylike, real-valued

  Returns:
    array of 1F1 values.
  """
  # This is backed by https://doi.org/10.48550/arXiv.1407.7786
  # There is room for improvement in the implementation using recursion to
  # evaluate lower values of hyp1f1 when a or b or both are > 60-80
  a, b, x = promote_args_inexact('hyp1f1', a, b, x)

  # Taylor series for small |x|, asymptotic expansion for large |x|.
  result = lax.cond(lax.abs(x) < 100, _hyp1f1_serie, _hyp1f1_asymptotic, a, b, x)
  # Special-case selector. The three conditions are mutually exclusive, so
  # the weighted sum yields a single index in {0, 1, 2, 3}:
  #   1: a == 0           -> 1F1 = 1 (also covers a == b == 0 by convention)
  #   2: a == b, nonzero  -> 1F1 = exp(x)
  #   3: b == 0, a != 0   -> divergent, return inf
  index = (a == 0) * 1 + ((a == b) & (a != 0)) * 2 + ((b == 0) & (a != 0)) * 3

  return lax.select_n(index,
                      result,
                      jnp.array(1, dtype=x.dtype),
                      jnp.exp(x),
                      jnp.array(jnp.inf, dtype=x.dtype))
2568
+
2569
+
2570
# Register forward-mode rules for hyp1f1, one per positional argument,
# using the series / closed-form derivatives defined above.
hyp1f1.defjvps(
    lambda a_dot, primal_out, a, b, x: _hyp1f1_a_derivative(a, b, x) * a_dot,
    lambda b_dot, primal_out, a, b, x: _hyp1f1_b_derivative(a, b, x) * b_dot,
    lambda x_dot, primal_out, a, b, x: _hyp1f1_x_derivative(a, b, x) * x_dot
)