"""Gradio demo: load a Hugging Face dataset and rank language models on it.

The page has three steps: (1) load a dataset from the Hub and configure its
columns, (2) pick the language models to compare, (3) run TransformerRanker
and show the resulting ranking table.
"""

import gradio as gr
from datasets import disable_caching, load_dataset
from transformer_ranker import TransformerRanker

from demo.config import (
    SAMPLE_SIZE,
    MAX_SAMPLE_SIZE,
    ALL_LMS,
    PRESELECTED_LMS,
    GRADIO_THEME,
)
from demo.utils import (
    BANNER,
    FOOTER,
    CSS,
    UNSET,
    EmbeddingProgressTracker,
    compute_ratio,
    validate_dataset,
    preprocess_dataset,
    ensure_dataset_is_loaded,
)

disable_caching()

# NOTE(review): GRADIO_THEME is imported above but `theme=None` is passed
# here -- confirm whether the configured theme was meant to be applied.
with gr.Blocks(css=CSS, theme=None) as demo:
    gr.Markdown(BANNER)

    ##### 1. Load from datasets #####

    gr.Markdown("## Load Downstream Dataset")

    gr.Markdown(
        "Select a dataset from the Hugging Face Hub such as `trec`. "
        "This defines your downstream task."
    )

    with gr.Group():
        # Holds the loaded dataset object between events.
        dataset = gr.State(None)
        dataset_id = gr.Textbox(
            label="Dataset name",
            placeholder="try: trec, conll2003, ag_news",
            max_lines=1,
        )

        load_dataset_button = gr.Button(
            value="Load data",
            variant="primary",
            interactive=True,
        )

        # enable loading if dataset exists on hub
        dataset_id.change(
            validate_dataset, inputs=dataset_id, outputs=load_dataset_button
        )

    gr.Markdown(
        "Settings auto-configured. "
        "Adjust the downsampling ratio in Dataset Setup, "
        "or use the complete dataset with the [framework](https://github.com/flairNLP/transformer-ranker)."
    )

    ##### data preprocessing #####

    with gr.Accordion("Dataset Setup", open=False) as dataset_config:
        with gr.Row() as dataset_details:
            dataset_id_label = gr.Label("", label="Dataset")
            num_samples = gr.State(0)
            num_samples_label = gr.Label("", label="Dataset size")
            # Mirror the numeric state into its text label.
            num_samples.change(
                lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
            )

        with gr.Row():
            text_column = gr.Dropdown("", label="Text Column")
            text_pair_column = gr.Dropdown("", label="Text Pair")

        with gr.Row():
            label_column = gr.Dropdown("", label="Labels")
            task_category = gr.Dropdown("", label="Downstream Task")

        with gr.Group():
            downsample_ratio = gr.State(0.0)
            sampling_rate = gr.Slider(
                20, MAX_SAMPLE_SIZE, label="Sampling rate", value=SAMPLE_SIZE, step=1
            )
            downsample_ratio_label = gr.Label("", label="Sampling rate")
            # Show the ratio as a percentage with one decimal place.
            downsample_ratio.change(
                lambda x: f"{x:.1%}",
                inputs=[downsample_ratio],
                outputs=[downsample_ratio_label],
            )

            # Recompute the ratio whenever the slider or the dataset size changes.
            sampling_rate.change(
                compute_ratio,
                inputs=[sampling_rate, num_samples],
                outputs=downsample_ratio,
            )
            num_samples.change(
                compute_ratio,
                inputs=[sampling_rate, num_samples],
                outputs=downsample_ratio,
            )

    # load and show details
    def load_hf_dataset(dataset_id):
        """Load a dataset from the Hub and return values for the setup widgets.

        Args:
            dataset_id: Dataset name as typed into the "Dataset name" textbox.

        Returns:
            A tuple matching the ``outputs`` of ``load_dataset_button.click``:
            a button update setting its text to "Loaded", the dataset id, the
            loaded dataset, then the unpacked result of ``preprocess_dataset``
            (presumably task category, text column, text pair column, label
            column and sample count, in that order -- confirm against
            ``preprocess_dataset``).

        Raises:
            gr.Error: If loading or preprocessing raises ``ValueError``.
        """
        try:
            # SECURITY NOTE(review): trust_remote_code=True lets the dataset
            # repository run its own loading script for a user-supplied id.
            # Confirm this is acceptable for the deployment environment.
            loaded = load_dataset(dataset_id, trust_remote_code=True)
            details = preprocess_dataset(loaded)
        except ValueError as e:
            # Abort the event here. Previously only a warning was issued and
            # execution fell through to the return below, which crashed with
            # UnboundLocalError because nothing had been loaded.
            raise gr.Error("Collections not supported. Load one dataset only.") from e

        return (
            gr.update(value="Loaded"),
            dataset_id,
            loaded,
            *details,
        )

    load_dataset_button.click(
        load_hf_dataset,
        inputs=[dataset_id],
        outputs=[
            load_dataset_button,
            dataset_id_label,
            dataset,
            task_category,
            text_column,
            text_pair_column,
            label_column,
            num_samples,
        ],
        scroll_to_output=True,
    )

    ########## 2. Select LMs ##########

    gr.Markdown("## Select Language Models")

    gr.Markdown(
        "Add two or more pretrained models for ranking. "
        "Go with small models since this demo runs on CPU."
    )

    with gr.Group():
        # (display name, value) pairs: show only the part after the last "/".
        model_options = [
            (model_handle.split("/")[-1], model_handle) for model_handle in ALL_LMS
        ]
        models = gr.CheckboxGroup(
            choices=model_options, label="Model List", value=PRESELECTED_LMS
        )

    ########## 3. Run ranking ##########

    gr.Markdown("## Rank Language Models")

    gr.Markdown(
        "Rank models by transferability to your downstream task. "
        "Adjust the metric and layer aggregation in Advanced Settings."
    )

    with gr.Group():
        submit_button = gr.Button("Run ranking", variant="primary", interactive=False)

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                estimator = gr.Dropdown(
                    choices=["hscore", "logme", "knn"],
                    label="Transferability metric",
                    value="hscore",
                )
                layer_aggregator = gr.Dropdown(
                    choices=["lastlayer", "layermean", "bestlayer"],
                    label="Layer aggregation",
                    value="layermean",
                )

    # ranking button works after dataset loads
    dataset.change(
        ensure_dataset_is_loaded,
        inputs=[dataset, text_column, label_column, task_category],
        outputs=submit_button,
    )

    label_column.change(
        ensure_dataset_is_loaded,
        inputs=[dataset, text_column, label_column, task_category],
        outputs=submit_button,
    )

    text_column.change(
        ensure_dataset_is_loaded,
        inputs=[dataset, text_column, label_column, task_category],
        outputs=submit_button,
    )

    def rank_models(
        dataset,
        downsample_ratio,
        selected_models,
        layer_aggregator,
        estimator,
        text_column,
        text_pair_column,
        label_column,
        task_category,
        progress=gr.Progress(),
    ):
        """Rank the selected models on the loaded dataset.

        Args:
            dataset: Dataset object held in the ``dataset`` state.
            downsample_ratio: Value forwarded as ``dataset_downsample``.
            selected_models: Model handles chosen in the checkbox group.
            layer_aggregator: Layer aggregation option passed to ``ranker.run``.
            estimator: Transferability metric passed to ``ranker.run``.
            text_column: Name of the text column; must not be ``UNSET``.
            text_pair_column: Name of the text pair column; ``UNSET`` is
                converted to ``None``.
            label_column: Name of the label column; must not be ``UNSET``.
            task_category: Downstream task category; must not be ``UNSET``.
            progress: Gradio progress handle used for status updates.

        Returns:
            A list of ``(rank, model, score)`` tuples ordered by descending
            score with ranks starting at 1, or an empty list if ranking
            raised an exception (a warning is shown in that case).

        Raises:
            gr.Error: If the text column, label column or task category is
                unset.
        """
        if text_column == UNSET:
            raise gr.Error("Text column is not set.")

        if label_column == UNSET:
            raise gr.Error("Label column is not set.")

        if task_category == UNSET:
            raise gr.Error(
                "Task category not set. Dataset must support classification or regression."
            )

        # The text pair column is optional.
        if text_pair_column == UNSET:
            text_pair_column = None

        progress(0.0, "Starting")

        with EmbeddingProgressTracker(
            progress=progress, model_names=selected_models
        ) as tracker:
            try:
                ranker = TransformerRanker(
                    dataset,
                    dataset_downsample=downsample_ratio,
                    text_column=text_column,
                    text_pair_column=text_pair_column,
                    label_column=label_column,
                    task_category=task_category,
                )

                results = ranker.run(
                    models=selected_models,
                    layer_aggregator=layer_aggregator,
                    estimator=estimator,
                    batch_size=64,
                    tracker=tracker,
                )

                # NOTE(review): this reads the private `_results` attribute;
                # prefer a public accessor if the library provides one.
                sorted_results = sorted(
                    results._results.items(), key=lambda item: item[1], reverse=True
                )
                return [
                    (i + 1, model, score)
                    for i, (model, score) in enumerate(sorted_results)
                ]
            except Exception as e:
                # Deliberate best effort: report the problem in the UI and
                # return an empty table instead of failing the event.
                print(e)
                gr.Warning(f"Ranking issue: {e}")
                return []

    gr.Markdown("Ranking table → higher scores indicate better downstream performance.")

    ranking_results = gr.Dataframe(
        headers=["Rank", "Model", "Score"],
        datatype=["number", "str", "number"],
        value=[["-", "-", "-"]],
    )

    submit_button.click(
        rank_models,
        inputs=[
            dataset,
            downsample_ratio,
            models,
            layer_aggregator,
            estimator,
            text_column,
            text_pair_column,
            label_column,
            task_category,
        ],
        outputs=ranking_results,
        scroll_to_output=True,
    )

    gr.Markdown(FOOTER)

if __name__ == "__main__":
    # run up to 3 requests at once
    demo.queue(default_concurrency_limit=3)
    # run with 6 workers
    demo.launch(max_threads=6)