aramt5 / src /data /balance_corpus.py
crossroderick's picture
Data augmentation and balancing updates for a re-run of v3
11632a3
#!/usr/bin/env python3
"""
Balance the augmented corpus to include more multi-word examples.
Distribution of the augmented corpus at the time of writing: 98.5% single,
0.0% two-word, 1.5% multi-word
Target distribution: ~40% single, ~30% two-word, ~30% multi-word
Strategy:
1. Extract single-word vocabulary with transliterations
2. Generate two-word pairs using Syriac patterns:
- Juxtaposition: word + word
- Construct with d-: noun + d- + noun (e.g., beyt d-ʾabrāhām)
- Proclitic combos: b-/w-/l- + word + word
3. Generate 3-, 4- and 5-word phrases from the same vocabulary
4. Downsample single-word examples
5. Output balanced corpus
"""
import json
import random
from collections import defaultdict
from pathlib import Path
# Syriac proclitic letters mapped to the Latin prefix used in transliterations,
# for building synthetic multi-word entries.
# Both dialects currently use the same mapping; PROCLITICS_EAST is an
# independent copy of PROCLITICS_WEST (same keys, values and order).
PROCLITICS_WEST = {
    "ܒ": "b-",  # in/with
    "ܕ": "d-",  # of/that
    "ܘ": "w-",  # and
    "ܠ": "l-",  # to/for
}
PROCLITICS_EAST = dict(PROCLITICS_WEST)
def load_corpus(path: Path) -> list[dict]:
    """Load a JSONL corpus (one JSON object per line).

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        The decoded objects in file order. Blank / whitespace-only lines
        (e.g. a trailing empty line) are skipped.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # Explicit UTF-8: the corpus contains Syriac script, and the platform's
    # default text encoding is not guaranteed to be UTF-8.
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
def extract_vocabulary(entries: list[dict]) -> dict[str, str]:
    """Extract a single-word vocabulary mapping Syriac ``src`` -> ``tgt``.

    An entry contributes when its ``src`` is exactly one whitespace-separated
    token (the same notion of "word" that ``str.split`` gives) and its
    ``tgt`` is non-empty and does not already start with a proclitic prefix
    (``b-``, ``d-``, ``w-``, ``l-``). Surrounding whitespace is stripped from
    both sides. Later entries overwrite earlier ones with the same ``src``.

    Args:
        entries: Corpus entries with ``entry["transliteration"]["src"]`` and
            ``entry["transliteration"]["tgt"]`` strings.

    Returns:
        Dict of Syriac word -> Latin transliteration.
    """
    proclitic_prefixes = ("b-", "d-", "w-", "l-")
    vocab: dict[str, str] = {}
    for entry in entries:
        t = entry["transliteration"]
        src_tokens = t["src"].split()
        tgt = t["tgt"].strip()
        # Requiring exactly one token also rejects empty / whitespace-only
        # sources, which would otherwise put blank "words" into every
        # synthetic phrase built from this vocabulary.
        if len(src_tokens) != 1 or not tgt:
            continue
        # Skip forms that already carry a proclitic; the generators add their own.
        if tgt.startswith(proclitic_prefixes):
            continue
        vocab[src_tokens[0]] = tgt
    return vocab
def _two_word_variants(
    syr1: str,
    lat1: str,
    syr2: str,
    lat2: str,
    dialect: str,
    proclitic_map: dict[str, str],
) -> list[dict]:
    """Build every synthetic two-word entry for the ordered pair (word1, word2)."""

    def make(src: str, tgt: str, title: str, source: str) -> dict:
        return {
            "transliteration": {
                "src": src,
                "tgt": tgt,
                "title": title,
                "dialect": dialect,
                "source": source,
            }
        }

    variants = [
        # Pattern 1: simple juxtaposition (word1 word2)
        make(f"{syr1} {syr2}", f"{lat1} {lat2}", "compound", "synthetic-2word"),
        # Pattern 2: construct with d- (word1 d-word2)
        make(
            f"{syr1} ܕ{syr2}",
            f"{lat1} d-{lat2}",
            "construct",
            "synthetic-construct",
        ),
    ]
    # Pattern 3: proclitic + word1 + word2. d- is skipped because pattern 2
    # already covers it.
    for syr_pro, lat_pro in proclitic_map.items():
        if syr_pro == "ܕ":
            continue
        variants.append(
            make(
                f"{syr_pro}{syr1} {syr2}",
                f"{lat_pro}{lat1} {lat2}",
                "proclitic-pair",
                "synthetic-proclitic",
            )
        )
    return variants


def generate_two_word_pairs(
    vocab: dict[str, str],
    dialect: str,
    count: int,
    *,
    max_passes: int = 50,
) -> list[dict]:
    """Generate up to ``count`` synthetic two-word entries from ``vocab``.

    Each ordered word pair (word1, word2) yields a juxtaposition
    (``word1 word2``), a ``d-`` construct (``word1 d-word2``) and one
    ``proclitic+word1 word2`` entry per proclitic other than ``d-``. Dialect
    ``"west"`` uses PROCLITICS_WEST; any other value uses PROCLITICS_EAST.

    The vocabulary is shuffled with the global ``random`` module and paired
    off two words at a time, so one pass yields at most ``len(vocab) // 2``
    pairs. When that is not enough to reach ``count``, the list is
    reshuffled and paired again, up to ``max_passes`` passes, skipping
    ordered pairs that were already used. The first pass is identical to a
    plain single pass, so requests one pass can satisfy give the same result.

    Args:
        vocab: Syriac word -> Latin transliteration.
        dialect: Dialect tag copied into each entry; selects the proclitic table.
        count: Maximum number of entries to return.
        max_passes: Upper bound on shuffle-and-pair passes (at least one runs).

    Returns:
        At most ``count`` entries; fewer if the vocabulary cannot supply them.
    """
    if count <= 0 or len(vocab) < 2:
        return []
    proclitic_map = PROCLITICS_WEST if dialect == "west" else PROCLITICS_EAST
    words = list(vocab.items())
    # Once every ordered pair has been used, further passes cannot add anything.
    max_distinct_pairs = len(words) * (len(words) - 1)
    pairs: list[dict] = []
    used: set[tuple[str, str]] = set()
    for _ in range(max(1, max_passes)):
        random.shuffle(words)
        # Pair (words[0], words[1]), (words[2], words[3]), ...; with an odd
        # count the last word is left unpaired for this pass.
        for (syr1, lat1), (syr2, lat2) in zip(words[0::2], words[1::2]):
            if len(pairs) >= count:
                return pairs[:count]
            key = (syr1, syr2)
            if key in used:
                continue
            used.add(key)
            pairs.extend(
                _two_word_variants(syr1, lat1, syr2, lat2, dialect, proclitic_map)
            )
        if len(pairs) >= count or len(used) >= max_distinct_pairs:
            break
    return pairs[:count]
def _multi_word_variants(group: list[tuple[str, str]], dialect: str) -> list[dict]:
    """Build the six synthetic 3/4/5-word phrases for five (Syriac, Latin) words."""
    (syr1, lat1), (syr2, lat2), (syr3, lat3), (syr4, lat4), (syr5, lat5) = group

    def make(src: str, tgt: str, source: str) -> dict:
        return {
            "transliteration": {
                "src": src,
                "tgt": tgt,
                "title": "phrase",
                "dialect": dialect,
                "source": source,
            }
        }

    return [
        # 3-word patterns
        make(f"{syr1} ܕ{syr2} ܘ{syr3}", f"{lat1} d-{lat2} w-{lat3}", "synthetic-3word"),
        make(f"{syr1} {syr2} {syr3}", f"{lat1} {lat2} {lat3}", "synthetic-3word"),
        # 4-word patterns
        make(
            f"{syr1} {syr2} ܕ{syr3} {syr4}",
            f"{lat1} {lat2} d-{lat3} {lat4}",
            "synthetic-4word",
        ),
        make(
            f"ܒ{syr1} {syr2} ܘ{syr3} {syr4}",
            f"b-{lat1} {lat2} w-{lat3} {lat4}",
            "synthetic-4word",
        ),
        # 5-word patterns
        make(
            f"{syr1} ܕ{syr2} {syr3} ܘ{syr4} {syr5}",
            f"{lat1} d-{lat2} {lat3} w-{lat4} {lat5}",
            "synthetic-5word",
        ),
        make(
            f"{syr1} {syr2} {syr3} {syr4} {syr5}",
            f"{lat1} {lat2} {lat3} {lat4} {lat5}",
            "synthetic-5word",
        ),
    ]


def generate_multi_word_phrases(
    vocab: dict[str, str],
    dialect: str,
    count: int,
    *,
    max_passes: int = 50,
) -> list[dict]:
    """Generate up to ``count`` synthetic 3-, 4- and 5-word phrases.

    The vocabulary is shuffled with the global ``random`` module and consumed
    in disjoint groups of five words; each group yields six phrases (two of
    each length, mixing plain juxtaposition with ``d-``/``w-``/``b-``
    proclitics). One pass therefore yields at most ``6 * (len(vocab) // 5)``
    phrases. When that is not enough to reach ``count``, the list is
    reshuffled and grouped again, up to ``max_passes`` passes, skipping word
    groups (in order) that were already used. The first pass is identical
    to a plain single pass, so requests one pass can satisfy give the same
    result.

    Args:
        vocab: Syriac word -> Latin transliteration.
        dialect: Dialect tag copied into each entry.
        count: Maximum number of phrases to return.
        max_passes: Upper bound on shuffle-and-group passes (at least one runs).

    Returns:
        At most ``count`` phrase entries; empty if ``vocab`` has fewer than
        five words.
    """
    if count <= 0 or len(vocab) < 5:
        return []
    words = list(vocab.items())
    phrases: list[dict] = []
    used: set[tuple[str, ...]] = set()
    for _ in range(max(1, max_passes)):
        random.shuffle(words)
        # Disjoint groups of five; up to four leftover words are skipped this pass.
        for start in range(0, len(words) - 4, 5):
            if len(phrases) >= count:
                return phrases[:count]
            group = words[start : start + 5]
            key = tuple(syr for syr, _ in group)
            if key in used:
                continue
            used.add(key)
            phrases.extend(_multi_word_variants(group, dialect))
        if len(phrases) >= count:
            break
    return phrases[:count]
def _word_count(entry: dict) -> int:
    """Number of whitespace-separated tokens in an entry's Syriac ``src``."""
    return len(entry["transliteration"]["src"].split())


def _print_final_distribution(balanced: list[dict]) -> None:
    """Print single / two / multi-word counts and percentages of ``balanced``."""
    print("\nFinal distribution:")
    counts = [_word_count(e) for e in balanced]
    final_single = sum(1 for n in counts if n == 1)
    final_two = sum(1 for n in counts if n == 2)
    final_multi = sum(1 for n in counts if n >= 3)
    total = len(balanced)
    # An empty result (e.g. empty input and empty vocabulary) would otherwise
    # divide by zero; all percentages then print as 0.0%.
    denom = total or 1
    print(f" Single: {final_single:>8} ({100*final_single/denom:.1f}%)")
    print(f" Two: {final_two:>8} ({100*final_two/denom:.1f}%)")
    print(f" Multi: {final_multi:>8} ({100*final_multi/denom:.1f}%)")
    print(f" Total: {total:>8}")


def balance_corpus(
    entries: list[dict],
    dialect: str,
    target_single_ratio: float = 0.40,
    target_two_ratio: float = 0.30,
    target_multi_ratio: float = 0.30,
    min_multi: int = 100_000,
) -> list[dict]:
    """Rebalance ``entries`` towards the given single/two/multi-word ratios.

    Entries are bucketed by the number of whitespace-separated tokens in
    ``src`` (1, 2, or 3+). Entries whose ``src`` is empty or whitespace-only
    are dropped. The multi-word bucket anchors the targets: its target is
    ``max(existing multi-word count, min_multi)``, and the other two targets
    are scaled from it by the ratios. Missing two-word and multi-word
    examples are generated synthetically from the corpus' single-word
    vocabulary. Surplus single-word examples are randomly downsampled (never
    upsampled). All existing two- and multi-word entries are kept.
    Distribution summaries are printed to stdout.

    Args:
        entries: Corpus entries with ``entry["transliteration"]["src"]``.
        dialect: Dialect tag passed to the synthetic generators.
        target_single_ratio: Desired share of single-word entries.
        target_two_ratio: Desired share of two-word entries.
        target_multi_ratio: Desired share of 3+-word entries; must be > 0.
        min_multi: Floor for the multi-word target count.

    Returns:
        The combined entries, shuffled with the global ``random`` module.

    Raises:
        ValueError: if ``target_multi_ratio`` is not positive.
    """
    if target_multi_ratio <= 0:
        raise ValueError("target_multi_ratio must be positive")

    # Categorize existing entries by word count.
    single: list[dict] = []
    two_word: list[dict] = []
    multi: list[dict] = []
    empty = 0
    for entry in entries:
        n_words = _word_count(entry)
        if n_words == 0:
            # An empty source is not a usable example of any length; without
            # this check it would land in (and inflate) the multi-word bucket.
            empty += 1
        elif n_words == 1:
            single.append(entry)
        elif n_words == 2:
            two_word.append(entry)
        else:
            multi.append(entry)

    print("Original distribution:")
    print(f" Single: {len(single):>8}")
    print(f" Two: {len(two_word):>8}")
    print(f" Multi: {len(multi):>8}")
    if empty:
        print(f" Empty (dropped): {empty:>8}")

    # Extract vocabulary for synthetic generation.
    vocab = extract_vocabulary(entries)
    print(f" Vocabulary size: {len(vocab)}")

    # Keep every existing multi-word entry and anchor the other targets on the
    # multi-word target, with a floor so each category is well represented.
    target_multi = max(len(multi), min_multi)
    target_two = int(target_multi * target_two_ratio / target_multi_ratio)
    target_single = int(target_multi * target_single_ratio / target_multi_ratio)
    print("\nTarget counts:")
    print(f" Single: {target_single:>8}")
    print(f" Two: {target_two:>8}")
    print(f" Multi: {target_multi:>8}")

    # Generate synthetic two-word pairs.
    needed_two = max(0, target_two - len(two_word))
    if needed_two > 0:
        print(f"\nGenerating {needed_two} synthetic two-word pairs...")
        synthetic_two = generate_two_word_pairs(vocab, dialect, needed_two)
        two_word.extend(synthetic_two)
        print(f" Generated: {len(synthetic_two)}")

    # Generate synthetic multi-word phrases (3, 4, 5 words).
    needed_multi = max(0, target_multi - len(multi))
    if needed_multi > 0:
        print(
            f"\nGenerating {needed_multi} synthetic multi-word phrases (3-5 words)..."
        )
        synthetic_multi = generate_multi_word_phrases(vocab, dialect, needed_multi)
        multi.extend(synthetic_multi)
        print(f" Generated: {len(synthetic_multi)}")

    # Downsample single-word examples (no upsampling when there are too few).
    if len(single) > target_single:
        print(f"\nDownsampling single-word from {len(single)} to {target_single}...")
        random.shuffle(single)
        single = single[:target_single]

    balanced = single + two_word + multi
    random.shuffle(balanced)
    _print_final_distribution(balanced)
    return balanced
def main() -> None:
    """Balance the per-dialect augmented corpora stored next to this script.

    For each dialect in ``("west", "east")``, reads
    ``syriac_{dialect}_augmented_corpus.jsonl`` from this file's directory,
    balances it with ``balance_corpus`` and writes
    ``syriac_{dialect}_balanced_corpus.jsonl`` alongside it as UTF-8 JSONL.
    Missing inputs are reported and skipped.
    """
    data_dir = Path(__file__).parent
    for dialect in ["west", "east"]:
        print(f"\n{'='*60}")
        print(f"Processing {dialect.capitalize()} dialect")
        print("=" * 60)
        input_path = data_dir / f"syriac_{dialect}_augmented_corpus.jsonl"
        output_path = data_dir / f"syriac_{dialect}_balanced_corpus.jsonl"
        if not input_path.exists():
            print(f" Skipping - {input_path} not found")
            continue
        entries = load_corpus(input_path)
        balanced = balance_corpus(entries, dialect)
        # Explicit UTF-8: with ensure_ascii=False the output contains raw
        # Syriac characters, which a non-UTF-8 platform default encoding
        # (e.g. cp1252) cannot encode.
        with open(output_path, "w", encoding="utf-8") as f:
            f.writelines(
                json.dumps(entry, ensure_ascii=False) + "\n" for entry in balanced
            )
        print(f"\nWritten to: {output_path}")
if __name__ == "__main__":
    # Seed the module-level RNG once, before any work, so every random
    # shuffle performed via main() is reproducible from run to run.
    random.seed(42)
    main()