|
|
|
|
|
""" |
|
|
Million-scale comprehensive test suite for deeplatent-nlp. |
|
|
|
|
|
Tests: |
|
|
1. Roundtrip accuracy on 1M+ samples from /root/.cache/deeplatent/base_data/ |
|
|
2. All 12 edge case categories from test_edge_cases.py |
|
|
3. Performance metrics (throughput, memory) |
|
|
4. PyPI vs Local tokenizer comparison |
|
|
|
|
|
Usage: |
|
|
python test_comprehensive_million.py [--samples 1000000] [--report] |
|
|
|
|
|
# Quick test with 10k samples |
|
|
python test_comprehensive_million.py --samples 10000 |
|
|
|
|
|
# Full million-scale test |
|
|
python test_comprehensive_million.py --samples 1000000 --report |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
import tracemalloc |
|
|
from collections import defaultdict |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
|
|
|
import pyarrow.parquet as pq |
|
|
|
|
|
|
|
|
# Make the repo checkout importable ahead of any installed deeplatent package.
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
|
from deeplatent import SARFTokenizer, version, RUST_AVAILABLE |
|
|
from deeplatent.config import ( |
|
|
NormalizationConfig, |
|
|
UnicodeNormalizationForm, |
|
|
WhitespaceNormalization, |
|
|
ControlCharStrategy, |
|
|
ZeroWidthStrategy, |
|
|
) |
|
|
from deeplatent.utils import ( |
|
|
|
|
|
is_arabic, |
|
|
is_arabic_diacritic, |
|
|
is_pua, |
|
|
is_zero_width, |
|
|
is_unicode_whitespace, |
|
|
is_control_char, |
|
|
is_emoji, |
|
|
is_emoji_sequence, |
|
|
is_skin_tone_modifier, |
|
|
is_regional_indicator, |
|
|
|
|
|
normalize_nfc, |
|
|
normalize_nfkc, |
|
|
normalize_apostrophes, |
|
|
normalize_dashes, |
|
|
normalize_whitespace, |
|
|
normalize_unicode_whitespace, |
|
|
remove_zero_width, |
|
|
remove_zero_width_all, |
|
|
remove_zero_width_preserve_zwj, |
|
|
remove_control_chars, |
|
|
strip_diacritics, |
|
|
normalize_alef, |
|
|
remove_tatweel, |
|
|
full_normalize_extended, |
|
|
|
|
|
contains_url, |
|
|
contains_email, |
|
|
contains_path, |
|
|
extract_urls, |
|
|
extract_emails, |
|
|
is_valid_url, |
|
|
is_valid_email, |
|
|
|
|
|
grapheme_count, |
|
|
|
|
|
validate_input, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Directory of parquet shards with the evaluation corpus (see load_base_data()).
DATA_DIR = "/root/.cache/deeplatent/base_data/"
# HuggingFace Hub repo hosting the published tokenizer.
HF_REPO = "almaghrabima/SARFTokenizer"
# Cache location used when the HF tokenizer was downloaded previously.
HF_TOKENIZER_PATH = os.path.expanduser("~/.cache/deeplatent/tokenizers/SARFTokenizer")
# Fallback: locally-built tokenizer artifacts.
# NOTE(review): this path uses "DeepLatent" (different casing from the caches
# above) -- presumably a separate cache tree; confirm it is intentional.
LOCAL_TOKENIZER = "/root/.cache/DeepLatent/SARFTokenizer/SARF-65k-v2-fixed/"
|
|
|
|
|
|
|
|
def download_tokenizer_from_hf(repo_id: str, cache_dir: Optional[str] = None) -> str:
    """
    Download tokenizer files from HuggingFace Hub.

    Args:
        repo_id: HuggingFace repo ID (e.g., "almaghrabima/SARFTokenizer")
        cache_dir: Optional cache directory
            (default: ~/.cache/deeplatent/tokenizers)

    Returns:
        Local path to downloaded tokenizer directory

    Raises:
        Exception: Re-raises whatever snapshot_download() fails with after
            printing a warning (network errors, missing repo, auth, ...).
    """
    # Imported lazily so the script works without huggingface_hub installed
    # unless a download is actually needed.
    # (Removed the unused `hf_hub_download` import -- only snapshot_download
    # was ever called.)
    from huggingface_hub import snapshot_download

    if cache_dir is None:
        cache_dir = os.path.expanduser("~/.cache/deeplatent/tokenizers")

    os.makedirs(cache_dir, exist_ok=True)

    # One subdirectory per repo; "/" is not a valid path component.
    local_dir = os.path.join(cache_dir, repo_id.replace("/", "_"))

    try:
        # snapshot_download returns the directory it materialized the repo in.
        local_dir = snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            repo_type="model",
        )
        print(f" Downloaded tokenizer to: {local_dir}")
        return local_dir
    except Exception as e:
        print(f" Warning: Could not download from HF Hub: {e}")
        raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_base_data(data_dir: str, num_samples: int = 1000000) -> Tuple[List[str], List[str], List[str]]:
    """
    Load samples from base_data parquet shards and bucket them by script.

    Each text is classified by the fraction of Arabic-block characters
    (U+0600-U+06FF): >50% -> Arabic, <10% -> English, otherwise mixed.
    Loading stops early once every bucket reaches its quota.

    Args:
        data_dir: Directory containing shard_*.parquet files.
        num_samples: Total sample budget, split evenly across the 3 buckets.

    Returns:
        Tuple of (arabic_samples, english_samples, mixed_samples)

    Raises:
        FileNotFoundError: If no parquet shards exist in data_dir.
    """
    import re
    AR_DETECT = re.compile(r'[\u0600-\u06FF]')

    parquet_files = sorted(Path(data_dir).glob("shard_*.parquet"))
    if not parquet_files:
        raise FileNotFoundError(f"No parquet files found in {data_dir}")

    print(f"Found {len(parquet_files)} parquet shards")

    arabic_samples: List[str] = []
    english_samples: List[str] = []
    mixed_samples: List[str] = []

    target_per_category = num_samples // 3

    def _quotas_met() -> bool:
        # True once all three buckets have reached their quota.
        return (len(arabic_samples) >= target_per_category and
                len(english_samples) >= target_per_category and
                len(mixed_samples) >= target_per_category)

    for pq_file in parquet_files:
        if _quotas_met():
            break

        # BUG FIX: the old code requested columns=["text", "language"] and then
        # guarded on `"language" in table.column_names` -- but pyarrow raises
        # when a requested column is missing, so the guard was dead code.
        # The language values were never used anyway; read only "text".
        table = pq.read_table(pq_file, columns=["text"])
        texts = table.column("text").to_pylist()

        for text in texts:
            if _quotas_met():
                break

            if not text or not isinstance(text, str):
                continue

            ar_chars = len(AR_DETECT.findall(text))
            total_chars = len(text)
            ar_ratio = ar_chars / total_chars if total_chars > 0 else 0

            if ar_ratio > 0.5 and len(arabic_samples) < target_per_category:
                arabic_samples.append(text)
            elif ar_ratio < 0.1 and len(english_samples) < target_per_category:
                english_samples.append(text)
            elif 0.1 <= ar_ratio <= 0.5 and len(mixed_samples) < target_per_category:
                mixed_samples.append(text)

        print(f" {pq_file.name}: AR={len(arabic_samples):,}, EN={len(english_samples):,}, Mixed={len(mixed_samples):,}")

    total_loaded = len(arabic_samples) + len(english_samples) + len(mixed_samples)
    print(f"\nTotal loaded: {total_loaded:,} samples")
    print(f" Arabic: {len(arabic_samples):,}")
    print(f" English: {len(english_samples):,}")
    print(f" Mixed: {len(mixed_samples):,}")

    return arabic_samples, english_samples, mixed_samples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_roundtrip_batch( |
|
|
tokenizer: SARFTokenizer, |
|
|
samples: List[str], |
|
|
category: str, |
|
|
max_failures: int = 100, |
|
|
) -> Dict: |
|
|
""" |
|
|
Test roundtrip on a batch of samples. |
|
|
|
|
|
Returns: |
|
|
Dict with success count, failures, accuracy, timing |
|
|
""" |
|
|
success = 0 |
|
|
failures = [] |
|
|
total_encode_time = 0 |
|
|
total_decode_time = 0 |
|
|
|
|
|
for i, text in enumerate(samples): |
|
|
try: |
|
|
|
|
|
t0 = time.perf_counter() |
|
|
ids = tokenizer.encode(text) |
|
|
total_encode_time += time.perf_counter() - t0 |
|
|
|
|
|
|
|
|
t0 = time.perf_counter() |
|
|
decoded = tokenizer.decode(ids) |
|
|
total_decode_time += time.perf_counter() - t0 |
|
|
|
|
|
|
|
|
|
|
|
if decoded == tokenizer.normalize(text) if hasattr(tokenizer, 'normalize') else True: |
|
|
success += 1 |
|
|
else: |
|
|
|
|
|
if decoded == text: |
|
|
success += 1 |
|
|
elif len(failures) < max_failures: |
|
|
failures.append({ |
|
|
"index": i, |
|
|
"original": text[:100], |
|
|
"decoded": decoded[:100], |
|
|
}) |
|
|
except Exception as e: |
|
|
if len(failures) < max_failures: |
|
|
failures.append({ |
|
|
"index": i, |
|
|
"original": text[:100] if text else "", |
|
|
"error": str(e), |
|
|
}) |
|
|
|
|
|
total = len(samples) |
|
|
accuracy = success / total if total > 0 else 0 |
|
|
|
|
|
return { |
|
|
"category": category, |
|
|
"total": total, |
|
|
"success": success, |
|
|
"failed": total - success, |
|
|
"accuracy": accuracy, |
|
|
"accuracy_pct": f"{accuracy * 100:.2f}%", |
|
|
"encode_time": total_encode_time, |
|
|
"decode_time": total_decode_time, |
|
|
"failures": failures, |
|
|
} |
|
|
|
|
|
|
|
|
def run_roundtrip_tests(
    tokenizer: SARFTokenizer,
    arabic_samples: List[str],
    english_samples: List[str],
    mixed_samples: List[str],
) -> Dict:
    """Run roundtrip tests on all categories."""
    results = {}

    named_batches = (
        ("Arabic", arabic_samples),
        ("English", english_samples),
        ("Mixed", mixed_samples),
    )

    # Roundtrip each non-empty batch and record its per-category result.
    for label, batch in named_batches:
        if not batch:
            continue
        print(f" Testing {label} ({len(batch):,} samples)...", end=" ", flush=True)
        outcome = test_roundtrip_batch(tokenizer, batch, label)
        results[label] = outcome
        print(f"Accuracy: {outcome['accuracy_pct']}")

    # Fold the per-category counters into a grand total.
    grand_success = grand_total = grand_failed = 0
    for outcome in results.values():
        grand_success += outcome["success"]
        grand_total += outcome["total"]
        grand_failed += outcome["failed"]
    overall = grand_success / grand_total if grand_total > 0 else 0

    results["TOTAL"] = {
        "category": "TOTAL",
        "total": grand_total,
        "success": grand_success,
        "failed": grand_failed,
        "accuracy": overall,
        "accuracy_pct": f"{overall * 100:.2f}%",
    }

    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Edge case catalog: category name -> list of (input, expected, description)
# tuples. `expected` is the exact normalized output to compare against, or
# None when the test only checks detection / non-crashing rather than an
# exact output string (see run_edge_case_tests()).
EDGE_CASE_TESTS = {
    "Unicode Normalization": [
        ("cafe\u0301", "cafรฉ", "NFC: combining acute"),
        ("n\u0303", "รฑ", "NFC: combining tilde"),
        ("e\u0308", "รซ", "NFC: combining diaeresis"),
        ("\uFB01", "fi", "NFKC: fi ligature"),
        ("\uFF21", "A", "NFKC: fullwidth A"),
        ("ู\u0651", None, "Arabic shadda combining"),
    ],
    "Zero-Width Characters": [
        ("a\u200Bb", "ab", "ZWSP removal"),
        ("a\u200C\u200Db", None, "ZWNJ + ZWJ"),
        ("a\u200Eb", None, "LRM"),
        ("a\u200Fb", None, "RLM"),
        ("a\u2060b", None, "Word Joiner"),
        ("a\uFEFFb", None, "BOM"),
    ],
    "Unicode Whitespace": [
        ("a\u00A0b", "a b", "NBSP"),
        ("a\u2003b", "a b", "Em Space"),
        ("a\u2009b", "a b", "Thin Space"),
        ("a\u202Fb", None, "Narrow NBSP"),
        ("a\u3000b", None, "Ideographic Space"),
        ("a\r\nb", None, "CRLF"),
    ],
    "Grapheme Clusters": [
        ("๐จโ๐ฉโ๐งโ๐ฆ", None, "Family emoji ZWJ"),
        ("๐ธ๐ฆ", None, "Flag emoji"),
        ("๐๐ฝ", None, "Emoji with skin tone"),
        ("โ๐ป", None, "Fist with light skin"),
        ("๐จโ๐ป", None, "Man technologist"),
        ("๐ณ๏ธโ๐", None, "Rainbow flag"),
    ],
    "Apostrophes": [
        ("don\u2019t", "don't", "Right single quote"),
        ("don\u2018t", "don't", "Left single quote"),
        ("James\u2019", "James'", "Possessive"),
        ("l\u2019homme", "l'homme", "French contraction"),
    ],
    "Dashes": [
        ("10\u201312", "10-12", "En dash range"),
        ("\u22125", "-5", "Minus sign"),
        ("state\u2014of\u2014the\u2014art", None, "Em dashes"),
        ("COVID\u201019", None, "Hyphen"),
    ],
    "Decimal Separators": [
        ("3.14159", None, "Standard decimal"),
        ("ูขูฃ\u066Bูฅ", None, "Arabic decimal separator"),
        ("ู ูกูขูฃูคูฅูฆูงูจูฉ", None, "Arabic-Indic digits"),
    ],
    "URLs/Emails": [
        ("https://example.com", None, "Simple URL"),
        ("https://example.com/path?x=1&y=2#top", None, "Complex URL"),
        ("user@example.com", None, "Simple email"),
        ("first.last+tag@domain.co.uk", None, "Complex email"),
    ],
    "File Paths": [
        ("C:\\Windows\\System32", None, "Windows path"),
        ("/home/user/file.txt", None, "Unix path"),
        ("\\\\server\\share\\file.txt", None, "UNC path"),
    ],
    "Code Identifiers": [
        ("snake_case_variable", None, "snake_case"),
        ("camelCaseVariable", None, "camelCase"),
        ("HTTPServerError500", None, "PascalCase"),
        ("kebab-case-id", None, "kebab-case"),
    ],
    "Mixed Scripts/RTL": [
        ("Hello ูุฑุญุจุง World", None, "Arabic + English"),
        ("Riyadh ุงูุฑูุงุถ", None, "City name mixed"),
        ("ุจูุณูููู", None, "Arabic with diacritics"),
        ("ููููุฑุญูููุจุง", None, "Arabic with tatweel"),
        ("ุฃุญููุฏ", None, "Alef variants"),
        ("ูกูขูฃ", None, "Arabic numerals"),
    ],
    "Robustness": [
        ("", None, "Empty string"),
        (" ", None, "Whitespace only"),
        ("\t\n\r", None, "Control whitespace"),
        ("a\x00b", "ab", "NULL byte"),
        ("a\x1Fb", "ab", "Control char"),
        ("a" * 10000, None, "Large input"),
    ],
}
|
|
|
|
|
|
|
|
def run_edge_case_tests() -> Dict:
    """Run all 12 categories of edge case tests.

    Dispatches on the category name to the matching deeplatent.utils helper
    and compares against the tuple's expected output when one is given
    (expected None means "just don't crash" / detection-only).

    Returns:
        Dict mapping category -> {"tests", "passed", "failed", "failures"},
        plus a "TOTAL" entry with the aggregate counts.
    """
    results = {}
    total_tests = 0
    total_passed = 0

    for category, tests in EDGE_CASE_TESTS.items():
        passed = 0
        failed = []

        for test_input, expected_output, description in tests:
            total_tests += 1
            try:
                if category == "Unicode Normalization":
                    # Only check cases with a concrete expected output that
                    # actually differs from the input; others pass trivially.
                    if expected_output and expected_output != test_input:
                        if "NFKC" in description:
                            result = normalize_nfkc(test_input)
                        else:
                            result = normalize_nfc(test_input)
                        if result == expected_output:
                            passed += 1
                        else:
                            failed.append(f"{description}: got '{result}', expected '{expected_output}'")
                    else:
                        passed += 1

                elif category == "Zero-Width Characters":
                    # Sanity-check the classifier on every known zero-width
                    # char before testing removal.
                    for char in test_input:
                        if char in "\u200B\u200C\u200D\u200E\u200F\u2060\uFEFF":
                            assert is_zero_width(char)
                    result = remove_zero_width_all(test_input)
                    if expected_output and result != expected_output:
                        failed.append(f"{description}: got '{result}', expected '{expected_output}'")
                    else:
                        passed += 1

                elif category == "Unicode Whitespace":
                    result = normalize_unicode_whitespace(test_input)
                    if expected_output and result != expected_output:
                        failed.append(f"{description}: got '{result}', expected '{expected_output}'")
                    else:
                        passed += 1

                elif category == "Grapheme Clusters":
                    # Detection-only: every catalog entry here is an emoji
                    # sequence; grapheme_count() is called as a smoke test
                    # (it must not raise).
                    is_seq = is_emoji_sequence(test_input)
                    count = grapheme_count(test_input)
                    if not is_seq:
                        failed.append(f"{description}: not detected as emoji sequence")
                    else:
                        passed += 1

                elif category == "Apostrophes":
                    result = normalize_apostrophes(test_input)
                    if expected_output and result != expected_output:
                        failed.append(f"{description}: got '{result}', expected '{expected_output}'")
                    else:
                        passed += 1

                elif category == "Dashes":
                    result = normalize_dashes(test_input)
                    if expected_output and result != expected_output:
                        failed.append(f"{description}: got '{result}', expected '{expected_output}'")
                    else:
                        passed += 1

                elif category == "Decimal Separators":
                    # No normalizer exercised here; entries only document the
                    # inputs and pass unconditionally.
                    passed += 1

                elif category == "URLs/Emails":
                    # Route by description: "URL" entries use contains_url,
                    # everything else is treated as an email.
                    if "URL" in description:
                        if not contains_url(test_input):
                            failed.append(f"{description}: URL not detected")
                        else:
                            passed += 1
                    else:
                        if not contains_email(test_input):
                            failed.append(f"{description}: Email not detected")
                        else:
                            passed += 1

                elif category == "File Paths":
                    if not contains_path(test_input):
                        failed.append(f"{description}: Path not detected")
                    else:
                        passed += 1

                elif category == "Code Identifiers":
                    # Detection-only placeholders; pass unconditionally.
                    passed += 1

                elif category == "Mixed Scripts/RTL":
                    # Entries whose description mentions "Arabic" must contain
                    # at least one Arabic character.
                    has_arabic = any(is_arabic(c) for c in test_input)
                    if "Arabic" in description and not has_arabic:
                        failed.append(f"{description}: Arabic not detected")
                    else:
                        passed += 1

                elif category == "Robustness":
                    # Smoke tests: the normalizers must not raise on degenerate
                    # input; outputs are not compared.
                    result = normalize_whitespace(test_input)
                    if "NULL" in description or "Control" in description:
                        result = remove_control_chars(test_input)
                    passed += 1

            except Exception as e:
                # Any exception in a helper counts as a failure for that case.
                failed.append(f"{description}: Exception {e}")

        total_passed += passed
        results[category] = {
            "tests": len(tests),
            "passed": passed,
            "failed": len(tests) - passed,
            "failures": failed,
        }

    results["TOTAL"] = {
        "tests": total_tests,
        "passed": total_passed,
        "failed": total_tests - total_passed,
    }

    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def measure_performance(
    tokenizer: "SARFTokenizer",
    samples: List[str],
    batch_sizes: Optional[List[int]] = None,
    num_runs: int = 3,
) -> Dict:
    """
    Measure tokenizer throughput and memory usage.

    Args:
        tokenizer: Object with encode() (and optionally encode_batch()).
        samples: Texts to benchmark on (at most 10,000 are used per run).
        batch_sizes: Batch sizes for encode_batch benchmarks
            (default: [1000, 10000]).
        num_runs: Number of timed repetitions to average over.

    Returns:
        Dict with "single_thread", optional "batch_<n>" entries, and "memory".
    """
    # BUG FIX: the old signature used a mutable default ([1000, 10000]).
    if batch_sizes is None:
        batch_sizes = [1000, 10000]

    results = {}

    bench_samples = samples[:10000]
    n_bench = len(bench_samples)

    print(" Single-threaded benchmark...", end=" ", flush=True)
    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        for text in bench_samples:
            tokenizer.encode(text)
        times.append(time.perf_counter() - start)

    avg_time = sum(times) / len(times)
    # BUG FIX: divide by the number of texts actually encoded, not a
    # hard-coded 10000 -- `samples` may hold fewer than 10k entries.
    throughput = n_bench / avg_time
    print(f"{throughput:,.0f} texts/sec")

    results["single_thread"] = {
        "throughput_per_sec": throughput,
        "avg_time": avg_time,
        "samples": n_bench,
    }

    # Batch benchmarks only apply to tokenizers exposing encode_batch().
    if hasattr(tokenizer, 'encode_batch'):
        for batch_size in batch_sizes:
            batch_samples = samples[:batch_size]
            print(f" Batch encode ({batch_size:,})...", end=" ", flush=True)

            times = []
            for _ in range(num_runs):
                start = time.perf_counter()
                tokenizer.encode_batch(batch_samples)
                times.append(time.perf_counter() - start)

            avg_time = sum(times) / len(times)
            # Same fix as above: use the real batch length.
            throughput = len(batch_samples) / avg_time
            print(f"{throughput:,.0f} texts/sec")

            results[f"batch_{batch_size}"] = {
                "throughput_per_sec": throughput,
                "avg_time": avg_time,
                "samples": len(batch_samples),
            }

    print(" Memory measurement...", end=" ", flush=True)
    tracemalloc.start()

    for text in bench_samples:
        tokenizer.encode(text)

    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    print(f"Peak: {peak / 1024 / 1024:.1f} MB")

    results["memory"] = {
        "current_mb": current / 1024 / 1024,
        "peak_mb": peak / 1024 / 1024,
        "samples": n_bench,
    }

    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_report(
    roundtrip_results: Dict,
    edge_case_results: Dict,
    performance_results: Dict,
    tokenizer_name: str,
) -> str:
    """
    Generate a comprehensive plain-text report.

    All three result dicts may be empty (when the corresponding test phase
    was skipped); the report degrades gracefully in that case.

    Args:
        roundtrip_results: Output of run_roundtrip_tests(), or {}.
        edge_case_results: Output of run_edge_case_tests(), or {}.
        performance_results: Output of measure_performance(), or {}.
        tokenizer_name: Name/path of the tokenizer under test.

    Returns:
        The formatted report as a single string.
    """
    lines = []

    lines.append("=" * 80)
    lines.append(f"COMPREHENSIVE TEST REPORT - deeplatent-nlp v{version()}")
    lines.append("=" * 80)
    lines.append("")

    # --- Section 1: roundtrip accuracy table ---
    lines.append("## 1. ROUNDTRIP ACCURACY")
    lines.append("-" * 70)
    lines.append(f"{'Category':<20} {'Samples':>12} {'Success':>12} {'Failed':>10} {'Accuracy':>12}")
    lines.append("-" * 70)

    for category in ["Arabic", "English", "Mixed", "TOTAL"]:
        if category in roundtrip_results:
            r = roundtrip_results[category]
            lines.append(
                f"{r['category']:<20} {r['total']:>12,} {r['success']:>12,} {r['failed']:>10,} {r['accuracy_pct']:>12}"
            )

    lines.append("-" * 70)
    lines.append("")

    # --- Section 2: edge case table ---
    lines.append("## 2. EDGE CASE TESTS (12 categories)")
    lines.append("-" * 70)
    lines.append(f"{'Category':<30} {'Tests':>8} {'Passed':>8} {'Failed':>8}")
    lines.append("-" * 70)

    for category, r in edge_case_results.items():
        if category != "TOTAL":
            lines.append(f"{category:<30} {r['tests']:>8} {r['passed']:>8} {r['failed']:>8}")

    lines.append("-" * 70)
    # BUG FIX: the old code indexed edge_case_results["TOTAL"] unconditionally,
    # which raised KeyError when edge case tests were skipped
    # (--skip-edge-cases passes an empty dict here).
    if "TOTAL" in edge_case_results:
        total = edge_case_results["TOTAL"]
        lines.append(f"{'TOTAL':<30} {total['tests']:>8} {total['passed']:>8} {total['failed']:>8}")
        lines.append("-" * 70)
    lines.append("")

    # --- Section 3: performance ---
    lines.append("## 3. PERFORMANCE METRICS")
    lines.append("-" * 70)

    if "single_thread" in performance_results:
        st = performance_results["single_thread"]
        lines.append(f"Single-threaded: {st['throughput_per_sec']:,.0f} texts/sec")

    for key, value in performance_results.items():
        if key.startswith("batch_"):
            batch_size = key.replace("batch_", "")
            lines.append(f"Batch ({batch_size}): {value['throughput_per_sec']:,.0f} texts/sec")

    if "memory" in performance_results:
        mem = performance_results["memory"]
        lines.append(f"Memory (peak): {mem['peak_mb']:.1f} MB")

    lines.append("-" * 70)
    lines.append("")

    # --- Section 4: summary ---
    lines.append("## 4. SUMMARY")
    lines.append("-" * 70)
    lines.append(f"Tokenizer: {tokenizer_name}")
    lines.append(f"Rust available: {RUST_AVAILABLE}")

    total_rt = roundtrip_results.get("TOTAL", {})
    if total_rt:
        lines.append(f"Roundtrip accuracy: {total_rt.get('accuracy_pct', 'N/A')}")

    total_ec = edge_case_results.get("TOTAL", {})
    if total_ec:
        lines.append(f"Edge case tests: {total_ec['passed']}/{total_ec['tests']} passed")

    lines.append("=" * 80)

    return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: load a tokenizer, run the selected test phases,
    print/save a report, and return a process exit code (0 = all passed)."""
    parser = argparse.ArgumentParser(description="Million-scale comprehensive tests")
    parser.add_argument(
        "--samples",
        type=int,
        default=100000,
        help="Number of samples to test (default: 100000)",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default=DATA_DIR,
        help="Path to base_data directory",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        default=HF_REPO,
        help="Tokenizer name or path",
    )
    parser.add_argument(
        "--report",
        action="store_true",
        help="Generate JSON report",
    )
    parser.add_argument(
        "--skip-roundtrip",
        action="store_true",
        help="Skip roundtrip tests",
    )
    parser.add_argument(
        "--skip-edge-cases",
        action="store_true",
        help="Skip edge case tests",
    )
    parser.add_argument(
        "--skip-performance",
        action="store_true",
        help="Skip performance tests",
    )
    args = parser.parse_args()

    print("=" * 80)
    print("COMPREHENSIVE TEST SUITE - deeplatent-nlp")
    print("=" * 80)
    print(f"Version: {version()}")
    print(f"Rust available: {RUST_AVAILABLE}")
    print(f"Samples: {args.samples:,}")
    print()

    # Tokenizer resolution order: explicit local path -> HF cache ->
    # locally-built cache -> fresh HF Hub download. First success wins.
    print("Loading tokenizer...")
    tokenizer = None
    tokenizer_source = args.tokenizer

    # 1) Treat --tokenizer as a local filesystem path if it exists.
    if os.path.exists(args.tokenizer):
        try:
            tokenizer = SARFTokenizer.from_pretrained(args.tokenizer)
            print(f" Loaded from local path: {args.tokenizer}")
        except Exception as e:
            print(f" Local load failed: {e}")

    # 2) Previously downloaded HuggingFace cache.
    if tokenizer is None and os.path.exists(HF_TOKENIZER_PATH):
        try:
            tokenizer = SARFTokenizer.from_pretrained(HF_TOKENIZER_PATH)
            tokenizer_source = HF_REPO
            print(f" Loaded from HuggingFace cache: {HF_TOKENIZER_PATH}")
        except Exception as e:
            print(f" HF cache load failed: {e}")

    # 3) Locally-built tokenizer artifacts.
    if tokenizer is None and os.path.exists(LOCAL_TOKENIZER):
        try:
            tokenizer = SARFTokenizer.from_pretrained(LOCAL_TOKENIZER)
            tokenizer_source = LOCAL_TOKENIZER
            print(f" Loaded from local cache: {LOCAL_TOKENIZER}")
        except Exception as e:
            print(f" Local cache load failed: {e}")

    # 4) Last resort: download from the Hub ("/" implies an org/name repo id).
    if tokenizer is None and "/" in args.tokenizer:
        try:
            print(f" Downloading from HuggingFace: {args.tokenizer}")
            local_path = download_tokenizer_from_hf(args.tokenizer)
            tokenizer = SARFTokenizer.from_pretrained(local_path)
            tokenizer_source = args.tokenizer
            print(f" Loaded from HuggingFace Hub")
        except Exception as e:
            print(f" HuggingFace download failed: {e}")

    if tokenizer is None:
        print(" Failed to load tokenizer from any source!")
        sys.exit(1)

    print(f" Vocab size: {tokenizer.vocab_size:,}")

    # Accumulates everything that may be dumped as JSON with --report.
    results = {
        "version": version(),
        "rust_available": RUST_AVAILABLE,
        "tokenizer": tokenizer_source,
        "samples": args.samples,
    }

    # Real corpus if available, tiny synthetic fallback otherwise.
    print("\nLoading test data...")
    try:
        arabic_samples, english_samples, mixed_samples = load_base_data(args.data_dir, args.samples)
    except FileNotFoundError as e:
        print(f" Warning: {e}")
        print(" Using synthetic test data...")
        arabic_samples = ["ูุฑุญุจุง ุจุงูุนุงูู"] * 1000
        english_samples = ["Hello world"] * 1000
        mixed_samples = ["Hello ูุฑุญุจุง world"] * 1000

    # Phase 1: roundtrip accuracy.
    roundtrip_results = {}
    if not args.skip_roundtrip:
        print("\n" + "=" * 60)
        print("1. ROUNDTRIP TESTS")
        print("=" * 60)
        roundtrip_results = run_roundtrip_tests(
            tokenizer, arabic_samples, english_samples, mixed_samples
        )
        results["roundtrip"] = roundtrip_results

    # Phase 2: edge cases.
    edge_case_results = {}
    if not args.skip_edge_cases:
        print("\n" + "=" * 60)
        print("2. EDGE CASE TESTS")
        print("=" * 60)
        edge_case_results = run_edge_case_tests()
        results["edge_cases"] = edge_case_results

        for category, r in edge_case_results.items():
            if category != "TOTAL":
                status = "PASS" if r["failed"] == 0 else f"FAIL ({r['failed']})"
                print(f" {category}: {status}")

        total = edge_case_results["TOTAL"]
        print(f"\n TOTAL: {total['passed']}/{total['tests']} passed")

    # Phase 3: performance.
    performance_results = {}
    if not args.skip_performance:
        print("\n" + "=" * 60)
        print("3. PERFORMANCE TESTS")
        print("=" * 60)
        all_samples = arabic_samples + english_samples + mixed_samples
        performance_results = measure_performance(tokenizer, all_samples)
        results["performance"] = performance_results

    print("\n" + "=" * 60)
    print("REPORT")
    print("=" * 60)

    report = generate_report(
        roundtrip_results,
        edge_case_results,
        performance_results,
        tokenizer_source,
    )
    print(report)

    if args.report:
        output_path = "test_comprehensive_results.json"
        with open(output_path, "w", encoding="utf-8") as f:
            # Roundtrip through json with default=str to coerce any
            # non-serializable values before the pretty dump.
            clean_results = json.loads(json.dumps(results, default=str))
            json.dump(clean_results, f, indent=2, ensure_ascii=False)
        print(f"\nResults saved to {output_path}")

    # Exit code policy: nonzero on <99% roundtrip accuracy or any
    # failed edge case test.
    total_rt = roundtrip_results.get("TOTAL", {})
    total_ec = edge_case_results.get("TOTAL", {})

    if total_rt and total_rt.get("accuracy", 1.0) < 0.99:
        print("\nWARNING: Roundtrip accuracy below 99%")
        return 1

    if total_ec and total_ec.get("failed", 0) > 0:
        print(f"\nWARNING: {total_ec['failed']} edge case tests failed")
        return 1

    print("\nAll tests passed!")
    return 0
|
|
|
|
|
|
|
|
# Script entry point: propagate main()'s status code to the shell.
if __name__ == "__main__":
    sys.exit(main())
|
|
|