"""Download and prepare text datasets for training.

Each ``prepare_*`` function fetches one dataset through the HuggingFace
``datasets`` library and flattens it into plain-text files in the output
directory.  ``download_file`` fetches an arbitrary URL with ``curl`` or
``wget``.  Run the module as a script to select a dataset on the command line.
"""

import argparse
import logging
import os
import subprocess
import sys
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

DEFAULT_DATA_URLS = {
    "openwebtext": "https://huggingface.co/datasets/Skylion007/openwebtext/resolve/main/train.json",
    "slimpajama": "https://huggingface.co/datasets/cerebras/SlimPajama-627B/resolve/main/train.tar.gz",
}

# Short split names used in output file names -> split names on the Hub.
_HF_SPLIT_NAMES = {"val": "validation"}

# Module-level logger so the helpers also work when this file is imported as
# a library; ``setup_logging`` configures this same logger.
logger = logging.getLogger("download_data")


def setup_logging(level: str = "INFO") -> logging.Logger:
    """Configure and return the ``download_data`` logger.

    Args:
        level: Name of a standard logging level (case-insensitive).

    Returns:
        The configured logger.

    Raises:
        ValueError: If ``level`` is not a known logging level name.
    """
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {level!r}")

    log = logging.getLogger("download_data")
    log.setLevel(numeric_level)
    # Attach the handler only once; repeated calls would otherwise emit every
    # message multiple times.
    if not log.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
        )
        log.addHandler(handler)
    return log


def _import_load_dataset():
    """Return ``datasets.load_dataset``; exit with status 1 if it is missing."""
    try:
        from datasets import load_dataset
    except ImportError:
        logger.error("Please install datasets: pip install datasets")
        sys.exit(1)
    return load_dataset


def _example_text(example) -> str:
    """Return the text of *example*, trying the common text column names."""
    for key in ("text", "content", "article"):
        if key in example:
            return example[key]
    return str(example)


def _write_examples(dataset, output_file: Path, to_text, log_every: int = 0) -> int:
    """Write ``to_text(example)`` for every example and return the count.

    The data is written to a ``.part`` file that is renamed on success, so an
    interrupted run never leaves a truncated file that a later run would
    mistake for a finished one.
    """
    partial_file = output_file.with_name(output_file.name + ".part")
    count = 0
    with open(partial_file, "w", encoding="utf-8") as f:
        for count, example in enumerate(dataset, start=1):
            f.write(to_text(example))
            if log_every and count % log_every == 0:
                logger.info(f"Processed {count} examples...")
    partial_file.replace(output_file)
    return count


def download_file(url: str, output_dir: str, force: bool = False) -> str:
    """Download ``url`` into ``output_dir`` using ``curl``, falling back to ``wget``.

    Args:
        url: Address of the file to fetch.
        output_dir: Directory to store the file in (created if missing).
        force: Download again even if the target file already exists.

    Returns:
        Path of the downloaded (or already present) file.

    Raises:
        FileNotFoundError: If neither ``curl`` nor ``wget`` is installed.
        RuntimeError: If every available tool failed to download the file.
    """
    # Take the name from the URL path so a query string such as
    # "?download=true" does not end up in the file name.
    filename = Path(urlparse(url).path).name or Path(url).name
    output_path = Path(output_dir) / filename
    if output_path.exists() and not force:
        logger.info(f"Using cached {output_path}")
        return str(output_path)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Download to a temporary name so that a failed transfer neither clobbers
    # an existing file nor leaves a partial one that looks like a cache hit.
    partial_path = output_path.with_name(output_path.name + ".part")
    commands = (
        # --fail makes curl exit non-zero on HTTP errors instead of saving the
        # server's error page as if it were the requested file.
        ["curl", "--fail", "-L", "-o", str(partial_path), url],
        ["wget", "-O", str(partial_path), url],
    )

    logger.info(f"Downloading {url}...")
    tools_found = 0
    last_error = ""
    for command in commands:
        tool = command[0]
        try:
            result = subprocess.run(command, capture_output=True, text=True)
        except FileNotFoundError:
            # This tool is not installed; try the next one.
            logger.warning(f"{tool} not found")
            continue
        tools_found += 1
        if result.returncode == 0:
            partial_path.replace(output_path)
            logger.info(f"Downloaded to {output_path}")
            return str(output_path)
        last_error = result.stderr
        logger.warning(f"{tool} failed: {result.stderr}")

    partial_path.unlink(missing_ok=True)
    if not tools_found:
        message = "curl or wget not found. Please install curl or wget."
        logger.error(message)
        raise FileNotFoundError(message)
    raise RuntimeError(f"Download failed: {last_error}")


def download_huggingface(
    dataset_name: str,
    output_dir: str,
    split: str = "train",
    cache_dir: Optional[str] = None,
) -> str:
    """Download a HuggingFace dataset split and flatten it to a text file.

    Each example contributes one line taken from its ``text``, ``content`` or
    ``article`` field (first one present), or ``str(example)`` otherwise.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        output_dir: Directory for the text file (created if missing).
        split: Split to download.
        cache_dir: Cache directory for ``datasets``; defaults to ``output_dir``.

    Returns:
        Path of the written text file.
    """
    load_dataset = _import_load_dataset()

    logger.info(f"Downloading {dataset_name} from HuggingFace...")
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    cache = cache_dir or output_dir
    dataset = load_dataset(
        dataset_name,
        split=split,
        cache_dir=cache,
    )

    # Namespaced identifiers ("org/name") contain a slash, which would
    # otherwise point into a subdirectory that does not exist.
    safe_name = dataset_name.replace("/", "_")
    output_path = target_dir / f"{safe_name}_{split}.txt"
    _write_examples(
        dataset,
        output_path,
        lambda example: _example_text(example) + "\n",
        log_every=10000,
    )

    logger.info(f"Saved to {output_path}")
    return str(output_path)


def prepare_openwebtext(output_dir: str, force: bool = False) -> str:
    """Write the OpenWebText train split to ``openwebtext_train.txt``.

    Args:
        output_dir: Directory for the text file and the ``cache`` subdirectory.
        force: Rebuild the file even if it already exists.

    Returns:
        Path of the text file.
    """
    load_dataset = _import_load_dataset()

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    output_file = target_dir / "openwebtext_train.txt"

    if output_file.exists() and not force:
        logger.info(f"Using cached {output_file}")
        return str(output_file)

    logger.info("Downloading OpenWebText dataset...")
    dataset = load_dataset(
        "openwebtext",
        split="train",
        cache_dir=str(target_dir / "cache"),
    )

    count = _write_examples(
        dataset,
        output_file,
        lambda example: example["text"] + "\n",
        log_every=10000,
    )
    logger.info(f"Saved {count} examples to {output_file}")
    return str(output_file)


def prepare_slimpajama(output_dir: str, force: bool = False) -> list[str]:
    """Write the SlimPajama train and validation splits to text files.

    Files are named ``slimpajama_train.txt`` and ``slimpajama_val.txt``.

    Args:
        output_dir: Directory for the text files and the ``cache`` subdirectory.
        force: Rebuild files even if they already exist.

    Returns:
        Paths of the text files, in split order.
    """
    load_dataset = _import_load_dataset()

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    files = []
    for split in ["train", "val"]:
        output_file = target_dir / f"slimpajama_{split}.txt"

        if output_file.exists() and not force:
            logger.info(f"Using cached {output_file}")
            files.append(str(output_file))
            continue

        logger.info(f"Downloading SlimPajama {split} split...")
        dataset = load_dataset(
            "cerebras/SlimPajama-627B",
            # The Hub names the split "validation"; "val" is only the short
            # name used in the output file name.
            split=_HF_SPLIT_NAMES.get(split, split),
            cache_dir=str(target_dir / "cache"),
        )

        count = _write_examples(
            dataset,
            output_file,
            lambda example: example["text"] + "\n",
            log_every=100000,
        )
        logger.info(f"Saved {count} examples to {output_file}")
        files.append(str(output_file))

    return files


def prepare_wikitext(
    output_dir: str,
    version: str = "wikitext-2-raw-v1",
    force: bool = False,
) -> list[str]:
    """Write the WikiText train, validation and test splits to text files.

    Files are named ``wikitext_train.txt``, ``wikitext_val.txt`` and
    ``wikitext_test.txt``.

    Args:
        output_dir: Directory for the text files and the ``cache`` subdirectory.
        version: WikiText configuration name passed to ``load_dataset``.
        force: Rebuild files even if they already exist.

    Returns:
        Paths of the text files, in split order.
    """
    load_dataset = _import_load_dataset()

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    files = []
    for split in ["train", "val", "test"]:
        output_file = target_dir / f"wikitext_{split}.txt"

        if output_file.exists() and not force:
            logger.info(f"Using cached {output_file}")
            files.append(str(output_file))
            continue

        logger.info(f"Downloading WikiText {split} split...")
        dataset = load_dataset(
            "wikitext",
            version,
            # The Hub names the split "validation"; "val" is only the short
            # name used in the output file name.
            split=_HF_SPLIT_NAMES.get(split, split),
            cache_dir=str(target_dir / "cache"),
        )

        # The text is written as-is, with no separator added between examples.
        _write_examples(dataset, output_file, lambda example: example["text"])
        logger.info(f"Saved to {output_file}")
        files.append(str(output_file))

    return files


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the download script."""
    parser = argparse.ArgumentParser(description="Download training data for Codsworth")
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["openwebtext", "slimpajama", "wikitext", "custom"],
        default="openwebtext",
        help="Dataset to download",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="data",
        help="Output directory",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force re-download",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        default="INFO",
        help="Logging level name (e.g. DEBUG, INFO, WARNING)",
    )
    return parser.parse_args()


def main() -> None:
    """Entry point: parse arguments and prepare the selected dataset."""
    # Arguments must be parsed first: the log level comes from them.
    args = parse_args()
    setup_logging(args.log_level)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.dataset == "openwebtext":
        prepare_openwebtext(str(output_dir), args.force)
    elif args.dataset == "slimpajama":
        prepare_slimpajama(str(output_dir), args.force)
    elif args.dataset == "wikitext":
        prepare_wikitext(str(output_dir), force=args.force)
    elif args.dataset == "custom":
        logger.info("Custom dataset mode - please provide your own data files")
    else:
        # Unreachable through argparse `choices`; kept as a defensive guard.
        logger.error(f"Unknown dataset: {args.dataset}")
        sys.exit(1)

    logger.info("Data preparation complete!")


if __name__ == "__main__":
    main()