|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import json |
|
|
import argparse |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
from collections import Counter |
|
|
from concurrent.futures import ProcessPoolExecutor, as_completed |
|
|
from typing import Dict, List, Tuple, Set |
|
|
|
|
|
|
|
|
|
|
|
# Column layout of each tokenized row written in pass 2; the order here
# must match the row list built in tokenize_chunk_with_global_vocab.
FEATURE_NAMES = [
    'syllable_id',      # vocab id of the syllable text (sentinel keys for non-words)
    'onset_id',         # vocab id of the onset (consonants before the first vowel run)
    'nucleus_id',       # vocab id of the nucleus (the first vowel run)
    'coda_id',          # vocab id of the coda (everything after the nucleus)
    'position',         # passed through from the tokenizer (4 values per config — confirm semantics)
    'is_capitalized',   # tokenizer flag (0/1 per config)
    'token_type',       # 0=word, 1=number, 2=punctuation, 3=char/other
    'has_space_after',  # tokenizer flag (0/1 per config)
    'is_word_end',      # tokenizer flag (0/1 per config)
]

# Number of columns per token row (second dimension of the output memmaps).
N_FEATURES = len(FEATURE_NAMES)

# Downstream config hard-codes 9 feature columns; fail fast if the list drifts.
assert N_FEATURES == 9, f"Expect 9 features, got {N_FEATURES}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_tokenizer():
    """Return a new LunaTokenizer, suppressing anything printed during creation.

    stdout is temporarily swapped for an in-memory buffer while the
    tokenizer module is imported and instantiated, and is always
    restored afterwards — even if construction fails.
    """
    import sys
    from io import StringIO

    saved_stdout = sys.stdout
    try:
        sys.stdout = StringIO()
        from tokenizer import LunaTokenizer
        instance = LunaTokenizer()
    finally:
        sys.stdout = saved_stdout
    return instance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_vocab_from_chunk(args: Tuple[str, int, int]) -> Dict[str, Counter]:
    """Count syllable / onset / nucleus / coda occurrences in one file chunk.

    Args:
        args: (input_path, start_byte, end_byte) — the byte window of the
            source text file this worker should scan.

    Returns:
        Dict with Counter values under keys 'syllables', 'onsets',
        'nuclei', 'codas'. All four counters are empty when the chunk
        contains no non-whitespace text.
    """
    input_path, start_byte, end_byte = args

    with open(input_path, 'r', encoding='utf-8', errors='ignore') as f:
        f.seek(start_byte)
        text = f.read(end_byte - start_byte)

    if not text or not text.strip():
        return {
            'syllables': Counter(),
            'onsets': Counter(),
            'nuclei': Counter(),
            'codas': Counter(),
        }

    tokenizer = get_tokenizer()
    encoded = tokenizer.encode(text)

    syllable_counts = Counter()
    onset_counts = Counter()
    nucleus_counts = Counter()
    coda_counts = Counter()

    # Loop-invariant: hoisted out of the per-token loop (the original
    # rebuilt this set once per token).
    vowels = set('aeiouy')

    for token in encoded:
        text_content = token.get('text', '')
        token_type = token.get('token_type', 0)

        # Non-word tokens get sentinel keys so they can never collide
        # with real (lowercased) syllable strings.
        if token_type == 2:
            syl_key = f"<punct_{text_content}>"
        elif token_type == 1:
            syl_key = f"<num_{text_content}>"
        elif token_type == 3:
            syl_key = f"<char_{text_content}>"
        else:
            syl_key = text_content.lower() if text_content else ''

        if syl_key:
            syllable_counts[syl_key] += 1

        # Phonetic decomposition only applies to word syllables (type 0):
        # onset = chars before the first vowel run, nucleus = that vowel
        # run, coda = everything after it.
        if token_type == 0 and text_content:
            syl_lower = text_content.lower()

            nucleus_start = -1
            nucleus_end = -1
            for i, char in enumerate(syl_lower):
                if char in vowels:
                    if nucleus_start == -1:
                        nucleus_start = i
                    nucleus_end = i + 1
                elif nucleus_start != -1:
                    # First consonant after the vowel run ends the nucleus.
                    break

            if nucleus_start != -1:
                onset = syl_lower[:nucleus_start]
                nucleus = syl_lower[nucleus_start:nucleus_end]
                coda = syl_lower[nucleus_end:]
            else:
                # Vowel-less syllable: the whole thing counts as onset.
                onset, nucleus, coda = syl_lower, '', ''

            onset_counts[onset] += 1
            nucleus_counts[nucleus] += 1
            coda_counts[coda] += 1

    return {
        'syllables': syllable_counts,
        'onsets': onset_counts,
        'nuclei': nucleus_counts,
        'codas': coda_counts,
    }
|
|
|
|
|
def build_global_vocab(
    input_path: str,
    output_dir: str,
    n_workers: int = 8,
    min_freq: int = 15,
    max_syllables: int = 32768,
    max_onsets: int = 2048,
    max_nuclei: int = 512,
    max_codas: int = 2048
) -> str:
    """
    Build global vocabulary with frequency filtering and caps.

    Pass 1 of the pipeline: scans the corpus in parallel, merges the
    per-chunk counters, filters each vocabulary (min_freq + hard cap)
    to prevent VRAM explosion from garbage phonetic components, and
    writes vocab.json to output_dir.

    Args:
        input_path: UTF-8 text corpus to scan.
        output_dir: directory that receives vocab.json (must exist).
        n_workers: process-pool size.
        min_freq: minimum occurrence count for an entry to be kept.
        max_syllables / max_onsets / max_nuclei / max_codas: hard caps
            applied after the frequency filter (most frequent kept).

    Returns:
        Path to the written vocab.json.
    """
    file_size = os.path.getsize(input_path)

    print(f"\n{'='*70}")
    print("PASS 1: Building Global Vocabulary (v4 with phonetic caps)")
    print(f"{'='*70}")
    print(f"Input: {input_path} ({file_size/1e6:.0f} MB)")
    print(f"Caps: syllables={max_syllables}, onsets={max_onsets}, nuclei={max_nuclei}, codas={max_codas}")

    # Split the file into ~2MB chunks, each extended to the next newline
    # so no line is ever split between two workers.
    chunk_size = 2 * 1024 * 1024
    chunk_boundaries = [0]

    with open(input_path, 'rb') as f:
        while True:
            pos = chunk_boundaries[-1] + chunk_size
            if pos >= file_size:
                chunk_boundaries.append(file_size)
                break
            f.seek(pos)
            f.readline()  # advance to the end of the current line
            chunk_boundaries.append(f.tell())

    n_chunks = len(chunk_boundaries) - 1
    print(f"Processing {n_chunks} chunks with {n_workers} workers...")

    jobs = [(input_path, chunk_boundaries[i], chunk_boundaries[i + 1]) for i in range(n_chunks)]

    syllable_counts = Counter()
    onset_counts = Counter()
    nucleus_counts = Counter()
    coda_counts = Counter()

    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = [executor.submit(extract_vocab_from_chunk, job) for job in jobs]

        # Context manager guarantees the bar is closed even if a worker
        # raises (the original leaked it on error).
        with tqdm(total=len(futures), desc="Scanning") as pbar:
            for future in as_completed(futures):
                vocab = future.result()
                syllable_counts.update(vocab['syllables'])
                onset_counts.update(vocab['onsets'])
                nucleus_counts.update(vocab['nuclei'])
                coda_counts.update(vocab['codas'])
                pbar.update(1)

    print(f"\nRaw vocab sizes: syllables={len(syllable_counts)}, onsets={len(onset_counts)}, nuclei={len(nucleus_counts)}, codas={len(coda_counts)}")

    print(f"\nApplying frequency filters...")

    def _filter_and_cap(counts: Counter, cap: int) -> Set[str]:
        """Keep entries with count >= min_freq; if more than `cap` survive,
        keep only the `cap` most frequent (which all necessarily meet
        min_freq, since more than `cap` entries already did)."""
        kept = {key for key, cnt in counts.items() if cnt >= min_freq}
        if len(kept) > cap:
            kept = {key for key, _ in counts.most_common(cap)}
        return kept

    filtered_syls = _filter_and_cap(syllable_counts, max_syllables)
    filtered_onsets = _filter_and_cap(onset_counts, max_onsets)
    filtered_nuclei = _filter_and_cap(nucleus_counts, max_nuclei)
    filtered_codas = _filter_and_cap(coda_counts, max_codas)

    # Report what fraction of corpus tokens the kept syllables still cover.
    total_tokens = sum(syllable_counts.values())
    kept_tokens = sum(syllable_counts[s] for s in filtered_syls)
    coverage = kept_tokens / total_tokens * 100 if total_tokens > 0 else 0

    print(f" Syllables: {len(syllable_counts)} → {len(filtered_syls)} ({coverage:.1f}% coverage)")
    print(f" Onsets: {len(onset_counts)} → {len(filtered_onsets)}")
    print(f" Nuclei: {len(nucleus_counts)} → {len(filtered_nuclei)}")
    print(f" Codas: {len(coda_counts)} → {len(filtered_codas)}")

    def _build_index(specials: List[str], kept: Set[str]) -> Dict[str, int]:
        """Assign ids: specials take the lowest ids in the given order,
        remaining entries follow in sorted order for determinism."""
        return {s: i for i, s in enumerate(specials + sorted(kept - set(specials)))}

    syllable_to_id = _build_index(['<pad>', '<unk>'], filtered_syls)
    id_to_syllable = {i: s for s, i in syllable_to_id.items()}

    onset_to_id = _build_index(['<pad>', '', '<unk>', '<num>', '<punct>', '<special>'], filtered_onsets)
    nucleus_to_id = _build_index(['<pad>', '', '<unk>'], filtered_nuclei)
    coda_to_id = _build_index(['<pad>', '', '<unk>'], filtered_codas)

    print(f"\nFinal vocab sizes: syllables={len(syllable_to_id)}, onsets={len(onset_to_id)}, nuclei={len(nucleus_to_id)}, codas={len(coda_to_id)}")

    vocab_path = os.path.join(output_dir, "vocab.json")
    vocab_data = {
        'syllable_to_id': syllable_to_id,
        # JSON object keys must be strings, hence the str() conversion.
        'id_to_syllable': {str(k): v for k, v in id_to_syllable.items()},
        'onset_to_id': onset_to_id,
        'nucleus_to_id': nucleus_to_id,
        'coda_to_id': coda_to_id,
        'version': 'v4',
        'features': FEATURE_NAMES,
        'n_features': N_FEATURES,
    }

    with open(vocab_path, 'w', encoding='utf-8') as f:
        json.dump(vocab_data, f, indent=2)

    print(f"Saved: {vocab_path}")
    return vocab_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tokenize_chunk_with_global_vocab(args: Tuple[str, int, int, str, str]) -> Tuple[str, int]:
    """Tokenize a chunk using global vocabulary.

    Args:
        args: (input_path, start_byte, end_byte, output_path, vocab_path).
            The byte window [start_byte, end_byte) of input_path is
            encoded and saved to output_path as an (n_tokens, 9) int32
            .npy array whose columns follow FEATURE_NAMES.

    Returns:
        (output_path, token_count), or (None, 0) when the chunk is
        blank or yields no tokens (despite the annotated return type).
    """
    input_path, start_byte, end_byte, output_path, vocab_path = args

    with open(input_path, 'r', encoding='utf-8', errors='ignore') as f:
        f.seek(start_byte)
        text = f.read(end_byte - start_byte)

    if not text or not text.strip():
        return None, 0

    # Each worker process loads the vocab JSON itself rather than having
    # it pickled into every job.
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)

    syllable_to_id = vocab['syllable_to_id']
    onset_to_id = vocab['onset_to_id']
    nucleus_to_id = vocab['nucleus_to_id']
    coda_to_id = vocab['coda_to_id']

    # Fallback ids for out-of-vocabulary items; the numeric defaults
    # mirror where '<unk>' sits in the special-entry lists at build time.
    syl_unk = syllable_to_id.get('<unk>', 1)
    onset_unk = onset_to_id.get('<unk>', 2)
    nucleus_unk = nucleus_to_id.get('<unk>', 2)
    coda_unk = coda_to_id.get('<unk>', 2)

    tokenizer = get_tokenizer()
    encoded = tokenizer.encode(text)

    if not encoded:
        return None, 0

    tokens = []
    vowels = set('aeiouy')

    for e in encoded:
        text_content = e.get('text', '')
        token_type = e['token_type']

        # Same sentinel-key scheme used when the vocab was built:
        # 2=punctuation, 1=number, 3=char, otherwise a word syllable.
        if token_type == 2:
            syl_key = f"<punct_{text_content}>"
        elif token_type == 1:
            syl_key = f"<num_{text_content}>"
        elif token_type == 3:
            syl_key = f"<char_{text_content}>"
        else:
            syl_key = text_content.lower()

        syl_id = syllable_to_id.get(syl_key, syl_unk)

        # Phonetic split for word tokens: onset = chars before the first
        # vowel run, nucleus = that run, coda = the remainder.
        if token_type == 0 and text_content:
            syl_lower = text_content.lower()
            nucleus_start = -1
            nucleus_end = -1

            for i, char in enumerate(syl_lower):
                if char in vowels:
                    if nucleus_start == -1:
                        nucleus_start = i
                    nucleus_end = i + 1
                elif nucleus_start != -1:
                    # First consonant after the vowel run ends the nucleus.
                    break

            if nucleus_start != -1:
                onset = syl_lower[:nucleus_start]
                nucleus = syl_lower[nucleus_start:nucleus_end]
                coda = syl_lower[nucleus_end:]
            else:
                # Vowel-less syllable: the whole thing counts as onset.
                onset, nucleus, coda = syl_lower, '', ''

            onset_id = onset_to_id.get(onset, onset_unk)
            nucleus_id = nucleus_to_id.get(nucleus, nucleus_unk)
            coda_id = coda_to_id.get(coda, coda_unk)

        elif token_type == 1:
            # Numbers share the '<num>' onset; nucleus/coda use the
            # empty-component id (default 1 matches '' in the special lists).
            onset_id = onset_to_id.get('<num>', onset_unk)
            nucleus_id = nucleus_to_id.get('', 1)
            coda_id = coda_to_id.get('', 1)

        elif token_type == 2:
            onset_id = onset_to_id.get('<punct>', onset_unk)
            nucleus_id = nucleus_to_id.get('', 1)
            coda_id = coda_to_id.get('', 1)

        else:
            # Char tokens (type 3) and empty word tokens land here.
            onset_id = onset_to_id.get('<special>', onset_unk)
            nucleus_id = nucleus_to_id.get('', 1)
            coda_id = coda_to_id.get('', 1)

        # Column order must match FEATURE_NAMES.
        tokens.append([
            syl_id,
            onset_id,
            nucleus_id,
            coda_id,
            e['position'],
            e['is_capitalized'],
            e['token_type'],
            e['has_space_after'],
            e['is_word_end'],
        ])

    arr = np.array(tokens, dtype=np.int32)
    np.save(output_path, arr)

    return output_path, len(arr)
|
|
def tokenize_with_global_vocab(
    input_path: str,
    output_dir: str,
    vocab_path: str,
    val_split: float = 0.02,
    n_workers: int = 8
):
    """Tokenize entire dataset using global vocabulary.

    Pass 2 of the pipeline: workers write per-chunk .npy arrays into a
    temp directory; those arrays are then merged in file order into two
    int32 memmaps (train/val, split by row count), the temp files are
    removed, and a config.json describing the dataset is written.

    Args:
        input_path: UTF-8 text corpus to tokenize.
        output_dir: receives train_tokens.dat, val_tokens.dat, config.json.
        vocab_path: path to vocab.json produced by build_global_vocab.
        val_split: fraction of rows assigned to the validation memmap.
        n_workers: process-pool size.
    """
    file_size = os.path.getsize(input_path)

    print(f"\n{'='*70}")
    print("PASS 2: Tokenizing with Global Vocabulary")
    print(f"{'='*70}")

    temp_dir = os.path.join(output_dir, "_temp")
    os.makedirs(temp_dir, exist_ok=True)

    # ~2MB chunks aligned to line boundaries (same scheme as pass 1).
    chunk_size = 2 * 1024 * 1024
    chunk_boundaries = [0]

    with open(input_path, 'rb') as f:
        while True:
            pos = chunk_boundaries[-1] + chunk_size
            if pos >= file_size:
                chunk_boundaries.append(file_size)
                break
            f.seek(pos)
            f.readline()  # advance to the end of the current line
            chunk_boundaries.append(f.tell())

    n_chunks = len(chunk_boundaries) - 1

    jobs = []
    for i in range(n_chunks):
        chunk_output = os.path.join(temp_dir, f"chunk_{i:06d}.npy")
        jobs.append((input_path, chunk_boundaries[i], chunk_boundaries[i+1], chunk_output, vocab_path))

    chunk_files = []
    total_tokens = 0

    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = {executor.submit(tokenize_chunk_with_global_vocab, job): i for i, job in enumerate(jobs)}

        pbar = tqdm(total=len(jobs), desc="Tokenizing")
        for future in as_completed(futures):
            path, count = future.result()
            # Blank chunks return (None, 0) and produce no file.
            if path:
                chunk_files.append(path)
                total_tokens += count
            pbar.update(1)
        pbar.close()

    print(f"Total tokens: {total_tokens:,}")

    # Chunk filenames are zero-padded, so a lexicographic sort restores
    # the original corpus order regardless of completion order.
    chunk_files.sort()

    total_rows = 0
    chunk_sizes = []
    for cf in chunk_files:
        # mmap_mode='r': only the header is touched to get the length.
        arr = np.load(cf, mmap_mode='r')
        chunk_sizes.append(len(arr))
        total_rows += len(arr)

    n_val = int(total_rows * val_split)
    n_train = total_rows - n_val

    print(f"Split: train={n_train:,}, val={n_val:,}")

    train_path = os.path.join(output_dir, "train_tokens.dat")
    val_path = os.path.join(output_dir, "val_tokens.dat")

    train_mm = np.memmap(train_path, dtype=np.int32, mode='w+', shape=(n_train, N_FEATURES))
    val_mm = np.memmap(val_path, dtype=np.int32, mode='w+', shape=(n_val, N_FEATURES))

    # Stream chunks into the memmaps: the first n_train rows go to the
    # train file, the rest to val; one chunk may straddle the boundary.
    offset = 0
    for cf, size in tqdm(zip(chunk_files, chunk_sizes), total=len(chunk_files), desc="Merging"):
        arr = np.load(cf)
        end = offset + size

        if end <= n_train:
            # Chunk lies entirely in the train region.
            train_mm[offset:end] = arr
        elif offset >= n_train:
            # Chunk lies entirely in the val region.
            val_offset = offset - n_train
            val_mm[val_offset:val_offset + size] = arr
        else:
            # Chunk straddles the split point: divide it between files.
            split_point = n_train - offset
            train_mm[offset:n_train] = arr[:split_point]
            val_mm[0:size - split_point] = arr[split_point:]

        offset = end
        del arr  # release this chunk before loading the next

    train_mm.flush()
    val_mm.flush()
    del train_mm, val_mm

    # The temp chunk files are no longer needed once merged.
    for cf in chunk_files:
        os.remove(cf)
    os.rmdir(temp_dir)

    # Re-read the vocab only to record its sizes in the dataset config.
    with open(vocab_path, 'r') as f:
        vocab = json.load(f)

    config = {
        "total_tokens": total_tokens,
        "train_tokens": n_train,
        "val_tokens": n_val,
        "n_features": N_FEATURES,
        "feature_names": FEATURE_NAMES,
        "vocab_sizes": {
            "syllables": len(vocab['syllable_to_id']),
            "onsets": len(vocab['onset_to_id']),
            "nuclei": len(vocab['nucleus_to_id']),
            "codas": len(vocab['coda_to_id']),
            "positions": 4,
            "capitalized": 2,
            "token_types": 4,
            "has_space_after": 2,
            "is_word_end": 2,
        },
        "dtype": "int32",
        "version": "v4",
    }

    config_path = os.path.join(output_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)

    print(f"\nOutput: train={os.path.getsize(train_path)/1e9:.2f}GB, val={os.path.getsize(val_path)/1e9:.2f}GB")
|
|
|
def generate_dataset(
    input_path: str,
    output_dir: str,
    val_split: float = 0.02,
    n_workers: int = 8,
    min_freq: int = 10,
    max_syllables: int = 30000,
    max_onsets: int = 1500,
    max_nuclei: int = 500,
    max_codas: int = 2000
):
    """Run the full two-pass pipeline: vocabulary build, then tokenization."""
    banner = "=" * 70
    print(banner)
    print("Luna - Data Generation Pipeline")
    print(banner)

    os.makedirs(output_dir, exist_ok=True)

    # Pass 1: scan the corpus and write vocab.json.
    vocab_path = build_global_vocab(
        input_path,
        output_dir,
        n_workers,
        min_freq,
        max_syllables,
        max_onsets,
        max_nuclei,
        max_codas,
    )

    # Pass 2: encode the corpus with that vocabulary.
    tokenize_with_global_vocab(input_path, output_dir, vocab_path, val_split, n_workers)

    print(f"\n{banner}")
    print("COMPLETE!")
    print(banner)
|
|
|
if __name__ == "__main__":
    # CLI entry point: every pipeline knob is exposed as a flag.
    cli = argparse.ArgumentParser(description="Generate Luna training data")
    cli.add_argument("--input", type=str, required=True, help="Input text file")
    cli.add_argument("--output_dir", type=str, required=True, help="Output directory")
    cli.add_argument("--val_split", type=float, default=0.02, help="Validation split")
    cli.add_argument("--workers", type=int, default=8, help="Number of workers")
    cli.add_argument("--min_freq", type=int, default=10, help="Min frequency for syllables")
    cli.add_argument("--max_syllables", type=int, default=32768, help="Max syllable vocab")
    cli.add_argument("--max_onsets", type=int, default=1500, help="Max onset vocab")
    cli.add_argument("--max_nuclei", type=int, default=500, help="Max nucleus vocab")
    cli.add_argument("--max_codas", type=int, default=2000, help="Max coda vocab")

    opts = cli.parse_args()

    generate_dataset(
        input_path=opts.input,
        output_dir=opts.output_dir,
        val_split=opts.val_split,
        n_workers=opts.workers,
        min_freq=opts.min_freq,
        max_syllables=opts.max_syllables,
        max_onsets=opts.max_onsets,
        max_nuclei=opts.max_nuclei,
        max_codas=opts.max_codas,
    )