Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import random | |
| import numpy as np | |
| import gdown | |
| from datasets import load_dataset | |
| from hfserver import HuggingFaceDatasetSaver, HuggingFaceDatasetJSONSaver | |
| # download data from huggingface dataset | |
| # dataset = load_dataset("quantumiracle-git/robotinder-data") | |
| # download data from google drive | |
| # url = 'https://drive.google.com/drive/folders/10UmNM2YpvNSkdLMgYiIAxk5IbS4dUezw?usp=sharing' | |
| # output = './' | |
| # id = url.split('/')[-1] | |
| # os.system(f"gdown --id {id} -O {output} --folder --no-cookies") | |
def video_identity(video):
    """Identity callback: hand ``video`` back exactly as received."""
    unchanged = video
    return unchanged
def nan():
    """Callback that yields no value; calling it always evaluates to ``None``."""
| # demo = gr.Interface(video_identity, | |
| # gr.Video(), | |
| # "playable_video", | |
| # examples=[ | |
| # os.path.join(os.path.dirname(__file__), | |
| # "videos/rl-video-episode-0.mp4")], | |
| # cache_examples=True) | |
# Clip file extension to load: 'mp4' (shown with gr.PlayableVideo) or 'gif' (shown with gr.Image).
FORMAT = 'gif'
def update(data_folder='videos'):
    """Pick a random environment and two distinct random clips from it.

    Args:
        data_folder: Root folder holding one sub-folder per environment; each
            sub-folder contains clips whose extension matches ``FORMAT``.

    Returns:
        ``(left, right)``: paths of two different clips from the same environment.

    Raises:
        ValueError: if ``data_folder`` has no environment sub-folders, or the
            chosen environment holds fewer than two matching clips.
    """
    # Enumerate environments from the same root we read clips from; listing a
    # fixed './videos' here could pick an env that does not exist under data_folder.
    envs = parse_envs(data_folder)
    if not envs:
        raise ValueError(f"No environment folders found in {data_folder!r}")
    env_name = random.choice(envs)

    env_dir = os.path.join(data_folder, env_name)
    video_files = [
        os.path.join(env_dir, f)
        for f in os.listdir(env_dir)
        if f.endswith(f'.{FORMAT}')
    ]
    if len(video_files) < 2:
        raise ValueError(
            f"Need at least two .{FORMAT} clips in {env_dir!r}, found {len(video_files)}"
        )

    # Sample without replacement so the two sides never show the same clip.
    left, right = random.sample(video_files, 2)
    print(env_name, left, right)
    return left, right
| # def update(left, right): | |
| # if FORMAT == 'mp4': | |
| # left = os.path.join(os.path.dirname(__file__), | |
| # "videos/rl-video-episode-2.mp4") | |
| # right = os.path.join(os.path.dirname(__file__), | |
| # "videos/rl-video-episode-3.mp4") | |
| # else: | |
| # left = os.path.join(os.path.dirname(__file__), | |
| # "videos/rl-video-episode-2.gif") | |
| # right = os.path.join(os.path.dirname(__file__), | |
| # "videos/rl-video-episode-3.gif") | |
| # print(left, right) | |
| # return left, right | |
def replay(left, right):
    """Return both clips unchanged.

    Wired to the "Replay" button, so the current left/right values are sent
    straight back to the same components.
    """
    pair = (left, right)
    return pair
def parse_envs(folder='./videos'):
    """List the names of the immediate sub-directories of ``folder``.

    Each sub-directory is treated as one environment; plain files are
    skipped. Order is whatever ``os.listdir`` yields.
    """
    return [
        entry
        for entry in os.listdir(folder)
        if os.path.isdir(os.path.join(folder, entry))
    ]
def build_interface(iter=3, data_folder='./videos'):
    """Build the RoboTinder Gradio UI for pairwise robot-behaviour preference.

    Two clips are shown side by side. The user picks "Left", "Right" or
    "Not Sure" and presses "Next", which logs the choice through
    ``HuggingFaceDatasetSaver`` and loads a new random pair via ``update``.

    Args:
        iter: Unused; kept so existing callers keep working.
        data_folder: Root folder of per-environment clip folders, forwarded
            to ``update`` when "Next" is pressed.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    # The write token must come from the environment (e.g. a Space secret).
    # It was previously hard-coded here and printed to stdout; a token that
    # has been committed to source should be treated as leaked and revoked.
    hf_token = os.getenv('HF_TOKEN')
    if not hf_token:
        print("Warning: HF_TOKEN is not set; flagged data cannot be pushed to the Hub.")
    # Hub-backed logger used instead of a local gr.CSVLogger.
    callback = HuggingFaceDatasetSaver(hf_token, "crowdsourced-robotinder-demo")
    here = os.path.dirname(__file__)

    with gr.Blocks() as demo:
        gr.Markdown("Here is RoboTinder!")
        gr.Markdown("Select the best robot behaviour in your choice!")
        with gr.Row():
            # Initial pair shown before the first "Next" click.
            if FORMAT == 'mp4':
                left_video_path = os.path.join(here, "videos/rl-video-episode-0.mp4")
                right_video_path = os.path.join(here, "videos/rl-video-episode-1.mp4")
                left = gr.PlayableVideo(left_video_path, label="left_video")
                right = gr.PlayableVideo(right_video_path, label="right_video")
            else:
                left_video_path = os.path.join(here, "videos/rl-video-episode-0.gif")
                right_video_path = os.path.join(here, "videos/rl-video-episode-1.gif")
                left = gr.Image(left_video_path, shape=(1024, 768), label="left_video")
                right = gr.Image(right_video_path, label="right_video")
        btn1 = gr.Button("Replay")
        user_choice = gr.Radio(["Left", "Right", "Not Sure"], label="Which one is your favorite?")
        btn2 = gr.Button("Next")

        # Must be called before the first callback.flag().
        callback.setup([user_choice, left, right], "flagged_data_points")
        btn1.click(fn=replay, inputs=[left, right], outputs=[left, right])
        btn2.click(fn=lambda: update(data_folder), inputs=None, outputs=[left, right])
        # Flag the choice together with both clips.
        # NOTE(review): this is a separate listener from the `update` one above;
        # confirm it records the pair the user just rated, not the next pair.
        btn2.click(lambda *args: callback.flag(args), [user_choice, left, right], None, preprocess=False)
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it without a public share link.
    demo = build_interface()
    demo.launch(share=False)