download
raw
6.19 kB
from numpy import *
from matplotlib.pyplot import *
import sys
def solver(I, a, T, dt, theta):
"""Solve u'=-a*u, u(0)=I, for t in (0,T] with steps of dt."""
dt = float(dt) # avoid integer division
Nt = int(round(T/dt)) # no of time intervals
T = Nt*dt # adjust T to fit time step dt
u = zeros(Nt+1) # array of u[n] values
t = linspace(0, T, Nt+1) # time mesh
u[0] = I # assign initial condition
for n in range(0, Nt): # n=0,1,...,Nt-1
u[n+1] = (1 - (1-theta)*a*dt)/(1 + theta*dt*a)*u[n]
return u, t
def verify_three_steps():
"""
Run three steps and compare with pre-computed correct values.
Return True if the computations are right.
"""
theta = 0.8; a = 2; I = 0.1; dt = 0.8
A = (1 - (1-theta)*a*dt)/(1 + theta*dt*a)
u1 = A*I
u2 = A*u1
u3 = A*u2
Nt = 3 # number of time steps
u, t = solver(I=I, a=a, T=Nt*dt, dt=dt, theta=theta)
tol = 1E-15 # tolerance for comparing floats
difference = abs(u1-u[1]) + abs(u2-u[2]) + abs(u3-u[3])
success = difference <= tol
return success
def verify_exact_discrete_solution():
"""
Compare the solution computed by solver
with a formula for the exact discrete solution.
Return True if the computations are right.
"""
def exact_discrete_solution(n, I, a, theta, dt):
A = (1 - (1-theta)*a*dt)/(1 + theta*dt*a)
return I*A**n
theta = 0.8; a = 2; I = 0.1; dt = 0.8
Nt = int(8/dt) # no of steps
u, t = solver(I=I, a=a, T=Nt*dt, dt=dt, theta=theta)
u_de = array([exact_discrete_solution(n, I, a, theta, dt)
for n in range(Nt+1)])
difference = abs(u_de - u).max() # max deviation
tol = 1E-15 # tolerance for comparing floats
success = difference <= tol
return success
def exact_solution(t, I, a):
return I*exp(-a*t)
def explore(I, a, T, dt, theta=0.5, makeplot=True):
"""
Run a case with the solver, compute error measure,
and plot the numerical and exact solutions (if makeplot=True).
"""
u, t = solver(I, a, T, dt, theta) # Numerical solution
u_e = exact_solution(t, I, a)
e = u_e - u
E = sqrt(dt*sum(e**2))
if makeplot:
figure() # create new plot
t_e = linspace(0, T, 1001) # very fine mesh for u_e
u_e = exact_solution(t_e, I, a)
plot(t, u, 'r--o') # red dashes w/circles
plot(t_e, u_e, 'b-') # blue line for u_e
legend(['numerical', 'exact'])
xlabel('t')
ylabel('u')
title('Method: theta-rule, theta=%g, dt=%g' % (theta, dt))
theta2name = {0: 'FE', 1: 'BE', 0.5: 'CN'}
savefig('%s_%g.png' % (theta2name[theta], dt))
show()
return E
def define_command_line_options():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--I', '--initial_condition', type=float,
default=1.0, help='initial condition, u(0)',
metavar='I')
parser.add_argument('--a', type=float,
default=1.0, help='coefficient in ODE',
metavar='a')
parser.add_argument('--T', '--stop_time', type=float,
default=1.0, help='end time of simulation',
metavar='T')
parser.add_argument('--makeplot', action='store_true',
help='display plot or not')
parser.add_argument('--dt', '--time_step_values', type=float,
default=[1.0], help='time step values',
metavar='dt', nargs='+', dest='dt_values')
return parser
def read_command_line(use_argparse=True):
if use_argparse:
parser = define_command_line_options()
args = parser.parse_args()
print 'I={}, a={}, makeplot={}, dt_values={}'.format(
args.I, args.a, args.makeplot, args.dt_values)
return args.I, args.a, args.T, args.makeplot, args.dt_values
else:
if len(sys.argv) < 6:
print 'Usage: %s I a on/off dt1 dt2 dt3 ...' % \
sys.argv[0]; sys.exit(1)
I = float(sys.argv[1])
a = float(sys.argv[2])
T = float(sys.argv[3])
makeplot = sys.argv[4] in ('on', 'True')
dt_values = [float(arg) for arg in sys.argv[5:]]
return I, a, T, makeplot, dt_values
def main():
I, a, T, makeplot, dt_values = read_command_line()
r = {}
for theta in 0, 0.5, 1:
E_values = []
for dt in dt_values:
E = explore(I, a, T, dt, theta, makeplot=False)
E_values.append(E)
# Compute convergence rates
m = len(dt_values)
r[theta] = [log(E_values[i-1]/E_values[i])/
log(dt_values[i-1]/dt_values[i])
for i in range(1, m, 1)]
for theta in r:
print '\nPairwise convergence rates for theta=%g:' % theta
print ' '.join(['%.2f' % r_ for r_ in r[theta]])
return r
def verify_convergence_rate():
r = main()
tol = 0.1
expected_rates = {0: 1, 1: 1, 0.5: 2}
for theta in r:
r_final = r[theta][-1]
diff = abs(expected_rates[theta] - r_final)
if diff > tol:
return False
return True # all tests passed
if __name__ == '__main__':
# Use verify and verify_rates on the command line to
# run verification tests. These command-line arguments
# must be removed before parameters are read from the
# command line.
if 'verify' in sys.argv:
sys.argv.remove('verify')
if verify_three_steps() and verify_discrete_solution():
pass # ok
else:
print 'Bug in the implementation!'
elif 'verify_rates' in sys.argv:
sys.argv.remove('verify_rates')
if not '--dt' in sys.argv:
print 'Must assign several dt values through the --dt option'
sys.exit(1) # abort
if verify_convergence_rate():
pass
else:
print 'Bug in the implementation!'
else:
# Perform simulations
main()

Xet Storage Details

Size:
6.19 kB
·
Xet hash:
d1301aca19c502bc0fb21ab47cd3efbd3966b266258cc2136be4cbc8f5b71d96

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.