shomez commited on
Commit
628ecb3
·
verified ·
1 Parent(s): e8d9ef4

Upload ctx_base.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ctx_base.py +494 -0
ctx_base.py ADDED
@@ -0,0 +1,494 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from operator import gt, lt
2
+
3
+ from .libmp.backend import xrange
4
+
5
+ from .functions.functions import SpecialFunctions
6
+ from .functions.rszeta import RSCache
7
+ from .calculus.quadrature import QuadratureMethods
8
+ from .calculus.inverselaplace import LaplaceTransformInversionMethods
9
+ from .calculus.calculus import CalculusMethods
10
+ from .calculus.optimization import OptimizationMethods
11
+ from .calculus.odes import ODEMethods
12
+ from .matrices.matrices import MatrixMethods
13
+ from .matrices.calculus import MatrixCalculusMethods
14
+ from .matrices.linalg import LinearAlgebraMethods
15
+ from .matrices.eigen import Eigen
16
+ from .identification import IdentificationMethods
17
+ from .visualization import VisualizationMethods
18
+
19
+ from . import libmp
20
+
21
+ class Context(object):
22
+ pass
23
+
24
+ class StandardBaseContext(Context,
25
+ SpecialFunctions,
26
+ RSCache,
27
+ QuadratureMethods,
28
+ LaplaceTransformInversionMethods,
29
+ CalculusMethods,
30
+ MatrixMethods,
31
+ MatrixCalculusMethods,
32
+ LinearAlgebraMethods,
33
+ Eigen,
34
+ IdentificationMethods,
35
+ OptimizationMethods,
36
+ ODEMethods,
37
+ VisualizationMethods):
38
+
39
+ NoConvergence = libmp.NoConvergence
40
+ ComplexResult = libmp.ComplexResult
41
+
42
+ def __init__(ctx):
43
+ ctx._aliases = {}
44
+ # Call those that need preinitialization (e.g. for wrappers)
45
+ SpecialFunctions.__init__(ctx)
46
+ RSCache.__init__(ctx)
47
+ QuadratureMethods.__init__(ctx)
48
+ LaplaceTransformInversionMethods.__init__(ctx)
49
+ CalculusMethods.__init__(ctx)
50
+ MatrixMethods.__init__(ctx)
51
+
52
+ def _init_aliases(ctx):
53
+ for alias, value in ctx._aliases.items():
54
+ try:
55
+ setattr(ctx, alias, getattr(ctx, value))
56
+ except AttributeError:
57
+ pass
58
+
59
+ _fixed_precision = False
60
+
61
+ # XXX
62
+ verbose = False
63
+
64
+ def warn(ctx, msg):
65
+ print("Warning:", msg)
66
+
67
+ def bad_domain(ctx, msg):
68
+ raise ValueError(msg)
69
+
70
+ def _re(ctx, x):
71
+ if hasattr(x, "real"):
72
+ return x.real
73
+ return x
74
+
75
+ def _im(ctx, x):
76
+ if hasattr(x, "imag"):
77
+ return x.imag
78
+ return ctx.zero
79
+
80
+ def _as_points(ctx, x):
81
+ return x
82
+
83
+ def fneg(ctx, x, **kwargs):
84
+ return -ctx.convert(x)
85
+
86
+ def fadd(ctx, x, y, **kwargs):
87
+ return ctx.convert(x)+ctx.convert(y)
88
+
89
+ def fsub(ctx, x, y, **kwargs):
90
+ return ctx.convert(x)-ctx.convert(y)
91
+
92
+ def fmul(ctx, x, y, **kwargs):
93
+ return ctx.convert(x)*ctx.convert(y)
94
+
95
+ def fdiv(ctx, x, y, **kwargs):
96
+ return ctx.convert(x)/ctx.convert(y)
97
+
98
+ def fsum(ctx, args, absolute=False, squared=False):
99
+ if absolute:
100
+ if squared:
101
+ return sum((abs(x)**2 for x in args), ctx.zero)
102
+ return sum((abs(x) for x in args), ctx.zero)
103
+ if squared:
104
+ return sum((x**2 for x in args), ctx.zero)
105
+ return sum(args, ctx.zero)
106
+
107
+ def fdot(ctx, xs, ys=None, conjugate=False):
108
+ if ys is not None:
109
+ xs = zip(xs, ys)
110
+ if conjugate:
111
+ cf = ctx.conj
112
+ return sum((x*cf(y) for (x,y) in xs), ctx.zero)
113
+ else:
114
+ return sum((x*y for (x,y) in xs), ctx.zero)
115
+
116
+ def fprod(ctx, args):
117
+ prod = ctx.one
118
+ for arg in args:
119
+ prod *= arg
120
+ return prod
121
+
122
+ def nprint(ctx, x, n=6, **kwargs):
123
+ """
124
+ Equivalent to ``print(nstr(x, n))``.
125
+ """
126
+ print(ctx.nstr(x, n, **kwargs))
127
+
128
+ def chop(ctx, x, tol=None):
129
+ """
130
+ Chops off small real or imaginary parts, or converts
131
+ numbers close to zero to exact zeros. The input can be a
132
+ single number or an iterable::
133
+
134
+ >>> from mpmath import *
135
+ >>> mp.dps = 15; mp.pretty = False
136
+ >>> chop(5+1e-10j, tol=1e-9)
137
+ mpf('5.0')
138
+ >>> nprint(chop([1.0, 1e-20, 3+1e-18j, -4, 2]))
139
+ [1.0, 0.0, 3.0, -4.0, 2.0]
140
+
141
+ The tolerance defaults to ``100*eps``.
142
+ """
143
+ if tol is None:
144
+ tol = 100*ctx.eps
145
+ try:
146
+ x = ctx.convert(x)
147
+ absx = abs(x)
148
+ if abs(x) < tol:
149
+ return ctx.zero
150
+ if ctx._is_complex_type(x):
151
+ #part_tol = min(tol, absx*tol)
152
+ part_tol = max(tol, absx*tol)
153
+ if abs(x.imag) < part_tol:
154
+ return x.real
155
+ if abs(x.real) < part_tol:
156
+ return ctx.mpc(0, x.imag)
157
+ except TypeError:
158
+ if isinstance(x, ctx.matrix):
159
+ return x.apply(lambda a: ctx.chop(a, tol))
160
+ if hasattr(x, "__iter__"):
161
+ return [ctx.chop(a, tol) for a in x]
162
+ return x
163
+
164
+ def almosteq(ctx, s, t, rel_eps=None, abs_eps=None):
165
+ r"""
166
+ Determine whether the difference between `s` and `t` is smaller
167
+ than a given epsilon, either relatively or absolutely.
168
+
169
+ Both a maximum relative difference and a maximum difference
170
+ ('epsilons') may be specified. The absolute difference is
171
+ defined as `|s-t|` and the relative difference is defined
172
+ as `|s-t|/\max(|s|, |t|)`.
173
+
174
+ If only one epsilon is given, both are set to the same value.
175
+ If none is given, both epsilons are set to `2^{-p+m}` where
176
+ `p` is the current working precision and `m` is a small
177
+ integer. The default setting typically allows :func:`~mpmath.almosteq`
178
+ to be used to check for mathematical equality
179
+ in the presence of small rounding errors.
180
+
181
+ **Examples**
182
+
183
+ >>> from mpmath import *
184
+ >>> mp.dps = 15
185
+ >>> almosteq(3.141592653589793, 3.141592653589790)
186
+ True
187
+ >>> almosteq(3.141592653589793, 3.141592653589700)
188
+ False
189
+ >>> almosteq(3.141592653589793, 3.141592653589700, 1e-10)
190
+ True
191
+ >>> almosteq(1e-20, 2e-20)
192
+ True
193
+ >>> almosteq(1e-20, 2e-20, rel_eps=0, abs_eps=0)
194
+ False
195
+
196
+ """
197
+ t = ctx.convert(t)
198
+ if abs_eps is None and rel_eps is None:
199
+ rel_eps = abs_eps = ctx.ldexp(1, -ctx.prec+4)
200
+ if abs_eps is None:
201
+ abs_eps = rel_eps
202
+ elif rel_eps is None:
203
+ rel_eps = abs_eps
204
+ diff = abs(s-t)
205
+ if diff <= abs_eps:
206
+ return True
207
+ abss = abs(s)
208
+ abst = abs(t)
209
+ if abss < abst:
210
+ err = diff/abst
211
+ else:
212
+ err = diff/abss
213
+ return err <= rel_eps
214
+
215
+ def arange(ctx, *args):
216
+ r"""
217
+ This is a generalized version of Python's :func:`~mpmath.range` function
218
+ that accepts fractional endpoints and step sizes and
219
+ returns a list of ``mpf`` instances. Like :func:`~mpmath.range`,
220
+ :func:`~mpmath.arange` can be called with 1, 2 or 3 arguments:
221
+
222
+ ``arange(b)``
223
+ `[0, 1, 2, \ldots, x]`
224
+ ``arange(a, b)``
225
+ `[a, a+1, a+2, \ldots, x]`
226
+ ``arange(a, b, h)``
227
+ `[a, a+h, a+h, \ldots, x]`
228
+
229
+ where `b-1 \le x < b` (in the third case, `b-h \le x < b`).
230
+
231
+ Like Python's :func:`~mpmath.range`, the endpoint is not included. To
232
+ produce ranges where the endpoint is included, :func:`~mpmath.linspace`
233
+ is more convenient.
234
+
235
+ **Examples**
236
+
237
+ >>> from mpmath import *
238
+ >>> mp.dps = 15; mp.pretty = False
239
+ >>> arange(4)
240
+ [mpf('0.0'), mpf('1.0'), mpf('2.0'), mpf('3.0')]
241
+ >>> arange(1, 2, 0.25)
242
+ [mpf('1.0'), mpf('1.25'), mpf('1.5'), mpf('1.75')]
243
+ >>> arange(1, -1, -0.75)
244
+ [mpf('1.0'), mpf('0.25'), mpf('-0.5')]
245
+
246
+ """
247
+ if not len(args) <= 3:
248
+ raise TypeError('arange expected at most 3 arguments, got %i'
249
+ % len(args))
250
+ if not len(args) >= 1:
251
+ raise TypeError('arange expected at least 1 argument, got %i'
252
+ % len(args))
253
+ # set default
254
+ a = 0
255
+ dt = 1
256
+ # interpret arguments
257
+ if len(args) == 1:
258
+ b = args[0]
259
+ elif len(args) >= 2:
260
+ a = args[0]
261
+ b = args[1]
262
+ if len(args) == 3:
263
+ dt = args[2]
264
+ a, b, dt = ctx.mpf(a), ctx.mpf(b), ctx.mpf(dt)
265
+ assert a + dt != a, 'dt is too small and would cause an infinite loop'
266
+ # adapt code for sign of dt
267
+ if a > b:
268
+ if dt > 0:
269
+ return []
270
+ op = gt
271
+ else:
272
+ if dt < 0:
273
+ return []
274
+ op = lt
275
+ # create list
276
+ result = []
277
+ i = 0
278
+ t = a
279
+ while 1:
280
+ t = a + dt*i
281
+ i += 1
282
+ if op(t, b):
283
+ result.append(t)
284
+ else:
285
+ break
286
+ return result
287
+
288
+ def linspace(ctx, *args, **kwargs):
289
+ """
290
+ ``linspace(a, b, n)`` returns a list of `n` evenly spaced
291
+ samples from `a` to `b`. The syntax ``linspace(mpi(a,b), n)``
292
+ is also valid.
293
+
294
+ This function is often more convenient than :func:`~mpmath.arange`
295
+ for partitioning an interval into subintervals, since
296
+ the endpoint is included::
297
+
298
+ >>> from mpmath import *
299
+ >>> mp.dps = 15; mp.pretty = False
300
+ >>> linspace(1, 4, 4)
301
+ [mpf('1.0'), mpf('2.0'), mpf('3.0'), mpf('4.0')]
302
+
303
+ You may also provide the keyword argument ``endpoint=False``::
304
+
305
+ >>> linspace(1, 4, 4, endpoint=False)
306
+ [mpf('1.0'), mpf('1.75'), mpf('2.5'), mpf('3.25')]
307
+
308
+ """
309
+ if len(args) == 3:
310
+ a = ctx.mpf(args[0])
311
+ b = ctx.mpf(args[1])
312
+ n = int(args[2])
313
+ elif len(args) == 2:
314
+ assert hasattr(args[0], '_mpi_')
315
+ a = args[0].a
316
+ b = args[0].b
317
+ n = int(args[1])
318
+ else:
319
+ raise TypeError('linspace expected 2 or 3 arguments, got %i' \
320
+ % len(args))
321
+ if n < 1:
322
+ raise ValueError('n must be greater than 0')
323
+ if not 'endpoint' in kwargs or kwargs['endpoint']:
324
+ if n == 1:
325
+ return [ctx.mpf(a)]
326
+ step = (b - a) / ctx.mpf(n - 1)
327
+ y = [i*step + a for i in xrange(n)]
328
+ y[-1] = b
329
+ else:
330
+ step = (b - a) / ctx.mpf(n)
331
+ y = [i*step + a for i in xrange(n)]
332
+ return y
333
+
334
+ def cos_sin(ctx, z, **kwargs):
335
+ return ctx.cos(z, **kwargs), ctx.sin(z, **kwargs)
336
+
337
+ def cospi_sinpi(ctx, z, **kwargs):
338
+ return ctx.cospi(z, **kwargs), ctx.sinpi(z, **kwargs)
339
+
340
+ def _default_hyper_maxprec(ctx, p):
341
+ return int(1000 * p**0.25 + 4*p)
342
+
343
+ _gcd = staticmethod(libmp.gcd)
344
+ list_primes = staticmethod(libmp.list_primes)
345
+ isprime = staticmethod(libmp.isprime)
346
+ bernfrac = staticmethod(libmp.bernfrac)
347
+ moebius = staticmethod(libmp.moebius)
348
+ _ifac = staticmethod(libmp.ifac)
349
+ _eulernum = staticmethod(libmp.eulernum)
350
+ _stirling1 = staticmethod(libmp.stirling1)
351
+ _stirling2 = staticmethod(libmp.stirling2)
352
+
353
+ def sum_accurately(ctx, terms, check_step=1):
354
+ prec = ctx.prec
355
+ try:
356
+ extraprec = 10
357
+ while 1:
358
+ ctx.prec = prec + extraprec + 5
359
+ max_mag = ctx.ninf
360
+ s = ctx.zero
361
+ k = 0
362
+ for term in terms():
363
+ s += term
364
+ if (not k % check_step) and term:
365
+ term_mag = ctx.mag(term)
366
+ max_mag = max(max_mag, term_mag)
367
+ sum_mag = ctx.mag(s)
368
+ if sum_mag - term_mag > ctx.prec:
369
+ break
370
+ k += 1
371
+ cancellation = max_mag - sum_mag
372
+ if cancellation != cancellation:
373
+ break
374
+ if cancellation < extraprec or ctx._fixed_precision:
375
+ break
376
+ extraprec += min(ctx.prec, cancellation)
377
+ return s
378
+ finally:
379
+ ctx.prec = prec
380
+
381
+ def mul_accurately(ctx, factors, check_step=1):
382
+ prec = ctx.prec
383
+ try:
384
+ extraprec = 10
385
+ while 1:
386
+ ctx.prec = prec + extraprec + 5
387
+ max_mag = ctx.ninf
388
+ one = ctx.one
389
+ s = one
390
+ k = 0
391
+ for factor in factors():
392
+ s *= factor
393
+ term = factor - one
394
+ if (not k % check_step):
395
+ term_mag = ctx.mag(term)
396
+ max_mag = max(max_mag, term_mag)
397
+ sum_mag = ctx.mag(s-one)
398
+ #if sum_mag - term_mag > ctx.prec:
399
+ # break
400
+ if -term_mag > ctx.prec:
401
+ break
402
+ k += 1
403
+ cancellation = max_mag - sum_mag
404
+ if cancellation != cancellation:
405
+ break
406
+ if cancellation < extraprec or ctx._fixed_precision:
407
+ break
408
+ extraprec += min(ctx.prec, cancellation)
409
+ return s
410
+ finally:
411
+ ctx.prec = prec
412
+
413
+ def power(ctx, x, y):
414
+ r"""Converts `x` and `y` to mpmath numbers and evaluates
415
+ `x^y = \exp(y \log(x))`::
416
+
417
+ >>> from mpmath import *
418
+ >>> mp.dps = 30; mp.pretty = True
419
+ >>> power(2, 0.5)
420
+ 1.41421356237309504880168872421
421
+
422
+ This shows the leading few digits of a large Mersenne prime
423
+ (performing the exact calculation ``2**43112609-1`` and
424
+ displaying the result in Python would be very slow)::
425
+
426
+ >>> power(2, 43112609)-1
427
+ 3.16470269330255923143453723949e+12978188
428
+ """
429
+ return ctx.convert(x) ** ctx.convert(y)
430
+
431
+ def _zeta_int(ctx, n):
432
+ return ctx.zeta(n)
433
+
434
+ def maxcalls(ctx, f, N):
435
+ """
436
+ Return a wrapped copy of *f* that raises ``NoConvergence`` when *f*
437
+ has been called more than *N* times::
438
+
439
+ >>> from mpmath import *
440
+ >>> mp.dps = 15
441
+ >>> f = maxcalls(sin, 10)
442
+ >>> print(sum(f(n) for n in range(10)))
443
+ 1.95520948210738
444
+ >>> f(10) # doctest: +IGNORE_EXCEPTION_DETAIL
445
+ Traceback (most recent call last):
446
+ ...
447
+ NoConvergence: maxcalls: function evaluated 10 times
448
+
449
+ """
450
+ counter = [0]
451
+ def f_maxcalls_wrapped(*args, **kwargs):
452
+ counter[0] += 1
453
+ if counter[0] > N:
454
+ raise ctx.NoConvergence("maxcalls: function evaluated %i times" % N)
455
+ return f(*args, **kwargs)
456
+ return f_maxcalls_wrapped
457
+
458
+ def memoize(ctx, f):
459
+ """
460
+ Return a wrapped copy of *f* that caches computed values, i.e.
461
+ a memoized copy of *f*. Values are only reused if the cached precision
462
+ is equal to or higher than the working precision::
463
+
464
+ >>> from mpmath import *
465
+ >>> mp.dps = 15; mp.pretty = True
466
+ >>> f = memoize(maxcalls(sin, 1))
467
+ >>> f(2)
468
+ 0.909297426825682
469
+ >>> f(2)
470
+ 0.909297426825682
471
+ >>> mp.dps = 25
472
+ >>> f(2) # doctest: +IGNORE_EXCEPTION_DETAIL
473
+ Traceback (most recent call last):
474
+ ...
475
+ NoConvergence: maxcalls: function evaluated 1 times
476
+
477
+ """
478
+ f_cache = {}
479
+ def f_cached(*args, **kwargs):
480
+ if kwargs:
481
+ key = args, tuple(kwargs.items())
482
+ else:
483
+ key = args
484
+ prec = ctx.prec
485
+ if key in f_cache:
486
+ cprec, cvalue = f_cache[key]
487
+ if cprec >= prec:
488
+ return +cvalue
489
+ value = f(*args, **kwargs)
490
+ f_cache[key] = (prec, value)
491
+ return value
492
+ f_cached.__name__ = f.__name__
493
+ f_cached.__doc__ = f.__doc__
494
+ return f_cached