Spaces:
Sleeping
Sleeping
| import matplotlib.pyplot as plt, numpy as np | |
| import gradio as gr | |
# Try to import numba for a speed boost; fall back to a no-op decorator.
try:
    from numba import jit

    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False

    def jit(*args, **kwargs):
        """No-op stand-in for ``numba.jit`` used when numba is not installed.

        Accepts the same call shapes as the real decorator so decorated code
        runs unchanged (just without compilation):

        - bare form ``@jit`` returns the function itself;
        - factory form ``@jit(nopython=True, ...)`` (any arguments) returns a
          decorator that returns the function unchanged.
        """
        # Bare form: jit was applied directly to a single function.
        if len(args) == 1 and not kwargs and callable(args[0]):
            return args[0]

        def decorator(func):
            return func

        return decorator
@jit(nopython=True)
def drift_simulation_numba(N, T, s, f0, Nsim=500):
    """Simulate ``Nsim`` Wright-Fisher allele-frequency trajectories (numba).

    Each generation applies the deterministic selection update
    ``f -> f * (1 + s) / (f * s + 1)`` and then resamples ``2 * N`` gene
    copies binomially. A replicate stops early once the allele is lost (0)
    or fixed (1), and its remaining columns are filled with that value.

    When numba is not installed, ``jit`` is a no-op and this runs as plain
    Python (correct, but slow).

    Parameters
    ----------
    N : int
        Population size; ``2 * N`` gene copies are resampled per generation.
        Must be at least 1.
    T : int
        Number of generations to simulate.
    s : float
        Selection coefficient.
    f0 : float
        Initial allele frequency.
    Nsim : int, optional
        Number of independent replicates (default 500).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(Nsim, T + 1)``; column 0 holds ``f0``.

    Raises
    ------
    ValueError
        If ``N < 1``, since the resampled count is divided by ``2 * N``.
    """
    if N < 1:
        raise ValueError("N must be at least 1")
    n_copies = 2 * N
    trajs = np.zeros((Nsim, T + 1))
    for sim in range(Nsim):
        trajs[sim, 0] = f0
        # Keep f float-typed throughout so numba infers a single type for it.
        f = float(f0)
        for t in range(1, T + 1):
            # Deterministic selection step.
            if s != 0:
                f = f * (1 + s) / (f * s + 1)
            # Defensive clamp: stop if selection pushed f onto a boundary.
            if f <= 0:
                trajs[sim, t:] = 0.0
                break
            elif f >= 1:
                trajs[sim, t:] = 1.0
                break
            # Drift step: binomial resampling of the 2N gene copies.
            count = np.random.binomial(n_copies, f)
            f = count / n_copies
            trajs[sim, t] = f
            # Absorbed (lost or fixed): fill the remaining generations.
            if f == 0 or f == 1:
                trajs[sim, t + 1:] = f
                break
    return trajs
def drift_simulation_numpy(N, T, s, f0, Nsim=500):
    """Simulate ``Nsim`` Wright-Fisher allele-frequency trajectories with NumPy.

    All replicates advance together. Each generation:

    1. applies the selection update ``f -> f * (1 + s) / (f * s + 1)`` to the
       still-segregating replicates;
    2. draws one vectorized binomial sample of ``2 * N`` gene copies for them.

    A replicate that reaches 0 (loss) or 1 (fixation) is frozen, and its
    remaining columns are filled with that value.

    Parameters
    ----------
    N : int
        Population size; ``2 * N`` gene copies are resampled per generation.
        Must be at least 1.
    T : int
        Number of generations to simulate.
    s : float
        Selection coefficient.
    f0 : float
        Initial allele frequency.
    Nsim : int, optional
        Number of independent replicates (default 500).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(Nsim, T + 1)``; column 0 holds ``f0``.

    Raises
    ------
    ValueError
        If ``N < 1``, since the resampled count is divided by ``2 * N``.
    """
    if N < 1:
        raise ValueError("N must be at least 1")
    n_copies = 2 * N
    trajs = np.zeros((Nsim, T + 1))
    trajs[:, 0] = f0
    # Replicates that are still segregating (neither lost nor fixed).
    active = np.ones(Nsim, dtype=bool)
    # Explicit float dtype: an integer f0 would otherwise create an integer
    # array that silently truncates the resampled frequencies.
    f_current = np.full(Nsim, f0, dtype=float)
    for t in range(1, T + 1):
        if not active.any():
            break
        # Deterministic selection step on the active replicates only.
        if s != 0:
            f_act = f_current[active]
            f_current[active] = f_act * (1 + s) / (f_act * s + 1)
        # Defensive clamp: absorb any replicate pushed onto/over a boundary.
        lost = active & (f_current <= 0)
        fixed = active & (f_current >= 1)
        f_current[lost] = 0.0
        f_current[fixed] = 1.0
        trajs[lost, t:] = 0.0
        trajs[fixed, t:] = 1.0
        active &= ~(lost | fixed)
        active_idx = np.flatnonzero(active)
        if active_idx.size == 0:
            continue
        # Drift step: one vectorized binomial draw for all active replicates.
        counts = np.random.binomial(n_copies, f_current[active_idx])
        f_current[active_idx] = counts / n_copies
        trajs[active_idx, t] = f_current[active_idx]
        # Freeze replicates the draw just lost or fixed; column t is already
        # written, so fill from t + 1 onward (empty index arrays are no-ops).
        f_new = f_current[active_idx]
        newly_lost = active_idx[f_new == 0]
        newly_fixed = active_idx[f_new == 1]
        trajs[newly_lost, t + 1:] = 0.0
        trajs[newly_fixed, t + 1:] = 1.0
        active[newly_lost] = False
        active[newly_fixed] = False
    return trajs
def _first_hit_times(trajs, value):
    """Return, per replicate, the first generation at which ``trajs == value``.

    ``np.argmax`` yields 0 both when the first hit is at generation 0 and when
    there is no hit at all, so index 0 is dropped: replicates that start at
    the boundary, or are never absorbed, are excluded from the result.
    """
    times = np.argmax(trajs == value, axis=1)
    return times[times > 0]


def _plot_trajectories(trajs, T):
    """Draw every replicate's frequency trajectory into pyplot figure 137."""
    fig = plt.figure(137)
    fig.clear()
    ax = fig.add_subplot(1, 1, 1)
    # One call plots each column of trajs.T, i.e. one line per replicate.
    ax.plot(trajs.T)
    ax.set_ylim(0, 1)
    # Avoid identical left/right x-limits (a matplotlib warning) when T == 0.
    ax.set_xlim(0, max(T, 1))
    ax.set_xlabel('Time (generations)')
    ax.set_ylabel('Mutation frequency')
    fig.tight_layout()
    return fig


def _plot_absorption_times(tloss, tfix):
    """Histogram loss/fixation times into pyplot figure 139, with mean lines."""
    fig = plt.figure(139)
    fig.clear()
    ax = fig.add_subplot(1, 1, 1)
    for times, label in ((tloss, 'Losses'), (tfix, 'Fixations')):
        _, _, patches = ax.hist(times, bins=20, alpha=0.5,
                                label='%s (total=%d)' % (label, len(times)))
        # An empty sample has no mean; drawing a line at 0 would mislead.
        if len(times) > 0:
            ax.axvline(np.mean(times), color=patches[0].get_facecolor(),
                       ls='--', lw=3)
    ax.legend()
    ax.set_xlabel('Time to loss or fixation (generations)')
    ax.set_ylabel('Number of simulations')
    fig.tight_layout()
    return fig


def drift_main(N, T, s, f0, Nsim=500):
    """Run the drift simulations and build the two figures shown by the app.

    Parameters
    ----------
    N : number
        Population size. Truncated to ``int`` and clamped to at least 1,
        because each generation resamples ``2 * N`` gene copies.
    T : number
        Number of generations. Truncated to ``int`` and clamped to at least 0.
    s : float
        Selection coefficient.
    f0 : float
        Initial allele frequency.
    Nsim : int, optional
        Number of replicate simulations (default 500).

    Returns
    -------
    tuple
        ``(fig, fig2)``: the trajectory plot (pyplot figure 137) and the
        histogram of loss/fixation times (pyplot figure 139). The same two
        figure numbers are reused and cleared on every call.
    """
    # Slider values may arrive as floats; the simulators need integers.
    N = max(int(N), 1)
    T = max(int(T), 0)
    # Choose the fastest available implementation.
    simulate = drift_simulation_numba if NUMBA_AVAILABLE else drift_simulation_numpy
    trajs = simulate(N, T, s, f0, Nsim)
    tfix = _first_hit_times(trajs, 1)
    tloss = _first_hit_times(trajs, 0)
    fig = _plot_trajectories(trajs, T)
    fig2 = _plot_absorption_times(tloss, tfix)
    return fig, fig2
description = """This app runs 500 simulations of a single biallelic mutation in a constant-size population under the Wright-Fisher model.
Use the sliders to adjust the values of the population size N, the number of generations to run T, the selection coefficient s, and the initial allele frequency f0.
The top plot shows the allele frequency trajectories over time for all simulations. The bottom plot shows histograms of the times to loss and fixation for the mutation across the simulations.
Note that you may need to extend the running time T to see more fixations or losses, especially for larger population sizes or weaker selection.
"""
article = """
Questions? Contact [Ryan Gutenkunst at rgutenk@arizona.edu](mailto:rgutenk@arizona.edu).
"""
# Set up Gradio. The population-size slider starts at 1 with integer steps:
# each generation resamples 2N gene copies, so N = 0 would divide by zero.
drift = gr.Interface(
    fn=drift_main,
    inputs=[
        gr.Slider(minimum=1, maximum=2000, value=100, step=1,
                  label='Population size N'),
        gr.Slider(0, 10000, 1000, label='Generations to run T'),
        gr.Slider(-0.01, 0.01, 0, label='Selection coefficient s'),
        gr.Slider(0, 1, 0.5, label='Initial frequency f0'),
    ],
    outputs=[gr.Plot(),
             gr.Plot()],
    live=True,
    clear_btn=None,
    title="Drift simulation in a single constant-size population",
    description=description,
    article=article,
    flagging_mode="never",
)

if __name__ == "__main__":
    drift.launch()  # Start the Gradio server when run as a script.