| """Prepare the unified corpus for training. |
| |
| Splits the unified corpus into training chunks, with chronological ordering |
| preserved within each source. Outputs JSONL format suitable for HF datasets. |
| """ |
| import json |
| import os |
| from pathlib import Path |
| from transformers import AutoTokenizer |
|
|
def chunk_text(text, tokenizer, chunk_size=2048, overlap=128,
               min_chunk_tokens=100):
    """Split text into overlapping chunks based on token count.

    Args:
        text: raw text to tokenize and split.
        tokenizer: object exposing ``encode(text, add_special_tokens=False)``
            whose result supports ``len()`` and slicing.
        chunk_size: maximum number of tokens per chunk.
        overlap: number of tokens shared between consecutive chunks.
        min_chunk_tokens: a window shorter than this is dropped and ends the
            chunking. Only the last window of a text can be short.

    Returns:
        List of token slices, in the order they occur in ``text``.

    Raises:
        ValueError: if ``chunk_size`` is not positive or ``overlap`` is not
            in the range ``[0, chunk_size)``.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        # A non-positive stride would never advance through the tokens
        # (endless loop when overlap == chunk_size).
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, "
            f"got overlap={overlap}, chunk_size={chunk_size}"
        )

    tokens = tokenizer.encode(text, add_special_tokens=False)
    n_tokens = len(tokens)
    stride = chunk_size - overlap

    chunks = []
    for start in range(0, n_tokens, stride):
        window = tokens[start:start + chunk_size]
        if len(window) < min_chunk_tokens:
            break
        chunks.append(window)
        if start + chunk_size >= n_tokens:
            # This window already reaches the end of the text; the next one
            # would consist solely of tokens repeated from this window.
            break
    return chunks
|
|
|
|
def prepare(corpus_path, output_path, tokenizer_name="EleutherAI/pythia-1.4b",
            chunk_size=2048, overlap=128):
    """Prepare training data from unified corpus.

    Reads the whole corpus into memory, splits it on the per-source header
    (a line of 70 '#' characters followed by a line starting with
    '# SOURCE: '), chunks each source body with chunk_text(), and writes one
    JSON object per chunk to ``output_path``.

    Args:
        corpus_path: path to unified_corpus.txt
        output_path: path for train.jsonl output
        tokenizer_name: HF model whose tokenizer to use
        chunk_size: tokens per training example
        overlap: overlap between consecutive chunks for context continuity

    Returns:
        List of dicts with keys 'text', 'source' and 'n_tokens', in the same
        order as they are written to ``output_path``.
    """
    print(f"Loading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    print(f"Reading corpus: {corpus_path}")
    # Explicit encoding: the platform default is not UTF-8 everywhere, which
    # would make the same corpus decode differently (or fail) across machines.
    with open(corpus_path, encoding="utf-8") as f:
        text = f.read()
    # len(text) counts characters, so this figure is approximate for
    # non-ASCII corpora.
    print(f"Corpus size: {len(text)/(1024*1024):.2f} MB")

    sources = text.split('#'*70 + '\n# SOURCE: ')
    print(f"Sources: {len(sources)}")

    all_chunks = []
    for src_block in sources:
        if not src_block.strip():
            continue

        # The first line of a block is the source name, the remainder is its
        # body. Any non-blank text before the first header is handled the
        # same way, so its first line becomes its "source name".
        header, _, body = src_block.partition('\n')
        src_name = header.strip()

        chunks = chunk_text(body, tokenizer, chunk_size, overlap)
        for chunk in chunks:
            all_chunks.append({
                'text': tokenizer.decode(chunk),
                'source': src_name,
                'n_tokens': len(chunk),
            })
        print(f"  {src_name}: {len(chunks)} chunks")

    print(f"\nTotal chunks: {len(all_chunks)}")
    total_tokens = sum(c['n_tokens'] for c in all_chunks)
    print(f"Total tokens: {total_tokens:,}")

    with open(output_path, 'w', encoding="utf-8") as f:
        for chunk in all_chunks:
            f.write(json.dumps(chunk) + '\n')
    print(f"Saved: {output_path}")

    return all_chunks
|
|
|
|
if __name__ == '__main__':
    import argparse

    # Command-line entry point: every flag maps onto one argument of
    # prepare() and carries the same default value.
    cli = argparse.ArgumentParser()
    for flag, fallback in (
        ('--corpus', 'unified_corpus.txt'),
        ('--output', 'train.jsonl'),
        ('--tokenizer', 'EleutherAI/pythia-1.4b'),
    ):
        cli.add_argument(flag, default=fallback)
    for flag, fallback in (('--chunk-size', 2048), ('--overlap', 128)):
        cli.add_argument(flag, type=int, default=fallback)
    opts = cli.parse_args()

    prepare(
        corpus_path=opts.corpus,
        output_path=opts.output,
        tokenizer_name=opts.tokenizer,
        chunk_size=opts.chunk_size,
        overlap=opts.overlap,
    )
|
|