drift-sim / app.py
RyanGutenkunst's picture
Update description
a99f577
import matplotlib.pyplot as plt, numpy as np
import gradio as gr
# Try to import numba for a speed boost. When it is not installed, fall back
# to a no-op ``jit`` so decorated functions still run as plain Python.
try:
    from numba import jit
    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False

    def jit(*args, **kwargs):
        """No-op stand-in for ``numba.jit``.

        Supports both the bare form (``@jit``) and the parameterized form
        (``@jit(nopython=True)``, ``@jit(nopython=True, cache=True)``, ...).
        All options are accepted and ignored; the decorated function is
        returned unchanged.
        """
        # Bare usage: Python passes the function itself as the only argument.
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]

        def decorator(func):
            return func

        return decorator
@jit(nopython=True)
def drift_simulation_numba(N, T, s, f0, Nsim=500):
    """Simulate allele-frequency trajectories one run at a time.

    Written as explicit scalar loops so that it can be compiled by numba
    when available; it also runs (more slowly) as plain Python.

    Each generation applies the selection update
    ``f -> f * (1 + s) / (f * s + 1)`` (skipped when ``s == 0``) and then
    redraws the frequency as a binomial sample of ``2 * N`` trials divided
    by ``2 * N``. A run stops as soon as the frequency reaches 0 or 1, and
    the rest of its trajectory is filled with that value.

    Parameters
    ----------
    N : int
        Population size; ``2 * N`` is the number of binomial trials.
    T : int
        Number of generations to simulate.
    s : float
        Selection coefficient.
    f0 : float
        Initial allele frequency.
    Nsim : int, optional
        Number of independent runs.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(Nsim, T + 1)``; row ``i`` is the trajectory of
        run ``i`` and column 0 holds ``f0``.

    Raises
    ------
    ValueError
        If ``N`` is smaller than 1.
    """
    # 2 * N is used as a divisor below, so N == 0 would otherwise fail with a
    # division by zero (or produce NaN when running without numba).
    if N < 1:
        raise ValueError("Population size N must be at least 1")

    n_copies = 2 * N
    trajs = np.zeros((Nsim, T + 1))
    for sim in range(Nsim):
        trajs[sim, 0] = f0
        f = f0
        for t in range(1, T + 1):
            # Selection update
            if s != 0:
                f = f * (1 + s) / (f * s + 1)

            # Early stopping: the allele is already lost or fixed
            if f <= 0:
                trajs[sim, t:] = 0.0
                break
            elif f >= 1:
                trajs[sim, t:] = 1.0
                break

            # Binomial sampling for drift
            count = np.random.binomial(n_copies, f)
            f = count / n_copies
            trajs[sim, t] = f

            # Early stopping if sampling fixed or lost the allele
            if f == 0 or f == 1:
                trajs[sim, t + 1:] = f
                break
    return trajs
def drift_simulation_numpy(N, T, s, f0, Nsim=500):
    """Simulate allele-frequency trajectories for all runs at once.

    Vectorized over runs: each generation applies the selection update
    ``f -> f * (1 + s) / (f * s + 1)`` (skipped when ``s == 0``) to the
    still-active runs and then redraws their frequencies as binomial samples
    of ``2 * N`` trials divided by ``2 * N``. A run becomes inactive once its
    frequency reaches 0 or 1, and the rest of its trajectory is filled with
    that value.

    Parameters
    ----------
    N : int
        Population size; ``2 * N`` is the number of binomial trials.
    T : int
        Number of generations to simulate.
    s : float
        Selection coefficient.
    f0 : float
        Initial allele frequency.
    Nsim : int, optional
        Number of independent runs.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(Nsim, T + 1)``; row ``i`` is the trajectory of
        run ``i`` and column 0 holds ``f0``.

    Raises
    ------
    ValueError
        If ``N`` is smaller than 1.
    """
    # 2 * N is used as a divisor below; N == 0 would silently fill the
    # trajectories with NaN instead of failing.
    if N < 1:
        raise ValueError("Population size N must be at least 1")

    n_copies = 2 * N
    trajs = np.zeros((Nsim, T + 1))
    trajs[:, 0] = f0

    # Track which simulations are still segregating
    active = np.ones(Nsim, dtype=bool)
    # Force a float array: with an integer f0, np.full would pick an integer
    # dtype and truncate any fractional frequency written into it.
    f_current = np.full(Nsim, f0, dtype=float)

    for t in range(1, T + 1):
        if not active.any():
            break

        # Selection update for active simulations
        if s != 0:
            f_act = f_current[active]
            f_current[active] = f_act * (1 + s) / (f_act * s + 1)

        # Runs pushed onto (or past) a boundary are absorbed from time t on
        lost = active & (f_current <= 0)
        fixed = active & (f_current >= 1)
        f_current[lost] = 0.0
        f_current[fixed] = 1.0
        trajs[lost, t:] = 0.0
        trajs[fixed, t:] = 1.0
        active &= ~(lost | fixed)

        if not active.any():
            break

        # Vectorized binomial sampling for all active simulations
        active_idx = np.where(active)[0]
        counts = np.random.binomial(n_copies, f_current[active_idx])
        f_new = counts / n_copies
        f_current[active_idx] = f_new
        trajs[active_idx, t] = f_new

        # Runs that sampling just fixed or lost keep that value afterwards
        new_lost = f_new == 0
        new_fixed = f_new == 1
        if new_lost.any():
            lost_idx = active_idx[new_lost]
            trajs[lost_idx, t + 1:] = 0.0
            active[lost_idx] = False
        if new_fixed.any():
            fixed_idx = active_idx[new_fixed]
            trajs[fixed_idx, t + 1:] = 1.0
            active[fixed_idx] = False

    return trajs
def drift_main(N, T, s, f0, Nsim=500):
    """Run the drift simulations and build the two figures shown in the app.

    Parameters
    ----------
    N : number
        Population size; converted to ``int``. Must be at least 1.
    T : number
        Number of generations to run; converted to ``int``.
    s : float
        Selection coefficient.
    f0 : float
        Initial allele frequency.
    Nsim : int, optional
        Number of independent simulations.

    Returns
    -------
    tuple
        ``(fig, fig2)``: the trajectory plot and the histogram of times to
        loss and fixation.

    Raises
    ------
    gr.Error
        If ``N`` is smaller than 1 (the simulation divides by ``2 * N``).
    """
    # Slider values may arrive as floats; the simulations need integers.
    N, T = int(N), int(T)
    if N < 1:
        raise gr.Error("Population size N must be at least 1.")

    # Choose the fastest available implementation
    simulate = drift_simulation_numba if NUMBA_AVAILABLE else drift_simulation_numpy
    trajs = simulate(N, T, s, f0, Nsim)

    # argmax gives the first generation at which a run sits on the boundary,
    # or 0 when it never does; dropping zeros removes runs that never hit it
    # (and runs that started on it).
    tfix = np.argmax(trajs == 1, axis=1)
    tfix = tfix[tfix > 0]
    tloss = np.argmax(trajs == 0, axis=1)
    tloss = tloss[tloss > 0]

    # Fixed figure numbers are reused (and cleared) on every call so that
    # repeated calls do not accumulate open figures.
    fig = plt.figure(137)
    fig.clear()
    ax = fig.add_subplot(1, 1, 1)
    # A 2-D array is plotted column by column, so transpose to get one line
    # per simulation in a single call.
    ax.plot(trajs.T)
    ax.set_ylim(0, 1)
    ax.set_xlim(0, T)
    ax.set_xlabel('Time (generations)')
    ax.set_ylabel('Mutation frequency')
    fig.tight_layout()

    fig2 = plt.figure(139)
    fig2.clear()
    ax2 = fig2.add_subplot(1, 1, 1)
    for times, name in ((tloss, 'Losses'), (tfix, 'Fixations')):
        _, _, bars = ax2.hist(times, bins=20, alpha=0.5,
                              label=f'{name} (total={len(times)})')
        # Mark the mean time only when at least one such event occurred; a
        # line at 0 for "no events" would read as a real mean.
        if len(times) > 0:
            ax2.axvline(np.mean(times), color=bars[0].get_facecolor(),
                        ls='--', lw=3)
    ax2.legend()
    ax2.set_xlabel('Time to loss or fixation (generations)')
    ax2.set_ylabel('Number of simulations')
    fig2.tight_layout()

    return fig, fig2
# Introductory text passed to gr.Interface as `description`.
# NOTE: the "500 simulations" wording mirrors the default Nsim=500 of
# drift_main; keep the two in sync.
description = """This app runs 500 simulations of a single biallelic mutation in a constant-size population under the Wright-Fisher model.
Use the sliders to adjust the values of the population size N, the number of generations to run T, the selection coefficient s, and the initial allele frequency f0.
The top plot shows the allele frequency trajectories over time for all simulations. The bottom plot shows histograms of the times to loss and fixation for the mutation across the simulations.
Note that you may need to extend the running time T to see more fixations or losses, especially for larger population sizes or weaker selection.
"""
# Contact text passed to gr.Interface as `article`.
article = """
Questions? Contact [Ryan Gutenkunst at rgutenk@arizona.edu](mailto:rgutenk@arizona.edu).
"""
# Gradio interface: four sliders feed drift_main, whose two figures are
# shown as plots. live=True re-runs the simulation whenever a slider changes.
drift = gr.Interface(
    fn=drift_main,
    inputs=[
        gr.Slider(minimum=0, maximum=2000, value=100,
                  label='Population size N'),
        gr.Slider(minimum=0, maximum=10000, value=1000,
                  label='Generations to run T'),
        gr.Slider(minimum=-0.01, maximum=0.01, value=0,
                  label='Selection coefficient s'),
        gr.Slider(minimum=0, maximum=1, value=0.5,
                  label='Initial frequency f0'),
    ],
    outputs=[gr.Plot(), gr.Plot()],
    title="Drift simulation in a single constant-size population",
    description=description,
    article=article,
    live=True,
    clear_btn=None,
    flagging_mode="never",
)
if __name__ == "__main__":
    # Launch the Gradio interface when this file is executed as a script.
    drift.launch()