"""Gradio app: Wright-Fisher simulation of drift and selection at one locus.

Runs many independent replicate simulations of a single biallelic mutation in
a constant-size diploid population and plots (1) the allele-frequency
trajectories and (2) histograms of the times to loss and to fixation.
"""

import matplotlib.pyplot as plt
import numpy as np
import gradio as gr

# Try to import numba for speed boost
try:
    from numba import jit
    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False

    # Create a dummy decorator if numba isn't available
    def jit(*_args, **_kwargs):
        """Return a no-op decorator so ``@jit(...)`` works without numba."""
        def decorator(func):
            return func
        return decorator


@jit(nopython=True)
def drift_simulation_numba(N, T, s, f0, Nsim=500):
    """Simulate allele-frequency trajectories one replicate at a time.

    Written as plain scalar loops so that it can be compiled by numba when
    numba is installed; without numba it runs as ordinary Python.

    Parameters
    ----------
    N : int
        Population size; each generation draws ``2 * N`` allele copies.
    T : int
        Number of generations to simulate.
    s : float
        Selection coefficient applied before each round of sampling.
    f0 : float
        Initial allele frequency.
    Nsim : int, optional
        Number of independent replicate simulations.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(Nsim, T + 1)``; row ``i`` holds the frequency of
        replicate ``i`` at generations ``0 .. T``.
    """
    trajs = np.zeros((Nsim, T + 1))
    for sim in range(Nsim):
        trajs[sim, 0] = f0
        f = f0
        for t in range(1, T + 1):
            # Selection update
            if s != 0:
                f = f * (1 + s) / (f * s + 1)

            # Early stopping: once at a boundary the allele stays there, so
            # fill the remainder of the row and skip the sampling step.
            if f <= 0:
                trajs[sim, t:] = 0.0
                break
            elif f >= 1:
                trajs[sim, t:] = 1.0
                break

            # Binomial sampling for drift
            count = np.random.binomial(2 * N, f)
            f = count / (2 * N)
            trajs[sim, t] = f

            # Early stopping if fixed or lost
            if f == 0 or f == 1:
                trajs[sim, t + 1:] = f
                break
    return trajs


def drift_simulation_numpy(N, T, s, f0, Nsim=500):
    """Simulate allele-frequency trajectories, vectorized across replicates.

    Performs the same per-generation steps as ``drift_simulation_numba``
    (selection update, boundary check, binomial sampling) but advances all
    still-segregating replicates together with NumPy array operations.

    Parameters
    ----------
    N : int
        Population size; each generation draws ``2 * N`` allele copies.
    T : int
        Number of generations to simulate.
    s : float
        Selection coefficient applied before each round of sampling.
    f0 : float
        Initial allele frequency.
    Nsim : int, optional
        Number of independent replicate simulations.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(Nsim, T + 1)``; row ``i`` holds the frequency of
        replicate ``i`` at generations ``0 .. T``.
    """
    trajs = np.zeros((Nsim, T + 1))
    trajs[:, 0] = f0

    # Track which simulations are still active (neither lost nor fixed).
    active = np.ones(Nsim, dtype=bool)
    # Force a float dtype: an integer f0 would otherwise create an integer
    # array and truncate every frequency written into it.
    f_current = np.full(Nsim, f0, dtype=float)

    for t in range(1, T + 1):
        if not active.any():
            break

        # Selection update for active simulations
        if s != 0:
            f_act = f_current[active]
            f_current[active] = f_act * (1 + s) / (f_act * s + 1)

        # Replicates pushed onto a boundary are absorbed from generation t on.
        lost = active & (f_current <= 0)
        fixed = active & (f_current >= 1)
        f_current[lost] = 0.0
        f_current[fixed] = 1.0
        trajs[lost, t:] = 0.0
        trajs[fixed, t:] = 1.0
        active &= ~(lost | fixed)

        if not active.any():
            break

        # Vectorized binomial sampling for all active simulations
        active_idx = np.flatnonzero(active)
        counts = np.random.binomial(2 * N, f_current[active_idx])
        f_new = counts / (2 * N)
        f_current[active_idx] = f_new
        trajs[active_idx, t] = f_new

        # Replicates that drifted to 0 or 1 this generation: generation t is
        # already recorded above, so fill only the remaining generations.
        lost_idx = active_idx[f_new == 0]
        fixed_idx = active_idx[f_new == 1]
        trajs[lost_idx, t + 1:] = 0.0
        trajs[fixed_idx, t + 1:] = 1.0
        active[lost_idx] = False
        active[fixed_idx] = False

    return trajs


def drift_main(N, T, s, f0, Nsim=500):
    """Run the simulations and build the two figures shown in the app.

    Parameters
    ----------
    N : int or float
        Population size; converted to int and clamped to at least 1, because
        a value of 0 would make the frequency ``count / (2 * N)`` undefined.
    T : int or float
        Generations to run; converted to int and clamped to at least 1.
    s : float
        Selection coefficient.
    f0 : float
        Initial allele frequency.
    Nsim : int, optional
        Number of replicate simulations.

    Returns
    -------
    tuple
        ``(fig, fig2)``: ``fig`` shows every trajectory over time, ``fig2``
        shows histograms of the times to loss and to fixation, with a dashed
        vertical line at each mean when at least one such event occurred.
    """
    # Convert to integers to ensure compatibility; the sliders allow 0, which
    # is not a usable population size or run length.
    N = max(int(N), 1)
    T = max(int(T), 1)
    # Always pass floats so the simulation sees consistent argument types.
    s, f0 = float(s), float(f0)
    Nsim = int(Nsim)

    # Choose the fastest available implementation
    if NUMBA_AVAILABLE:
        trajs = drift_simulation_numba(N, T, s, f0, Nsim)
    else:
        trajs = drift_simulation_numpy(N, T, s, f0, Nsim)

    # Find fixation and loss times. argmax returns 0 when a row never reaches
    # the boundary, so keeping only indices > 0 drops those rows (and rows
    # that were already at the boundary at generation 0).
    tfix = np.argmax(trajs == 1, axis=1)
    tfix = tfix[tfix > 0]
    tloss = np.argmax(trajs == 0, axis=1)
    tloss = tloss[tloss > 0]

    # Fixed figure numbers let successive calls reuse the same two figures
    # instead of creating new ones on every slider change.
    fig = plt.figure(137)
    fig.clear()
    ax = fig.add_subplot(1, 1, 1)
    # One call draws one line per column, i.e. one line per replicate.
    ax.plot(trajs.T)
    ax.set_ylim(0, 1)
    ax.set_xlim(0, T)
    ax.set_xlabel('Time (generations)')
    ax.set_ylabel('Mutation frequency')
    fig.tight_layout()

    fig2 = plt.figure(139)
    fig2.clear()
    ax = fig2.add_subplot(1, 1, 1)
    _, _, loss_bars = ax.hist(tloss, bins=20, alpha=0.5,
                              label=f'Losses (total={len(tloss)})')
    _, _, fix_bars = ax.hist(tfix, bins=20, alpha=0.5,
                             label=f'Fixations (total={len(tfix)})')
    # Mark each mean in the colour of its histogram, but only when there is
    # something to average; a line at 0 would suggest a mean time of zero.
    for times, bars in ((tloss, loss_bars), (tfix, fix_bars)):
        if len(times) > 0:
            ax.axvline(np.mean(times), color=bars[0].get_facecolor(),
                       ls='--', lw=3)
    ax.legend()
    ax.set_xlabel('Time to loss or fixation (generations)')
    ax.set_ylabel('Number of simulations')
    fig2.tight_layout()

    return fig, fig2


description = """This app runs 500 simulations of a single biallelic mutation in a constant-size population under the Wright-Fisher model.
Use the sliders to adjust the values of the population size N, the number of generations to run T, the selection coefficient s, and the initial allele frequency f0.
The top plot shows the allele frequency trajectories over time for all simulations.
The bottom plot shows histograms of the times to loss and fixation for the mutation across the simulations.
Note that you may need to extend the running time T to see more fixations or losses, especially for larger population sizes or weaker selection.
"""

article = """
Questions? Contact [Ryan Gutenkunst at rgutenk@arizona.edu](mailto:rgutenk@arizona.edu).
"""

# Set up Gradio
drift = gr.Interface(
    fn=drift_main,
    inputs=[
        gr.Slider(0, 2000, 100, label='Population size N'),
        gr.Slider(0, 10000, 1000, label='Generations to run T'),
        gr.Slider(-0.01, 0.01, 0, label='Selection coefficient s'),
        gr.Slider(0, 1, 0.5, label='Initial frequency f0'),
    ],
    outputs=[gr.Plot(), gr.Plot()],
    live=True,
    clear_btn=None,
    title="Drift simulation in a single constant-size population",
    description=description,
    article=article,
    flagging_mode="never",
)

if __name__ == "__main__":
    # Start the Gradio server when this file is run as a script.
    drift.launch()