WHU1psh commited on
Commit
4e4000d
·
verified ·
1 Parent(s): 881c265

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -4
app.py CHANGED
@@ -9,18 +9,128 @@ import threading
9
  from collections import defaultdict
10
  from datetime import datetime
11
  from pathlib import Path
12
- from typing import Any, Dict, List, Tuple
13
 
14
  import gradio as gr
 
15
 
16
  # 路径配置(按用户要求)
17
- ROOT_DIR = Path(os.environ.get("VIDEOEVAL_ROOT", "MemDirector"))
18
- INPUT_DIR = ROOT_DIR / "user_study_input"
19
- OUTPUT_DIR = ROOT_DIR / "user_study_results"
 
 
 
 
 
 
 
 
 
 
 
20
  STORY_DIR = INPUT_DIR / "clip_movie_story"
21
  VIDEO_DIR = INPUT_DIR / "video"
22
 
23
  Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # Movie-Level 指标定义
26
  MOVIE_CRITERIA: List[Tuple[str, str, str]] = [
@@ -144,6 +254,19 @@ def build_pending_samples() -> List[Dict[str, Any]]:
144
  return pending
145
 
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  def compute_derived(scores: Dict[str, float]) -> Dict[str, float]:
148
  """计算 CL / CRh / AVG。"""
149
  cl = (
@@ -258,6 +381,7 @@ def create_app():
258
  gr.Markdown(
259
  f"<span class='hint'>输入目录:`{INPUT_DIR}` | 输出目录:`{OUTPUT_DIR}`</span>",
260
  )
 
261
 
262
  current_idx = gr.State(0)
263
  evaluator_state = gr.State("anonymous")
 
9
  from collections import defaultdict
10
  from datetime import datetime
11
  from pathlib import Path
12
+ from typing import Any, Dict, List, Optional, Tuple
13
 
14
  import gradio as gr
15
+ from huggingface_hub import CommitScheduler, snapshot_download
16
 
17
  # 路径配置(按用户要求)
18
+ # Spaces 推荐优先读取当前 Space 仓库内文件(app.py 同级)
19
+ APP_DIR = Path(__file__).resolve().parent
20
+ LOCAL_INPUT_DIR = APP_DIR / "user_study_input"
21
+ LOCAL_OUTPUT_DIR = APP_DIR / "user_study_results"
22
+ DATA_INPUT_DIR = Path("/data/user_study_input")
23
+ DATA_OUTPUT_DIR = Path("/data/user_study_results")
24
+ DATA_REPO_ID = os.environ.get("DATA_REPO_ID", "MemDirector/user_study_input")
25
+ RESULTS_REPO_ID = os.environ.get("RESULTS_REPO_ID", "MemDirector/user_study_results")
26
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
27
+ SPACE_MODE = os.environ.get("SPACE_MODE", "repo_first") # repo_first / data_first / hub_only
28
+
29
+ ROOT_DIR = APP_DIR
30
+ INPUT_DIR = LOCAL_INPUT_DIR
31
+ OUTPUT_DIR = LOCAL_OUTPUT_DIR
32
  STORY_DIR = INPUT_DIR / "clip_movie_story"
33
  VIDEO_DIR = INPUT_DIR / "video"
34
 
35
  Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
36
+ scheduler: Optional[CommitScheduler] = None
37
+
38
+
39
+ def _set_paths(input_dir: Path, output_dir: Path) -> None:
40
+ global INPUT_DIR, OUTPUT_DIR, STORY_DIR, VIDEO_DIR, ROOT_DIR
41
+ INPUT_DIR = input_dir
42
+ OUTPUT_DIR = output_dir
43
+ STORY_DIR = INPUT_DIR / "clip_movie_story"
44
+ VIDEO_DIR = INPUT_DIR / "video"
45
+ ROOT_DIR = INPUT_DIR.parent
46
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
47
+
48
+
49
+ def _try_use_local_repo_layout() -> bool:
50
+ # Space 仓库内自带 user_study_input 时,直接读取(最符合“已放上去直接跑”)
51
+ if LOCAL_INPUT_DIR.exists():
52
+ _set_paths(LOCAL_INPUT_DIR, LOCAL_OUTPUT_DIR)
53
+ return True
54
+ return False
55
+
56
+
57
+ def _try_use_data_volume_layout() -> bool:
58
+ # 如果使用 /data 持久卷,则可放在 /data/user_study_input
59
+ if DATA_INPUT_DIR.exists():
60
+ _set_paths(DATA_INPUT_DIR, DATA_OUTPUT_DIR)
61
+ return True
62
+ return False
63
+
64
+
65
+ def _try_download_from_hub() -> bool:
66
+ # 最后兜底:从 dataset repo 下载
67
+ if not DATA_REPO_ID:
68
+ return False
69
+ hub_root = APP_DIR / ".hf_space_cache"
70
+ try:
71
+ snapshot_download(
72
+ repo_id=DATA_REPO_ID,
73
+ repo_type="dataset",
74
+ local_dir=str(hub_root),
75
+ token=HF_TOKEN,
76
+ allow_patterns=[
77
+ "clip_movie_story/**",
78
+ "video/**",
79
+ "user_study_input/**",
80
+ "user_study_results/**",
81
+ ],
82
+ )
83
+ except Exception as e:
84
+ print(f"[INIT] snapshot_download failed: {e}")
85
+ return False
86
+
87
+ # 兼容两种 dataset 结构:
88
+ # A) 仓库根目录直接是 clip_movie_story/ 与 video/
89
+ # B) 仓库里有 user_study_input/ 子目录
90
+ if (hub_root / "clip_movie_story").exists() and (hub_root / "video").exists():
91
+ hub_input = hub_root
92
+ elif (hub_root / "user_study_input").exists():
93
+ hub_input = hub_root / "user_study_input"
94
+ else:
95
+ return False
96
+
97
+ hub_output = hub_root / "user_study_results"
98
+ _set_paths(hub_input, hub_output)
99
+ return True
100
+
101
+
102
+ def init_space_storage() -> None:
103
+ """
104
+ Hugging Face Spaces 规范:
105
+ - 从 dataset repo 拉取 user_study_input 与 user_study_results 到本地 ROOT_DIR
106
+ - 使用 CommitScheduler 持续回写 user_study_results
107
+ """
108
+ global scheduler
109
+
110
+ if SPACE_MODE == "hub_only":
111
+ ok = _try_download_from_hub()
112
+ elif SPACE_MODE == "data_first":
113
+ ok = _try_use_data_volume_layout() or _try_use_local_repo_layout() or _try_download_from_hub()
114
+ else:
115
+ ok = _try_use_local_repo_layout() or _try_use_data_volume_layout() or _try_download_from_hub()
116
+ print(f"[INIT] storage init mode={SPACE_MODE}, success={ok}, input={INPUT_DIR}, output={OUTPUT_DIR}")
117
+
118
+ if RESULTS_REPO_ID:
119
+ try:
120
+ scheduler = CommitScheduler(
121
+ repo_id=RESULTS_REPO_ID,
122
+ repo_type="dataset",
123
+ folder_path=str(OUTPUT_DIR),
124
+ path_in_repo="user_study_results",
125
+ every=3,
126
+ token=HF_TOKEN,
127
+ )
128
+ print(f"[INIT] CommitScheduler enabled: {RESULTS_REPO_ID}")
129
+ except Exception as e:
130
+ print(f"[INIT] CommitScheduler init failed: {e}")
131
+
132
+
133
+ init_space_storage()
134
 
135
  # Movie-Level 指标定义
136
  MOVIE_CRITERIA: List[Tuple[str, str, str]] = [
 
254
  return pending
255
 
256
 
257
+ def build_data_diagnostics(samples: List[Dict[str, Any]]) -> str:
258
+ return (
259
+ f"**SPACE_MODE**: `{SPACE_MODE}` \n"
260
+ f"**DATA_REPO_ID**: `{DATA_REPO_ID}` \n"
261
+ f"**RESULTS_REPO_ID**: `{RESULTS_REPO_ID}` \n"
262
+ f"**ROOT_DIR**: `{ROOT_DIR}` \n"
263
+ f"**INPUT_DIR exists**: `{INPUT_DIR.exists()}` \n"
264
+ f"**STORY_DIR exists**: `{STORY_DIR.exists()}` \n"
265
+ f"**VIDEO_DIR exists**: `{VIDEO_DIR.exists()}` \n"
266
+ f"**Pending samples**: `{len(samples)}`"
267
+ )
268
+
269
+
270
  def compute_derived(scores: Dict[str, float]) -> Dict[str, float]:
271
  """计算 CL / CRh / AVG。"""
272
  cl = (
 
381
  gr.Markdown(
382
  f"<span class='hint'>输入目录:`{INPUT_DIR}` | 输出目录:`{OUTPUT_DIR}`</span>",
383
  )
384
+ gr.Markdown(build_data_diagnostics(samples))
385
 
386
  current_idx = gr.State(0)
387
  evaluator_state = gr.State("anonymous")