SIKAI-C committed on
Commit
078d201
Β·
verified Β·
1 Parent(s): bb483aa

Create download.py

Browse files
Files changed (1) hide show
  1. download.py +192 -0
download.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Download script for CSI-4CAST datasets.
3
+
4
+ This script downloads all available datasets from the CSI-4CAST Hugging Face organization
5
+ by checking for all possible combinations of channel models, delay spreads, and speeds.
6
+
7
+ Usage:
8
+ python3 download.py [--output-dir OUTPUT_DIR]
9
+
10
+ If no arguments provided, it will download datasets to a 'datasets' folder.
11
+ """
12
+
13
+ import argparse
14
+ from pathlib import Path
15
+
16
+ from huggingface_hub import HfApi, snapshot_download
17
+ from tqdm import tqdm
18
+
19
+ # Configuration constants
20
+ ORG = "CSI-4CAST"
21
+
22
+ # Regular dataset parameters
23
+ LIST_CHANNEL_MODEL = ["A", "C", "D"]
24
+ LIST_DELAY_SPREAD = [30e-9, 100e-9, 300e-9]
25
+ LIST_MIN_SPEED = [1, 10, 30]
26
+
27
+ # Generalization dataset parameters
28
+ LIST_CHANNEL_MODEL_GEN = ["A", "B", "C", "D", "E"]
29
+ LIST_DELAY_SPREAD_GEN = [30e-9, 50e-9, 100e-9, 200e-9, 300e-9, 400e-9]
30
+ LIST_MIN_SPEED_GEN = sorted([*range(3, 46, 3), 1, 10])
31
+
32
def make_folder_name(cm: str, ds: float, ms: int, **kwargs) -> str:
    """Build the canonical dataset folder name for one parameter combination.

    Args:
        cm (str): Channel model identifier (e.g., 'A', 'B', 'C', 'D', 'E')
        ds (float): Delay spread in seconds (e.g., 30e-9, 100e-9, 300e-9)
        ms (int): Minimum speed in km/h (e.g., 1, 10, 30)
        **kwargs: Ignored; accepted so callers may pass extra fields.

    Returns:
        str: Name of the form 'cm_{cm}_ds_{DDD}_ms_{MMM}' where DDD is the
            delay spread in whole nanoseconds and MMM the minimum speed,
            each zero-padded to three digits.

    Example:
        >>> make_folder_name('A', 30e-9, 10)
        'cm_A_ds_030_ms_010'
    """
    # Delay spread is encoded as integer nanoseconds, zero-padded to 3 digits.
    nanos = round(ds * 1e9)
    # Speed is stringified as given, then zero-padded to 3 digits.
    speed_part = str(ms).zfill(3)
    return f"cm_{cm}_ds_{nanos:03d}_ms_{speed_part}"
60
+
61
def check_repo_exists(api: HfApi, repo_id: str) -> bool:
    """Return True if *repo_id* exists as a dataset repository, else False."""
    try:
        api.repo_info(repo_id, repo_type="dataset")
    except Exception:
        # Any failure (404, auth, network) is treated as "does not exist".
        return False
    return True
68
+
69
def generate_dataset_combinations():
    """Return the list of all candidate dataset repository names.

    The list contains the shared 'stats' repo, then the train and test
    "regular" grids (channel model x delay spread x min speed), then the
    wider "test_generalization" grid, in that order.

    Returns:
        list[str]: Candidate repository names (not all necessarily exist
        on the Hub; callers probe each one).
    """
    from itertools import product

    combinations = ["stats"]

    # Train and test share the same regular parameter grid, so build the
    # grid once instead of duplicating the triple loop.
    regular_grid = list(product(LIST_CHANNEL_MODEL, LIST_DELAY_SPREAD, LIST_MIN_SPEED))
    for prefix in ("train_regular", "test_regular"):
        for cm, ds, ms in regular_grid:
            combinations.append(f"{prefix}_{make_folder_name(cm, ds, ms)}")

    # Generalization test sets cover a denser parameter grid.
    for cm, ds, ms in product(LIST_CHANNEL_MODEL_GEN, LIST_DELAY_SPREAD_GEN, LIST_MIN_SPEED_GEN):
        combinations.append(f"test_generalization_{make_folder_name(cm, ds, ms)}")

    return combinations
101
+
102
def download_dataset(api: HfApi, org: str, repo_name: str, output_dir: Path, dry_run: bool = False) -> bool:
    """Download a single dataset repo into ``output_dir / repo_name``.

    Args:
        api (HfApi): Hub API client, used only for the existence check.
        org (str): Organization name on the Hub.
        repo_name (str): Repository name within the organization.
        output_dir (Path): Root directory receiving one folder per repo.
        dry_run (bool): If True, create an empty placeholder file instead
            of downloading.

    Returns:
        bool: True on success; False when the repo does not exist or the
        download failed (failures are printed, never raised).
    """
    repo_id = f"{org}/{repo_name}"

    # Skip quietly: most grid combinations have no corresponding repo.
    if not check_repo_exists(api, repo_id):
        return False

    try:
        # Create target directory
        target_dir = output_dir / repo_name
        target_dir.mkdir(parents=True, exist_ok=True)

        if dry_run:
            # Leave an empty marker so the folder layout can be inspected.
            placeholder_file = target_dir / "placeholder.txt"
            placeholder_file.write_text("")
            print(f"βœ… Dry run - Created placeholder: {repo_name}")
        else:
            # NOTE: `local_dir_use_symlinks` is deprecated and ignored in
            # current huggingface_hub (>=0.23 downloads real files into
            # local_dir by default), so it is intentionally omitted.
            snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                local_dir=target_dir,
            )
            print(f"βœ… Downloaded: {repo_name}")

        return True

    except Exception as e:
        # Best-effort: report the failure and let the caller continue with
        # the remaining repositories.
        print(f"❌ Error downloading {repo_name}: {e}")
        return False
134
+
135
def main():
    """Parse CLI arguments and fetch every available CSI-4CAST dataset."""
    parser = argparse.ArgumentParser(description="Download all CSI-4CAST datasets from Hugging Face")
    parser.add_argument("--output-dir", "-o", default="datasets",
                        help="Output directory for downloaded datasets (default: 'datasets')")
    parser.add_argument("--dry-run", action="store_true",
                        help="Dry run mode: create empty placeholder files instead of downloading")
    args = parser.parse_args()

    output_dir = Path(args.output_dir).resolve()
    org = ORG

    print(f"{'Dry run' if args.dry_run else 'Downloading'} datasets from organization: {org}")
    print(f"Output directory: {output_dir}")
    print()

    output_dir.mkdir(parents=True, exist_ok=True)

    # Hub client is only used to probe repo existence before downloading.
    api = HfApi()

    print("Generating dataset combinations...")
    combinations = generate_dataset_combinations()
    print(f"Total possible combinations: {len(combinations)}")
    print()

    action = "Checking and creating placeholders for" if args.dry_run else "Checking and downloading"
    print(f"{action} existing datasets...")

    # Tally repos fetched vs. repos that are absent (or failed).
    downloaded_count = 0
    skipped_count = 0
    for repo_name in tqdm(combinations, desc="Processing datasets"):
        if download_dataset(api, org, repo_name, output_dir, args.dry_run):
            downloaded_count += 1
        else:
            skipped_count += 1

    print()
    if args.dry_run:
        print("πŸŽ‰ Dry run complete!")
        print(f"βœ… Created placeholders: {downloaded_count} datasets")
        print(f"⏭️ Skipped: {skipped_count} datasets (not found)")
        print(f"πŸ“ Placeholders saved to: {output_dir}")
    else:
        print("πŸŽ‰ Download complete!")
        print(f"βœ… Downloaded: {downloaded_count} datasets")
        print(f"⏭️ Skipped: {skipped_count} datasets (not found)")
        print(f"πŸ“ Datasets saved to: {output_dir}")
        print()
        print("To reconstruct the original folder structure, run:")
        print(f"python3 reconstruction.py --input-dir {output_dir}")
190
+
191
+ if __name__ == "__main__":
192
+ main()