arach commited on
Commit
f9666b9
·
verified ·
1 Parent(s): 3a6cce4

Add Hugging Face Jobs launcher for news summarization eval

Browse files
eval/news_summarization/launch_hf_job.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import os
6
+ import textwrap
7
+ from pathlib import Path
8
+
9
+ from huggingface_hub import HfApi, run_job
10
+
11
+ TOKEN_KEYS = [
12
+ "HF_TOKEN",
13
+ "HUGGINGFACE_TOKEN",
14
+ "HUGGING_FACE_TOKEN",
15
+ "HUGGING_FACE_HUB_TOKEN",
16
+ "HF_API_TOKEN",
17
+ ]
18
+
19
+ ENV_CANDIDATES = [
20
+ Path.cwd() / ".env.local",
21
+ Path.cwd() / ".env",
22
+ Path.home() / ".env.local",
23
+ ]
24
+
25
+
26
+ def load_token_from_env_files() -> str | None:
27
+ for env_path in ENV_CANDIDATES:
28
+ if not env_path.exists():
29
+ continue
30
+ for line in env_path.read_text().splitlines():
31
+ stripped = line.strip()
32
+ if not stripped or stripped.startswith("#") or "=" not in stripped:
33
+ continue
34
+ key, value = stripped.split("=", 1)
35
+ key = key.strip()
36
+ value = value.strip().strip('"').strip("'")
37
+ if key in TOKEN_KEYS and value:
38
+ return value
39
+ return None
40
+
41
+
42
+ def resolve_token(cli_token: str | None) -> str:
43
+ if cli_token:
44
+ return cli_token
45
+ for key in TOKEN_KEYS:
46
+ value = os.environ.get(key)
47
+ if value:
48
+ return value
49
+ env_file_token = load_token_from_env_files()
50
+ if env_file_token:
51
+ return env_file_token
52
+ raise SystemExit("No Hugging Face token found. Export one or pass --token.")
53
+
54
+
55
+ def parse_args() -> argparse.Namespace:
56
+ parser = argparse.ArgumentParser(description="Launch the news summarization eval as a Hugging Face Job.")
57
+ parser.add_argument("--source-repo-id", default="arach/training-lab")
58
+ parser.add_argument("--source-repo-type", default="model")
59
+ parser.add_argument("--results-repo-id")
60
+ parser.add_argument("--results-repo-type", default="model")
61
+ parser.add_argument("--results-path-prefix", default="eval/news_summarization/results/jobs")
62
+ parser.add_argument("--image", default="pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel")
63
+ parser.add_argument("--flavor", default="a10g-small")
64
+ parser.add_argument("--timeout", default="2h")
65
+ parser.add_argument("--token")
66
+ parser.add_argument("--namespace")
67
+ parser.add_argument("--model", default="Qwen/Qwen2.5-7B-Instruct")
68
+ parser.add_argument("--limit", type=int, default=50)
69
+ parser.add_argument("--seed", type=int, default=7)
70
+ parser.add_argument("--prompt-style", default="simple", choices=["simple", "helpful", "detailed"])
71
+ parser.add_argument("--dtype", default="bfloat16")
72
+ parser.add_argument("--device-map", default="auto")
73
+ parser.add_argument("--max-new-tokens", type=int, default=220)
74
+ parser.add_argument("--max-article-chars", type=int, default=8000)
75
+ parser.add_argument("--trust-remote-code", action="store_true")
76
+ parser.add_argument("--disable-rouge", action="store_true")
77
+ parser.add_argument("--disable-bertscore", action="store_true")
78
+ parser.add_argument("--bertscore-model", default="roberta-large")
79
+ parser.add_argument("--sync-code", action="store_true")
80
+ return parser.parse_args()
81
+
82
+
83
+ def sync_news_eval_to_hf(source_repo_id: str, source_repo_type: str, token: str) -> None:
84
+ api = HfApi(token=token)
85
+ root = Path(__file__).resolve().parents[2]
86
+ eval_dir = Path(__file__).resolve().parent
87
+ for rel in [
88
+ "run_news_summary_pilot.py",
89
+ "run_hf_transformers.py",
90
+ "launch_hf_job.py",
91
+ "README.md",
92
+ "COLAB_QUICKSTART.md",
93
+ "requirements.txt",
94
+ "notebook.ipynb",
95
+ ]:
96
+ api.upload_file(
97
+ repo_id=source_repo_id,
98
+ repo_type=source_repo_type,
99
+ path_or_fileobj=str(eval_dir / rel),
100
+ path_in_repo=f"eval/news_summarization/{rel}",
101
+ commit_message="Sync news summarization eval harness",
102
+ )
103
+
104
+
105
+ def main() -> int:
106
+ args = parse_args()
107
+ token = resolve_token(args.token)
108
+ source_repo_id = args.source_repo_id
109
+ source_repo_type = args.source_repo_type
110
+ results_repo_id = args.results_repo_id or source_repo_id
111
+ results_repo_type = args.results_repo_type
112
+
113
+ if args.sync_code:
114
+ sync_news_eval_to_hf(source_repo_id, source_repo_type, token)
115
+
116
+ flags = [
117
+ f"--model {args.model}",
118
+ f"--limit {args.limit}",
119
+ f"--seed {args.seed}",
120
+ f"--prompt-style {args.prompt_style}",
121
+ f"--dtype {args.dtype}",
122
+ f"--device-map {args.device_map}",
123
+ f"--max-new-tokens {args.max_new_tokens}",
124
+ f"--max-article-chars {args.max_article_chars}",
125
+ "--save-every 1",
126
+ "--verbose",
127
+ ]
128
+ if args.trust_remote_code:
129
+ flags.append("--trust-remote-code")
130
+ if args.disable_rouge:
131
+ flags.append("--disable-rouge")
132
+ if args.disable_bertscore:
133
+ flags.append("--disable-bertscore")
134
+ else:
135
+ flags.append(f"--bertscore-model {args.bertscore_model}")
136
+
137
+ job_script = textwrap.dedent(
138
+ f"""
139
+ set -euo pipefail
140
+ python - <<'PY'
141
+ from huggingface_hub import snapshot_download
142
+ snapshot_download(
143
+ repo_id='{source_repo_id}',
144
+ repo_type='{source_repo_type}',
145
+ local_dir='/workspace/training-lab',
146
+ token='{token}',
147
+ allow_patterns=['eval/news_summarization/*'],
148
+ )
149
+ PY
150
+ cd /workspace/training-lab
151
+ mkdir -p {args.results_path_prefix}/$JOB_ID
152
+ python eval/news_summarization/run_hf_transformers.py {' '.join(flags)} --output {args.results_path_prefix}/$JOB_ID/results.json
153
+ python - <<'PY'
154
+ import os
155
+ from huggingface_hub import HfApi
156
+ api = HfApi(token=os.environ['HF_TOKEN'])
157
+ api.upload_folder(
158
+ repo_id='{results_repo_id}',
159
+ repo_type='{results_repo_type}',
160
+ folder_path='{args.results_path_prefix}/$JOB_ID',
161
+ path_in_repo='{args.results_path_prefix}/$JOB_ID',
162
+ commit_message='Upload news summarization eval results for $JOB_ID',
163
+ )
164
+ PY
165
+ """
166
+ ).strip()
167
+
168
+ job = run_job(
169
+ image=args.image,
170
+ command=["bash", "-lc", job_script],
171
+ flavor=args.flavor,
172
+ timeout=args.timeout,
173
+ namespace=args.namespace,
174
+ token=token,
175
+ secrets={"HF_TOKEN": token},
176
+ )
177
+ print(f"Job launched: {job.url}")
178
+ print(f"Job id: {job.id}")
179
+ return 0
180
+
181
+
182
+ if __name__ == "__main__":
183
+ raise SystemExit(main())