ethanker committed on
Commit
fad9f17
·
verified ·
1 Parent(s): 65998db

Upload push_to_hf.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. push_to_hf.py +267 -0
push_to_hf.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
import argparse
import json
import os
from pathlib import Path
from typing import Generator, Optional

import torch
from datasets import Dataset, load_dataset
from huggingface_hub import HfApi, Repository, create_repo, upload_file
from omegaconf import OmegaConf
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn

from models import SymbolFIMModel
18
+
19
+ console = Console()
20
+
21
+
22
def load_jsonl(path: str) -> Generator[dict, None, None]:
    """Yield one parsed JSON object per non-blank line of *path*."""
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped:
                yield json.loads(stripped)
28
+
29
+
30
def push_model(
    model_path: str,
    config_path: str,
    repo_id: str,
    token: str,
    private: bool = False,
) -> None:
    """Load the trained SymbolFIM checkpoint and upload its weights and config to the Hub.

    Args:
        model_path: Path to the saved ``state_dict`` checkpoint on disk.
        config_path: OmegaConf YAML file providing ``model.*`` hyperparameters.
        repo_id: Target Hub repo ID, e.g. ``username/model-name``.
        token: Hugging Face access token with write permission.
        private: Create the repo as private if it does not exist yet.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    console.print(Panel(f"[bold cyan]Pushing Model to {repo_id}[/bold cyan]", border_style="cyan"))

    cfg = OmegaConf.load(config_path)

    model = SymbolFIMModel(
        vocab_size=260,
        d_model=cfg.model.d_model,
        n_layers=cfg.model.n_layers,
        n_heads=cfg.model.n_heads,
        window=cfg.model.window,
        max_len=cfg.max_len,
        ast_head_cfg=None,
    )

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at {model_path}")

    # The checkpoint is a plain state_dict of tensors, so weights_only=True is
    # sufficient and avoids unpickling arbitrary objects from disk.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    try:
        create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
    except Exception as e:
        if "403" in str(e) or "Forbidden" in str(e):
            console.print("[bold yellow]⚠️ Warning:[/bold yellow] Cannot create repo. Make sure:")
            console.print(" 1. Your HF token has 'write' permissions (not just 'read')")
            # BUG FIX: this string was missing the f prefix, so the literal
            # text "{repo_id}" was printed instead of the repository name.
            console.print(f" 2. The repo exists at https://huggingface.co/{repo_id}")
            console.print(" 3. You have access to the namespace")
            console.print("\n[yellow]Trying to upload to existing repo...[/yellow]")
        else:
            raise
    api = HfApi(token=token)

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        console=console,
    ) as progress:
        task = progress.add_task("[cyan]Uploading model files...", total=100)

        # Stage weights + config in a scratch directory so everything goes up
        # in a single upload_folder call.
        model_dir = Path("/tmp/hf_model_upload")
        model_dir.mkdir(parents=True, exist_ok=True)

        torch.save(model.state_dict(), model_dir / "pytorch_model.bin")
        progress.update(task, advance=30)

        model_config = {
            "vocab_size": model.vocab_size,
            "d_model": cfg.model.d_model,
            "n_layers": cfg.model.n_layers,
            "n_heads": cfg.model.n_heads,
            "window": cfg.model.window,
            "max_len": cfg.max_len,
            "model_type": "symbol_fim_transformer",
        }

        with open(model_dir / "config.json", "w") as f:
            json.dump(model_config, f, indent=2)
        progress.update(task, advance=20)

        api.upload_folder(
            folder_path=str(model_dir),
            repo_id=repo_id,
            repo_type="model",
            token=token,
        )
        progress.update(task, advance=50)

    console.print(f"[green]✓[/green] Model pushed to: [cyan]https://huggingface.co/{repo_id}[/cyan]")
109
+
110
+
111
def push_dataset(
    dataset_path: str,
    repo_id: str,
    token: str,
    private: bool = False,
    max_samples: Optional[int] = None,
) -> None:
    """Load a JSONL dataset from disk and push it to the Hub as a dataset repo.

    Args:
        dataset_path: Path to a JSONL file with one record per line.
        repo_id: Target Hub dataset repo ID, e.g. ``username/dataset-name``.
        token: Hugging Face access token with write permission.
        private: Create the dataset repo as private.
        max_samples: If given, stop after this many records (useful for testing).
            Note: the annotation was ``int = None``; fixed to ``Optional[int]``.
    """
    console.print(Panel(f"[bold cyan]Pushing Dataset to {repo_id}[/bold cyan]", border_style="cyan"))

    records = []
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        console=console,
    ) as progress:
        task = progress.add_task("[cyan]Loading dataset...", total=None)

        for idx, record in enumerate(load_jsonl(dataset_path)):
            records.append(record)
            if max_samples and idx + 1 >= max_samples:
                break
            # Refresh the spinner description periodically instead of per record.
            if (idx + 1) % 1000 == 0:
                progress.update(task, description=f"[cyan]Loaded {idx + 1:,} samples...")

        progress.update(task, completed=True)
        console.print(f"[green]✓[/green] Loaded {len(records):,} samples")

    # Materialize the records into a datasets.Dataset before uploading.
    dataset = Dataset.from_list(records)

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        console=console,
    ) as progress:
        task = progress.add_task("[cyan]Pushing dataset to Hub...", total=100)

        dataset.push_to_hub(
            repo_id=repo_id,
            token=token,
            private=private,
        )
        progress.update(task, completed=True)

    console.print(f"[green]✓[/green] Dataset pushed to: [cyan]https://huggingface.co/datasets/{repo_id}[/cyan]")
159
+
160
+
161
def push_code(
    code_dir: str,
    repo_id: str,
    token: str,
    private: bool = False,
) -> None:
    """Upload source files (.py/.yaml/.yml/.txt/.md) under *code_dir* to the Hub.

    Args:
        code_dir: Directory tree to scan recursively for source files.
        repo_id: Target Hub repo ID (a model repo is used to host the code).
        token: Hugging Face access token with write permission.
        private: Create the repo as private if it does not exist yet.
    """
    console.print(Panel(f"[bold cyan]Pushing Code to {repo_id}[/bold cyan]", border_style="cyan"))

    create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
    api = HfApi(token=token)

    root = Path(code_dir)
    # Collect every matching source file, skipping bytecode cache directories.
    candidates = [
        found
        for pattern in ("*.py", "*.yaml", "*.yml", "*.txt", "*.md")
        for found in root.rglob(pattern)
        if "__pycache__" not in str(found)
    ]

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        console=console,
    ) as progress:
        task = progress.add_task(f"[cyan]Uploading {len(candidates)} files...", total=len(candidates))

        for source in candidates:
            # Preserve the layout relative to code_dir inside the repo.
            api.upload_file(
                path_or_fileobj=str(source),
                path_in_repo=str(source.relative_to(root)),
                repo_id=repo_id,
                repo_type="model",
                token=token,
            )
            progress.update(task, advance=1)

    console.print(f"[green]✓[/green] Code pushed to: [cyan]https://huggingface.co/{repo_id}[/cyan]")
201
+
202
+
203
def main() -> None:
    """CLI entry point: parse arguments and push model, dataset, and code to the Hub.

    Exits with status 1 when no Hugging Face token is available.
    """
    parser = argparse.ArgumentParser(description="Push experiment to Hugging Face Hub")
    parser.add_argument("--model-path", type=str, default="/workspace/runs/model.pt")
    parser.add_argument("--config-path", type=str, required=True)
    parser.add_argument("--dataset-path", type=str, required=True)
    parser.add_argument("--code-dir", type=str, default="/workspace/experiments")
    parser.add_argument("--model-repo", type=str, required=True, help="HF repo ID for model (e.g., username/model-name)")
    parser.add_argument("--dataset-repo", type=str, required=True, help="HF repo ID for dataset (e.g., username/dataset-name)")
    parser.add_argument("--code-repo", type=str, help="HF repo ID for code (optional, defaults to model-repo)")
    parser.add_argument("--token", type=str, help="HF token (or set HF_TOKEN env var)")
    parser.add_argument("--private", action="store_true", help="Make repos private")
    parser.add_argument("--max-dataset-samples", type=int, help="Limit dataset samples (for testing)")
    # BUG FIX: these were `action="store_true", default=True`, which made the
    # flags no-ops (always True, impossible to disable). BooleanOptionalAction
    # keeps `--push-model` working and adds `--no-push-model`, etc.
    parser.add_argument("--push-model", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--push-dataset", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--push-code", action=argparse.BooleanOptionalAction, default=True)

    args = parser.parse_args()

    token = args.token or os.getenv("HF_TOKEN")
    if not token:
        console.print("[bold red]Error:[/bold red] HF token required. Set --token or HF_TOKEN env var")
        # Exit non-zero so calling scripts can detect the failure (was a bare return).
        raise SystemExit(1)

    code_repo = args.code_repo or args.model_repo

    try:
        if args.push_model:
            push_model(
                model_path=args.model_path,
                config_path=args.config_path,
                repo_id=args.model_repo,
                token=token,
                private=args.private,
            )
            console.print()

        if args.push_dataset:
            push_dataset(
                dataset_path=args.dataset_path,
                repo_id=args.dataset_repo,
                token=token,
                private=args.private,
                max_samples=args.max_dataset_samples,
            )
            console.print()

        if args.push_code:
            push_code(
                code_dir=args.code_dir,
                repo_id=code_repo,
                token=token,
                private=args.private,
            )
            console.print()

        console.print(Panel("[bold green]✓ All components pushed successfully![/bold green]", border_style="green"))

    except Exception as e:
        # Surface the error to the console, then re-raise so the traceback
        # and non-zero exit code are preserved.
        console.print(f"[bold red]Error:[/bold red] {e}")
        raise
263
+
264
+
265
# Allow importing this module without triggering an upload.
if __name__ == "__main__":
    main()
267
+