Luna-150M / generate_data.py
JMSykala's picture
Upload 9 files
9c737ff verified
# Copyright 2026 Jakub Sykała
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
import argparse
import numpy as np
from tqdm import tqdm
from collections import Counter
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, List, Tuple, Set
# Column layout of every token row written by this script; the position of a
# name in this list is the column index of that feature.
FEATURE_NAMES = [
    # Columns 0-3: ids looked up in the saved vocabulary.
    'syllable_id',
    'onset_id',
    'nucleus_id',
    'coda_id',
    # Columns 4-8: values copied from the tokenizer's output.
    'position',
    'is_capitalized',
    'token_type',
    'has_space_after',
    'is_word_end',
]
N_FEATURES = len(FEATURE_NAMES)
# Sanity check: the row-building code emits exactly nine columns.
assert N_FEATURES == 9, f"Expect 9 features, got {N_FEATURES}"
#-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
# Helper Functions
def get_tokenizer():
    """Build a new ``LunaTokenizer`` while silencing anything it prints.

    ``sys.stdout`` is pointed at an in-memory buffer while the tokenizer
    module is imported and the tokenizer is constructed, and is put back
    afterwards even if either step raises.

    Returns:
        A freshly constructed ``LunaTokenizer`` instance.
    """
    import sys
    from io import StringIO

    saved_stdout = sys.stdout
    sys.stdout = StringIO()  # swallow prints made during import/construction
    try:
        from tokenizer import LunaTokenizer
        return LunaTokenizer()
    finally:
        # Runs on both the success and the failure path.
        sys.stdout = saved_stdout
#-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
# Pass 1: Build Global Vocabulary with Frequency Filtering
def extract_vocab_from_chunk(args: Tuple[str, int, int]) -> Dict[str, Counter]:
    """Count syllables and their onset/nucleus/coda parts in one file chunk.

    Args:
        args: ``(input_path, start_byte, end_byte)``. The chunk is the byte
            range ``[start_byte, end_byte)`` of ``input_path``.

    Returns:
        A dict with the keys ``'syllables'``, ``'onsets'``, ``'nuclei'`` and
        ``'codas'``, each holding a ``Counter`` of occurrences. All four
        counters are empty when the chunk is empty or whitespace only.
    """
    input_path, start_byte, end_byte = args

    # The bounds are *byte* offsets, so the chunk is read in binary mode.
    # A text-mode read(n) counts characters, which over-reads into the next
    # chunk as soon as the text contains multi-byte UTF-8 sequences.
    with open(input_path, 'rb') as f:
        f.seek(start_byte)
        raw = f.read(end_byte - start_byte)
    text = raw.decode('utf-8', errors='ignore')
    # Keep the newline normalisation that text mode used to perform.
    text = text.replace('\r\n', '\n').replace('\r', '\n')

    syllable_counts = Counter()
    onset_counts = Counter()
    nucleus_counts = Counter()
    coda_counts = Counter()
    result = {
        'syllables': syllable_counts,
        'onsets': onset_counts,
        'nuclei': nucleus_counts,
        'codas': coda_counts,
    }
    if not text or not text.strip():
        return result

    tokenizer = get_tokenizer()
    encoded = tokenizer.encode(text)

    vowels = frozenset('aeiouy')
    # Non-syllable token types get a tagged key so they cannot collide with
    # real syllables in the shared syllable vocabulary.
    tagged_prefix = {1: 'num', 2: 'punct', 3: 'char'}

    for token in encoded:
        text_content = token.get('text', '')
        token_type = token.get('token_type', 0)

        if token_type in tagged_prefix:
            syl_key = f"<{tagged_prefix[token_type]}_{text_content}>"
        else:
            syl_key = text_content.lower() if text_content else ''
        if syl_key:
            syllable_counts[syl_key] += 1

        # Only regular syllables (type 0) are split into phonetic parts.
        if token_type != 0 or not text_content:
            continue

        syl_lower = text_content.lower()
        # The nucleus is the first unbroken run of vowels; what precedes it
        # is the onset and what follows it is the coda.
        nucleus_start = next(
            (i for i, char in enumerate(syl_lower) if char in vowels), -1
        )
        if nucleus_start == -1:
            # No vowel at all: the whole syllable counts as onset.
            onset, nucleus, coda = syl_lower, '', ''
        else:
            nucleus_end = nucleus_start
            while nucleus_end < len(syl_lower) and syl_lower[nucleus_end] in vowels:
                nucleus_end += 1
            onset = syl_lower[:nucleus_start]
            nucleus = syl_lower[nucleus_start:nucleus_end]
            coda = syl_lower[nucleus_end:]

        onset_counts[onset] += 1
        nucleus_counts[nucleus] += 1
        coda_counts[coda] += 1

    return result
def build_global_vocab(
    input_path: str,
    output_dir: str,
    n_workers: int = 8,
    min_freq: int = 15,
    max_syllables: int = 32768,
    max_onsets: int = 2048,
    max_nuclei: int = 512,
    max_codas: int = 2048
) -> str:
    """Scan the corpus and write a frequency-filtered, capped vocabulary.

    The input file is cut into roughly 2 MiB chunks that end on line
    boundaries. Each chunk is counted in a worker process and the counts are
    merged. Entries seen fewer than ``min_freq`` times are dropped, and each
    vocabulary is limited to its cap by keeping the most frequent entries.
    Ties are broken alphabetically so the result does not depend on the order
    in which workers finish.

    Args:
        input_path: Path of the text corpus.
        output_dir: Directory that receives ``vocab.json`` (created if missing).
        n_workers: Number of worker processes.
        min_freq: Minimum occurrence count for an entry to be kept.
        max_syllables: Cap on syllable entries (special tokens not counted).
        max_onsets: Cap on onset entries (special tokens not counted).
        max_nuclei: Cap on nucleus entries (special tokens not counted).
        max_codas: Cap on coda entries (special tokens not counted).

    Returns:
        Path of the written ``vocab.json``.
    """
    file_size = os.path.getsize(input_path)
    os.makedirs(output_dir, exist_ok=True)

    print(f"\n{'='*70}")
    print("PASS 1: Building Global Vocabulary (v4 with phonetic caps)")
    print(f"{'='*70}")
    print(f"Input: {input_path} ({file_size/1e6:.0f} MB)")
    print(f"Caps: syllables={max_syllables}, onsets={max_onsets}, nuclei={max_nuclei}, codas={max_codas}")

    # Find chunk boundaries: jump ahead one chunk, then advance to the end of
    # the current line so that no line is split between two chunks.
    chunk_size = 2 * 1024 * 1024
    chunk_boundaries = [0]
    with open(input_path, 'rb') as f:
        while True:
            pos = chunk_boundaries[-1] + chunk_size
            if pos >= file_size:
                chunk_boundaries.append(file_size)
                break
            f.seek(pos)
            f.readline()
            chunk_boundaries.append(f.tell())

    n_chunks = len(chunk_boundaries) - 1
    print(f"Processing {n_chunks} chunks with {n_workers} workers...")

    # Extract vocabulary counts
    jobs = [(input_path, chunk_boundaries[i], chunk_boundaries[i + 1]) for i in range(n_chunks)]
    totals = {name: Counter() for name in ('syllables', 'onsets', 'nuclei', 'codas')}
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = [executor.submit(extract_vocab_from_chunk, job) for job in jobs]
        # The context manager closes the bar even if a worker raises.
        with tqdm(total=len(futures), desc="Scanning") as pbar:
            for future in as_completed(futures):
                partial = future.result()
                for name, counter in totals.items():
                    counter.update(partial[name])
                pbar.update(1)

    syllable_counts = totals['syllables']
    onset_counts = totals['onsets']
    nucleus_counts = totals['nuclei']
    coda_counts = totals['codas']
    print(f"\nRaw vocab sizes: syllables={len(syllable_counts)}, onsets={len(onset_counts)}, nuclei={len(nucleus_counts)}, codas={len(coda_counts)}")

    # Frequency filtering
    print("\nApplying frequency filters...")
    filtered_syls = _select_top_keys(syllable_counts, min_freq, max_syllables)
    filtered_onsets = _select_top_keys(onset_counts, min_freq, max_onsets)
    filtered_nuclei = _select_top_keys(nucleus_counts, min_freq, max_nuclei)
    filtered_codas = _select_top_keys(coda_counts, min_freq, max_codas)

    # Coverage: share of all counted tokens whose syllable key was kept.
    total_tokens = sum(syllable_counts.values())
    kept_tokens = sum(syllable_counts[s] for s in filtered_syls)
    coverage = kept_tokens / total_tokens * 100 if total_tokens > 0 else 0
    print(f" Syllables: {len(syllable_counts)} -> {len(filtered_syls)} ({coverage:.1f}% coverage)")
    print(f" Onsets: {len(onset_counts)} -> {len(filtered_onsets)}")
    print(f" Nuclei: {len(nucleus_counts)} -> {len(filtered_nuclei)}")
    print(f" Codas: {len(coda_counts)} -> {len(filtered_codas)}")

    # Create deterministic ID mappings: special tokens take the lowest ids,
    # the remaining entries follow in sorted order.
    syllable_to_id = _build_id_map(['<pad>', '<unk>'], filtered_syls)
    onset_to_id = _build_id_map(
        ['<pad>', '', '<unk>', '<num>', '<punct>', '<special>'], filtered_onsets
    )
    nucleus_to_id = _build_id_map(['<pad>', '', '<unk>'], filtered_nuclei)
    coda_to_id = _build_id_map(['<pad>', '', '<unk>'], filtered_codas)
    print(f"\nFinal vocab sizes: syllables={len(syllable_to_id)}, onsets={len(onset_to_id)}, nuclei={len(nucleus_to_id)}, codas={len(coda_to_id)}")

    # Save vocabulary
    vocab_path = os.path.join(output_dir, "vocab.json")
    vocab_data = {
        'syllable_to_id': syllable_to_id,
        # JSON object keys must be strings, hence str(id).
        'id_to_syllable': {str(i): s for s, i in syllable_to_id.items()},
        'onset_to_id': onset_to_id,
        'nucleus_to_id': nucleus_to_id,
        'coda_to_id': coda_to_id,
        'version': 'v4',
        'features': FEATURE_NAMES,
        'n_features': N_FEATURES,
    }
    with open(vocab_path, 'w', encoding='utf-8') as f:
        json.dump(vocab_data, f, indent=2)
    print(f"Saved: {vocab_path}")
    return vocab_path


def _select_top_keys(counts: Counter, min_freq: int, cap: int) -> Set[str]:
    """Return the keys seen at least ``min_freq`` times, limited to ``cap``.

    When more keys qualify than the cap allows, the most frequent are kept and
    equal counts are ordered alphabetically, so the result is reproducible.
    """
    eligible = [key for key, count in counts.items() if count >= min_freq]
    if len(eligible) > cap:
        eligible.sort(key=lambda key: (-counts[key], key))
        eligible = eligible[:max(cap, 0)]
    return set(eligible)


def _build_id_map(specials: List[str], keys: Set[str]) -> Dict[str, int]:
    """Map ``specials`` to ids 0.. in order, then the other keys in sorted order."""
    ordered = specials + sorted(keys - set(specials))
    return {entry: index for index, entry in enumerate(ordered)}
#-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
# Pass 2: Tokenize with global vocabulary
def tokenize_chunk_with_global_vocab(args: Tuple[str, int, int, str, str]) -> Tuple[str, int]:
    """Encode one file chunk into 9-column rows and save them as a ``.npy`` file.

    Args:
        args: ``(input_path, start_byte, end_byte, output_path, vocab_path)``.
            The chunk is the byte range ``[start_byte, end_byte)`` of
            ``input_path``; ``vocab_path`` is the JSON vocabulary to load.

    Returns:
        ``(output_path, row_count)`` after saving an ``int32`` array of shape
        ``(row_count, 9)``, or ``(None, 0)`` when the chunk is empty,
        whitespace only, or yields no tokens (nothing is written then).
    """
    input_path, start_byte, end_byte, output_path, vocab_path = args

    # The bounds are *byte* offsets, so the chunk is read in binary mode.
    # A text-mode read(n) counts characters, which over-reads into the next
    # chunk as soon as the text contains multi-byte UTF-8 sequences.
    with open(input_path, 'rb') as f:
        f.seek(start_byte)
        raw = f.read(end_byte - start_byte)
    text = raw.decode('utf-8', errors='ignore')
    # Keep the newline normalisation that text mode used to perform.
    text = text.replace('\r\n', '\n').replace('\r', '\n')
    if not text or not text.strip():
        return None, 0

    # Load global vocabulary
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    syllable_to_id = vocab['syllable_to_id']
    onset_to_id = vocab['onset_to_id']
    nucleus_to_id = vocab['nucleus_to_id']
    coda_to_id = vocab['coda_to_id']

    # UNK IDs for fallback
    syl_unk = syllable_to_id.get('<unk>', 1)
    onset_unk = onset_to_id.get('<unk>', 2)
    nucleus_unk = nucleus_to_id.get('<unk>', 2)
    coda_unk = coda_to_id.get('<unk>', 2)
    # Non-syllable tokens have no nucleus/coda and use the '' entries.
    nucleus_empty = nucleus_to_id.get('', 1)
    coda_empty = coda_to_id.get('', 1)

    tokenizer = get_tokenizer()
    encoded = tokenizer.encode(text)
    if not encoded:
        return None, 0

    vowels = frozenset('aeiouy')
    tagged_prefix = {1: 'num', 2: 'punct', 3: 'char'}
    # Onset marker for non-syllable tokens; every other case is '<special>'.
    onset_marker = {1: '<num>', 2: '<punct>'}

    tokens = []
    for e in encoded:
        text_content = e.get('text', '')
        token_type = e['token_type']

        # Determine syllable key (same scheme as the vocabulary pass).
        if token_type in tagged_prefix:
            syl_key = f"<{tagged_prefix[token_type]}_{text_content}>"
        else:
            syl_key = text_content.lower()
        syl_id = syllable_to_id.get(syl_key, syl_unk)

        if token_type == 0 and text_content:
            # Regular syllable: the nucleus is the first unbroken vowel run,
            # the onset precedes it and the coda follows it.
            syl_lower = text_content.lower()
            nucleus_start = next(
                (i for i, char in enumerate(syl_lower) if char in vowels), -1
            )
            if nucleus_start == -1:
                onset, nucleus, coda = syl_lower, '', ''
            else:
                nucleus_end = nucleus_start
                while nucleus_end < len(syl_lower) and syl_lower[nucleus_end] in vowels:
                    nucleus_end += 1
                onset = syl_lower[:nucleus_start]
                nucleus = syl_lower[nucleus_start:nucleus_end]
                coda = syl_lower[nucleus_end:]
            onset_id = onset_to_id.get(onset, onset_unk)
            nucleus_id = nucleus_to_id.get(nucleus, nucleus_unk)
            coda_id = coda_to_id.get(coda, coda_unk)
        else:
            marker = onset_marker.get(token_type, '<special>')
            onset_id = onset_to_id.get(marker, onset_unk)
            nucleus_id = nucleus_empty
            coda_id = coda_empty

        # Build 9-feature token (v4 format)
        tokens.append([
            syl_id,                # 0: syllable_id
            onset_id,              # 1: onset_id
            nucleus_id,            # 2: nucleus_id
            coda_id,               # 3: coda_id
            e['position'],         # 4: position
            e['is_capitalized'],   # 5: is_capitalized
            e['token_type'],       # 6: token_type
            e['has_space_after'],  # 7: has_space_after
            e['is_word_end'],      # 8: is_word_end
        ])

    arr = np.array(tokens, dtype=np.int32)
    np.save(output_path, arr)
    return output_path, len(arr)
def tokenize_with_global_vocab(
    input_path: str,
    output_dir: str,
    vocab_path: str,
    val_split: float = 0.02,
    n_workers: int = 8
):
    """Tokenize the whole corpus and write train/val token files plus a config.

    The input is cut into roughly 2 MiB chunks that end on line boundaries.
    Each chunk is encoded in a worker process into a temporary ``.npy`` file,
    and the chunks are then concatenated in file order. The last
    ``int(total * val_split)`` rows go to ``val_tokens.dat`` and the rest to
    ``train_tokens.dat``, both raw ``int32`` arrays with ``N_FEATURES``
    columns. ``config.json`` records the sizes and vocabulary sizes.

    Args:
        input_path: Path of the text corpus.
        output_dir: Directory that receives the output files.
        vocab_path: Path of the JSON vocabulary to encode with.
        val_split: Fraction of rows, in ``[0, 1]``, placed in the validation file.
        n_workers: Number of worker processes.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1]``.
    """
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(f"val_split must be within [0, 1], got {val_split}")

    file_size = os.path.getsize(input_path)
    print(f"\n{'='*70}")
    print("PASS 2: Tokenizing with Global Vocabulary")
    print(f"{'='*70}")

    temp_dir = os.path.join(output_dir, "_temp")
    os.makedirs(temp_dir, exist_ok=True)

    # Find chunk boundaries: jump ahead one chunk, then advance to the end of
    # the current line so that no line is split between two chunks.
    chunk_size = 2 * 1024 * 1024
    chunk_boundaries = [0]
    with open(input_path, 'rb') as f:
        while True:
            pos = chunk_boundaries[-1] + chunk_size
            if pos >= file_size:
                chunk_boundaries.append(file_size)
                break
            f.seek(pos)
            f.readline()
            chunk_boundaries.append(f.tell())
    n_chunks = len(chunk_boundaries) - 1

    # Tokenize chunks
    jobs = [
        (
            input_path,
            chunk_boundaries[i],
            chunk_boundaries[i + 1],
            os.path.join(temp_dir, f"chunk_{i:06d}.npy"),
            vocab_path,
        )
        for i in range(n_chunks)
    ]
    # chunk index -> (path, row count). Workers finish in arbitrary order, so
    # results are keyed by index and put back into file order afterwards.
    finished = {}
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = {
            executor.submit(tokenize_chunk_with_global_vocab, job): i
            for i, job in enumerate(jobs)
        }
        # The context manager closes the bar even if a worker raises.
        with tqdm(total=len(jobs), desc="Tokenizing") as pbar:
            for future in as_completed(futures):
                path, count = future.result()
                if path:
                    finished[futures[future]] = (path, count)
                pbar.update(1)

    ordered = [finished[i] for i in sorted(finished)]
    total_tokens = sum(count for _, count in ordered)
    print(f"Total tokens: {total_tokens:,}")

    n_val = int(total_tokens * val_split)
    n_train = total_tokens - n_val
    print(f"Split: train={n_train:,}, val={n_val:,}")

    # Create memmap files (None when a split has no rows).
    train_path = os.path.join(output_dir, "train_tokens.dat")
    val_path = os.path.join(output_dir, "val_tokens.dat")
    train_mm = _create_row_memmap(train_path, n_train)
    val_mm = _create_row_memmap(val_path, n_val)

    # Merge chunks: rows [0, n_train) go to train, the rest to val. A chunk
    # straddling the split point contributes to both files.
    offset = 0
    for path, size in tqdm(ordered, total=len(ordered), desc="Merging"):
        arr = np.load(path)
        to_train = min(max(n_train - offset, 0), size)
        if to_train:
            train_mm[offset:offset + to_train] = arr[:to_train]
        if to_train < size:
            val_offset = offset + to_train - n_train
            val_mm[val_offset:val_offset + size - to_train] = arr[to_train:]
        offset += size
        del arr

    if train_mm is not None:
        train_mm.flush()
    if val_mm is not None:
        val_mm.flush()
    del train_mm, val_mm

    # Cleanup
    for path, _ in ordered:
        os.remove(path)
    try:
        os.rmdir(temp_dir)
    except OSError:
        # Files this run did not create (e.g. from an interrupted earlier run)
        # keep the directory non-empty; do not fail a finished run over that.
        print(f"Warning: could not remove temporary directory {temp_dir}")

    # Save config
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    config = {
        "total_tokens": total_tokens,
        "train_tokens": n_train,
        "val_tokens": n_val,
        "n_features": N_FEATURES,
        "feature_names": FEATURE_NAMES,
        "vocab_sizes": {
            "syllables": len(vocab['syllable_to_id']),
            "onsets": len(vocab['onset_to_id']),
            "nuclei": len(vocab['nucleus_to_id']),
            "codas": len(vocab['coda_to_id']),
            "positions": 4,
            "capitalized": 2,
            "token_types": 4,
            "has_space_after": 2,
            "is_word_end": 2,
        },
        "dtype": "int32",
        "version": "v4",
    }
    config_path = os.path.join(output_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    print(f"\nOutput: train={os.path.getsize(train_path)/1e9:.2f}GB, val={os.path.getsize(val_path)/1e9:.2f}GB")


def _create_row_memmap(path: str, n_rows: int):
    """Create ``path`` as a writable int32 memmap of ``n_rows`` token rows.

    Returns ``None`` and leaves an empty file behind when ``n_rows`` is 0,
    because an empty file cannot be memory-mapped.
    """
    if n_rows == 0:
        open(path, 'wb').close()
        return None
    return np.memmap(path, dtype=np.int32, mode='w+', shape=(n_rows, N_FEATURES))
def generate_dataset(
    input_path: str,
    output_dir: str,
    val_split: float = 0.02,
    n_workers: int = 8,
    min_freq: int = 10,
    max_syllables: int = 30000,
    max_onsets: int = 1500,
    max_nuclei: int = 500,
    max_codas: int = 2000
):
    """Run the full pipeline: build the vocabulary, then tokenize the corpus.

    Args:
        input_path: Path of the text corpus.
        output_dir: Directory for all outputs (created if missing).
        val_split: Validation fraction forwarded to the tokenization pass.
        n_workers: Number of worker processes used by both passes.
        min_freq: Minimum frequency forwarded to the vocabulary pass.
        max_syllables: Syllable cap forwarded to the vocabulary pass.
        max_onsets: Onset cap forwarded to the vocabulary pass.
        max_nuclei: Nucleus cap forwarded to the vocabulary pass.
        max_codas: Coda cap forwarded to the vocabulary pass.
    """
    banner = "=" * 70
    print(banner)
    print("Luna - Data Generation Pipeline")
    print(banner)

    os.makedirs(output_dir, exist_ok=True)

    # Pass 1 writes the vocabulary; pass 2 encodes the corpus with it.
    vocab_path = build_global_vocab(
        input_path,
        output_dir,
        n_workers=n_workers,
        min_freq=min_freq,
        max_syllables=max_syllables,
        max_onsets=max_onsets,
        max_nuclei=max_nuclei,
        max_codas=max_codas,
    )
    tokenize_with_global_vocab(
        input_path,
        output_dir,
        vocab_path,
        val_split=val_split,
        n_workers=n_workers,
    )

    print(f"\n{banner}")
    print("COMPLETE!")
    print(banner)
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the pipeline.
    cli = argparse.ArgumentParser(description="Generate Luna training data")
    cli.add_argument("--input", type=str, required=True, help="Input text file")
    cli.add_argument("--output_dir", type=str, required=True, help="Output directory")

    # Optional settings as (flag, type, default, help), in help-listing order.
    optional_settings = (
        ("--val_split", float, 0.02, "Validation split"),
        ("--workers", int, 8, "Number of workers"),
        ("--min_freq", int, 10, "Min frequency for syllables"),
        ("--max_syllables", int, 32768, "Max syllable vocab"),
        ("--max_onsets", int, 1500, "Max onset vocab"),
        ("--max_nuclei", int, 500, "Max nucleus vocab"),
        ("--max_codas", int, 2000, "Max coda vocab"),
    )
    for flag, value_type, default_value, help_text in optional_settings:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    opts = cli.parse_args()
    generate_dataset(
        input_path=opts.input,
        output_dir=opts.output_dir,
        val_split=opts.val_split,
        n_workers=opts.workers,
        min_freq=opts.min_freq,
        max_syllables=opts.max_syllables,
        max_onsets=opts.max_onsets,
        max_nuclei=opts.max_nuclei,
        max_codas=opts.max_codas,
    )