| import os |
| import sys |
| import argparse |
| import logging |
| import subprocess |
| from pathlib import Path |
| from typing import Optional |
|
|
|
|
# Hugging Face Hub "resolve" URLs for raw dataset archives, keyed by the
# dataset name. None of the functions visible in this file read this mapping.
DEFAULT_DATA_URLS = dict(
    openwebtext="https://huggingface.co/datasets/Skylion007/openwebtext/resolve/main/train.json",
    slimpajama="https://huggingface.co/datasets/cerebras/SlimPajama-627B/resolve/main/train.tar.gz",
)
|
|
|
|
def setup_logging(level: str = "INFO") -> logging.Logger:
    """Configure and return the ``download_data`` logger.

    Safe to call more than once: the stream handler is attached only when the
    logger has no handler yet, so repeated calls update the level without
    duplicating every log line.

    Args:
        level: Case-insensitive name of a standard ``logging`` level
            (e.g. ``"debug"``, ``"INFO"``).

    Returns:
        The configured ``download_data`` logger.

    Raises:
        ValueError: If ``level`` does not name a ``logging`` level.
    """
    numeric_level = getattr(logging, level.upper(), None)
    # Level constants are plain ints; anything else (missing attribute,
    # format strings, ...) is not a usable level.
    if not isinstance(numeric_level, int) or isinstance(numeric_level, bool):
        raise ValueError(f"Unknown log level: {level!r}")

    logger = logging.getLogger("download_data")
    logger.setLevel(numeric_level)

    # getLogger() returns the same object on every call, so adding a handler
    # unconditionally would emit each message once per setup_logging() call.
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
        )
        logger.addHandler(handler)

    return logger
|
|
|
|
def download_file(url: str, output_dir: str, force: bool = False) -> str:
    """Download ``url`` into ``output_dir`` using curl, falling back to wget.

    The file name is taken from the path component of the URL (query string
    and fragment are ignored). The download is written to a ``.part`` file and
    renamed only on success, so an interrupted or failed transfer is never
    mistaken for a cached file on the next call.

    Args:
        url: Location of the file to fetch.
        output_dir: Directory to store the file in; created if missing.
        force: Re-download even if the target file already exists.

    Returns:
        Path of the downloaded (or already present) file, as a string.

    Raises:
        ValueError: If no file name can be derived from ``url``.
        FileNotFoundError: If neither curl nor wget is installed.
        RuntimeError: If every available downloader exited with an error.
    """
    from urllib.parse import urlparse

    # Resolve the logger by name so the function also works when main() has
    # not populated the module-level ``logger`` global.
    log = logging.getLogger("download_data")

    filename = Path(urlparse(url).path).name
    if not filename:
        raise ValueError(f"Cannot derive a file name from URL: {url}")

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    output_path = target_dir / filename

    if output_path.exists() and not force:
        return str(output_path)

    log.info(f"Downloading {url}...")

    partial_path = output_path.with_name(output_path.name + ".part")
    commands = (
        # --fail: exit non-zero on HTTP errors instead of saving the error page.
        ["curl", "--fail", "-L", "-o", str(partial_path), url],
        ["wget", "-O", str(partial_path), url],
    )

    last_stderr = None
    for command in commands:
        tool = command[0]
        try:
            result = subprocess.run(command, capture_output=True, text=True)
        except FileNotFoundError:
            # This tool is not installed; the other one may still be.
            log.warning(f"{tool} not found")
            continue

        if result.returncode == 0:
            partial_path.replace(output_path)
            log.info(f"Downloaded to {output_path}")
            return str(output_path)

        # Drop whatever the failed attempt left behind before retrying.
        partial_path.unlink(missing_ok=True)
        last_stderr = result.stderr
        if tool == "curl":
            log.warning(f"curl failed, trying wget: {result.stderr}")

    if last_stderr is None:
        log.error("curl or wget not found. Please install curl or wget.")
        raise FileNotFoundError(
            "curl or wget not found. Please install curl or wget."
        )

    raise RuntimeError(f"Download failed: {last_stderr}")
|
|
|
|
def _example_text(example) -> str:
    """Return the text of one dataset row, trying common column names in order."""
    for key in ("text", "content", "article"):
        if key in example:
            return example[key]
    # No known text column: fall back to the row's string representation.
    return str(example)


def download_huggingface(
    dataset_name: str,
    output_dir: str,
    split: str = "train",
    cache_dir: Optional[str] = None,
) -> str:
    """Load a dataset split via ``datasets`` and dump it to a text file.

    Each example is written followed by a newline to
    ``<output_dir>/<dataset_name>_<split>.txt``. The text is taken from the
    first of the ``text``, ``content`` or ``article`` fields that is present,
    otherwise from ``str(example)``.

    Args:
        dataset_name: Name passed to ``load_dataset``. It may contain ``/``
            (``org/name``), in which case the output file lands in the
            matching sub-directory of ``output_dir``.
        output_dir: Directory for the text file; created if missing.
        split: Split passed to ``load_dataset``.
        cache_dir: Cache directory passed to ``load_dataset``; defaults to
            ``output_dir``.

    Returns:
        Path of the written text file, as a string.

    Note:
        Exits the process with status 1 if ``datasets`` cannot be imported.
    """
    # Resolve the logger by name so the function also works when main() has
    # not populated the module-level ``logger`` global.
    log = logging.getLogger("download_data")

    try:
        from datasets import load_dataset
    except ImportError:
        log.error("Please install datasets: pip install datasets")
        sys.exit(1)

    log.info(f"Downloading {dataset_name} from HuggingFace...")

    dataset = load_dataset(
        dataset_name,
        split=split,
        cache_dir=cache_dir or output_dir,
    )

    output_path = Path(output_dir) / f"{dataset_name}_{split}.txt"
    # Covers both a missing output_dir and the extra directory level that an
    # "org/name" dataset name introduces into the path.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        for count, example in enumerate(dataset, start=1):
            f.write(_example_text(example) + "\n")

            if count % 10000 == 0:
                log.info(f"Processed {count} examples...")

    log.info(f"Saved to {output_path}")
    return str(output_path)
|
|
|
|
def prepare_openwebtext(output_dir: str, force: bool = False) -> str:
    """Export the OpenWebText ``train`` split to a newline-separated text file.

    The result is stored as ``<output_dir>/openwebtext_train.txt`` and reused
    on later calls unless ``force`` is set. The export is written to a
    ``.part`` file and renamed only once complete, so an interrupted run does
    not leave a truncated file that would later be treated as the cache.

    Args:
        output_dir: Directory for the text file and the ``cache``
            sub-directory handed to ``load_dataset``; created if missing.
        force: Rebuild the text file even if it already exists.

    Returns:
        Path of the text file, as a string.
    """
    from datasets import load_dataset

    # Resolve the logger by name so the function also works when main() has
    # not populated the module-level ``logger`` global.
    log = logging.getLogger("download_data")

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    output_file = out_dir / "openwebtext_train.txt"

    if output_file.exists() and not force:
        log.info(f"Using cached {output_file}")
        return str(output_file)

    log.info("Downloading OpenWebText dataset...")

    dataset = load_dataset(
        "openwebtext",
        split="train",
        cache_dir=str(out_dir / "cache"),
    )

    partial_file = output_file.with_name(output_file.name + ".part")
    try:
        with open(partial_file, "w", encoding="utf-8") as f:
            for count, example in enumerate(dataset, start=1):
                f.write(example["text"] + "\n")

                if count % 10000 == 0:
                    log.info(f"Processed {count} examples...")
        partial_file.replace(output_file)
    finally:
        # No-op after a successful rename; removes the leftover otherwise.
        partial_file.unlink(missing_ok=True)

    log.info(f"Saved {len(dataset)} examples to {output_file}")
    return str(output_file)
|
|
|
|
def prepare_slimpajama(output_dir: str, force: bool = False) -> list[str]:
    """Export the SlimPajama train and validation splits to text files.

    Files are written as ``slimpajama_train.txt`` and ``slimpajama_val.txt``
    inside ``output_dir`` and reused on later calls unless ``force`` is set.
    Each export goes to a ``.part`` file first and is renamed only once
    complete, so an interrupted run is not mistaken for a finished cache.

    Args:
        output_dir: Directory for the text files and the ``cache``
            sub-directory handed to ``load_dataset``; created if missing.
        force: Rebuild the text files even if they already exist.

    Returns:
        Paths of the train and validation text files, in that order.
    """
    from datasets import load_dataset

    # Resolve the logger by name so the function also works when main() has
    # not populated the module-level ``logger`` global.
    log = logging.getLogger("download_data")

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # (label used in file names and log messages, split name on the Hub).
    # The Hub names the held-out split "validation"; requesting "val" fails.
    splits = (("train", "train"), ("val", "validation"))

    files = []

    for label, hub_split in splits:
        output_file = out_dir / f"slimpajama_{label}.txt"

        if output_file.exists() and not force:
            log.info(f"Using cached {output_file}")
            files.append(str(output_file))
            continue

        log.info(f"Downloading SlimPajama {label} split...")

        dataset = load_dataset(
            "cerebras/SlimPajama-627B",
            split=hub_split,
            cache_dir=str(out_dir / "cache"),
        )

        partial_file = output_file.with_name(output_file.name + ".part")
        try:
            with open(partial_file, "w", encoding="utf-8") as f:
                for count, example in enumerate(dataset, start=1):
                    f.write(example["text"] + "\n")

                    if count % 100000 == 0:
                        log.info(f"Processed {count} examples...")
            partial_file.replace(output_file)
        finally:
            # No-op after a successful rename; removes the leftover otherwise.
            partial_file.unlink(missing_ok=True)

        log.info(f"Saved {len(dataset)} examples to {output_file}")
        files.append(str(output_file))

    return files
|
|
|
|
def prepare_wikitext(
    output_dir: str,
    version: str = "wikitext-2-raw-v1",
    force: bool = False,
) -> list[str]:
    """Export the WikiText train, validation and test splits to text files.

    Files are written as ``wikitext_train.txt``, ``wikitext_val.txt`` and
    ``wikitext_test.txt`` inside ``output_dir`` and reused on later calls
    unless ``force`` is set. Example text is written as-is, without an added
    newline. Each export goes to a ``.part`` file first and is renamed only
    once complete, so an interrupted run is not mistaken for a finished cache.

    Args:
        output_dir: Directory for the text files and the ``cache``
            sub-directory handed to ``load_dataset``; created if missing.
        version: WikiText configuration name passed to ``load_dataset``.
        force: Rebuild the text files even if they already exist.

    Returns:
        Paths of the train, validation and test text files, in that order.
    """
    from datasets import load_dataset

    # Resolve the logger by name so the function also works when main() has
    # not populated the module-level ``logger`` global.
    log = logging.getLogger("download_data")

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # (label used in file names and log messages, split name on the Hub).
    # WikiText names its held-out split "validation"; requesting "val" fails.
    splits = (("train", "train"), ("val", "validation"), ("test", "test"))

    files = []

    for label, hub_split in splits:
        output_file = out_dir / f"wikitext_{label}.txt"

        if output_file.exists() and not force:
            log.info(f"Using cached {output_file}")
            files.append(str(output_file))
            continue

        log.info(f"Downloading WikiText {label} split...")

        dataset = load_dataset(
            "wikitext",
            version,
            split=hub_split,
            cache_dir=str(out_dir / "cache"),
        )

        partial_file = output_file.with_name(output_file.name + ".part")
        try:
            with open(partial_file, "w", encoding="utf-8") as f:
                for example in dataset:
                    f.write(example["text"])
            partial_file.replace(output_file)
        finally:
            # No-op after a successful rename; removes the leftover otherwise.
            partial_file.unlink(missing_ok=True)

        log.info(f"Saved to {output_file}")
        files.append(str(output_file))

    return files
|
|
|
|
def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the download script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        Namespace with ``dataset``, ``output_dir``, ``force`` and
        ``log_level``. ``log_level`` is upper-cased and guaranteed to be one
        of the standard ``logging`` level names.
    """
    parser = argparse.ArgumentParser(description="Download training data for Codsworth")

    parser.add_argument(
        "--dataset",
        type=str,
        choices=["openwebtext", "slimpajama", "wikitext", "custom"],
        default="openwebtext",
        help="Dataset to download",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="data",
        help="Output directory",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force re-download",
    )
    parser.add_argument(
        "--log_level",
        # Normalise case first so "debug" and "DEBUG" are both accepted, then
        # let argparse reject anything that is not a logging level name
        # instead of crashing later when the logger is configured.
        type=str.upper,
        choices=[
            "DEBUG",
            "INFO",
            "WARNING",
            "WARN",
            "ERROR",
            "CRITICAL",
            "FATAL",
            "NOTSET",
        ],
        default="INFO",
        metavar="LEVEL",
        help="Logging level name, case-insensitive (default: INFO)",
    )

    return parser.parse_args(argv)
|
|
|
|
def main() -> None:
    """Entry point: parse options, set up logging, prepare the chosen dataset.

    ``--force`` is forwarded to the OpenWebText and SlimPajama preparers;
    ``prepare_wikitext`` is called without it. The ``custom`` choice downloads
    nothing and only logs a hint.
    """
    # The download helpers in this module read the module-level ``logger``.
    global logger

    # Arguments must be parsed first: the log level comes from them.
    args = parse_args()
    logger = setup_logging(args.log_level)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.dataset == "openwebtext":
        prepare_openwebtext(str(output_dir), args.force)
    elif args.dataset == "slimpajama":
        prepare_slimpajama(str(output_dir), args.force)
    elif args.dataset == "wikitext":
        prepare_wikitext(str(output_dir))
    elif args.dataset == "custom":
        logger.info("Custom dataset mode - please provide your own data files")
    else:
        # argparse restricts --dataset to the choices above; kept as a guard
        # in case the two lists ever drift apart.
        logger.error(f"Unknown dataset: {args.dataset}")
        sys.exit(1)

    logger.info("Data preparation complete!")


if __name__ == "__main__":
    main()