File size: 3,109 Bytes
"""Prepare the unified corpus for training.
Splits the unified corpus into training chunks, with chronological ordering
preserved within each source. Outputs JSONL format suitable for HF datasets.
"""
import json
import os
from pathlib import Path
from transformers import AutoTokenizer
def chunk_text(text, tokenizer, chunk_size=2048, overlap=128):
    """Split text into overlapping chunks based on token count.

    Args:
        text: string to tokenize and split.
        tokenizer: object whose ``encode(text, add_special_tokens=False)``
            returns a sliceable sequence of tokens.
        chunk_size: maximum number of tokens per chunk; must be positive.
        overlap: number of tokens shared by consecutive chunks; must be
            smaller than ``chunk_size`` so the window always moves forward.

    Returns:
        A list of token slices in document order. A window shorter than
        100 tokens ends the split and is not returned, so a text that
        encodes to fewer than 100 tokens yields an empty list.

    Raises:
        ValueError: if ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size``.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        # A stride of zero (or less) would never advance the window; with
        # overlap == chunk_size the old loop appended chunks forever.
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    tokens = tokenizer.encode(text, add_special_tokens=False)
    n_tokens = len(tokens)
    stride = chunk_size - overlap

    chunks = []
    for start in range(0, n_tokens, stride):
        window = tokens[start:start + chunk_size]
        if len(window) < 100:  # skip tiny tail
            break
        chunks.append(window)
        if start + chunk_size >= n_tokens:
            # This window already reaches the last token. Another step would
            # only re-emit tokens from the overlap region as a duplicate chunk.
            break
    return chunks
def prepare(corpus_path, output_path, tokenizer_name="EleutherAI/pythia-1.4b",
            chunk_size=2048, overlap=128):
    """Prepare training data from unified corpus.

    Reads the whole corpus, splits it on the source banner, chunks each
    source body with ``chunk_text``, and writes one JSON object per chunk
    to ``output_path``. Progress is reported with ``print``.

    Args:
        corpus_path: path to unified_corpus.txt
        output_path: path for train.jsonl output
        tokenizer_name: HF model whose tokenizer to use
        chunk_size: tokens per training example
        overlap: overlap between consecutive chunks for context continuity

    Returns:
        The list of chunk records that was written. Each record is a dict
        with keys ``'text'`` (decoded chunk), ``'source'`` (first line of
        the source block) and ``'n_tokens'`` (token count of the chunk).
    """
    print(f"Loading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    print(f"Reading corpus: {corpus_path}")
    # Explicit encoding: without it open() uses the platform locale, so the
    # same corpus could decode differently (or fail) on another machine.
    # NOTE(review): assumes the corpus file is UTF-8 -- confirm with producer.
    with open(corpus_path, encoding='utf-8') as f:
        text = f.read()
    # len(text) counts characters, so the figure matches the on-disk size
    # only for ASCII content.
    print(f"Corpus size: {len(text)/(1024*1024):.2f} MB")

    # Split by source markers (preserve source attribution). Anything that
    # precedes the first marker becomes the first element and, if non-blank,
    # is processed like any other block (its first line is used as the name).
    sources = text.split('#'*70 + '\n# SOURCE: ')
    print(f"Sources: {len(sources)}")

    all_chunks = []
    for src_block in sources:
        if not src_block.strip():
            continue
        # The text right after the marker, up to the first newline, is the
        # source name; everything after that newline is the body.
        src_name, _, body = src_block.partition('\n')
        src_name = src_name.strip()
        chunks = chunk_text(body, tokenizer, chunk_size, overlap)
        all_chunks.extend(
            {
                'text': tokenizer.decode(chunk),
                'source': src_name,
                'n_tokens': len(chunk),
            }
            for chunk in chunks
        )
        print(f" {src_name}: {len(chunks)} chunks")

    print(f"\nTotal chunks: {len(all_chunks)}")
    total_tokens = sum(c['n_tokens'] for c in all_chunks)
    print(f"Total tokens: {total_tokens:,}")

    # Write JSONL: one record per line, with the encoding pinned so the
    # output does not depend on the locale either.
    with open(output_path, 'w', encoding='utf-8') as f:
        for chunk in all_chunks:
            f.write(json.dumps(chunk) + '\n')
    print(f"Saved: {output_path}")
    return all_chunks
if __name__ == '__main__':
    import argparse

    # Command-line entry point: every flag feeds one argument of prepare().
    cli = argparse.ArgumentParser()
    option_table = (
        ('--corpus', {'default': 'unified_corpus.txt'}),
        ('--output', {'default': 'train.jsonl'}),
        ('--tokenizer', {'default': 'EleutherAI/pythia-1.4b'}),
        ('--chunk-size', {'type': int, 'default': 2048}),
        ('--overlap', {'type': int, 'default': 128}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    opts = cli.parse_args()

    prepare(
        corpus_path=opts.corpus,
        output_path=opts.output,
        tokenizer_name=opts.tokenizer,
        chunk_size=opts.chunk_size,
        overlap=opts.overlap,
    )
|