|
|
import torch |
|
|
def epxm1_x(x):
    """Compute (exp(x) - 1) / x elementwise, continuous through x = 0.

    Parameters:
        x: Tensor of arguments.

    Returns:
        Tensor with the value of the quotient. For |x| < 1e-2 a Maclaurin
        series is used; elsewhere the direct quotient is used, with any
        non-finite quotient (e.g. x = +inf or NaN) replaced by 0.
    """
    near_zero = torch.abs(x) < 1e-2

    # Where the series is selected, divide by 1 instead of x so that 0/0 is
    # never formed: a NaN in the discarded torch.where branch would otherwise
    # turn the gradient at x = 0 into NaN.
    safe_x = torch.where(near_zero, torch.ones_like(x), x)
    direct = torch.special.expm1(safe_x) / safe_x
    direct = torch.where(torch.isfinite(direct), direct, torch.zeros_like(direct))

    # Series 1 + x/2 + x^2/6 + x^3/24 + x^4/120; the first omitted term,
    # x^5/720, stays below 2e-13 on |x| < 1e-2.
    series = 1 + x / 2. + x**2 / 6. + x**3 / 24. + x**4 / 120.

    return torch.where(near_zero, series, direct)
|
|
def epxm1mx_x2(x):
    """Compute (exp(x) - 1 - x) / x**2 elementwise, continuous through x = 0.

    Parameters:
        x: Tensor of arguments.

    Returns:
        Tensor with the value of the quotient. For x**2 < 1e-2 a Maclaurin
        series is used; elsewhere the direct quotient is used, with any
        non-finite quotient replaced by 0.
    """
    near_zero = torch.abs(x**2) < 1e-2

    # Where the series is selected, divide by 1 instead of x so that 0/0 is
    # never formed: a NaN in the discarded torch.where branch would otherwise
    # turn the gradient at x = 0 into NaN.
    safe_x = torch.where(near_zero, torch.ones_like(x), x)
    direct = (torch.special.expm1(safe_x) - safe_x) / safe_x**2
    direct = torch.where(torch.isfinite(direct), direct, torch.zeros_like(direct))

    # Series sum_k x^k / (k + 2)! up to k = 6; the first omitted term,
    # x^7/362880, stays below 3e-13 on |x| < 0.1.
    series = (
        1 / 2.
        + x / 6
        + x**2 / 24
        + x**3 / 120
        + x**4 / 720
        + x**5 / 5040
        + x**6 / 40320
    )

    return torch.where(near_zero, series, direct)
|
|
|
|
|
def expm1mxmhx2_x3(x):
    """Compute (exp(x) - 1 - x - x**2 / 2) / x**3 elementwise, continuous through x = 0.

    Parameters:
        x: Tensor of arguments.

    Returns:
        Tensor with the value of the quotient. For |x**3| < 1e-2 a Maclaurin
        series is used; elsewhere the direct quotient is used, with any
        non-finite quotient replaced by 0.
    """
    near_zero = torch.abs(x**3) < 1e-2

    # Where the series is selected, divide by 1 instead of x so that 0/0 is
    # never formed: a NaN in the discarded torch.where branch would otherwise
    # turn the gradient at x = 0 into NaN.
    safe_x = torch.where(near_zero, torch.ones_like(x), x)
    direct = (torch.special.expm1(safe_x) - safe_x - safe_x**2 / 2) / safe_x**3
    direct = torch.where(torch.isfinite(direct), direct, torch.zeros_like(direct))

    # Series sum_k x^k / (k + 3)! up to k = 7; the first omitted term,
    # x^8/39916800, stays below 2e-13 on |x| < 0.2155.
    series = (
        1 / 6
        + x / 24
        + x**2 / 120
        + x**3 / 720
        + x**4 / 5040
        + x**5 / 40320
        + x**6 / 362880
        + x**7 / 3628800
    )

    return torch.where(near_zero, series, direct)
|
|
|
|
|
def exp_1mcosh_GD(gamma_t, delta):
    """Evaluate e^(-Γt) * (1 - cosh(Γt√Δ)) / ((Γt)^2 Δ) elementwise.

    Parameters:
        gamma_t: Tensor holding the Γ*t term.
        delta: Tensor holding the Δ term; may be positive, zero or negative.

    Returns:
        Tensor with the value of the expression. A three-term series in
        (Γt)^2 Δ is used where (Γt)^2 |Δ| < 5e-2; elsewhere the direct
        quotient is used, with non-finite quotients replaced by 0.
    """
    root = torch.abs(delta).sqrt()
    arg = gamma_t * root
    damping = torch.exp(-gamma_t)

    # Δ > 0: cosh is written as two exponentials with the e^(-Γt) factor
    # folded into each exponent.
    growing = torch.exp(gamma_t * (root - 1))
    shrinking = torch.exp(gamma_t * (-root - 1))
    hyperbolic_num = damping - (growing + shrinking) / 2

    # Δ <= 0: the cosine of Γt√|Δ| takes the place of cosh.
    trig_num = damping * (1 - torch.cos(arg))

    top = torch.where(delta > 0, hyperbolic_num, trig_num)
    quotient = top / (delta * gamma_t**2)
    quotient = torch.where(torch.isfinite(quotient), quotient, torch.zeros_like(quotient))

    # Series of (1 - cosh z) / z^2 in z^2 = (Γt)^2 Δ, times e^(-Γt).
    series = (-0.5 - gamma_t**2 / 24 * delta - gamma_t**4 / 720 * delta**2) * damping
    use_series = torch.abs(arg**2) < 5e-2

    return torch.where(use_series, series, quotient)
|
|
|
|
|
def exp_sinh_GsqrtD(gamma_t, delta):
    """Compute e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ) elementwise.

    Parameters:
        gamma_t: Tensor holding the Γ*t term.
        delta: Tensor holding the Δ term; may be positive, zero or negative.

    Returns:
        Tensor with the value of the expression. For Δ > 0 the exponential
        form is used (non-finite quotients replaced by 0, and a three-term
        series where |Γt√Δ| < 1e-2); for Δ <= 0 the sinc form is used.
    """
    is_positive = delta > 0
    sqrt_abs_delta = torch.sqrt(torch.abs(delta))
    gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta

    # Δ > 0: sinh written as two exponentials with the e^(-Γt) factor folded
    # into each exponent.
    numerator_pos = (
        torch.exp(gamma_t * (sqrt_abs_delta - 1))
        - torch.exp(gamma_t * (-sqrt_abs_delta - 1))
    ) / 2
    result_pos = numerator_pos / gamma_t_sqrt_delta
    result_pos = torch.where(torch.isfinite(result_pos), result_pos, torch.zeros_like(result_pos))

    # Series of sinh(z) / z in z^2 = (Γt)^2 Δ, times e^(-Γt), for small |z|.
    mask = torch.abs(gamma_t_sqrt_delta) < 1e-2
    taylor = (1 + gamma_t**2 / 6 * delta + gamma_t**4 / 120 * delta**2) * torch.exp(-gamma_t)
    result_pos = torch.where(mask, taylor, result_pos)

    # Δ <= 0: sin(z) / z with z = Γt√|Δ|. torch.special.sinc(u) is
    # sin(πu) / (πu), hence the division by π.
    result_neg = torch.exp(-gamma_t) * torch.special.sinc(gamma_t_sqrt_delta / torch.pi)

    return torch.where(is_positive, result_pos, result_neg)
|
|
|
|
|
def exp_cosh(gamma_t, delta):
    """Evaluate e^(-Γt) * cosh(Γt√Δ) elementwise.

    Parameters:
        gamma_t: Tensor holding the Γ*t term.
        delta: Tensor holding the Δ term.

    Returns:
        Tensor with the value of the expression, assembled from
        ``exp_1mcosh_GD`` so that its small-argument handling is reused.
    """
    # exp_1mcosh_GD gives e^(-Γt) (1 - cosh) / ((Γt)^2 Δ); multiplying the
    # denominator back and subtracting from e^(-Γt) leaves e^(-Γt) cosh.
    scaled_one_minus_cosh = exp_1mcosh_GD(gamma_t, delta)
    return torch.exp(-gamma_t) - gamma_t**2 * delta * scaled_one_minus_cosh
|
|
def exp_sinh_sqrtD(gamma_t, delta):
    """Evaluate e^(-Γt) * sinh(Γt√Δ) / √Δ elementwise.

    Parameters:
        gamma_t: Tensor holding the Γ*t term.
        delta: Tensor holding the Δ term.

    Returns:
        Tensor with the value of the expression, obtained as Γt times the
        result of ``exp_sinh_GsqrtD``.
    """
    return gamma_t * exp_sinh_GsqrtD(gamma_t, delta)
|
|
|
|
|
|
|
|
|
|
|
def zeta1(gamma_t, delta):
    """Evaluate 1 - N / A with A = Γt (1 - Δ) / 4 and
    N = 1 - e^(-Γt/2) * (cosh(Γt√Δ/2) + sinh(Γt√Δ/2) / √Δ).

    Parameters:
        gamma_t: Tensor holding the Γ*t term.
        delta: Tensor holding the Δ term.

    Returns:
        Tensor with the value of the expression. Where |A| < 5e-3 a
        second-order polynomial in A is used instead of the quotient;
        elsewhere non-finite quotients are replaced by 0.
    """
    half = gamma_t / 2
    cosh_part = exp_cosh(half, delta)
    sinh_part = exp_sinh_sqrtD(half, delta)

    top = 1 - (cosh_part + sinh_part)
    a_t = gamma_t * (1 - delta) / 4

    quotient_form = 1 - top / a_t
    quotient_form = torch.where(
        torch.isfinite(quotient_form), quotient_form, torch.zeros_like(quotient_form)
    )

    # Polynomial in A whose coefficients are built from the regularised
    # exponential quotients evaluated at -Γt.
    neg_gamma_t = -gamma_t
    phi1 = epxm1_x(neg_gamma_t)
    phi2 = epxm1mx_x2(neg_gamma_t)
    phi3 = expm1mxmhx2_x3(neg_gamma_t)
    linear_coef = 1 / 2. + phi1 - 3 * phi2
    quadratic_coef = -1 / 6. + phi1 / 2 - 4 * phi2 + 10 * phi3
    series = phi1 + linear_coef * a_t + quadratic_coef * a_t**2

    small_a = torch.abs(a_t) < 5e-3
    return torch.where(small_a, series, quotient_form)
|
|
|
|
|
def exp_cosh_minus_terms(gamma_t, delta):
    """Evaluate e^(-tΓ) * (cosh(tΓ) - 1 - (cosh(tΓ√Δ) - 1) / Δ) / (tΓ (1 - Δ)).

    Parameters:
        gamma_t: Tensor holding the Γ*t term.
        delta: Tensor holding the Δ term.

    Returns:
        Tensor with the value of the expression. Where |tΓ (1 - Δ)| < 1e-1 a
        first-order expansion in tΓ (1 - Δ) is used; elsewhere non-finite
        quotients are replaced by 0.
    """
    damping = torch.exp(-gamma_t)

    # e^(-tΓ) (cosh(tΓ) - 1): exp_cosh evaluated at Δ = 1 (gamma_t**0 is a
    # tensor of ones), minus the bare damping factor.
    unit_cosh_part = exp_cosh(gamma_t, gamma_t**0) - damping
    # e^(-tΓ) (cosh(tΓ√Δ) - 1) / Δ
    delta_cosh_part = -gamma_t**2 * exp_1mcosh_GD(gamma_t, delta)

    top = unit_cosh_part - delta_cosh_part
    bottom = gamma_t * (1 - delta)

    quotient = top / bottom
    quotient = torch.where(torch.isfinite(quotient), quotient, torch.zeros_like(quotient))

    # Expansion coefficients are the helper functions evaluated at Δ = 1
    # (delta**0 is a tensor of ones shaped like delta).
    unit_delta = delta**0
    one_minus_cosh_1 = exp_1mcosh_GD(gamma_t, unit_delta)
    sinh_ratio_1 = exp_sinh_GsqrtD(gamma_t, unit_delta)
    cosh_1 = exp_cosh(gamma_t, unit_delta)

    zeroth_order = gamma_t * one_minus_cosh_1 + 0.5 * gamma_t * sinh_ratio_1
    first_order = 0.5 * cosh_1 - 4 * one_minus_cosh_1 - 5 / 2 * sinh_ratio_1
    series = zeroth_order - bottom / 4 * first_order

    small_bottom = torch.abs(bottom) < 1e-1
    return torch.where(small_bottom, series, quotient)
|
|
|
|
|
|
|
|
def zeta2(gamma_t, delta):
    """Return ``exp_sinh_GsqrtD`` evaluated at half of ``gamma_t``.

    That is e^(-Γt/2) * sinh(Γt√Δ/2) / (Γt√Δ/2), elementwise.
    """
    return exp_sinh_GsqrtD(gamma_t / 2, delta)
|
|
|
|
|
def sig11(gamma_t, delta):
    """Return 1 - e^(-Γt) + (Γt)^2 * exp_1mcosh_GD(Γt, Δ) + exp_sinh_sqrtD(Γt, Δ)."""
    damping = torch.exp(-gamma_t)
    cosh_part = gamma_t**2 * exp_1mcosh_GD(gamma_t, delta)
    sinh_part = exp_sinh_sqrtD(gamma_t, delta)
    return 1 - damping + cosh_part + sinh_part
|
|
|
|
|
|
|
|
def Zcoefs(gamma_t, delta):
    """Build three coefficients and their common amplitude from Γt and Δ.

    Parameters:
        gamma_t: Tensor holding the Γ*t term.
        delta: Tensor holding the Δ term.

    Returns:
        Tuple ``(c1 * amplitude, c2 * amplitude, c3 * amplitude, amplitude)``
        where ``c3 = sqrt(max(1 - c1**2 - c2**2, 0))``.
    """
    z1 = zeta1(gamma_t, delta)
    z2 = zeta2(gamma_t, delta)

    sq_total = 1 - z1 + gamma_t * (delta - 1) * (z1 - 1)**2 / 8
    amplitude = torch.sqrt(sq_total)

    c1 = (gamma_t**0.5 * z2 / 2**0.5) / amplitude
    ratio = -2 * exp_1mcosh_GD(gamma_t, delta) / sig11(gamma_t, delta)
    c2 = c1 * gamma_t * ratio**0.5

    # Floor the remainder at zero before the square root so that rounding
    # cannot push it negative.
    floor = sq_total.new_zeros(sq_total.shape)
    c3 = torch.sqrt(torch.maximum(1 - c1**2 - c2**2, floor))

    return c1 * amplitude, c2 * amplitude, c3 * amplitude, amplitude
|
|
|
|
|
def Zcoefs_asymp(gamma_t, delta):
    """Return ``epxm1_x(-2 * A)`` with A = Γt (1 - Δ) / 4.

    Parameters:
        gamma_t: Tensor holding the Γ*t term.
        delta: Tensor holding the Δ term.
    """
    quarter_rate = (gamma_t * (1 - delta)) / 4
    return epxm1_x(-2 * quarter_rate)
|
|
|
|
|
class StochasticHarmonicOscillator:
    """
    Simulates a stochastic harmonic oscillator governed by the equations:
        dy(t) = q(t) dt
        dq(t) = -Γ A y(t) dt + Γ C dt + Γ D dw(t) - Γ q(t) dt

    Also define v(t) = q(t) / √Γ, which is numerically more stable.

    Where:
        y(t) - Position variable
        q(t) - Velocity variable
        Γ - Damping coefficient
        A - Harmonic potential strength
        C - Constant force term
        D - Noise amplitude
        dw(t) - Wiener process (Brownian motion)
    """

    def __init__(self, Gamma, A, C, D):
        """Store the coefficients and derive Δ = 1 - 4 A / Γ."""
        self.Gamma = Gamma
        self.A = A
        self.C = C
        self.D = D
        self.Delta = 1 - 4 * A / Gamma

    def sig11(self, gamma_t, delta):
        """Return 1 - e^(-Γt) + (Γt)^2 * exp_1mcosh_GD(Γt, Δ) + exp_sinh_sqrtD(Γt, Δ)."""
        return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta)

    def sig22(self, gamma_t, delta):
        """Return 1 - zeta1(2 Γt, Δ) + 2 Γt * exp_1mcosh_GD(Γt, Δ)."""
        return 1 - zeta1(2 * gamma_t, delta) + 2 * gamma_t * exp_1mcosh_GD(gamma_t, delta)

    def dynamics(self, y0, v0, t):
        """
        Draws the position and velocity variables at time t.

        The pair (y(t), v(t)) is sampled from a bivariate normal whose mean
        and covariance are computed from the initial state and the stored
        coefficients.

        Parameters:
            y0: Tensor with the initial position; its shape, dtype and device
                are used for every tensor created here.
            v0: Initial velocity v(0) = q(0) / √Γ, or None to draw it as
                standard normal noise scaled by D / √2.
            t: Time at which to evaluate the dynamics.

        Returns:
            tuple: (y(t), v(t)), each one sample per element of y0.
        """
        # Adding a one-element zero tensor built from y0 promotes every
        # coefficient to a tensor on y0's dtype and device.
        dummyzero = y0.new_zeros(1)
        Delta = self.Delta + dummyzero
        Gamma_hat = self.Gamma * t + dummyzero
        A = self.A + dummyzero
        C = self.C + dummyzero
        D = self.D + dummyzero
        Gamma = self.Gamma + dummyzero

        zeta_1 = zeta1(Gamma_hat, Delta)
        zeta_2 = zeta2(Gamma_hat, Delta)
        EE = 1 - Gamma_hat * zeta_2

        if v0 is None:
            v0 = torch.randn_like(y0) * D / 2 ** 0.5

        # Mean of the bivariate normal.
        term1 = (1 - zeta_1) * (C * t - A * t * y0) + zeta_2 * (Gamma ** 0.5) * v0 * t
        y_mean = term1 + y0
        v_mean = (1 - EE) * (C - A * y0) / (Gamma ** 0.5) + (EE - A * t * (1 - zeta_1)) * v0

        # Entries of the 2x2 covariance.
        cov_yy = D**2 * t * self.sig22(Gamma_hat, Delta)
        cov_vv = D**2 * self.sig11(Gamma_hat, Delta) / 2
        cov_yv = (zeta_2 * Gamma_hat * D) ** 2 / 2 / (Gamma ** 0.5)

        batch_shape = y0.shape

        # Closed-form Cholesky factor of the 2x2 covariance. Both variances
        # are clamped at `tol` so the diagonal stays strictly positive. The
        # tensor starts as zeros, so the upper-right entry needs no write.
        scale_tril = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
        tol = 1e-8
        cov_yy = torch.clamp(cov_yy, min=tol)
        sd_yy = torch.sqrt(cov_yy)
        inv_sd_yy = 1 / sd_yy

        scale_tril[..., 0, 0] = sd_yy
        scale_tril[..., 1, 0] = cov_yv * inv_sd_yy
        scale_tril[..., 1, 1] = torch.clamp(cov_vv - cov_yv**2 / cov_yy, min=tol) ** 0.5

        mean = torch.zeros(*batch_shape, 2, device=y0.device, dtype=y0.dtype)
        mean[..., 0] = y_mean
        mean[..., 1] = v_mean

        new_yv = torch.distributions.MultivariateNormal(
            loc=mean,
            scale_tril=scale_tril,
        ).sample()

        return new_yv[..., 0], new_yv[..., 1]