Merve Noyan
fixes
521af34
#!/usr/bin/env python
import ast
import copy
import tempfile

import gradio as gr
import pandas as pd
from huggingface_hub import CommitOperationAdd, HfApi

from papers import PaperList
# Dataset repo and file that this Space edits through pull requests.
REPO_ID = "CVPR2024/CVPR2024-papers"
FILENAME = "data.csv"
api = HfApi()
paper_list = PaperList()
# Download the dataset's CSV at import time.
path = api.hf_hub_download(repo_id=REPO_ID, filename=FILENAME, repo_type="dataset")
# Read every cell as text and keep empty cells as "" rather than NaN, so that writing
# the frame back out reproduces the stored values. With dtype inference, pandas can
# turn arXiv IDs such as 2404.01230 into the float 2404.0123, and it turns integer
# columns that contain gaps into floats ("5.0"). Both corrupt the file on write-back.
actual_df = pd.read_csv(path, dtype=str, keep_default_na=False)
# Map each paper ID (as text, the form typed into the UI) to its row *position*.
# Positions are what `actual_df.iloc[...]` expects. `iterrows()` would yield index
# labels instead, and it is much slower.
paper_id_to_index = {str(pid): position for position, pid in enumerate(actual_df["id"])}
with gr.Blocks() as demo_search:
    # Free-text filters; pressing Enter in either box re-runs the search.
    with gr.Group():
        search_title = gr.Textbox(label="Search title")
        search_author = gr.Textbox(label="Search author")
    # Read-only results table, pre-filled with the prettified paper list.
    df = gr.Dataframe(
        value=paper_list.df_prettified,
        datatype=paper_list.get_column_datatypes(paper_list.get_column_names()),
        type="pandas",
        row_count=(0, "dynamic"),
        interactive=False,
        height=1000,
        elem_id="table",
        wrap=True,
    )
    inputs = [search_title, search_author]
    # Both the submit triggers and the initial page load run the same search handler.
    _search_handler = dict(
        fn=paper_list.search,
        inputs=inputs,
        outputs=df,
        queue=False,
        api_name=False,
    )
    gr.on(triggers=[search_title.submit, search_author.submit], **_search_handler)
    demo_search.load(**_search_handler)
def _parse_list_cell(value: object) -> list[str]:
    """Return the items of a list-valued CSV cell such as "['a/b', 'c/d']".

    Missing (e.g. NaN) or blank cells give ``[]``. A cell that is not a valid
    Python literal is treated as newline-separated text instead, so a malformed
    row can still be loaded and corrected through the form.
    """
    if not isinstance(value, str) or not value.strip():
        return []
    try:
        # literal_eval only evaluates literals, so arbitrary cell text cannot run code.
        parsed = ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError):
        return [line.strip() for line in value.split("\n") if line.strip()]
    if isinstance(parsed, (list, tuple)):
        return [str(item) for item in parsed]
    return [str(parsed)]


def _cell_to_text(value: object) -> str:
    """Render a scalar CSV cell as textbox text, mapping missing values to ""."""
    return "" if pd.isna(value) else str(value)


def load_data(paper_id: str) -> tuple[str, str, str, str, str, str, str, str]:
    """Return the editable fields of one paper, used to fill the edit form.

    Args:
        paper_id: The paper's ID as shown in the search table. Surrounding
            whitespace is ignored.

    Returns:
        An 8-tuple ``(id, title, authors, arxiv_id, github_links, space_ids,
        model_ids, dataset_ids)``. The last four are plain newline-separated
        values, the same format the edit textboxes' placeholders show. They are
        deliberately not rendered as links: the textboxes are editable, and
        their contents are what gets submitted back.

    Raises:
        gr.Error: If ``paper_id`` is not a known paper ID.
    """
    paper_id = paper_id.strip()
    try:
        index = paper_id_to_index[paper_id]
    except KeyError:
        raise gr.Error(f"Paper ID {paper_id} not found.") from None
    paper = actual_df.iloc[index]
    # The list columns hold list literals as text (an empty one is "[]"), so they
    # must be parsed. Iterating the raw string would yield single characters.
    return (
        _cell_to_text(paper["id"]),
        _cell_to_text(paper["title"]),
        _cell_to_text(paper["authors"]),
        _cell_to_text(paper["arxiv_id"]),
        "\n".join(_parse_list_cell(paper["GitHub"])),
        "\n".join(_parse_list_cell(paper["Space"])),
        "\n".join(_parse_list_cell(paper["Model"])),
        "\n".join(_parse_list_cell(paper["Dataset"])),
    )
def split_and_strip(s: str) -> list[str]:
    """Split *s* on newline characters and return the non-blank pieces, each trimmed of surrounding whitespace."""
    trimmed_pieces = (piece.strip() for piece in s.split("\n"))
    # filter(None, ...) drops the pieces that became empty after trimming.
    return list(filter(None, trimmed_pieces))
def create_pr(
    paper_id: str,
    title: str,
    authors: str,
    arxiv_id: str,
    github_links: str,
    space_ids: str,
    model_ids: str,
    dataset_ids: str,
    oauth_token: gr.OAuthToken | None,
) -> str:
    """Open a pull request on the dataset repo that updates one paper's row.

    The first eight arguments arrive positionally from the edit form's eight
    components. ``oauth_token`` is not an input component; Gradio injects it
    based on its ``gr.OAuthToken`` annotation.

    Note: an earlier version also declared an unused ``project_page`` parameter
    before ``github_links``. The form never supplied it, so every later value
    landed one slot off. It has been removed.

    Args:
        paper_id: ID of the paper to update. Surrounding whitespace is ignored.
        title: New title, stored stripped.
        authors: New authors value, stored exactly as entered.
        arxiv_id: New arXiv ID, stored stripped.
        github_links, space_ids, model_ids, dataset_ids: One entry per line.
            Blank lines and surrounding whitespace are dropped, and the entries
            are stored as a list literal such as ``['a/b', 'c/d']`` (``[]`` when
            empty).
        oauth_token: Token of the logged-in user, or ``None`` when logged out.

    Returns:
        The URL of the created pull request, or a prompt to log in first.

    Raises:
        gr.Error: If ``paper_id`` is not a known paper ID.
    """
    if oauth_token is None:
        return "Please log in first."
    paper_id = paper_id.strip()
    try:
        index = paper_id_to_index[paper_id]
    except KeyError:
        raise gr.Error(f"Paper ID {paper_id} not found.") from None

    # astype(object) returns a fresh copy, so the shared frame is untouched. It also
    # lets text be written into any column, including one pandas inferred as numeric.
    data = actual_df.astype(object)
    # `index` is a row position (as used with .iloc); .at needs the row label.
    row = data.index[index]
    updates = {
        "title": title.strip(),
        "authors": authors,
        "arxiv_id": arxiv_id.strip(),
        # These columns hold list literals as text (an empty one is "[]").
        # NOTE(review): quote style assumed to be Python repr; confirm against data.csv.
        "GitHub": str(split_and_strip(github_links)),
        "Space": str(split_and_strip(space_ids)),
        "Model": str(split_and_strip(model_ids)),
        "Dataset": str(split_and_strip(dataset_ids)),
    }
    for column, value in updates.items():
        data.at[row, column] = value

    # Upload straight from memory, so no temporary file is left behind.
    # index=False keeps pandas from prepending its row index as an extra column.
    csv_bytes = data.to_csv(index=False).encode("utf-8")
    commit = CommitOperationAdd(path_in_repo=FILENAME, path_or_fileobj=csv_bytes)
    res = api.create_commit(
        repo_id=REPO_ID,
        operations=[commit],
        commit_message=f"Update {paper_id}",
        repo_type="dataset",
        create_pr=True,
        token=oauth_token.token,
    )
    return res.pr_url
with gr.Blocks() as demo_edit:
    _repo_id_placeholder = "org_name1/repo_name1\norg_name2/repo_name2"

    # Choose the paper to edit by its ID (Enter or the Load button fetches it).
    with gr.Group():
        paper_id = gr.Textbox(label="ID", max_lines=1)
        load_button = gr.Button("Load")
    # Editable fields; the multi-value ones take one entry per line.
    with gr.Group():
        title = gr.Textbox(label="Title", max_lines=1)
        authors = gr.Textbox(label="Authors", lines=5)
        arxiv_id = gr.Textbox(label="arXiv ID", max_lines=1, placeholder="2404.00000")
        github_links = gr.Textbox(
            label="GitHub links",
            lines=5,
            placeholder="https://github.com/aaa/bbb\nhttps://github.com/ccc/ddd",
        )
        space_ids, model_ids, dataset_ids = (
            gr.Textbox(label=field_label, lines=5, placeholder=_repo_id_placeholder)
            for field_label in ("Space IDs", "Model IDs", "Dataset IDs")
        )
    create_pr_button = gr.Button("Create PR")
    result = gr.Textbox(label="Result", max_lines=1)

    # The same eight components are filled by Load and read back by Create PR.
    _edit_fields = [paper_id, title, authors, arxiv_id, github_links, space_ids, model_ids, dataset_ids]
    gr.on(
        triggers=[paper_id.submit, load_button.click],
        fn=load_data,
        inputs=paper_id,
        outputs=list(_edit_fields),
        queue=False,
        api_name=False,
    )
    create_pr_button.click(
        fn=create_pr,
        inputs=list(_edit_fields),
        outputs=result,
        queue=False,
        api_name=False,
    )
with gr.Blocks(css="style.css") as demo:
    # Top-level page: a short intro followed by the three-step workflow as tabs.
    gr.Markdown(
        "You can create PRs to update the CSV files in the "
        "[CVPR2024-papers repo](https://huggingface.co/datasets/CVPR2024/CVPR2024-papers) "
        "with this Space."
    )
    with gr.Tabs():
        with gr.Tab(label="Step 1: Login"):
            gr.Markdown("To create a PR, you first need to log in. Please press the login button below.")
            gr.LoginButton()
        with gr.Tab(label="Step 2: Search for paper ID"):
            gr.Markdown("Search for the paper you would like to update and find its paper ID.")
            demo_search.render()
        with gr.Tab(label="Step 3: Edit and create PR"):
            for _instruction in (
                "Enter the paper ID in the field below and press the Load button.",
                "After making the necessary changes, press the Create PR button.",
            ):
                gr.Markdown(_instruction)
            demo_edit.render()
if __name__ == "__main__":
    # Enable the request queue with the API page closed, then start the server.
    queued_app = demo.queue(api_open=False)
    queued_app.launch(show_api=False, debug=True)