componavt's picture
+ temp old streamlit
723b602
Raw
History Blame Contribute Delete
3.11 kB
# Conditional import for neural ODE solver to handle cases where torch is not available
try:
import torch
import torch.nn as nn
import torch.optim as optim
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
import numpy as np
if TORCH_AVAILABLE:
class NeuralODE(nn.Module):
    """
    Fully connected surrogate network for an ODE solution.

    Maps a time value t (one input feature) to the pair (x(t), y(t))
    (two output features) through three tanh-activated hidden layers
    of width 64.
    """

    def __init__(self):
        super().__init__()
        width = 64
        # Input layer: a single time value per row.
        stack = [nn.Linear(1, width), nn.Tanh()]
        # Two further hidden layers of the same width.
        for _ in range(2):
            stack.extend([nn.Linear(width, width), nn.Tanh()])
        # Output layer: x and y.
        stack.append(nn.Linear(width, 2))
        self.net = nn.Sequential(*stack)

    def forward(self, t):
        return self.net(t)
def train_neural_ode(t_train, x_train, y_train, epochs=2000, lr=1e-3):
    """
    Train a neural network to approximate the solution on the training interval

    The network is fitted as a plain regressor from time to (x, y): every
    epoch is one full-batch Adam step on the mean squared error.

    Args:
        t_train: sequence of time points for training
        x_train: sequence of x values, one per time point
        y_train: sequence of y values, one per time point
        epochs: number of training epochs
        lr: learning rate

    Returns:
        Trained model

    Raises:
        ValueError: if the inputs are empty or do not all have the same
            number of samples
    """
    # Accept lists/tuples as well as arrays, and flatten each input so a
    # column-vector argument cannot turn into a 3-D tensor that the loss
    # would silently broadcast against the 2-D target.
    t_values = np.asarray(t_train, dtype=np.float32).reshape(-1)
    x_values = np.asarray(x_train, dtype=np.float32).reshape(-1)
    y_values = np.asarray(y_train, dtype=np.float32).reshape(-1)

    if t_values.size == 0:
        raise ValueError("training data must contain at least one sample")
    if not (t_values.size == x_values.size == y_values.size):
        raise ValueError(
            "t_train, x_train and y_train must have the same length, got "
            f"{t_values.size}, {x_values.size} and {y_values.size}"
        )

    # Prepare training data: one time value per row, (x, y) targets per row
    t_train_tensor = torch.tensor(t_values[:, None], dtype=torch.float32)
    xy_train_tensor = torch.tensor(
        np.column_stack([x_values, y_values]), dtype=torch.float32
    )

    # Initialize model
    model = NeuralODE()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    # Training loop
    for epoch in range(epochs):
        optimizer.zero_grad()
        pred = model(t_train_tensor)
        loss = loss_fn(pred, xy_train_tensor)
        loss.backward()
        optimizer.step()
        # Progress report every 500 epochs
        if epoch % 500 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.6f}")

    return model
def predict_with_neural_ode(model, t_eval):
    """
    Predict solution using the trained neural network

    Args:
        model: trained NeuralODE model, or None when no model is available
        t_eval: time points to evaluate (array, list, tuple or scalar)

    Returns:
        x_pred, y_pred arrays with one entry per time point, or
        (None, None) when model is None
    """
    if model is None:
        return None, None

    # Coerce to one time value per row. This accepts plain sequences and
    # scalars, and keeps a column-vector input from becoming a 3-D tensor.
    t_column = np.asarray(t_eval, dtype=np.float32).reshape(-1, 1)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        t_tensor = torch.tensor(t_column, dtype=torch.float32)
        pred = model(t_tensor).numpy()

    # Column 0 holds x, column 1 holds y.
    x_pred = pred[:, 0]
    y_pred = pred[:, 1]
    return x_pred, y_pred
else:
def train_neural_ode(t_train, x_train, y_train, epochs=2000, lr=1e-3):
    """
    Stand-in trainer for when torch is not available

    All arguments are ignored. A notice is printed and None is returned
    in place of a model.
    """
    notice = "PyTorch not available, skipping neural network training"
    print(notice)
    model = None
    return model
def predict_with_neural_ode(model, t_eval):
    """
    Stand-in predictor for when torch is not available

    Both arguments are ignored; the result is always the pair (None, None).
    """
    x_pred = y_pred = None
    return x_pred, y_pred