#!/usr/bin/env python3
"""Extend the cached token files with WikiText-103, C4 (en) and FineWeb-Edu.

The script loads the existing ``train_tokens.pt`` / ``val_tokens.pt`` from
``DATA_DIR``, tokenizes three additional corpora with the project tokenizer,
appends the new token ids after the existing ones, re-splits the result
95% / 5% and overwrites both files.

Each corpus is ingested on a best-effort basis: a failure is printed and the
script carries on with the remaining sources.

NOTE(review): because the old train and val tokens are concatenated before
the new data and then re-split, tokens that used to be validation data may
end up in the new train split, and re-running the script appends the same
corpora again. Confirm this is intended.
"""
import itertools
import json
import os
import sys

import torch
from datasets import load_dataset

sys.path.insert(0, "/root/cognet-1b")
from train_ultra import CharTokenizer

DATA_DIR = "/root/cognet-1b/data_1b"
# Number of documents taken from each streamed corpus (C4, FineWeb-Edu).
STREAM_SAMPLE_LIMIT = 100000
# Fraction of the combined token stream written to the train file.
TRAIN_FRACTION = 0.95

tokenizer = CharTokenizer.load("/root/cognet-1b/tokenizer_v3.json")

# Load existing train tokens
train_tokens = torch.load(
    os.path.join(DATA_DIR, "train_tokens.pt"), map_location="cpu", weights_only=True
)
val_tokens = torch.load(
    os.path.join(DATA_DIR, "val_tokens.pt"), map_location="cpu", weights_only=True
)
# Keep the tokens as tensors instead of expanding them into a Python list:
# the list form costs far more memory and is only converted back at the end.
token_chunks = [torch.cat([train_tokens, val_tokens]).to(torch.long)]
print(f"Existing tokens: {len(token_chunks[0]):,}")
del train_tokens, val_tokens


def tokenize_texts(texts, desc=""):
    """Encode every usable text with the module tokenizer.

    Texts that are empty or shorter than 10 characters once stripped are
    skipped. A progress line is printed whenever a processed text sits at a
    non-zero index that is a multiple of 50,000.

    Args:
        texts: Iterable of strings to encode.
        desc: Label used as the prefix of the progress lines.

    Returns:
        A flat list with the token ids of all kept texts, in input order.
    """
    ids = []
    for i, text in enumerate(texts):
        if not text or len(text.strip()) < 10:
            continue
        ids.extend(tokenizer.encode(text))
        if i % 50000 == 0 and i > 0:
            print(f" {desc}: {i:,} texts -> {len(ids):,} tokens")
    return ids


def _tokenize_wikitext(repo_id):
    """Load the WikiText-103 raw train split from *repo_id* and tokenize it."""
    ds = load_dataset(repo_id, "wikitext-103-raw-v1", split="train")
    texts = [row["text"] for row in ds if row["text"].strip()]
    return tokenize_texts(texts, "WikiText-103")


def _tokenize_stream(repo_id, config, desc, limit=STREAM_SAMPLE_LIMIT):
    """Tokenize the first *limit* documents of a streamed train split."""
    ds = load_dataset(repo_id, config, split="train", streaming=True)
    texts = [row["text"] for row in itertools.islice(ds, limit)]
    return tokenize_texts(texts, desc)


def _append_ids(ids):
    """Store *ids* as a tensor chunk to be concatenated at the end."""
    token_chunks.append(torch.tensor(ids, dtype=torch.long))


def _atomic_save(tensor, path):
    """Write *tensor* to *path* via a temporary file and an atomic rename.

    The target files are also this script's inputs, so a crash in the middle
    of a direct write would destroy the existing dataset.
    """
    tmp_path = path + ".tmp"
    torch.save(tensor, tmp_path)
    os.replace(tmp_path, path)


# 1. WIKITEXT with correct namespace
print("1/3 - WikiText-103 (fixed API)...")
try:
    ids = _tokenize_wikitext("Salesforce/wikitext")
    _append_ids(ids)
    print(f" OK WikiText-103: {len(ids):,} tokens")
    del ids
except Exception as e:
    print(f" FAIL WikiText: {e}")
    # Fall back to the legacy, un-namespaced dataset id.
    try:
        ids = _tokenize_wikitext("wikitext")
        _append_ids(ids)
        print(f" OK WikiText-103 (alt): {len(ids):,} tokens")
        del ids
    except Exception as e2:
        print(f" FAIL WikiText alt: {e2}")

# 2. C4 English subset
print("2/3 - C4 English...")
try:
    ids = _tokenize_stream("allenai/c4", "en", "C4-EN")
    _append_ids(ids)
    print(f" OK C4-EN: {len(ids):,} tokens")
    del ids
except Exception as e:
    print(f" FAIL C4: {e}")

# 3. FineWeb-Edu (sample-10BT)
print("3/3 - FineWeb-Edu...")
try:
    ids = _tokenize_stream("HuggingFaceFW/fineweb-edu", "sample-10BT", "FineWeb-Edu")
    _append_ids(ids)
    print(f" OK FineWeb-Edu: {len(ids):,} tokens")
    del ids
except Exception as e:
    print(f" FAIL FineWeb: {e}")

# SAVE ALL
tokens = torch.cat(token_chunks)
del token_chunks
print(f"TOTAL TOKENS: {len(tokens):,}")
split = int(len(tokens) * TRAIN_FRACTION)
# clone() detaches each slice from the shared storage; torch.save on a view
# would otherwise serialise the whole token buffer into both files.
train_tokens = tokens[:split].clone()
val_tokens = tokens[split:].clone()
del tokens
_atomic_save(train_tokens, os.path.join(DATA_DIR, "train_tokens.pt"))
_atomic_save(val_tokens, os.path.join(DATA_DIR, "val_tokens.pt"))
print(f"Train: {len(train_tokens):,} tokens ({len(train_tokens)/1e6:.1f}M)")
print(f"Val: {len(val_tokens):,} tokens ({len(val_tokens)/1e6:.1f}M)")
print("MORE DATA COMPLETE!")