# codsworth-3.8m / codsworth / scripts / download_data.py
# Uploaded by Jaqshanahan: "Initial upload of Codsworth model" (commit b84d85a, verified).
import os
import sys
import argparse
import logging
import subprocess
from pathlib import Path
from typing import Optional
# Direct-download URLs for the supported corpora. Not referenced by the
# functions in this script, which fetch through the `datasets` library instead.
# NOTE(review): confirm these file paths still exist on the Hub before use.
DEFAULT_DATA_URLS = {
    "openwebtext": "https://huggingface.co/datasets/Skylion007/openwebtext/resolve/main/train.json",
    "slimpajama": "https://huggingface.co/datasets/cerebras/SlimPajama-627B/resolve/main/train.tar.gz",
}

# Module-level logger so the helper functions below can be imported and called
# without going through main() (which rebinds this name to the result of
# setup_logging(), i.e. the same "download_data" logger with a handler attached).
logger = logging.getLogger("download_data")
def setup_logging(level: str = "INFO") -> logging.Logger:
    """Configure and return the ``download_data`` logger.

    Args:
        level: Case-insensitive name of a ``logging`` level, e.g. "info".

    Returns:
        The ``download_data`` logger with its level set and a stream handler
        attached.
    """
    logger = logging.getLogger("download_data")
    logger.setLevel(getattr(logging, level.upper()))
    # getLogger() returns the same object on every call, so attaching a handler
    # unconditionally would print each message once per setup_logging() call.
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
        logger.addHandler(handler)
    return logger
def download_file(url: str, output_dir: str, force: bool = False) -> str:
    """Download ``url`` into ``output_dir`` using curl, falling back to wget.

    The local file name is the last component of the URL path; any query
    string or fragment is ignored. An existing file is reused unless ``force``
    is set.

    Args:
        url: Remote file to fetch.
        output_dir: Destination directory; created if missing.
        force: Re-download even if the target file already exists.

    Returns:
        Path of the downloaded (or already present) file, as a string.

    Raises:
        FileNotFoundError: If neither curl nor wget is installed.
        RuntimeError: If every available downloader failed.
    """
    from urllib.parse import urlsplit

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Strip "?query"/"#fragment" so signed or parameterized URLs still map to a
    # sensible file name; fall back to the old naming for path-less URLs.
    filename = Path(urlsplit(url).path).name or Path(url).name
    output_path = out_dir / filename
    if output_path.exists() and not force:
        return str(output_path)

    logger.info(f"Downloading {url}...")
    downloaders = [
        # -f makes curl exit non-zero on HTTP errors instead of saving the
        # server's error page as if it were the requested file.
        ("curl", ["curl", "-fL", "-o", str(output_path), url]),
        ("wget", ["wget", "-O", str(output_path), url]),
    ]
    errors = []
    missing_tools = 0
    for tool, command in downloaders:
        try:
            result = subprocess.run(command, capture_output=True, text=True)
        except FileNotFoundError:
            # Tool not installed: try the next one rather than giving up.
            missing_tools += 1
            errors.append(f"{tool}: not installed")
            continue
        if result.returncode == 0:
            logger.info(f"Downloaded to {output_path}")
            return str(output_path)
        errors.append(f"{tool}: {result.stderr.strip()}")
        logger.warning(f"{tool} failed: {result.stderr}")

    # Never leave a truncated file behind: the exists() check above would
    # otherwise treat it as a finished download on the next run.
    output_path.unlink(missing_ok=True)
    if missing_tools == len(downloaders):
        logger.error("curl or wget not found. Please install curl or wget.")
        raise FileNotFoundError("curl or wget not found. Please install curl or wget.")
    raise RuntimeError(f"Download failed: {'; '.join(errors)}")
def download_huggingface(
    dataset_name: str,
    output_dir: str,
    split: str = "train",
    cache_dir: Optional[str] = None,
) -> str:
    """Export one split of a Hugging Face dataset to a plain-text file.

    For each example the first present field among ``text``, ``content`` and
    ``article`` is written followed by a newline; examples with none of those
    fields are written as ``str(example)``.

    Args:
        dataset_name: Hub dataset id, e.g. ``"wikitext"`` or ``"org/name"``.
        output_dir: Destination directory; created if missing.
        split: Split to load.
        cache_dir: ``datasets`` cache location; defaults to ``output_dir``.

    Returns:
        Path of the written text file, as a string.

    Exits:
        Calls ``sys.exit(1)`` if the ``datasets`` package is not installed.
    """
    try:
        from datasets import load_dataset
    except ImportError:
        logger.error("Please install datasets: pip install datasets")
        sys.exit(1)

    logger.info(f"Downloading {dataset_name} from HuggingFace...")
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    dataset = load_dataset(
        dataset_name,
        split=split,
        cache_dir=cache_dir or output_dir,
    )

    # Hub ids look like "org/name"; a raw "/" would point into a
    # subdirectory that does not exist, so flatten it into the file name.
    safe_name = dataset_name.replace("/", "_")
    output_path = out_dir / f"{safe_name}_{split}.txt"
    with open(output_path, "w", encoding="utf-8") as f:
        for i, example in enumerate(dataset):
            for field in ("text", "content", "article"):
                if field in example:
                    f.write(example[field] + "\n")
                    break
            else:
                f.write(str(example) + "\n")
            if (i + 1) % 10000 == 0:
                logger.info(f"Processed {i + 1} examples...")
    logger.info(f"Saved to {output_path}")
    return str(output_path)
def prepare_openwebtext(output_dir: str, force: bool = False) -> str:
    """Export the OpenWebText ``train`` split to ``openwebtext_train.txt``.

    Each example's ``text`` field is written followed by a newline. The
    ``datasets`` cache lives in ``<output_dir>/cache``.

    Args:
        output_dir: Destination directory; created if missing.
        force: Rebuild the text file even if it already exists.

    Returns:
        Path of the text file, as a string.
    """
    from datasets import load_dataset

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / "openwebtext_train.txt"
    if output_file.exists() and not force:
        logger.info(f"Using cached {output_file}")
        return str(output_file)

    logger.info("Downloading OpenWebText dataset...")
    dataset = load_dataset(
        "openwebtext",
        split="train",
        cache_dir=str(output_dir / "cache"),
    )

    # Write to a side file and rename only when complete, so an interrupted
    # run never leaves a truncated file that the cache check above would
    # later accept as a finished export.
    partial_file = output_file.with_name(output_file.name + ".partial")
    with open(partial_file, "w", encoding="utf-8") as f:
        for i, example in enumerate(dataset):
            f.write(example["text"] + "\n")
            if (i + 1) % 10000 == 0:
                logger.info(f"Processed {i + 1} examples...")
    partial_file.replace(output_file)

    logger.info(f"Saved {len(dataset)} examples to {output_file}")
    return str(output_file)
def prepare_slimpajama(output_dir: str, force: bool = False) -> list[str]:
    """Export SlimPajama train and validation splits to plain-text files.

    Writes ``slimpajama_train.txt`` and ``slimpajama_val.txt`` in
    ``output_dir`` (one ``text`` field per line). The ``datasets`` cache lives
    in ``<output_dir>/cache``.

    Args:
        output_dir: Destination directory; created if missing.
        force: Rebuild text files even if they already exist.

    Returns:
        Paths of the text files, train first.
    """
    from datasets import load_dataset

    # Local file suffix -> split name on the Hub. The validation split is
    # published as "validation"; requesting "val" fails in load_dataset.
    splits = {"train": "train", "val": "validation"}

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    files = []
    for local_split, hub_split in splits.items():
        output_file = output_dir / f"slimpajama_{local_split}.txt"
        if output_file.exists() and not force:
            logger.info(f"Using cached {output_file}")
            files.append(str(output_file))
            continue
        logger.info(f"Downloading SlimPajama {local_split} split...")
        # NOTE(review): this loads the full split (no streaming); confirm the
        # disk budget before running the train split.
        dataset = load_dataset(
            "cerebras/SlimPajama-627B",
            split=hub_split,
            cache_dir=str(output_dir / "cache"),
        )
        # Write to a side file and rename when complete so an interrupted run
        # is not later mistaken for a cached export.
        partial_file = output_file.with_name(output_file.name + ".partial")
        with open(partial_file, "w", encoding="utf-8") as f:
            for i, example in enumerate(dataset):
                f.write(example["text"] + "\n")
                if (i + 1) % 100000 == 0:
                    logger.info(f"Processed {i + 1} examples...")
        partial_file.replace(output_file)
        logger.info(f"Saved {len(dataset)} examples to {output_file}")
        files.append(str(output_file))
    return files
def prepare_wikitext(
    output_dir: str,
    version: str = "wikitext-2-raw-v1",
    force: bool = False,
) -> list[str]:
    """Export WikiText train/validation/test splits to plain-text files.

    Writes ``wikitext_train.txt``, ``wikitext_val.txt`` and
    ``wikitext_test.txt`` in ``output_dir``. The ``datasets`` cache lives in
    ``<output_dir>/cache``.

    Args:
        output_dir: Destination directory; created if missing.
        version: WikiText configuration name passed to ``load_dataset``.
        force: Rebuild text files even if they already exist.

    Returns:
        Paths of the text files in train, val, test order.
    """
    from datasets import load_dataset

    # Local file suffix -> split name on the Hub (published as "validation").
    splits = {"train": "train", "val": "validation", "test": "test"}

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    files = []
    for local_split, hub_split in splits.items():
        # NOTE(review): file names do not include `version`, so switching
        # versions reuses files built from a different one unless force=True.
        output_file = output_dir / f"wikitext_{local_split}.txt"
        if output_file.exists() and not force:
            logger.info(f"Using cached {output_file}")
            files.append(str(output_file))
            continue
        logger.info(f"Downloading WikiText {local_split} split...")
        dataset = load_dataset(
            "wikitext",
            version,
            split=hub_split,
            cache_dir=str(output_dir / "cache"),
        )
        partial_file = output_file.with_name(output_file.name + ".partial")
        with open(partial_file, "w", encoding="utf-8") as f:
            for example in dataset:
                # No newline appended: WikiText rows are expected to carry
                # their own line endings (TODO confirm for non-raw configs).
                f.write(example["text"])
        # Rename only once complete so a partial export is never cached.
        partial_file.replace(output_file)
        logger.info(f"Saved to {output_file}")
        files.append(str(output_file))
    return files
def parse_args():
    """Parse command-line options for the data download script.

    Returns:
        ``argparse.Namespace`` with ``dataset``, ``output_dir``, ``force`` and
        ``log_level`` (normalized to upper case).
    """
    parser = argparse.ArgumentParser(description="Download training data for Codsworth")
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["openwebtext", "slimpajama", "wikitext", "custom"],
        default="openwebtext",
        help="Dataset to download",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="data",
        help="Output directory",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force re-download",
    )
    parser.add_argument(
        "--log_level",
        # Upper-case before validating so "debug" and "DEBUG" both work; an
        # unknown name is rejected here with a usage message instead of
        # failing later when the level is looked up on the logging module.
        type=str.upper,
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        default="INFO",
        help="Logging verbosity",
    )
    return parser.parse_args()
def main():
    """Parse CLI arguments, configure logging, and prepare the chosen dataset."""
    global logger
    # Arguments must be parsed before logging is configured because the log
    # level comes from them (the previous order raised UnboundLocalError).
    args = parse_args()
    logger = setup_logging(args.log_level)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    if args.dataset == "openwebtext":
        prepare_openwebtext(str(output_dir), args.force)
    elif args.dataset == "slimpajama":
        prepare_slimpajama(str(output_dir), args.force)
    elif args.dataset == "wikitext":
        prepare_wikitext(str(output_dir))
    elif args.dataset == "custom":
        logger.info("Custom dataset mode - please provide your own data files")
    else:
        # Defensive: argparse `choices` should already reject other values.
        logger.error(f"Unknown dataset: {args.dataset}")
        sys.exit(1)
    logger.info("Data preparation complete!")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()