# AICME-runtime / sim_priors_pk/data/extra/compartment_models_vectorized.py
# Uploaded by: cesarali
# Commit 5686f5b (verified): manual runtime bundle push from load_and_push.ipynb
import numpy as np
import torch
def sample_individual_configs_vectorized(study_config):
    """
    Sample per-individual PK parameters for a whole population at once.

    Every rate/volume is drawn independently per individual from a
    log-normal distribution using NumPy's global RNG, then converted to a
    float32 torch tensor.

    Parameters
    ----------
    study_config : StudyConfig
        Attributes read: ``num_individuals`` (N), ``num_peripherals`` (P),
        ``log_k_a_mean/std``, ``log_k_e_mean/std``, ``log_V_mean/std``
        (scalars), ``log_k_1p_mean/std`` and ``log_k_p1_mean/std``
        (indexable, length P), ``k_a_tmag``, ``k_e_tmag``, ``V_tmag``, and
        ``k_1p_tmag``, ``k_p1_tmag`` (length P).

    Returns
    -------
    config_dict : dict
        Keys:
        'k_a', 'k_e', 'V': float32 tensors of shape (N,)
        'k_1p', 'k_p1': float32 tensors of shape (N, P); (N, 0) when P == 0
        'k_a_tmag', 'k_e_tmag', 'V_tmag': passed through from study_config
        'k_1p_tmag', 'k_p1_tmag': float32 tensors of shape (P,)
        'num_peripherals': int
    """
    N = study_config.num_individuals
    P = study_config.num_peripherals

    def _lognormal(mean, std):
        # One draw per individual, returned as a float32 tensor of shape (N,).
        return torch.from_numpy(np.random.lognormal(mean, std, size=N)).float()

    # Central parameters, shape (N,).
    k_a = _lognormal(study_config.log_k_a_mean, study_config.log_k_a_std)
    k_e = _lognormal(study_config.log_k_e_mean, study_config.log_k_e_std)
    V = _lognormal(study_config.log_V_mean, study_config.log_V_std)

    # Peripheral parameters. The draw order (k_1p_i, then k_p1_i, for each
    # peripheral i) is kept so that seeded runs reproduce earlier results.
    k_1p_cols = []
    k_p1_cols = []
    for i in range(P):
        k_1p_cols.append(_lognormal(study_config.log_k_1p_mean[i], study_config.log_k_1p_std[i]))
        k_p1_cols.append(_lognormal(study_config.log_k_p1_mean[i], study_config.log_k_p1_std[i]))

    if P > 0:
        # Stack along the peripheral dimension: shape (N, P).
        k_1p = torch.stack(k_1p_cols, dim=1)
        k_p1 = torch.stack(k_p1_cols, dim=1)
    else:
        # torch.stack rejects an empty list; a model without peripheral
        # compartments gets empty (N, 0) tensors instead of crashing.
        k_1p = torch.empty((N, 0), dtype=torch.float32)
        k_p1 = torch.empty((N, 0), dtype=torch.float32)

    # Time-magnitudes: central ones are passed through as given; peripheral
    # ones are expected to be length-P sequences and become (P,) tensors.
    k_1p_tmag = torch.tensor(study_config.k_1p_tmag).float()
    k_p1_tmag = torch.tensor(study_config.k_p1_tmag).float()

    return {
        'k_a': k_a,
        'k_e': k_e,
        'V': V,
        'k_1p': k_1p,
        'k_p1': k_p1,
        'k_a_tmag': study_config.k_a_tmag,
        'k_e_tmag': study_config.k_e_tmag,
        'V_tmag': study_config.V_tmag,
        'k_1p_tmag': k_1p_tmag,
        'k_p1_tmag': k_p1_tmag,
        'num_peripherals': P,
    }
import torch
def compute_rates(config, t):
    """
    Evaluate the time-varying rates for all individuals at time ``t``.

    Each sampled parameter decays exponentially in time:
    ``param(t) = param * exp(-param_tmag * t)``.

    Parameters
    ----------
    config : dict
        Dictionary returned by sample_individual_configs_vectorized. The
        ``*_tmag`` entries for k_a, k_e and V may be plain Python scalars or
        tensors; ``k_1p_tmag``/``k_p1_tmag`` broadcast against the (N, P)
        peripheral rates.
    t : float or torch.Tensor
        Current time point.

    Returns
    -------
    k_a, k_e, V : torch.Tensor
        Tensors of shape (N,).
    k_1p, k_p1 : torch.Tensor
        Tensors of shape (N, P).
    """
    if not isinstance(t, torch.Tensor):
        # Take dtype/device from the sampled k_a tensor: the *_tmag entries
        # may be plain Python scalars, which have no .dtype/.device.
        ref = config['k_a']
        t = torch.tensor(t, dtype=ref.dtype, device=ref.device)

    k_a = config['k_a'] * torch.exp(-config['k_a_tmag'] * t)
    k_e = config['k_e'] * torch.exp(-config['k_e_tmag'] * t)
    V = config['V'] * torch.exp(-config['V_tmag'] * t)
    # (P,) time-magnitudes broadcast across the N rows of the (N, P) rates.
    k_1p = config['k_1p'] * torch.exp(-config['k_1p_tmag'] * t)
    k_p1 = config['k_p1'] * torch.exp(-config['k_p1_tmag'] * t)
    return k_a, k_e, V, k_1p, k_p1
def ode_func(t_val, y, config):
    """
    Right-hand side ``dy/dt = A(t) @ y`` of the linear compartment ODE.

    The computation is batched over individuals. State columns are:
    0 = gut, 1 = central, 2..2+P-1 = peripheral compartments.

    Parameters
    ----------
    t_val : torch.Tensor
        Current time point.
    y : torch.Tensor
        Current state, shape (N, M) where M = 2 + num_peripherals.
    config : dict
        Vectorized individual configuration dictionary.

    Returns
    -------
    dy_dt : torch.Tensor
        Time derivative of y, shape (N, M), same dtype/device as ``y``.
    """
    # Dynamic rates for all individuals at time t_val (V is not needed here).
    k_a, k_e, _, k_1p, k_p1 = compute_rates(config, t_val)
    N = y.size(0)
    P = config['num_peripherals']
    M = 2 + P

    # Allocate A(t) on y's dtype/device so torch.bmm below does not fail for
    # float64 or non-CPU states.
    A_all = torch.zeros((N, M, M), dtype=y.dtype, device=y.device)
    A_all[:, 0, 0] = -k_a  # absorption out of the gut
    A_all[:, 1, 0] = k_a  # absorption into the central compartment
    # Central loses mass to elimination (k_e) and to every peripheral (k_1p).
    A_all[:, 1, 1] = -k_e - k_1p.sum(dim=1)
    # Row 1 feeds dy_central: return flow peripheral -> central at rate k_p1.
    A_all[:, 1, 2:M] = k_p1
    # Column 1 feeds dy_peripheral: distribution central -> peripheral at rate k_1p.
    A_all[:, 2:M, 1] = k_1p
    # Each peripheral loses mass by the same k_p1 return flow, so every column
    # except the central one sums to zero (mass conservation).
    torch.diagonal(A_all[:, 2:M, 2:M], dim1=1, dim2=2).copy_(-k_p1)

    # Batched matrix-vector product: (N, M, M) @ (N, M, 1) -> (N, M).
    return torch.bmm(A_all, y.unsqueeze(-1)).squeeze(-1)
def sample_study_vectorized(study_config, dosing_config, t, solver_method="rk4"):
    """
    Simulate a PK study for a sampled population in one batched ODE solve.

    Parameters
    ----------
    study_config : StudyConfig
        Global study settings and parameter distributions. It is passed to
        sample_individual_configs_vectorized.
    dosing_config : DosingConfig
        Its ``dose`` is placed in the gut compartment at time ``t[0]``.
    t : torch.Tensor or array-like
        1-D grid of evaluation times. Non-tensor inputs are converted to a
        float32 tensor; tensors are used unchanged.
    solver_method : str, default "rk4"
        Integration method forwarded to ``torchdiffeq.odeint``.

    Returns
    -------
    full_simulation : torch.Tensor
        Central-compartment state (column 1) for every individual, shape
        (N, len(t)), in the same units as ``dosing_config.dose``.
        NOTE(review): this is the ODE state (an amount); it is never divided
        by the sampled V. Confirm whether concentration was intended.
    full_times : torch.Tensor
        ``t`` repeated for each individual, shape (N, len(t)).
    """
    from torchdiffeq import odeint

    # Accept plain lists/arrays of times; the `.unsqueeze` call below and
    # odeint both require a tensor.
    if not isinstance(t, torch.Tensor):
        t = torch.as_tensor(t, dtype=torch.float32)

    config = sample_individual_configs_vectorized(study_config)
    N = study_config.num_individuals
    P = study_config.num_peripherals
    M = 2 + P

    # Initial condition: the whole dose in the gut, every other compartment empty.
    y0 = torch.zeros((N, M), dtype=torch.float32)
    y0[:, 0] = dosing_config.dose

    def wrapped_ode(t_val, y):
        return ode_func(t_val, y, config)

    # odeint returns shape (len(t), N, M).
    y = odeint(wrapped_ode, y0, t, method=solver_method)

    full_simulation = y[:, :, 1].T  # central compartment -> (N, len(t))
    full_times = t.unsqueeze(0).repeat(N, 1)
    return full_simulation, full_times