feat: Add separate sliders for all and selected repositories in the PapersWithCode tasks tab
15420a6
| import gradio as gr | |
| import pandas as pd | |
| import logging | |
| import re | |
| from task_visualizations import TaskVisualizations | |
| import plotly.graph_objects as go | |
# Configure the root logger at INFO so the logging.info calls in this module
# are emitted (the root logger's default level, WARNING, would drop them).
logging.basicConfig(level=logging.INFO)
class AppConfig:
    """Relative paths of the data files this app reads."""

    # JSON-lines file passed to load_repo_df.
    repo_representations_path = "data/repo_representations.jsonl"
    # The three CSV paths below are handed to TaskVisualizations, in this order.
    task_counts_path = "data/repos_task_counts.csv"
    selected_task_counts_path = "data/selected_repos_task_counts.csv"
    tasks_path = "data/paperswithcode_tasks.csv"
def load_repo_df(repo_representations_path):
    """Load repository representations and clean their text for display.

    Reads a JSON-lines file (one record per line) into a DataFrame and
    returns a copy whose ``text`` column has been normalised:

    * self-closing ``<img ... />`` tags are removed;
    * the separator characters ``│`` and ``⋮`` are replaced by newlines.

    Args:
        repo_representations_path: Path to the JSON-lines file. Every record
            must provide a string ``text`` field (a missing column raises
            ``KeyError``).

    Returns:
        A new DataFrame with the columns found in the file, where ``text``
        holds the cleaned strings.
    """
    data = pd.read_json(repo_representations_path, lines=True, orient="records")
    cleaned_text = (
        data["text"]
        # Non-greedy match: each tag is removed on its own. A greedy ".*"
        # would also swallow any text sitting between two tags on one line.
        .str.replace(r"<img.*?/>", "", regex=True)
        # Literal single-character replacements; regex=False makes that
        # explicit instead of relying on the pandas-version default.
        .str.replace("│", "\n", regex=False)
        .str.replace("⋮", "\n", regex=False)
    )
    return data.assign(text=cleaned_text)
def display_representations(repo, representation1, representation2):
    """Fetch the text of two representations of a single repository.

    Looks the repository up in the module-level ``repos_df`` and logs the
    matching rows at INFO level.

    Args:
        repo: Value compared against the ``repo_name`` column.
        representation1: Value compared against the ``representation`` column
            for the first result.
        representation2: Same, for the second result.

    Returns:
        A ``(text1, text2)`` tuple. Each element is the ``text`` of the first
        matching row, or ``"No data available"`` when no row matches.
    """
    repo_data = repos_df[repos_df["repo_name"] == repo]
    logging.info(f"repo_data: {repo_data}")

    texts = []
    for wanted in (representation1, representation2):
        matches = repo_data[repo_data["representation"] == wanted]
        if matches.empty:
            texts.append("No data available")
        else:
            texts.append(matches["text"].iloc[0])
    return tuple(texts)
def setup_repository_representations_tab(repos, representation_types):
    """Build the UI that shows two representations of a repository side by side.

    Creates a repository dropdown, two representation dropdowns and two
    Markdown panels, fills the panels for the default selection, and
    re-renders both panels whenever any of the three dropdowns changes.
    In this file it is called from inside a ``gr.Blocks`` / ``gr.Tab``
    context.

    Args:
        repos: Repository names offered in the dropdown. The first entry is
            the initial selection; an empty list leaves the dropdown
            unselected instead of raising ``IndexError``.
        representation_types: Choices offered in both representation
            dropdowns. The initial selections are the fixed names
            ``"readme"`` and ``"generated_readme"`` (presumably present in
            this list -- confirm against the data).
    """
    # Single source of truth for the initial selection, shared by the
    # dropdown defaults and the initial panel contents below.
    default_repo = repos[0] if repos else None
    default_representation1 = "readme"
    default_representation2 = "generated_readme"

    gr.Markdown("Select a repository and two representation types to compare them.")
    with gr.Row():
        repo = gr.Dropdown(choices=repos, label="Repository", value=default_repo)
        representation1 = gr.Dropdown(
            choices=representation_types,
            label="Representation 1",
            value=default_representation1,
        )
        representation2 = gr.Dropdown(
            choices=representation_types,
            label="Representation 2",
            value=default_representation2,
        )
    with gr.Row():
        with gr.Column(
            elem_id="column1",
            variant="panel",
            scale=1,
            min_width=300,
        ):
            text1 = gr.Markdown()
        with gr.Column(
            elem_id="column2",
            variant="panel",
            scale=1,
            min_width=300,
        ):
            text2 = gr.Markdown()

    def update_representations(repo, representation1, representation2):
        """Return the Markdown for both panels, each prefixed with a heading."""
        text1_content, text2_content = display_representations(
            repo, representation1, representation2
        )
        return (
            f"### Representation 1: {representation1}\n\n{text1_content}",
            f"### Representation 2: {representation2}\n\n{text2_content}",
        )

    # Initial call to populate the panels with the default selection.
    text1.value, text2.value = update_representations(
        default_repo, default_representation1, default_representation2
    )

    # Any dropdown change re-renders both panels from the current selections.
    for component in [repo, representation1, representation2]:
        component.change(
            fn=update_representations,
            inputs=[repo, representation1, representation2],
            outputs=[text1, text2],
        )
## main
# Load the repository representations once; the dropdown choices are the
# distinct values of the "repo_name" and "representation" columns.
repos_df = load_repo_df(AppConfig.repo_representations_path)
repos, representation_types = (
    list(repos_df[column].unique()) for column in ("repo_name", "representation")
)
logging.info(f"found {len(repos)} repositories")
logging.info(f"representation types: {representation_types}")

# Plot provider for the PapersWithCode tasks tab, built from the three CSV paths.
task_visualizations = TaskVisualizations(
    AppConfig.task_counts_path,
    AppConfig.selected_task_counts_path,
    AppConfig.tasks_path,
)
# Top-level UI: one tab for comparing repository representations and one for
# the PapersWithCode task plots.
with gr.Blocks() as demo:
    with gr.Tab("Explore Repository Representations"):
        setup_repository_representations_tab(repos, representation_types)
    with gr.Tab("Explore PapersWithCode Tasks"):
        task_counts_description = """
## PapersWithCode Tasks Visualization
PapersWithCode tasks are grouped by area.
""".strip()
        gr.Markdown(task_counts_description)
        # One slider per plot, both passed to get_tasks_sunbursts in the
        # order (all repositories, selected repositories).
        with gr.Row():
            min_task_counts_slider_all = gr.Slider(
                minimum=10,
                maximum=1000,
                value=100,
                step=10,
                label="Minimum Task Count (All Repositories)",
            )
            min_task_counts_slider_selected = gr.Slider(
                minimum=10,
                maximum=1000,
                value=100,
                step=10,
                label="Minimum Task Count (Selected Repositories)",
            )
        update_button = gr.Button("Update Plots")
        # gr.Row takes keyword-only arguments and has no label, so no
        # positional title string is passed here.
        with gr.Row():
            all_repos_tasks_plot = gr.Plot(label="All Repositories")
            selected_repos_tasks_plot = gr.Plot(label="Selected Repositories")
        # The plots are only (re)computed when the button is clicked.
        update_button.click(
            fn=task_visualizations.get_tasks_sunbursts,
            inputs=[min_task_counts_slider_all, min_task_counts_slider_selected],
            outputs=[all_repos_tasks_plot, selected_repos_tasks_plot],
        )
demo.launch()