| |
| import json, os, sys, torch |
| from datasets import load_dataset |
|
|
| sys.path.insert(0, "/root/cognet-1b") |
| from train_ultra import CharTokenizer |
|
|
# Directory holding the previously prepared token tensors, and the character
# tokenizer used for every new corpus added below.
DATA_DIR = "/root/cognet-1b/data_1b"
tokenizer = CharTokenizer.load("/root/cognet-1b/tokenizer_v3.json")

# Seed the running id list with what an earlier run already wrote to disk:
# train first, then val, concatenated into one flat Python list.
train_tokens, val_tokens = (
    torch.load(os.path.join(DATA_DIR, fname), map_location="cpu", weights_only=True)
    for fname in ("train_tokens.pt", "val_tokens.pt")
)
all_ids = torch.cat((train_tokens, val_tokens)).tolist()
print(f"Existing tokens: {len(all_ids):,}")
|
|
def tokenize_texts(texts, desc="", *, min_chars=10, log_every=50000, encode=None):
    """Encode an iterable of texts into one flat list of token ids.

    Texts that are falsy (``None``, ``""``) or whose stripped length is below
    ``min_chars`` are skipped. Kept texts are encoded unstripped, exactly as
    given.

    Args:
        texts: Iterable of strings (``None`` entries are tolerated).
        desc: Label used as the prefix of progress lines.
        min_chars: Minimum stripped length for a text to be encoded.
        log_every: Print a progress line every this many input texts; a falsy
            value disables progress output.
        encode: Callable mapping a string to an iterable of ids. Defaults to
            the module-level ``tokenizer.encode``.

    Returns:
        A list with the ids of all kept texts, concatenated in input order.
    """
    if encode is None:
        # Resolved at call time so the module-level tokenizer is only required
        # when no explicit encoder is supplied.
        encode = tokenizer.encode

    ids = []
    for i, text in enumerate(texts):
        if text and len(text.strip()) >= min_chars:
            ids.extend(encode(text))
        # Progress is reported on the input index regardless of whether this
        # particular text was kept, so a filtered item at a multiple of
        # ``log_every`` no longer silently drops the progress line.
        if log_every and i > 0 and i % log_every == 0:
            print(f" {desc}: {i:,} texts -> {len(ids):,} tokens")
    return ids
|
|
| |
print("1/3 - WikiText-103 (fixed API)...")
# Two hub ids are tried in order: the namespaced one first, then the bare
# legacy name. The first attempt that succeeds ends the loop; each failure is
# reported and the next id is tried.
wikitext_attempts = (
    ("Salesforce/wikitext", " OK WikiText-103", " FAIL WikiText"),
    ("wikitext", " OK WikiText-103 (alt)", " FAIL WikiText alt"),
)
for repo_id, ok_label, fail_label in wikitext_attempts:
    try:
        ds = load_dataset(repo_id, "wikitext-103-raw-v1", split="train")
        # Drop blank rows up front; tokenize_texts applies its own length filter.
        texts = [row["text"] for row in ds if row["text"].strip()]
        ids = tokenize_texts(texts, "WikiText-103")
        all_ids.extend(ids)
        print(f"{ok_label}: {len(ids):,} tokens")
        del ds, texts
    except Exception as err:
        print(f"{fail_label}: {err}")
    else:
        break
|
|
| |
print("2/3 - C4 English...")
try:
    c4_stream = load_dataset("allenai/c4", "en", split="train", streaming=True)
    # Take at most the first 100000 streamed documents. The range comes first
    # in zip() so the stream is not advanced once the cap has been reached.
    texts = [doc["text"] for _, doc in zip(range(100000), c4_stream)]
    c4_ids = tokenize_texts(texts, "C4-EN")
    all_ids.extend(c4_ids)
    print(f" OK C4-EN: {len(c4_ids):,} tokens")
    del texts
except Exception as err:
    # Best effort: a failed download must not abort the remaining steps.
    print(f" FAIL C4: {err}")
|
|
| |
print("3/3 - FineWeb-Edu...")
try:
    fineweb_stream = load_dataset(
        "HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train", streaming=True
    )
    # Take at most the first 100000 streamed documents. The range comes first
    # in zip() so the stream is not advanced once the cap has been reached.
    texts = [doc["text"] for _, doc in zip(range(100000), fineweb_stream)]
    fineweb_ids = tokenize_texts(texts, "FineWeb-Edu")
    all_ids.extend(fineweb_ids)
    print(f" OK FineWeb-Edu: {len(fineweb_ids):,} tokens")
    del texts
except Exception as err:
    # Best effort: a failed download must not prevent saving what was collected.
    print(f" FAIL FineWeb: {err}")
|
|
| |
print(f"TOTAL TOKENS: {len(all_ids):,}")
tokens = torch.tensor(all_ids, dtype=torch.long)
split = int(len(tokens) * 0.95)

# Slicing yields views that share `tokens`' storage, and torch.save writes the
# whole backing storage of a view. Without clone() each file would contain the
# full token stream (train + val) instead of only its own slice.
train_tokens = tokens[:split].clone()
val_tokens = tokens[split:].clone()

# These paths are also this script's inputs, so never truncate them in place:
# write both tensors to temporary files first, then swap them in. A failure
# while serialising leaves the previous files untouched.
outputs = ((train_tokens, "train_tokens.pt"), (val_tokens, "val_tokens.pt"))
for tensor, filename in outputs:
    torch.save(tensor, os.path.join(DATA_DIR, filename + ".tmp"))
for _, filename in outputs:
    final_path = os.path.join(DATA_DIR, filename)
    os.replace(final_path + ".tmp", final_path)

print(f"Train: {len(train_tokens):,} tokens ({len(train_tokens)/1e6:.1f}M)")
print(f"Val: {len(val_tokens):,} tokens ({len(val_tokens)/1e6:.1f}M)")
print("MORE DATA COMPLETE!")
|
|