| | |
| |
|
| | import copy |
| | import tempfile |
| |
|
| | import gradio as gr |
| | from huggingface_hub import CommitOperationAdd, HfApi |
| | import pandas as pd |
| | from papers import PaperList |
| |
|
# Dataset repository that holds the paper metadata, and the CSV file inside it.
REPO_ID = "CVPR2024/CVPR2024-papers"
FILENAME = "data.csv"

# Shared Hub client; used both for the initial download and for opening PRs.
api = HfApi()

# Backs the read-only search table.
paper_list = PaperList()

path = api.hf_hub_download(repo_id=REPO_ID, filename=FILENAME, repo_type="dataset")
# Read arXiv IDs as text: values such as "2404.00010" would otherwise be
# inferred as floats, silently dropping trailing zeros when the CSV is
# written back, and making later string assignments a dtype mismatch.
actual_df = pd.read_csv(path, dtype={"arxiv_id": str})
# Map each paper ID (as a string, which is how the ID textbox delivers it) to
# the row position of that paper in ``actual_df``.  Iterating the column
# directly avoids the per-row Series that ``iterrows`` builds and the dtype
# coercion that can come with it.
paper_id_to_index = {str(pid): pos for pos, pid in enumerate(actual_df["id"])}
| |
|
| |
|
# Search tab: two text filters and a read-only table of papers.
with gr.Blocks() as demo_search:
    with gr.Group():
        search_title = gr.Textbox(label="Search title")
        search_author = gr.Textbox(label="Search author")
    column_datatypes = paper_list.get_column_datatypes(paper_list.get_column_names())
    df = gr.Dataframe(
        value=paper_list.df_prettified,
        datatype=column_datatypes,
        type="pandas",
        row_count=(0, "dynamic"),
        interactive=False,
        height=1000,
        elem_id="table",
        wrap=True,
    )

    inputs = [search_title, search_author]
    # Both the submit events and the initial page load run the same search
    # with the same wiring, so the shared arguments are defined once.
    search_event = {
        "fn": paper_list.search,
        "inputs": inputs,
        "outputs": df,
        "queue": False,
        "api_name": False,
    }
    gr.on(triggers=[search_title.submit, search_author.submit], **search_event)
    demo_search.load(**search_event)
| |
|
| |
|
def _parse_list_cell(cell: object) -> list[str]:
    """Decode a CSV cell that stores a list as text (e.g. ``"['a', 'b']"``)."""
    import ast  # local import so this block does not depend on the file header

    if isinstance(cell, (list, tuple)):
        return [str(item) for item in cell]
    if not isinstance(cell, str):
        # Missing cells are read by pandas as NaN (a float), not as "[]".
        return []
    text = cell.strip()
    if not text or text == "[]":
        return []
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        # Not a Python literal: fall back to one entry per non-blank line.
        return [line.strip() for line in text.splitlines() if line.strip()]
    if isinstance(parsed, (list, tuple)):
        return [str(item) for item in parsed]
    return [str(parsed)]


def _join_links(links: list[str]) -> str:
    """Join entries one per line; an empty list yields a single space."""
    return "\n".join(links) if links else " "


def load_data(paper_id: str) -> tuple[str, str, str, str, str, str, str, str]:
    """Look up a paper by ID and return the values for the edit form.

    The eight returned items are, in order: id, title, authors, arXiv ID,
    GitHub links, Space IDs, Model IDs and Dataset IDs.  The four list fields
    are rendered one entry per line through ``PaperList.create_link``.

    Raises:
        gr.Error: if ``paper_id`` is not a key of ``paper_id_to_index``.
    """
    try:
        index = paper_id_to_index[paper_id]
    except KeyError:
        raise gr.Error(f"Paper ID {paper_id} not found.") from None

    paper = actual_df.iloc[index]

    # The list columns hold the *text* of a list; iterating the raw cell would
    # walk it character by character, so each cell is parsed first.
    # NOTE(review): these values populate editable textboxes whose contents
    # are later submitted for the PR -- confirm that the output of
    # ``create_link`` (rather than the raw URL / repo ID) is what should
    # appear there.
    github = [PaperList.create_link("GitHub", url) for url in _parse_list_cell(paper["GitHub"])]
    spaces = [
        PaperList.create_link(repo_id, f"https://huggingface.co/spaces/{repo_id}")
        for repo_id in _parse_list_cell(paper["Space"])
    ]
    models = [
        PaperList.create_link(repo_id, f"https://huggingface.co/{repo_id}")
        for repo_id in _parse_list_cell(paper["Model"])
    ]
    datasets = [
        PaperList.create_link(repo_id, f"https://huggingface.co/datasets/{repo_id}")
        for repo_id in _parse_list_cell(paper["Dataset"])
    ]

    return (
        paper["id"],
        paper["title"],
        paper["authors"],
        paper["arxiv_id"],
        _join_links(github),
        _join_links(spaces),
        _join_links(models),
        _join_links(datasets),
    )
| |
|
| |
|
def split_and_strip(s: str) -> list[str]:
    """Split *s* on newlines and return the non-blank lines, whitespace-trimmed."""
    trimmed_lines = map(str.strip, s.split("\n"))
    return [line for line in trimmed_lines if line]
| |
|
| |
|
def create_pr(
    paper_id: str,
    title: str,
    authors: str,
    arxiv_id: str,
    project_page: str,
    github_links: str,
    space_ids: str,
    model_ids: str,
    dataset_ids: str,
    oauth_token: gr.OAuthToken | None,
) -> str:
    """Open a pull request on the dataset repo that updates one paper's row.

    Args:
        paper_id: ID of the paper to update; must be a key of
            ``paper_id_to_index``.
        title: New title (surrounding whitespace is removed).
        authors: New authors text, stored as given.
        arxiv_id: New arXiv ID (surrounding whitespace is removed).
        project_page: Accepted for interface compatibility; this function
            does not write it to the CSV.
        github_links: GitHub URLs, one per line.
        space_ids: Space repo IDs, one per line.
        model_ids: Model repo IDs, one per line.
        dataset_ids: Dataset repo IDs, one per line.
        oauth_token: Token of the logged-in user, or ``None`` when nobody is
            logged in.

    Returns:
        The URL of the created pull request, or a prompt to log in when
        ``oauth_token`` is ``None``.

    Raises:
        gr.Error: if ``paper_id`` is unknown.
    """
    if oauth_token is None:
        return "Please log in first."

    try:
        index = paper_id_to_index[paper_id]
    except KeyError:
        raise gr.Error(f"Paper ID {paper_id} not found.") from None

    # Multi-line fields are stored as the text of a Python list ("[]" when
    # empty), which is the format the loader checks for.
    updates = {
        "title": title.strip(),
        "authors": authors,
        "arxiv_id": arxiv_id.strip(),
        "GitHub": str(split_and_strip(github_links)),
        "Space": str(split_and_strip(space_ids)),
        "Model": str(split_and_strip(model_ids)),
        "Dataset": str(split_and_strip(dataset_ids)),
    }

    # Work on a copy of the loaded table (not the UI widget) so the in-memory
    # data keeps mirroring the repository.  ``astype`` returns a new frame;
    # casting the edited columns to ``object`` lets them accept strings even
    # if pandas inferred a numeric dtype for them.
    data = actual_df.astype({column: object for column in updates})
    row_label = data.index[index]
    for column, value in updates.items():
        data.at[row_label, column] = value

    # ``index=False`` keeps the column layout identical to the file that was
    # read.  Passing bytes avoids leaving a temporary file behind.
    csv_bytes = data.to_csv(index=False).encode("utf-8")
    commit = CommitOperationAdd(path_in_repo=FILENAME, path_or_fileobj=csv_bytes)
    res = api.create_commit(
        repo_id=REPO_ID,
        operations=[commit],
        commit_message=f"Update {paper_id}",
        repo_type="dataset",
        create_pr=True,
        token=oauth_token.token,
    )
    return res.pr_url
| |
|
| |
|
# Edit tab: load a paper by ID, edit its fields, then open a PR.
with gr.Blocks() as demo_edit:
    with gr.Group():
        paper_id = gr.Textbox(label="ID", max_lines=1)
        load_button = gr.Button("Load")
    with gr.Group():
        title = gr.Textbox(label="Title", max_lines=1)
        authors = gr.Textbox(label="Authors", lines=5)
        arxiv_id = gr.Textbox(label="arXiv ID", max_lines=1, placeholder="2404.00000")
        # ``create_pr`` declares a ``project_page`` parameter between
        # ``arxiv_id`` and ``github_links``.  Nothing loads or stores that
        # value, so a hidden placeholder is passed purely to keep the
        # positional arguments aligned with the function signature.
        project_page = gr.Textbox(label="Project page", value="", max_lines=1, visible=False)
        github_links = gr.Textbox(
            label="GitHub links",
            lines=5,
            placeholder="https://github.com/aaa/bbb\nhttps://github.com/ccc/ddd",
        )
        space_ids = gr.Textbox(label="Space IDs", lines=5, placeholder="org_name1/repo_name1\norg_name2/repo_name2")
        model_ids = gr.Textbox(label="Model IDs", lines=5, placeholder="org_name1/repo_name1\norg_name2/repo_name2")
        dataset_ids = gr.Textbox(
            label="Dataset IDs", lines=5, placeholder="org_name1/repo_name1\norg_name2/repo_name2"
        )
    create_pr_button = gr.Button("Create PR")
    result = gr.Textbox(label="Result", max_lines=1)

    # Pressing Enter in the ID box or clicking Load fills the form.
    gr.on(
        triggers=[
            paper_id.submit,
            load_button.click,
        ],
        fn=load_data,
        inputs=paper_id,
        outputs=[
            paper_id,
            title,
            authors,
            arxiv_id,
            github_links,
            space_ids,
            model_ids,
            dataset_ids,
        ],
        queue=False,
        api_name=False,
    )
    # The OAuth token parameter of ``create_pr`` is not listed here; it is
    # supplied by Gradio based on the parameter's type annotation.
    create_pr_button.click(
        fn=create_pr,
        inputs=[
            paper_id,
            title,
            authors,
            arxiv_id,
            project_page,
            github_links,
            space_ids,
            model_ids,
            dataset_ids,
        ],
        outputs=result,
        queue=False,
        api_name=False,
    )
| |
|
# Top-level app: an intro line followed by the three workflow steps as tabs.
intro_markdown = (
    "You can create PRs to update the CSV files in the [CVPR2024-papers repo](https://huggingface.co/datasets/CVPR2024/CVPR2024-papers) with this Space."
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(intro_markdown)
    with gr.Tabs():
        with gr.Tab(label="Step 1: Login"):
            gr.Markdown("To create a PR, you first need to log in. Please press the login button below.")
            gr.LoginButton()

        with gr.Tab(label="Step 2: Search for paper ID"):
            gr.Markdown("Search for the paper you would like to update and find its paper ID.")
            # Embed the search UI built above.
            demo_search.render()

        with gr.Tab(label="Step 3: Edit and create PR"):
            gr.Markdown("Enter the paper ID in the field below and press the Load button.")
            gr.Markdown("After making the necessary changes, press the Create PR button.")
            # Embed the edit form built above.
            demo_edit.render()
| |
|
# Script entry point: enable the queue with its API closed, then start the app.
if __name__ == "__main__":
    queued_demo = demo.queue(api_open=False)
    queued_demo.launch(show_api=False, debug=True)
| |
|