Vaishnav14220 committed
Commit fdbfba8 · Parent: bd4ecf7

Push datasets via push_to_hub and load from Hub on resume

Files changed (2)
  1. app.py +9 -41
  2. src/dataset_prepare.py +8 -7
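This commit swaps a raw file sync (snapshot_download / upload_folder) for the datasets library's native round-trip: push_to_hub uploads each split as Parquet shards at build time, and load_dataset pulls them back on resume before re-materialising them with save_to_disk. A minimal sketch of that round-trip, assuming a hypothetical repo id (user/my-dataset stands in for FORWARD_DATASET_NAME / RETRO_DATASET_NAME from src.config):

from datasets import DatasetDict, load_dataset, load_from_disk

def push_then_resume(dsd: DatasetDict, repo_id: str, local_dir: str) -> DatasetDict:
    # Build time: upload every split of the DatasetDict to the Hub as Parquet shards.
    dsd.push_to_hub(repo_id, max_shard_size="2GB")
    # Resume time: pull the splits back from the Hub...
    restored = load_dataset(repo_id)
    # ...and materialise them on disk so later code can reopen them with load_from_disk().
    restored.save_to_disk(local_dir)
    return load_from_disk(local_dir)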
app.py CHANGED
@@ -10,7 +10,8 @@ from pathlib import Path
 from datetime import datetime
 from typing import List, Tuple
 
-from huggingface_hub import login, snapshot_download, hf_hub_download, HfApi
+from huggingface_hub import login, hf_hub_download, HfApi
+from datasets import load_dataset, DatasetDict
 from src.config import (
     FORWARD_DATASET_NAME,
     RETRO_DATASET_NAME,
@@ -72,52 +73,19 @@ def _ensure_clean_dir(path: Path):
 
 
 def _download_dataset(repo_id: str, target_dir: Path) -> bool:
-    if _dir_has_arrow_files(target_dir) and (target_dir / "dataset_dict.json").exists():
+    if (target_dir / "dataset_dict.json").exists() and _dir_has_arrow_files(target_dir):
         return True
     if not HF_MODEL_TOKEN:
         print(f"⚠️ Cannot download dataset {repo_id}: HF_MODEL_TOKEN not set.")
         return False
     try:
-        print(f"⬇️ Downloading cached dataset from {repo_id}...")
+        print(f"⬇️ Loading dataset {repo_id} from Hugging Face Hub...")
+        ds = load_dataset(repo_id)
+        if not isinstance(ds, DatasetDict):
+            ds = DatasetDict({k: v for k, v in ds.items()})
         _ensure_clean_dir(target_dir)
-        downloaded_path = Path(
-            snapshot_download(
-                repo_id=repo_id,
-                repo_type="dataset",
-                local_dir=str(target_dir),
-                local_dir_use_symlinks=False,
-                token=HF_MODEL_TOKEN,
-                allow_patterns=["*"],
-            )
-        )
-        if downloaded_path != target_dir:
-            for item in downloaded_path.iterdir():
-                dest = target_dir / item.name
-                if dest.exists():
-                    if dest.is_dir():
-                        shutil.rmtree(dest)
-                    else:
-                        dest.unlink()
-                shutil.move(str(item), str(dest))
-        dataset_file = target_dir / "dataset_dict.json"
-        if not dataset_file.exists():
-            nested = list(target_dir.glob("**/dataset_dict.json"))
-            for cand in nested:
-                if cand.parent == target_dir:
-                    dataset_file = cand
-                    break
-                # move nested dataset up one level
-                for child in cand.parent.iterdir():
-                    dest = target_dir / child.name
-                    if dest.exists():
-                        if dest.is_dir():
-                            shutil.rmtree(dest)
-                        else:
-                            dest.unlink()
-                    shutil.move(str(child), str(dest))
-                dataset_file = target_dir / "dataset_dict.json"
-                break
-        return dataset_file.exists() and _dir_has_arrow_files(target_dir)
+        ds.save_to_disk(str(target_dir))
+        return (target_dir / "dataset_dict.json").exists() and _dir_has_arrow_files(target_dir)
     except Exception as exc:
         print(f"⚠️ Could not download dataset {repo_id}: {exc}")
         return False
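One caveat, flagged here as an assumption rather than part of the commit: the old snapshot_download path passed token=HF_MODEL_TOKEN, while the new load_dataset(repo_id) call relies on ambient credentials (e.g. a prior login()). If the dataset repos are private, the token would need to be threaded through explicitly, roughly like this sketch:

import os
from datasets import load_dataset

# HF_MODEL_TOKEN mirrors the value app.py reads from its config (an assumption
# about its source); "user/private-dataset" is a hypothetical repo id.
HF_MODEL_TOKEN = os.environ.get("HF_MODEL_TOKEN")

# Recent datasets releases accept token=...; older ones used use_auth_token=...
ds = load_dataset("user/private-dataset", token=HF_MODEL_TOKEN)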
src/dataset_prepare.py CHANGED
@@ -24,7 +24,7 @@ HF_API = HfApi(token=HF_TOKEN)
 CACHE_DIR.mkdir(parents=True, exist_ok=True)
 
 
-def upload_split(local_dir: Path, repo_id: str, label: str):
+def push_dataset(dataset: DatasetDict, repo_id: str, label: str):
     if not UPLOAD_DATASETS:
         print(f"Skipping upload of {label} dataset (ORD_UPLOAD_DATASETS disabled).")
         return
@@ -34,12 +34,13 @@ def upload_split(local_dir: Path, repo_id: str, label: str):
         return
 
     try:
-        print(f"Uploading {label} dataset to {repo_id}...")
+        print(f"Uploading {label} dataset to {repo_id} via push_to_hub...")
         create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, token=HF_TOKEN)
-        HF_API.upload_folder(
-            folder_path=str(local_dir),
+        dataset.push_to_hub(
             repo_id=repo_id,
-            repo_type="dataset",
+            token=HF_TOKEN,
+            max_shard_size="2GB",
+            private=False,
             commit_message=f"Update {label} dataset",
         )
         print(f"✅ Uploaded {label} dataset to Hugging Face Hub.")
@@ -97,9 +98,9 @@ def build_dataset(map_fn, name: str, max_samples=None):
     dsd.save_to_disk(str(save_path))
 
     if name == "forward":
-        upload_split(save_path, FORWARD_DATASET_NAME, "forward")
+        push_dataset(dsd, FORWARD_DATASET_NAME, "forward")
     elif name == "retro":
-        upload_split(save_path, RETRO_DATASET_NAME, "retro")
+        push_dataset(dsd, RETRO_DATASET_NAME, "retro")
 
     print(f"\n{name} dataset statistics:")
     for split_name, ds in dsd.items():
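After push_dataset uploads the splits, the resume path in app.py should see the same splits and row counts. A quick check, sketched with a hypothetical repo id in place of FORWARD_DATASET_NAME:

from datasets import load_dataset

# Inspect what the Hub now serves, to compare against the locally built DatasetDict.
remote = load_dataset("user/ord-forward")  # hypothetical repo id
for split_name, split in remote.items():
    print(f"{split_name}: {split.num_rows} rows, columns={split.column_names}")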