AICME-runtime / sim_priors_pk /data /data_generation /compartment_models.py
cesarali's picture
manual runtime bundle push from load_and_push.ipynb
5686f5b verified
import random
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
from torchdiffeq import odeint
from torchtyping import TensorType
from sim_priors_pk.config_classes.data_config import (
DosingConfig,
DosingWithDurationConfig,
MetaDosingConfig,
MetaDosingWithDurationConfig,
MetaStudyConfig,
)
from sim_priors_pk.config_classes.node_pk_config import NodePKExperimentConfig
@dataclass
class StudyConfig:
    """
    Sampled settings describing one simulated study.

    The ``log_*_mean`` / ``log_*_std`` pairs are passed to
    ``np.random.lognormal`` when individuals are drawn, so they live on the
    log scale. The ``*_tmag`` / ``*_tscl`` pairs control how a rate varies
    over time within an individual. List-valued fields hold one entry per
    peripheral compartment.
    """

    # --- study layout -------------------------------------------------------
    drug_id: str  # label of the drug under study
    num_individuals: int  # population size
    num_peripherals: int  # how many peripheral compartments the model has

    # --- absorption rate k_a (gut -> central) -------------------------------
    log_k_a_mean: float  # lognormal location parameter
    log_k_a_std: float  # lognormal scale parameter
    k_a_tmag: float  # magnitude of the time-dependent variation
    k_a_tscl: float  # scale of the time-dependent variation

    # --- elimination rate k_e (out of central) ------------------------------
    log_k_e_mean: float  # lognormal location parameter
    log_k_e_std: float  # lognormal scale parameter
    k_e_tmag: float  # magnitude of the time-dependent variation
    k_e_tscl: float  # scale of the time-dependent variation

    # --- central volume V ---------------------------------------------------
    log_V_mean: float  # lognormal location parameter
    log_V_std: float  # lognormal scale parameter
    V_tmag: float  # magnitude of the time-dependent variation
    V_tscl: float  # scale of the time-dependent variation

    # --- exchange rates k_1p (central -> peripheral), one per peripheral ----
    log_k_1p_mean: List[float]  # lognormal location parameters
    log_k_1p_std: List[float]  # lognormal scale parameters
    k_1p_tmag: List[float]  # magnitude of the time-dependent variation
    k_1p_tscl: List[float]  # scale of the time-dependent variation

    # --- exchange rates k_p1 (peripheral -> central), one per peripheral ----
    log_k_p1_mean: List[float]  # lognormal location parameters
    log_k_p1_std: List[float]  # lognormal scale parameters
    k_p1_tmag: List[float]  # magnitude of the time-dependent variation
    k_p1_tscl: List[float]  # scale of the time-dependent variation

    # --- observation window and noise ---------------------------------------
    time_start: float  # first time of the study window
    time_stop: float  # last time of the study window
    rel_ruv: float  # relative residual unexplained variability
@dataclass
class IndividualConfig:
    """
    Pharmacokinetic parameters of a single individual.

    Every rate (and the volume) is stored as a callable of time, so that a
    parameter may change during the study. The defaults are constant
    functions for a model with two peripheral compartments.
    """

    # Number of peripheral compartments attached to the central one.
    num_peripherals: int = 2
    # Absorption rate (gut -> central) as a function of time.
    k_a: Callable[[float], float] = lambda t: 0.1
    # Elimination rate out of the central compartment as a function of time.
    k_e: Callable[[float], float] = lambda t: 0.05
    # Volume of the central compartment as a function of time.
    V: Callable[[float], float] = lambda t: 0.05
    # Central -> peripheral rates, one callable per peripheral compartment.
    k_1p: List[Callable[[float], float]] = field(
        default_factory=lambda: [(lambda t: 0.01) for _ in range(2)]
    )
    # Peripheral -> central rates, one callable per peripheral compartment.
    k_p1: List[Callable[[float], float]] = field(
        default_factory=lambda: [(lambda t: 0.01) for _ in range(2)]
    )
    # Relative residual unexplained variability for this individual.
    rel_ruv: float = 0.1
def sample_study_config(config: MetaStudyConfig):
    """
    Draw one StudyConfig from the ranges described by a MetaStudyConfig.

    The drug id is picked with ``random.choice``, the two counts with
    ``random.randint`` and every other sampled quantity with
    ``random.uniform`` over the matching ``<name>_range`` attribute.
    ``time_start`` and ``time_stop`` are copied over unchanged.

    The draws are made in a fixed order (counts, k_a, k_e, V, k_1p, k_p1,
    rel_ruv) so that a seeded ``random`` module reproduces the same study.
    """

    def draw(name):
        # One uniform draw from the interval stored in ``config.<name>_range``.
        return random.uniform(*getattr(config, f"{name}_range"))

    drug_id = random.choice(config.drug_id_options)
    num_individuals = random.randint(*config.num_individuals_range)
    num_peripherals = random.randint(*config.num_peripherals_range)

    def draw_per_peripheral(name):
        # Independent uniform draws, one for each peripheral compartment.
        return [draw(name) for _ in range(num_peripherals)]

    sampled = {}

    # Scalar parameters: location, scale, then the two time-variation knobs.
    for prefix in ("k_a", "k_e", "V"):
        for key in (
            f"log_{prefix}_mean",
            f"log_{prefix}_std",
            f"{prefix}_tmag",
            f"{prefix}_tscl",
        ):
            sampled[key] = draw(key)

    # Per-peripheral parameters follow the same four-field pattern.
    for prefix in ("k_1p", "k_p1"):
        for key in (
            f"log_{prefix}_mean",
            f"log_{prefix}_std",
            f"{prefix}_tmag",
            f"{prefix}_tscl",
        ):
            sampled[key] = draw_per_peripheral(key)

    sampled["rel_ruv"] = draw("rel_ruv")

    return StudyConfig(
        drug_id=drug_id,
        num_individuals=num_individuals,
        num_peripherals=num_peripherals,
        time_start=config.time_start,
        time_stop=config.time_stop,
        **sampled,
    )
def sample_rate_function(mean_rate, variability, variability_type="sinusoidal"):
    """
    Build a time-dependent rate function of the requested shape.

    :param mean_rate: Baseline value of the rate
    :param variability: Amplitude ("sinusoidal") or decay constant ("decaying")
    :param variability_type: Either "sinusoidal" or "decaying"
    :return: A callable ``f(t)``; ``t`` is passed to ``torch.sin`` / ``torch.exp``
    :raises ValueError: If ``variability_type`` is neither of the two options
    """
    if variability_type == "sinusoidal":

        def sinusoidal_rate(t):
            # Baseline plus an oscillation of the given amplitude.
            return mean_rate + variability * torch.sin(t)

        return sinusoidal_rate

    if variability_type == "decaying":

        def decaying_rate(t):
            # Baseline shrinking exponentially with time.
            return mean_rate * torch.exp(-variability * t)

        return decaying_rate

    raise ValueError(f"Unknown variability_type: {variability_type}")
def simulate_ou_process(
    mu: float, sigma: float, theta: float, dt: float, T: float, seed: Optional[int] = None
) -> np.ndarray:
    """
    Simulate a mean-reverting Ornstein-Uhlenbeck process with Euler-Maruyama steps.

    Parameters
    ----------
    mu : float
        Level the process reverts to.
    sigma : float
        Diffusion coefficient.
    theta : float
        Mean-reversion rate.
    dt : float
        Step size.
    T : float
        Total simulated duration; the path has ``int(T / dt)`` points.
    seed : int, optional
        If given, ``np.random.seed(seed)`` is called first. Note that this
        reseeds NumPy's global generator.

    Returns
    -------
    np.ndarray
        The sampled path, of length ``int(T / dt)`` (empty if that is 0).
    """
    if seed is not None:
        np.random.seed(seed)

    N = int(T / dt)
    X = np.zeros(N)
    if N == 0:
        # Nothing to simulate; indexing X[0] below would fail.
        return X

    if theta > 0:
        # Start from the stationary distribution N(mu, sigma^2 / (2 * theta)).
        X[0] = np.random.normal(mu, np.sqrt(sigma**2 / (2 * theta)))
    else:
        # Without mean reversion there is no stationary distribution (the
        # variance formula divides by zero), so start at the mean level.
        X[0] = mu

    noise_scale = np.sqrt(dt)
    for t in range(1, N):
        dW = np.random.normal(0, noise_scale)
        X[t] = X[t - 1] + theta * (mu - X[t - 1]) * dt + sigma * dW
    return X
def sample_individual_configs(study_config: StudyConfig, n: Optional[int] = None):
    """
    Samples parameters for a population of individuals.

    For each individual, baseline values of k_a, k_e, V and the per-peripheral
    exchange rates are drawn from lognormal distributions. k_a, k_e and V are
    then multiplied by the exponential of an Ornstein-Uhlenbeck path and
    linearly interpolated in time; the exchange rates get a sinusoidal
    modulation ``base * (1 + tmag * sin(t / tscl))``.

    Parameters
    ----------
    study_config : StudyConfig
        Configuration object with parameter distributions.
    n : int, optional
        Number of individuals to sample. If None, defaults to
        study_config.num_individuals.

    Returns
    -------
    List[IndividualConfig]
        A list of sampled individual configurations.
    """
    num_individuals = n if n is not None else study_config.num_individuals

    # Time grid on which the OU paths are tabulated (same for everyone).
    dt = 0.1
    duration = study_config.time_stop - study_config.time_start
    ou_times = np.arange(study_config.time_start, study_config.time_stop, dt)

    def make_ou_rate(base_value, tmag, tscl):
        # Returns f(t) = base_value * exp(OU path), interpolated over ou_times.
        if tmag == 0:
            # Zero magnitude means both OU coefficients below are zero; the
            # stationary variance would be 0/0 (NaN rates). The process is
            # simply constant, so use a flat log-path.
            log_path = np.zeros(len(ou_times))
        else:
            # NOTE(review): theta is set to tmag here while sigma uses tscl;
            # theta = tscl would give a stationary std of exactly tmag.
            # Left unchanged - confirm the intended parameterisation.
            log_path = simulate_ou_process(0, tmag * np.sqrt(2 * tscl), tmag, dt, duration)
        # np.arange and int(T / dt) can disagree by one point because of
        # floating-point rounding; np.interp needs equal lengths.
        n_points = min(len(ou_times), len(log_path))
        grid = ou_times[:n_points]
        values = base_value * np.exp(log_path[:n_points])
        return lambda t: np.interp(t, grid, values)

    def make_sinusoidal_rate(base_value, tmag, tscl):
        # Returns f(t) = base_value * (1 + tmag * sin(t / tscl)).
        return lambda t: base_value * (1 + tmag * np.sin(t / tscl))

    individual_configs = []
    for _ in range(num_individuals):
        # Baseline parameters from lognormal distributions. The order of the
        # random draws is kept as-is so seeded runs stay reproducible.
        k_a = np.random.lognormal(study_config.log_k_a_mean, study_config.log_k_a_std)
        k_e = np.random.lognormal(study_config.log_k_e_mean, study_config.log_k_e_std)
        V = np.random.lognormal(study_config.log_V_mean, study_config.log_V_std)
        k_1p = [
            np.random.lognormal(mean, std)
            for mean, std in zip(study_config.log_k_1p_mean, study_config.log_k_1p_std)
        ]
        k_p1 = [
            np.random.lognormal(mean, std)
            for mean, std in zip(study_config.log_k_p1_mean, study_config.log_k_p1_std)
        ]

        # Ornstein-Uhlenbeck modulation for absorption, elimination, volume.
        k_a_fn = make_ou_rate(k_a, study_config.k_a_tmag, study_config.k_a_tscl)
        k_e_fn = make_ou_rate(k_e, study_config.k_e_tmag, study_config.k_e_tscl)
        V_fn = make_ou_rate(V, study_config.V_tmag, study_config.V_tscl)

        # Peripheral exchange rates (sinusoidal modulation as placeholder).
        k_1p_fn = [
            make_sinusoidal_rate(rate, study_config.k_1p_tmag[i], study_config.k_1p_tscl[i])
            for i, rate in enumerate(k_1p)
        ]
        k_p1_fn = [
            make_sinusoidal_rate(rate, study_config.k_p1_tmag[i], study_config.k_p1_tscl[i])
            for i, rate in enumerate(k_p1)
        ]

        individual_configs.append(
            IndividualConfig(
                num_peripherals=study_config.num_peripherals,
                k_a=k_a_fn,
                k_e=k_e_fn,
                V=V_fn,
                k_1p=k_1p_fn,
                k_p1=k_p1_fn,
                rel_ruv=study_config.rel_ruv,
            )
        )

    return individual_configs
def create_dynamic_ode_matrix(config: IndividualConfig, t: float):
    """
    Assemble the linear ODE system matrix of one individual at time t.

    State ordering is [gut, central, peripheral_0, peripheral_1, ...], so the
    matrix is square with side ``2 + config.num_peripherals``.

    :param config: IndividualConfig object providing the rate callables
    :param t: Current time, passed to every rate callable
    :return: ODE matrix as a torch tensor
    """
    GUT, CENTRAL, FIRST_PERIPHERAL = 0, 1, 2

    side = FIRST_PERIPHERAL + config.num_peripherals
    matrix = torch.zeros((side, side))

    # Absorption: leaves the gut, enters the central compartment.
    matrix[GUT, GUT] = -config.k_a(t)
    matrix[CENTRAL, GUT] = config.k_a(t)

    # Elimination from the central compartment.
    matrix[CENTRAL, CENTRAL] = -config.k_e(t)

    # Exchange between the central compartment and each peripheral one.
    for p in range(config.num_peripherals):
        peripheral = FIRST_PERIPHERAL + p
        # Outflow central -> peripheral p reduces the central diagonal.
        matrix[CENTRAL, CENTRAL] -= config.k_1p[p](t)
        # Inflow into central from peripheral p.
        matrix[CENTRAL, peripheral] = config.k_p1[p](t)
        # Inflow into peripheral p from central.
        matrix[peripheral, CENTRAL] = config.k_1p[p](t)
        # Outflow of peripheral p back to central.
        matrix[peripheral, peripheral] = -config.k_p1[p](t)

    return matrix
def create_dynamic_ode_matrix_batched(configs, t, num_peripherals):
    """
    Creates batched ODE matrices for multiple individuals.

    State ordering is [gut, central, peripheral_0, ...]; row ``r`` of each
    matrix holds the coefficients of d(state_r)/dt.

    Parameters:
    ----------
    configs : list
        List of IndividualConfig objects. May be empty.
    t : float
        Current time point, passed to every rate callable.
    num_peripherals : int
        Number of peripheral compartments (same for all individuals).

    Returns:
    -------
    A_all : torch.Tensor
        Tensor of shape (N, M, M) containing ODE matrices for all individuals,
        with N = len(configs) and M = 2 + num_peripherals.
    """
    N = len(configs)
    M = 2 + num_peripherals
    A_all = torch.zeros((N, M, M), dtype=torch.float32)

    # Evaluate every rate at time t, one row per individual.
    k_a_all = torch.tensor([config.k_a(t) for config in configs], dtype=torch.float32)
    k_e_all = torch.tensor([config.k_e(t) for config in configs], dtype=torch.float32)
    # The explicit reshape keeps these 2-D when N == 0 or num_peripherals == 0,
    # where torch.tensor would otherwise return a 1-D empty tensor.
    k_1p_all = torch.tensor(
        [[config.k_1p[i](t) for i in range(num_peripherals)] for config in configs],
        dtype=torch.float32,
    ).reshape(N, num_peripherals)
    k_p1_all = torch.tensor(
        [[config.k_p1[i](t) for i in range(num_peripherals)] for config in configs],
        dtype=torch.float32,
    ).reshape(N, num_peripherals)

    # Populate the batched ODE matrices
    A_all[:, 0, 0] = -k_a_all  # Gut loses drug by absorption
    A_all[:, 1, 0] = k_a_all  # Absorption into central
    A_all[:, 1, 1] = -k_e_all - k_1p_all.sum(dim=1)  # Central: elimination + outflow
    A_all[:, 1, 2 : 2 + num_peripherals] = k_p1_all  # Peripheral -> central inflow
    A_all[:, 2 : 2 + num_peripherals, 1] = k_1p_all  # Central -> peripheral inflow
    for i in range(num_peripherals):
        A_all[:, 2 + i, 2 + i] = -k_p1_all[:, i]  # Peripheral i outflow

    return A_all
def sample_study(
    individual_config_array, dosing_config_array, t: torch.Tensor, solver_method: str = "rk4"
) -> Tuple[
    torch.Tensor,  # [N, T] concentration profiles
    torch.Tensor,  # [N, T] time points
    torch.Tensor,  # [N] dosing amounts
    torch.Tensor,  # [N] dosing route types (0 = oral, 1 = iv)
]:
    """
    Simulates the pharmacokinetic study for a group of individuals and returns
    concentration profiles, time points, and dosing metadata.

    The dose is placed in the initial state at the first time point ``t[0]``
    (gut for "oral", central for "iv"); the ``time`` attribute of the dosing
    configs is not read here. Concentrations are the central amount divided by
    V(t), multiplied by ``1 + rel_ruv * standard normal noise`` using each
    individual's own ``rel_ruv``.

    When all individuals share the same number of peripheral compartments the
    ODE system is solved as one batch; otherwise each individual is solved
    separately.

    Parameters:
    ----------
    individual_config_array : list
        List of IndividualConfig objects for each individual.
    dosing_config_array : list
        List of DosingConfig objects for each individual.
    t : torch.Tensor
        A 1D tensor of time points [T].
    solver_method : str
        Passed to ``odeint`` as ``method``.

    Returns:
    -------
    full_simulation : torch.Tensor
        Concentration profiles [N, T].
    full_simulation_times : torch.Tensor
        Time points [N, T].
    dosing_amounts : torch.Tensor
        Dosing amounts [N].
    dosing_route_types : torch.Tensor
        Route types [N], 0 = oral, 1 = iv.

    Raises:
    ------
    ValueError
        If the two input lists differ in length.
    KeyError
        If a dosing route is neither "oral" nor "iv".
    """
    # Sanity check
    if len(individual_config_array) != len(dosing_config_array):
        raise ValueError("Number of individuals and dosing configurations must match.")

    N = len(individual_config_array)

    # Extract dosing info
    dosing_amounts = torch.tensor(
        [cfg.dose for cfg in dosing_config_array], dtype=torch.float32
    )  # [N]
    route_map = {"oral": 0, "iv": 1}
    dosing_route_types = torch.tensor(
        [route_map[cfg.route] for cfg in dosing_config_array], dtype=torch.int64
    )  # [N]

    if N == 0:
        # Empty population: return correctly shaped empty tensors instead of
        # failing on the first-element lookups below.
        num_times = t.shape[0]
        return (
            torch.zeros((0, num_times), dtype=torch.float32),
            t.new_empty((0, num_times)),
            dosing_amounts,
            dosing_route_types,
        )

    num_peripherals_list = [cfg.num_peripherals for cfg in individual_config_array]
    all_same_peripherals = len(set(num_peripherals_list)) == 1

    if all_same_peripherals:
        P = num_peripherals_list[0]
        M = 2 + P  # gut, central, peripherals

        # Initial state: the whole dose sits in the gut (oral) or central (iv).
        y0 = torch.zeros((N, M), dtype=torch.float32)
        is_oral = dosing_route_types == 0
        is_iv = dosing_route_types == 1
        y0[is_oral, 0] = dosing_amounts[is_oral]
        y0[is_iv, 1] = dosing_amounts[is_iv]

        def ode_func(time, state):
            A_all = create_dynamic_ode_matrix_batched(individual_config_array, time.item(), P)
            return torch.bmm(A_all, state.unsqueeze(-1)).squeeze(-1)

        y = odeint(ode_func, y0, t, method=solver_method)  # [T, N, M]

        V_all = torch.tensor(
            [[cfg.V(ti.item()) for ti in t] for cfg in individual_config_array],
            dtype=torch.float32,
        )  # [N, T]
        # Per-individual noise level, matching what the per-individual branch
        # does (previously the first individual's value was used for everyone).
        rel_ruv_all = torch.tensor(
            [cfg.rel_ruv for cfg in individual_config_array], dtype=torch.float32
        ).unsqueeze(1)  # [N, 1]

        full_simulation = y[:, :, 1].T / V_all  # [N, T]
        full_simulation = full_simulation * (
            1 + torch.randn_like(full_simulation) * rel_ruv_all
        )
    else:
        profiles = []
        for config, dosing_config in zip(individual_config_array, dosing_config_array):
            M = 2 + config.num_peripherals
            if dosing_config.route == "oral":
                y0 = torch.tensor([dosing_config.dose] + [0.0] * (M - 1), dtype=torch.float32)
            elif dosing_config.route == "iv":
                y0 = torch.tensor([0.0, dosing_config.dose] + [0.0] * (M - 2), dtype=torch.float32)
            else:
                raise ValueError(f"Unsupported route: {dosing_config.route}")

            def ode_func(time, state):
                # Only called by odeint within this iteration, so closing over
                # the loop variable ``config`` is safe.
                A = create_dynamic_ode_matrix(config, time.item())
                return torch.matmul(A, state)

            y = odeint(ode_func, y0, t, method=solver_method)  # [T, M]
            V = torch.tensor([config.V(ti.item()) for ti in t], dtype=torch.float32)  # [T]
            concentration = y[:, 1] / V
            concentration *= 1 + torch.randn_like(concentration) * config.rel_ruv
            profiles.append(concentration)
        full_simulation = torch.stack(profiles)  # [N, T]

    full_times = t.unsqueeze(0).repeat(N, 1)  # [N, T]
    return full_simulation, full_times, dosing_amounts, dosing_route_types
def sample_study_with_duration(
    individual_config_array,
    dosing_config_array: List[DosingWithDurationConfig],
    t: torch.Tensor,
    solver_method: str = "rk4",
) -> Tuple[
    torch.Tensor,  # [N, T] concentration profiles
    torch.Tensor,  # [N, T] time points
    torch.Tensor,  # [N] dosing amounts
    torch.Tensor,  # [N] dosing route types (0 = oral, 1 = iv)
]:
    """
    Simulates the pharmacokinetic study for a group of individuals and returns
    concentration profiles, time points, and dosing metadata.

    This is a parallel implementation to sample_study that supports infusion dosing with duration.
    Once validated, the two can be merged.

    Dosing starts at the first time point ``t[0]``; the ``time`` attribute of
    the dosing configs is not read here. Oral and iv-bolus doses are placed in
    the initial state (gut / central). An iv dose with ``duration > 0`` is
    delivered as a constant-rate input ``dose / duration`` into the central
    compartment while ``t - t[0] < duration``.

    The batched solver is used only when every individual has the same number
    of peripheral compartments and no dose has a duration; otherwise each
    individual is solved separately. Noise uses each individual's own
    ``rel_ruv``.

    Parameters:
    ----------
    individual_config_array : list
        List of IndividualConfig objects for each individual.
    dosing_config_array : list
        List of DosingWithDurationConfig objects for each individual.
    t : torch.Tensor
        A 1D tensor of time points [T].
    solver_method : str
        Passed to ``odeint`` as ``method``.

    Returns:
    -------
    full_simulation : torch.Tensor
        Concentration profiles [N, T].
    full_simulation_times : torch.Tensor
        Time points [N, T].
    dosing_amounts : torch.Tensor
        Dosing amounts [N].
    dosing_route_types : torch.Tensor
        Route types [N], 0 = oral, 1 = iv.

    Raises:
    ------
    ValueError
        If the two input lists differ in length.
    KeyError
        If a dosing route is neither "oral" nor "iv".
    AssertionError
        If an oral dose has a non-zero duration.
    """
    # Sanity check
    if len(individual_config_array) != len(dosing_config_array):
        raise ValueError("Number of individuals and dosing configurations must match.")

    N = len(individual_config_array)

    # Extract dosing info
    dosing_amounts = torch.tensor(
        [cfg.dose for cfg in dosing_config_array], dtype=torch.float32
    )  # [N]
    route_map = {"oral": 0, "iv": 1}
    dosing_route_types = torch.tensor(
        [route_map[cfg.route] for cfg in dosing_config_array], dtype=torch.int64
    )  # [N]

    if N == 0:
        # Empty population: return correctly shaped empty tensors instead of
        # failing on the first-element lookups below.
        num_times = t.shape[0]
        return (
            torch.zeros((0, num_times), dtype=torch.float32),
            t.new_empty((0, num_times)),
            dosing_amounts,
            dosing_route_types,
        )

    dosing_durations = torch.tensor(
        [cfg.duration for cfg in dosing_config_array], dtype=torch.float32
    )  # [N]
    no_infusions = bool((dosing_durations == 0).all())

    num_peripherals_list = [cfg.num_peripherals for cfg in individual_config_array]
    all_same_peripherals = len(set(num_peripherals_list)) == 1

    if all_same_peripherals and no_infusions:
        P = num_peripherals_list[0]
        M = 2 + P  # gut, central, peripherals

        # Initial state: the whole dose sits in the gut (oral) or central (iv bolus).
        y0 = torch.zeros((N, M), dtype=torch.float32)
        is_oral = dosing_route_types == 0
        is_iv_bolus = dosing_route_types == 1
        y0[is_oral, 0] = dosing_amounts[is_oral]
        y0[is_iv_bolus, 1] = dosing_amounts[is_iv_bolus]

        def ode_func(time, state):
            A_all = create_dynamic_ode_matrix_batched(individual_config_array, time.item(), P)
            return torch.bmm(A_all, state.unsqueeze(-1)).squeeze(-1)

        # Homogeneous system only: this branch has no infusion input.
        y = odeint(ode_func, y0, t, method=solver_method)  # [T, N, M]

        V_all = torch.tensor(
            [[cfg.V(ti.item()) for ti in t] for cfg in individual_config_array],
            dtype=torch.float32,
        )  # [N, T]
        # Per-individual noise level, matching what the per-individual branch
        # does (previously the first individual's value was used for everyone).
        rel_ruv_all = torch.tensor(
            [cfg.rel_ruv for cfg in individual_config_array], dtype=torch.float32
        ).unsqueeze(1)  # [N, 1]

        full_simulation = y[:, :, 1].T / V_all  # [N, T]
        full_simulation = full_simulation * (
            1 + torch.randn_like(full_simulation) * rel_ruv_all
        )
    else:
        # Bolus doses are applied at t[0], so infusions start there as well.
        t_first = t[0].item()

        profiles = []
        for config, dosing_config in zip(individual_config_array, dosing_config_array):
            M = 2 + config.num_peripherals  # gut, central, peripherals
            is_infusion = dosing_config.route == "iv" and dosing_config.duration > 0

            if dosing_config.route == "oral":
                if dosing_config.duration != 0:
                    # Raised explicitly (not via ``assert``) so the check
                    # survives ``python -O``; the exception type is unchanged.
                    raise AssertionError("Oral dosing cannot have a duration.")
                y0 = torch.tensor([dosing_config.dose] + [0.0] * (M - 1), dtype=torch.float32)
            elif dosing_config.route == "iv":
                if is_infusion:
                    # Infusion dosing: nothing in the body yet, drug arrives via the input term.
                    y0 = torch.zeros(M, dtype=torch.float32)
                else:  # Bolus dosing
                    y0 = torch.tensor(
                        [0.0, dosing_config.dose] + [0.0] * (M - 2), dtype=torch.float32
                    )
            else:
                raise ValueError(f"Unsupported route: {dosing_config.route}")

            # The infusion window is measured from the first time point, so
            # the full dose is delivered even when t[0] != 0.
            infusion_end = t_first + dosing_config.duration

            def ode_func(time, state):
                # Only called by odeint within this iteration, so closing over
                # the loop variables is safe.
                A = create_dynamic_ode_matrix(config, time.item())
                b = torch.zeros_like(state)
                if is_infusion and time.item() < infusion_end:
                    # During infusion, add rate to central compartment
                    b[1] = dosing_config.dose / dosing_config.duration
                return torch.matmul(A, state) + b

            y = odeint(ode_func, y0, t, method=solver_method)  # [T, M]
            V = torch.tensor([config.V(ti.item()) for ti in t], dtype=torch.float32)  # [T]
            concentration = y[:, 1] / V
            concentration *= 1 + torch.randn_like(concentration) * config.rel_ruv
            profiles.append(concentration)
        full_simulation = torch.stack(profiles)  # [N, T]

    full_times = t.unsqueeze(0).repeat(N, 1)  # [N, T]
    return full_simulation, full_times, dosing_amounts, dosing_route_types
def derive_timescale_parameters(config: StudyConfig, meta_config: MetaStudyConfig):
    """
    Derive peak time and terminal half life for typical parameters,
    which can then be used to inform a study-specific sampling schedule.

    Uses the typical rates ``k_a = exp(log_k_a_mean)`` and
    ``k_e = exp(log_k_e_mean)``:

    * ``tmax = ln(k_e / k_a) / (k_e - k_a)``, with the limit ``1 / k_e`` when
      the two rates coincide;
    * ``t12 = ln(2) / k_e``; peripheral compartments are not taken into account.

    ``t12`` is replaced by ``time_stop / 2`` if it exceeds
    ``meta_config.time_stop``, and ``tmax`` by ``t12 / 2`` if it exceeds ``t12``.

    :return: float32 tensor ``[tmax, t12]``
    """
    k_a = np.exp(config.log_k_a_mean)
    k_e = np.exp(config.log_k_e_mean)

    if np.isclose(k_a, k_e, rtol=1e-9, atol=0.0):
        # The general formula is 0/0 here (NaN); its limit as k_a -> k_e is 1/k.
        tmax = 1.0 / k_e
    else:
        tmax = (np.log(k_e) - np.log(k_a)) / (k_e - k_a)

    # mean residence time approximation for terminal half-life
    MRT = 1 / k_e
    t12 = np.log(2) * MRT

    # Keep both time scales inside the observable window.
    if t12 > meta_config.time_stop:
        t12 = float(meta_config.time_stop / 2.0)
    if tmax > t12:
        tmax = float(t12 * 0.5)

    return torch.tensor([float(tmax), float(t12)], dtype=torch.float32)
def sample_dosing_configs(config: MetaDosingConfig):
    """
    Draw one DosingConfig per individual from a meta dosing configuration.

    If ``config.same_route`` is set, a single route is drawn and shared by
    everyone; otherwise each individual gets an independent draw. Doses are
    lognormal, with log-mean and log-std themselves sampled uniformly from the
    configured ranges. With ``logdose_std_range = (0, 0)`` all individuals
    receive the same dose. The dosing time is taken from ``config.time``.
    """
    population = config.num_individuals

    # Route(s): one shared draw repeated, or one draw per individual.
    if not config.same_route:
        route = np.random.choice(config.route_options, p=config.route_weights, size=population)
    else:
        shared_route = np.random.choice(config.route_options, p=config.route_weights)
        route = np.repeat(shared_route, population)

    # Parameters of the lognormal dose distribution, then the doses themselves.
    logdose_mean = np.random.uniform(*config.logdose_mean_range)
    logdose_std = np.random.uniform(*config.logdose_std_range)
    dose = np.random.lognormal(logdose_mean, logdose_std, size=population)

    return [
        DosingConfig(dose=dose[i], route=route[i], time=config.time)
        for i in range(population)
    ]
def sample_dosing_with_duration_configs(config: MetaDosingWithDurationConfig):
    """
    Draw one DosingWithDurationConfig per individual from a meta dosing configuration.

    Routes and doses are sampled as in the duration-free variant: one shared
    route if ``config.same_route`` is set, otherwise one per individual, and
    lognormal doses whose log-mean and log-std are drawn uniformly from the
    configured ranges (``logdose_std_range = (0, 0)`` gives identical doses).

    In addition, every individual gets a candidate duration drawn uniformly
    from ``config.duration_range``. A Bernoulli draw with probability
    ``config.route_duration_weights[route]`` decides whether that duration is
    kept or set to 0.
    """
    population = config.num_individuals

    # Route(s): one shared draw repeated, or one draw per individual.
    if not config.same_route:
        route = np.random.choice(config.route_options, p=config.route_weights, size=population)
    else:
        shared_route = np.random.choice(config.route_options, p=config.route_weights)
        route = np.repeat(shared_route, population)

    # Candidate durations for everyone; most may be zeroed out below.
    low, high = config.duration_range[0], config.duration_range[1]
    duration_raw = np.random.uniform(low, high, size=population)

    # Parameters of the lognormal dose distribution, then the doses themselves.
    logdose_mean = np.random.uniform(*config.logdose_mean_range)
    logdose_std = np.random.uniform(*config.logdose_std_range)
    dose = np.random.lognormal(logdose_mean, logdose_std, size=population)

    dosing_configs = []
    for i in range(population):
        # 1 keeps the sampled duration, 0 turns the dose into an instantaneous one.
        keep_duration = np.random.binomial(
            1, config.route_duration_weights[route[i]], size=1
        )[0]
        dosing_configs.append(
            DosingWithDurationConfig(
                dose=dose[i],
                route=route[i],
                time=config.time,
                duration=duration_raw[i] * keep_duration,
            )
        )
    return dosing_configs
def get_random_simulation(
    model_config: NodePKExperimentConfig,
) -> Tuple[TensorType["I", "T"], TensorType["I", "T"]]:
    """
    Produce placeholder simulation data shaped like a real study.

    The number of individuals ``I`` is the lower bound of
    ``meta_study.num_individuals_range`` and the number of time points ``T``
    is ``meta_study.time_num_steps``.

    Args:
        model_config (NodePKExperimentConfig): Configuration for the simulation.

    Returns:
        Tuple[TensorType["I", "T"], TensorType["I", "T"]]: ``(simulation_points,
        time_steps)`` in that order. ``simulation_points`` are uniform random
        numbers divided by ``meta_study.time_stop``; ``time_steps`` is an
        evenly spaced grid from ``time_start`` to ``time_stop`` repeated for
        every individual.
    """
    meta = model_config.meta_study
    num_individuals = meta.num_individuals_range[0]
    num_steps = meta.time_num_steps

    # One shared time grid, copied to every row. Shape: [I, T]
    grid = torch.linspace(meta.time_start, meta.time_stop, num_steps, dtype=torch.float32)
    time_steps = grid.unsqueeze(0).repeat(num_individuals, 1)

    # Random values of the same shape, scaled by the end of the time window.
    simulation_points = torch.rand(num_individuals, num_steps) / meta.time_stop

    return simulation_points, time_steps