import math
from dataclasses import dataclass
from typing import ClassVar

import jax
import jax.numpy as jnp
import sympy
import sympy as sp
import torch
from sympy import Matrix, Symbol

from sde_redefined_param import SDEDimension
@dataclass
class SDEConfig:
    """Symbolic configuration of a custom scalar SDE.

    Drift:     f(t) = -t**2 * f1**2
    Diffusion: g(t) = k * sin(pi * t / 2) for t < 1, and k for t >= 1.
               The two branches agree at t = 1 because sin(pi / 2) = 1.

    Every setting is declared ``ClassVar``. The dataclass therefore has no
    instance fields: ``SDEConfig()`` takes no arguments,
    ``dataclasses.fields(SDEConfig)`` is empty, and all values are read
    from the class itself. The same was true before the annotations were
    added; they only make that explicit.
    """

    name: ClassVar[str] = "Custom"
    # Independent (time) variable of the SDE; nonnegative and real.
    variable: ClassVar[Symbol] = Symbol('t', nonnegative=True, real=True)

    # Dimensionality of each term (all scalar for this SDE).
    drift_dimension: ClassVar[SDEDimension] = SDEDimension.SCALAR
    diffusion_dimension: ClassVar[SDEDimension] = SDEDimension.SCALAR
    diffusion_matrix_dimension: ClassVar[SDEDimension] = SDEDimension.SCALAR

    # Symbolic parameters of each term (presumably learnable — TODO confirm
    # with the consumer of this config).
    drift_parameters: ClassVar[Matrix] = Matrix([sympy.symbols("f1")])
    # NOTE(review): `l1` does not appear in `diffusion` below — confirm
    # that a parameter-free diffusion is intentional.
    diffusion_parameters: ClassVar[Matrix] = Matrix([sympy.symbols("l1")])

    drift: ClassVar[sympy.Expr] = -variable**2 * drift_parameters[0]**2

    # Amplitude of the diffusion schedule.
    k: ClassVar[int] = 1
    # Quarter sine wave rising from 0 at t = 0 to k at t = 1, then held
    # constant at k. Note that g(0) = 0.
    diffusion: ClassVar[sympy.Expr] = sympy.Piecewise(
        (k * sympy.sin(sympy.pi * variable / 2), variable < 1),
        (k, variable >= 1),
    )

    diffusion_matrix: ClassVar[int] = 1

    # Range of the time variable.
    initial_variable_value: ClassVar[int] = 0
    max_variable_value: ClassVar[int] = 1
    # Presumably the lower bound used when sampling t, keeping it away from
    # t = 0 where the diffusion vanishes — TODO confirm.
    min_sample_value: ClassVar[float] = 1e-6

    # Presumably the numerical backend the symbolic expressions are
    # compiled to — TODO confirm.
    module: ClassVar[str] = 'jax'

    # NOTE(review): the meaning of these flags is defined by whatever code
    # consumes this config; it is not visible here.
    drift_integral_form: ClassVar[bool] = True
    diffusion_integral_form: ClassVar[bool] = True
    diffusion_integral_decomposition: ClassVar[str] = 'cholesky'

    # Presumably the model's prediction target (noise) — TODO confirm.
    target: ClassVar[str] = "epsilon"
|
|