"""Gradio demo for bartab.""" from typing import Iterable, List, Union from functools import partial from io import TextIOWrapper import os os.environ["COMMANDLINE_ARGS"] = "--no-gradio-queue" from carabiner import cast, print_err from carabiner.decorators import decorator_with_params from carabiner.pd import read_table import gradio as gr import nemony as nm import numpy as np import pandas as pd import anndata anndata.settings.allow_write_nullable_strings = True import bartab from bartab.io import load_anndata from bartab.models.anndata import AnnDataWLSModel, AnnDataHillModel from bartab.plotting import ( count_vs_resid, dose_response, expansion_vs_count, expansion_vs_ratio, pred_vs_true, pred_vs_resid, time_vs_count, time_vs_ratio, volcano ) from bartab.transforms import compute_log_ratios pd.options.future.infer_string = True MODES: dict = { "single": "👟 Fitness in a single condition", "dose response": "📉 Fitness dose response to CRISPRi inducer", } def _message(s: str): print_err(f"[INFO] {s}") gr.Info(s, duration=10) return None def load_input_data( filename: str, cols: Iterable ) -> List[pd.DataFrame]: df = read_table(filename) print_err(df) out = [gr.update(value=df, visible=False)] for key, col in cols.items(): if isinstance(col, tuple): col_type = col[1] if col_type == "string": choices = list(df.select_dtypes(include="str")) elif col_type == "numeric": choices = list(df.select_dtypes(include="number")) else: choices = list(df) else: choices = list(df) choices = [""] + choices print_err(key, f"{choices=}") out.append( gr.update( choices=choices, value=key if key in choices else choices[0], interactive=True, visible=True, ) ) print_err(out) return out def load_barcode_names( df: pd.DataFrame, strain_col: str ) -> List[List[str]]: strains = sorted(df[strain_col].unique()) print_err(strain_col, f"{strains=}") return gr.update( choices=strains, value="wt" if "wt" in strains else strains[0], interactive=True, visible=True, ), gr.update( choices=strains, value="spike" if "spike" in strains else "", interactive=True, visible=True, ), gr.update( choices=strains, value=[], allow_custom_value=True, interactive=True, visible=True, ) def _prepare_to_fit( counts: pd.DataFrame, strain_sheet: pd.DataFrame, sample_sheet: pd.DataFrame, count_column: str, strain_id_column: str, timepoint_column: str, concentration_column: str, sample_id_column: str, culture_id_column: str, volume_column: str = "volume", growth_column: str = "growth", reference: str = "wt", spike_name: str = "spike", spike_mode: str = "Spike", growth_type: str = "density", pseudocount: float = 1. ): use_spike = (spike_mode == "Spike") adata = load_anndata( counts=counts, sample_meta=sample_sheet, strain_meta=strain_sheet, reference=reference, count_column=count_column, timepoint_column=timepoint_column, # t0=args.t0, concentration_column=concentration_column, strain_id=strain_id_column, sample_id=sample_id_column, culture_id=culture_id_column, spike=spike_name if spike_name else None, ) print_err(adata) adata = compute_log_ratios( adata=adata, pseudocount=pseudocount, volume_column=volume_column, growth_column=growth_column, growth_type=growth_type, use_spike=use_spike, ) print_err(adata) return adata def do_analysis(*args): args = [ a if not (isinstance(a, str) and a == "") else None for a in args ] mode = args[-1] concentration_column = args[6] print_err(args[:-1]) try: adata = _prepare_to_fit(*args[:-1]) except TypeError as e: print_err(*args[:-1]) raise e _message("Using a weighted least squares model") model = AnnDataWLSModel() results = model.fit(adata=adata) if mode == MODES["dose response"]: _message( "Using a Hill non-linear model with " f"'{concentration_column}' for concentration." ) model = AnnDataHillModel() results = model.fit(adata=results, concentration=concentration_column) return gr.update(value=results.obs, visible=True), results def _fig2img(fig): import PIL # img = PIL.Image.frombytes( # "RGBa", # fig.canvas.get_width_height(), # fig.canvas.buffer_rgba(), # ) import io buf = io.BytesIO() fig.savefig(buf) buf.seek(0) img = PIL.Image.open(buf) return img @decorator_with_params def _plot_wrapper(fn, message="Plotting..."): def _fn(*args, **kwargs): if args[1] == "": args[1] = None if message: _message(message) fig, axes = fn(*args, **kwargs) if isinstance(fig, tuple) and isinstance(axes, bool): fig, vis = fig return gr.update( value=_fig2img(fig) if fig is not None else fig, visible=vis, ) else: return gr.update(value=_fig2img(fig), visible=True) return _fn @_plot_wrapper() def _plot_dose_response( adata, highlight=None, control_prefix: str = "ctrl_", mode: str = MODES["single"] ): do_dose_response = mode == MODES["dose response"] if do_dose_response: print_err("Plotting dose response") fig, axes = dose_response( adata, highlight_barcodes=highlight, model_name="WLS", control_prefix=control_prefix, ) return fig, axes else: print_err("Skipping dose response") return (None, None), False @_plot_wrapper(message="Plotting time vs count") def _plot_time_vs_count( adata, highlight=None, control_prefix: str = "ctrl_", *args, **kwargs ): return time_vs_count( adata, highlight_barcodes=highlight, control_prefix=control_prefix, ) @_plot_wrapper(message="Plotting expansion vs count") def _plot_expansion_vs_count( adata, highlight=None, control_prefix: str = "ctrl_", *args, **kwargs ): return expansion_vs_count( adata, highlight_barcodes=highlight, control_prefix=control_prefix, ) @_plot_wrapper(message="Plotting time vs ratio") def _plot_time_vs_ratio( adata, highlight=None, control_prefix: str = "ctrl_", *args, **kwargs ): return time_vs_ratio( adata, highlight_barcodes=highlight, control_prefix=control_prefix, ) @_plot_wrapper(message="Plotting expansion vs ratio") def _plot_expansion_vs_ratio( adata, highlight=None, control_prefix: str = "ctrl_", *args, **kwargs ): return expansion_vs_ratio( adata, highlight_barcodes=highlight, control_prefix=control_prefix, ) @_plot_wrapper(message="Plotting predicted vs observed") def _plot_pred_vs_true( adata, highlight=None, control_prefix: str = "ctrl_", mode: str = MODES["single"] ): do_dose_response = mode == MODES["dose response"] return pred_vs_true( adata, model_name="HillFitnessModel" if do_dose_response else "WLS", highlight_barcodes=highlight, control_prefix=control_prefix, ) @_plot_wrapper(message="Plotting predicted vs residuals") def _plot_pred_vs_resid( adata, highlight=None, control_prefix: str = "ctrl_", mode: str = MODES["single"] ): do_dose_response = mode == MODES["dose response"] return pred_vs_resid( adata, model_name="HillFitnessModel" if do_dose_response else "WLS", highlight_barcodes=highlight, control_prefix=control_prefix, ) @_plot_wrapper(message="Plotting counts vs residuals") def _plot_count_vs_resid( adata, highlight=None, control_prefix: str = "ctrl_", mode: str = MODES["single"] ): do_dose_response = mode == MODES["dose response"] return count_vs_resid( adata, model_name="HillFitnessModel" if do_dose_response else "WLS", highlight_barcodes=highlight, control_prefix=control_prefix, ) @_plot_wrapper(message="Plotting volcano") def _plot_volcano( adata, highlight=None, control_prefix: str = "ctrl_", mode: str = MODES["single"] ): do_dose_response = mode == MODES["dose response"] return volcano( adata, highlight_barcodes=highlight, control_prefix=control_prefix, model_name="HillFitnessModel" if do_dose_response else "WLS", param="ic50" if do_dose_response else "fitness", xscale="log" if do_dose_response else "linear", vline=None if do_dose_response else 1., p="log_ic50_p" if do_dose_response else "slope_p", ) def download_tables( df: pd.DataFrame, adata ) -> str: df_hash = nm.hash(pd.util.hash_pandas_object(df).values) filename = f"bartab-{df_hash}" filename_csv = f"{filename}.csv" df.to_csv(filename, index=False) filename_adata = f"{filename}.h5ad" adata.write(filename_adata) return gr.update( value=filename_csv, visible=True, ), gr.update( value=filename_adata, visible=True, ) def _file_input(**kwargs): return partial(gr.File, file_types=[".xlsx", ".csv", ".tsv", ".txt"], )(**kwargs) def _load_from_file(*args): if len(args) > 1: return args else: return args[0] def _invisible_dropdown(**kwargs): return partial(gr.Dropdown, choices=[], interactive=False, visible=True, )(**kwargs) def _invisible_plot(**kwargs): return partial(gr.Image, visible=False, )(**kwargs) with gr.Blocks() as demo: gr.Markdown( f""" # 🍹 bartab: Fitness from pooled competition assays *Using* `bartab` v{bartab.__version__} | [Documentation](https://github.com/scbirlab/bartab) | [Tutorial on analysis principles](https://huggingface.co/spaces/scbirlab/tutorial-seq-fitness) Infer the competitive fitness of barcoded strains from next-generation sequencing of pooled growth experiments. Upload your count table, sample sheet, and barcode sheet, then click **Calculate fitness**. """ ) gr.Markdown( """ --- ## 1️⃣ Input tables Three tables are required. You can upload CSV, TSV, or XLSX files, or try one of the **example datasets** below. """ ) input_filenames = { "count_table": gr.Textbox(interactive=False, visible=False), "sample_sheet": gr.Textbox(interactive=False, visible=False), "strain_sheet": gr.Textbox(interactive=False, visible=False), } app_root = os.path.dirname(__file__) data_path = os.path.join(app_root, "data", "examples", "single-point") control_strains = { "reference": _invisible_dropdown( label="Reference (WT) barcode name", render=False, ), "spike": _invisible_dropdown( label="Spike-in barcode name (if using)", render=False, ), } analysis_opts = { "use_spike": gr.Radio( label="Culture expansion uses:", choices=["Spike", "Growth"], value="Spike", render=False, ), "growth_type": gr.Radio( label="Growth type", choices=["density", "generations"], value="density", visible=False, render=False, ), "pseudocount": gr.Number( label="Pseudocount", value=1., render=False, ), } plotting_opts = { "highlight": _invisible_dropdown( label="Strain(s) to highlight in plots", render=False, multiselect=True, ), "controls": gr.Textbox( label="Prefix of control barcode names", value="ctrl_", render=False, ), } mode_switch = gr.Radio( label="Analysis mode", choices=list(MODES.values()), value=MODES["single"], render=False, ) examples = gr.Examples( label="Examples with synthetic data", examples=[ [ os.path.join(data_path, "test_count.csv"), os.path.join(data_path, "test_sample_meta.csv"), os.path.join(data_path, "test_strain_meta.csv"), MODES["single"], "Spike", ], [ os.path.join(data_path, "test_count.csv"), os.path.join(data_path, "test_sample_meta.csv"), os.path.join(data_path, "test_strain_meta.csv"), MODES["single"], "Growth", ], [ os.path.join(data_path, "dose-response_count.csv"), os.path.join(data_path, "dose-response_sample_meta.csv"), os.path.join(data_path, "dose-response_strain_meta.csv"), MODES["dose response"], "Spike", ], [ os.path.join(data_path, "dose-response_count.csv"), os.path.join(data_path, "dose-response_sample_meta.csv"), os.path.join(data_path, "dose-response_strain_meta.csv"), MODES["dose response"], "Growth", ], ], example_labels=[ ["Single point, using spike-in"], ["Single point, using growth"], ["Dose response, using spike-in"], ["Dose response, using growth"], ], inputs=[ input_filenames["count_table"], input_filenames["sample_sheet"], input_filenames["strain_sheet"], mode_switch, analysis_opts["use_spike"], ], # cache_examples=True, # cache_mode="eager", ) input_files = {} input_cols = {} go_button = gr.Button( value="🚀 Calculate fitness!", interactive=False, render=False, ) with gr.Row(): with gr.Column(): gr.Markdown( """ --- ### 🧮 Count table One row per barcode per sample. Must contain: - a column of **barcode/strain identifiers** (matching your barcode sheet) - a column of **sample identifiers** (matching your sample sheet) - a column of **read or UMI counts** """ ) input_files["count_table"] = _file_input( label="Upload your barcode sequencing counts data here", ) input_cols["count_table"] = { "count": (_invisible_dropdown( label="Counts column", ), "numeric"), } with gr.Column(): gr.Markdown( """ --- ### 📶 Barcode information One row per unique barcode. Must contain: - a column of **barcode identifiers** Optionally: any metadata about strains (gene targets, constructs, etc.). These will be carried through to the output. """ ) input_files["strain_sheet"] = _file_input( label="Upload your barcode information here", ) input_cols["strain_sheet"] = { "strain_id": (_invisible_dropdown( label="Barcode identifier column", ), "string"), } with gr.Row(): gr.Markdown( r""" --- ### 🧪 Sample sheet One row per sample (sequencing library). Must contain: - **Sample ID**: unique identifier matching the count table - **Culture ID**: biological replicate identifier. Samples from the same culture share this label - **Timepoint**: numeric timepoint values. The earliest timepoint is treated as $t_0$. **For spike-in normalisation**: no extra columns needed. Just include your spike-in barcode in the count table and barcode sheet. **For growth-based normalisation**: add a column of OD600, CFU/mL, or generation counts measured at each sample. **For dose-response analysis**: add a column of inducer/drug concentrations. Samples with concentration = 0 are treated as uninduced controls. **For adaptive-volume sampling** (if you took different volumes from each sample): add a column of sampled volumes. """ ) input_files["sample_sheet"] = _file_input( label="Upload your sample information here", ) with gr.Row(): input_cols["sample_sheet"] = { "timepoint": (_invisible_dropdown( label="Timepoint column", ), "any"), "dose": (_invisible_dropdown( label="Concentration column", ), "numeric"), "sample_id": (_invisible_dropdown( label="Individual sample ID column", ), "string"), "replicate": (_invisible_dropdown( label="Culture / biological replicate column", ), "string"), "volume": (_invisible_dropdown( label="Volume column (if using)", ), "numeric"), "growth": (_invisible_dropdown( label="Growth column (if using)", ), "numeric"), } with gr.Row(): input_data = { key: gr.Dataframe( label=f"Input data: {key}", max_height=50, visible=False, interactive=False, ) for key in input_files } adata = gr.State() with gr.Row(): with gr.Column(): gr.Markdown( """ --- ## 2️⃣ Control strains - **Reference (WT)**: the strain relative to which all fitness values are calculated. Fitness = 1 by definition. - **Spike-in**: a non-growing strain (e.g. heat-killed or plasmid-only) added at a fixed concentration before library preparation. Used to infer how much the reference strain has expanded between timepoints, removing the need for growth measurements. Leave blank if using growth measurements instead. """ ) for key, val in control_strains.items(): val.render() with gr.Column(): gr.Markdown( """ --- ## 3️⃣ Analysis options - **Culture expansion**: choose **Spike** if you have a non-growing spike-in control, or **Growth** if you have OD600/CFU measurements. - **Pseudocount**: added to all counts before log transformation to avoid log(0). - **Analysis mode**: choose **Single concentration** for standard fitness screens, or **Dose response** if your sample sheet contains a concentration column. Dose response fitting uses a 2-parameter Hill model to estimate the IC₅₀ and maximum effectfor each barcode. """ ) for key, val in analysis_opts.items(): val.render() with gr.Column(): gr.Markdown( """ --- ## Plotting options """ ) for key, val in plotting_opts.items(): val.render() with gr.Column(): mode_switch.render() go_button.render() mode_switch.change( lambda x: gr.update(value=x), inputs=[mode_switch], outputs=[go_button], ) gr.Markdown( r""" ## 4️⃣ Results Fitness values are estimated by weighted least squares regression of the log-ratio of each barcode against the reference strain, using the spike-in or growth measurements as the x-axis. **Key output columns**: - For single concentration: - `fitness`: relative fitness ($w_i / w_{wt}$). Values < 1 indicate growth disadvantage; > 1 indicates advantage. - `fitness_low` / `fitness_high`: 95% confidence interval bounds. - `slope_p`: p-value for the slope being different from 0 (i.e. fitness ≠ 1). - For dose-response: - `log_ic50` (log₁₀ concentration at 50% inhibition) and `log_ic50_p`. Results and the full annotated dataset (`.h5ad`) can be downloaded below. """ ) plots = {} with gr.Row(): plots |= { "dose_response": ( _invisible_plot(label="Dose response"), _plot_dose_response, ), "count_time": ( _invisible_plot(label="Time vs count"), _plot_time_vs_count, ), "count_exp": ( _invisible_plot(label="Expansion vs count"), _plot_expansion_vs_count, ), } with gr.Row(): plots |= { "count_exp": ( _invisible_plot(label="Time vs ratio"), _plot_time_vs_ratio, ), "ratio_exp": ( _invisible_plot(label="Expansion vs ratio"), _plot_expansion_vs_ratio, ), } with gr.Row(): plots |= { "pred_obs": ( _invisible_plot(label="Predicted vs observed"), _plot_pred_vs_true, ), "volcano": ( _invisible_plot(label="Volcano"), _plot_volcano, ), } with gr.Row(): plots |= { "pred_resid": ( _invisible_plot(label="Predicted vs residuals"), _plot_pred_vs_resid, ), "count_resid": ( _invisible_plot(label="Counts vs residuals"), _plot_count_vs_resid, ), } with gr.Row(): download = gr.DownloadButton( label="Download parameters as CSV", visible=False, ) download_adata = gr.DownloadButton( label="Download all analysis as .h5ad", visible=False, ) output_table = gr.Dataframe( label="Fitted parameters", # max_height=100, visible=False, interactive=False, ) # ====== # EVENTS # ====== for key, input_file in input_files.items(): input_columns = input_cols[key] event_fn = { "fn": partial(load_input_data, cols=input_columns), "outputs": [input_data[key]] + [ col[0] if isinstance(col, tuple) else col for _, col in input_columns.items() ], } input_filenames[key].change( _load_from_file, inputs=[input_filenames[key]], outputs=[input_file], ).then( **event_fn, inputs=[input_filenames[key]], ) input_file.upload( **event_fn, inputs=[input_file], ) input_cols["strain_sheet"]["strain_id"][0].change( load_barcode_names, inputs=[ input_data["strain_sheet"], input_cols["strain_sheet"]["strain_id"][0], ], outputs=[ control_strains["reference"], control_strains["spike"], plotting_opts["highlight"], ], ).then( lambda : gr.update(interactive=True), inputs=[], outputs=[go_button], ) analysis_opts["use_spike"].change( lambda x: gr.update(visible=x == "Growth"), inputs=[analysis_opts["use_spike"]], outputs=[analysis_opts["growth_type"]], ) comptuation_inputs = ( [ f for _, f in input_data.items() ] + [ opt[0] if isinstance(opt, tuple) else opt for _, input_col in input_cols.items() for _, opt in input_col.items() ] + [ v for k, v in control_strains.items() ] + [ v for k, v in analysis_opts.items() ] ) evt = go_button.click( fn=do_analysis, inputs=comptuation_inputs + [mode_switch], outputs=[ output_table, adata, ], ).then( download_tables, inputs=[output_table, adata], outputs=[download, download_adata], ) for key, (p, fn) in plots.items(): evt = evt.then( fn, inputs=[ adata, plotting_opts["highlight"], plotting_opts["controls"], mode_switch, ], outputs=[p], ) if __name__ == "__main__": demo.queue() demo.launch(share=True)