Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,17 +2,47 @@ import gradio as gr
|
|
| 2 |
import random
|
| 3 |
from datasets import load_dataset
|
| 4 |
import os
|
| 5 |
-
hf_token = os.environ['hf_token']
|
| 6 |
|
| 7 |
-
#
|
| 8 |
-
|
| 9 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
def get_random_video():
|
| 11 |
# 随机选择一个索引
|
| 12 |
-
random_index = random.randint(0, len(
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
return video_path
|
| 17 |
|
| 18 |
# Gradio 接口
|
|
|
|
| 2 |
import random
|
| 3 |
from datasets import load_dataset
|
| 4 |
import os
|
|
|
|
| 5 |
|
| 6 |
+
hf_token = os.environ['hf_token'] # 确保环境变量中有你的令牌
|
| 7 |
+
submission_url = "Vchitect/VBench_sampled_video" # 数据集的 URL
|
| 8 |
+
local_dir = "VBench_sampled_video" # 本地文件夹路径
|
| 9 |
+
|
| 10 |
+
# 克隆数据集
|
| 11 |
+
submission_repo = Repository(local_dir=local_dir, clone_from=submission_url, use_auth_token=hf_token, repo_type="dataset")
|
| 12 |
+
submission_repo.git_pull() # 更新本地仓库
|
| 13 |
+
|
| 14 |
+
model_names = os.listdir(local_dir)
|
| 15 |
+
|
| 16 |
+
with open("videos_by_dimension.json") as f:
|
| 17 |
+
dimension = json.load(f)['videos_by_dimension']
|
| 18 |
+
|
| 19 |
+
# with open("all_videos.json") as f:
|
| 20 |
+
# all_videos = json.load(f)
|
| 21 |
+
|
| 22 |
+
types = ['appearance_style', 'color', 'temporal_style', 'spatial_relationship', 'temporal_flickering', 'scene', 'multiple_objects', 'object_class', 'human_action', 'overall_consistency', 'subject_consistency']
|
| 23 |
+
|
| 24 |
def get_random_video():
|
| 25 |
# 随机选择一个索引
|
| 26 |
+
random_index = random.randint(0, len(types) - 1)
|
| 27 |
+
type = types[random_index]
|
| 28 |
+
# 随机选择一个Prompt
|
| 29 |
+
random_index = random.randint(0, len(dimension[type]) - 1)
|
| 30 |
+
prompt = dimension[type][random_index]
|
| 31 |
+
# 随机一个模型
|
| 32 |
+
random_index = random.randint(0, len(model_names) - 1)
|
| 33 |
+
model_name = model_names[random_index]
|
| 34 |
+
|
| 35 |
+
video_path = os.path.join(model_name, type, prompt)
|
| 36 |
+
if os.path.exists(video_path):
|
| 37 |
+
print(video_path)
|
| 38 |
+
return video_path
|
| 39 |
+
else:
|
| 40 |
+
video_path = os.path.join(model_name, prompt)
|
| 41 |
+
if os.path.exists(video_path):
|
| 42 |
+
print(video_path)
|
| 43 |
+
return video_path
|
| 44 |
+
# video_path = dataset['train'][random_index]['video_path']
|
| 45 |
+
print('error:', video_path)
|
| 46 |
return video_path
|
| 47 |
|
| 48 |
# Gradio 接口
|