daa-tokenizers / dataset_stats.py
Ouaill's picture
Upload dataset_stats.py with huggingface_hub
ec4201a verified
Raw
History Blame Contribute Delete
9.51 kB
#!/usr/bin/env python3 -u
"""
dataset_stats.py — Load each Darija dataset from HuggingFace and compute:
- min / max / mean / median sentence length (characters)
- count of Arabic / Arabizi / Mixed sentences
"""
import json, csv, os, gc, warnings, statistics, random
import regex
# Silence every warning (including those raised by the HF libraries) so the
# progress output stays readable.
warnings.filterwarnings("ignore")

# Hub token passed to load_dataset. Use None (not "") when HF_TOKEN is unset:
# a str token is sent verbatim as the bearer token, whereas None lets
# `datasets` fall back to its default lookup (e.g. the cached Hub login).
HF_TOKEN = os.environ.get("HF_TOKEN") or None

# Directory where save_results writes dataset_stats.{csv,json}. Override it
# with the DATASET_STATS_OUTPUT_DIR environment variable.
OUTPUT_DIR = os.environ.get("DATASET_STATS_OUTPUT_DIR",
                            "/root/oiq_cc_tokenizer/results")

# Character classes counted by classify_script: the Arabic (U+0600-U+06FF)
# and Arabic Supplement (U+0750-U+077F) blocks, and ASCII Latin letters.
_AR_PAT = regex.compile(r"[\u0600-\u06FF\u0750-\u077F]")
_LAT_PAT = regex.compile(r"[a-zA-Z]")
def _dataset_entry(name, repo, split="train", text_col=None, config=None,
                   max_rows=0):
    """Build one DATASETS entry as the dict shape main() reads.

    text_col=None lets process_dataset pick the text column itself,
    config=None loads the repo's default config, and max_rows=0 means
    "process every row".
    """
    return {
        "name": name,
        "repo": repo,
        "split": split,
        "text_col": text_col,
        "config": config,
        "max_rows": max_rows,
    }


# Datasets analysed by main(), in the order they are processed.
DATASETS = [
    _dataset_entry("DODa", "atlasia/DODa"),
    _dataset_entry("Darija-Wiki", "atlasia/Moroccan-Darija-Wiki-Dataset"),
    _dataset_entry("Atlaset", "atlasia/Atlaset"),
    _dataset_entry("Ours (daa-pairs)", "OiQ/daa-pairs"),
]
class IncrementalStats:
    """Running count/min/max/mean plus a reservoir-sampled median.

    count, min, max and the sum (hence the mean) are exact. The median is
    taken from a uniform reservoir sample (Algorithm R) of at most
    ``reservoir_size`` values. It is therefore exact while
    ``count <= reservoir_size`` and an estimate beyond that. Sampling uses
    the module-level ``random`` generator, so seeding ``random`` makes the
    estimate reproducible.
    """

    def __init__(self, reservoir_size: int = 10000):
        """Create an empty accumulator.

        Args:
            reservoir_size: Maximum number of values kept for the median
                estimate; must be at least 1.

        Raises:
            ValueError: If ``reservoir_size`` is less than 1.
        """
        if reservoir_size < 1:
            raise ValueError(f"reservoir_size must be >= 1, got {reservoir_size}")
        self.count = 0
        self._min = float("inf")
        # -inf rather than 0 so the running max is also correct when every
        # value is negative; finalize() still reports 0 when nothing was added.
        self._max = float("-inf")
        self._sum = 0
        self._samples = []
        self._reservoir_size = reservoir_size

    def add(self, val):
        """Fold one numeric value into the running statistics."""
        self.count += 1
        self._min = min(self._min, val)
        self._max = max(self._max, val)
        self._sum += val
        if len(self._samples) < self._reservoir_size:
            self._samples.append(val)
        else:
            # Algorithm R: the newest value replaces a random slot with
            # probability reservoir_size / count, keeping the sample uniform.
            j = random.randint(0, self.count - 1)
            if j < self._reservoir_size:
                self._samples[j] = val

    def finalize(self):
        """Return {"count", "min", "max", "mean", "median"} for the values seen.

        The mean is rounded to 2 decimals. An empty accumulator reports 0 for
        every statistic (0.0 for the mean).
        """
        if self.count == 0:
            return {"count": 0, "min": 0, "max": 0, "mean": 0.0, "median": 0}
        return {
            "count": self.count,
            "min": self._min,
            "max": self._max,
            "mean": round(self._sum / self.count, 2),
            "median": statistics.median(self._samples),
        }
def classify_script(text: str) -> str:
    """Label *text* as "ar" (Arabic script), "az" (Arabizi/Latin) or "mi" (mixed).

    Only characters matched by ``_AR_PAT`` (code points U+0600-U+06FF and
    U+0750-U+077F, which also include Arabic punctuation and Arabic-Indic
    digits) and ``_LAT_PAT`` (ASCII a-z / A-Z) are counted. A script wins
    only with more than 90% of the counted characters; anything in between
    is "mi".

    NOTE(review): text with no counted characters at all (e.g. only ASCII
    digits, emoji or Latin punctuation) is labelled "ar". Confirm that this
    default is intended.
    """
    arabic = len(_AR_PAT.findall(text))
    latin = len(_LAT_PAT.findall(text))
    letters = arabic + latin
    if not letters:
        return "ar"

    arabic_share, latin_share = arabic / letters, latin / letters
    if arabic_share > 0.9 and latin_share < 0.1:
        return "ar"
    if latin_share > 0.9 and arabic_share < 0.1:
        return "az"
    return "mi"
def find_text_column_from_cols(cols, name: str) -> str:
candidates = ["text", "darija", "arabic", "sentence", "word",
"content", "src", "source", "ar", "az", "mixed",
"darija_ar", "darija_az", "darija_mix"]
for c in candidates:
if c in cols:
print(f" [{name}] Using column '{c}' from {cols}")
return c
print(f" [{name}] Could not find text column! Columns: {cols}")
return None
def compute_stats(lengths_by_script: dict) -> dict:
    """Turn per-script accumulators into plain statistic dicts.

    The result always has the keys "ar", "az", "mi" and "ALL", in that order.
    Each value is the accumulator's ``finalize()`` output. An all-zero record
    is used when the key is missing from *lengths_by_script* or maps to None.
    Any other keys in the input are ignored.
    """
    zero_record = dict.fromkeys(("count", "min", "max", "mean", "median"), 0)
    accumulators = {key: lengths_by_script.get(key) for key in ("ar", "az", "mi", "ALL")}
    return {
        key: dict(zero_record) if acc is None else acc.finalize()
        for key, acc in accumulators.items()
    }
def _open_for_stats(repo: str, split: str, load_kwargs: dict):
    """Load *repo*/*split*, preferring streaming, with a regular-load fallback.

    Returns ``(ds, total_rows, cols)`` or None if both attempts fail.
    ``total_rows`` is None for a streaming dataset (size unknown up front)
    and ``len(ds)`` otherwise.
    """
    from datasets import load_dataset

    try:
        ds = load_dataset(repo, split=split, streaming=True, **load_kwargs)
        # Streaming datasets are lazy, so some failures only surface when the
        # first row is fetched. Peek inside the try so those failures also
        # reach the regular-mode fallback instead of escaping to the caller.
        first = next(iter(ds), None)
    except Exception as e:
        print(f" Streaming failed ({e}), trying regular mode...")
    else:
        print(" Streaming mode enabled")
        cols = list(first.keys()) if first is not None else []
        return ds, None, cols

    try:
        ds = load_dataset(repo, split=split, **load_kwargs)
    except Exception as e2:
        print(f" Failed to load {repo}: {e2}")
        return None
    return ds, len(ds), ds.column_names


def _iter_texts(ds, text_col: str, total_rows, batch: int = 50000):
    """Yield the *text_col* value of every row of *ds*, in order.

    Streaming datasets (``total_rows is None``) are read row by row.
    In-memory datasets are sliced into *batch*-row chunks so only one chunk
    is converted to Python objects at a time.
    """
    if total_rows is None:
        for row in ds:
            yield row.get(text_col)
    else:
        for start in range(0, total_rows, batch):
            yield from ds[start:start + batch][text_col]


def process_dataset(name: str, repo: str, split: str, text_col=None,
                    config=None, max_rows: int = 0) -> dict | None:
    """Compute per-script sentence-length statistics for one HF dataset.

    Rows whose text is not a str, or has fewer than 2 characters after
    stripping surrounding whitespace, are skipped. Lengths are ``len(text)``
    of the unstripped text, bucketed by ``classify_script`` and also
    accumulated under "ALL".

    Args:
        name: Display name used in logs and stored under "dataset".
        repo: HuggingFace dataset repo id.
        split: Split to read.
        text_col: Text column. If None or not among the dataset's columns,
            one is chosen by ``find_text_column_from_cols``.
        config: Optional dataset config name (forwarded as ``name=``).
        max_rows: Stop after this many valid sentences; 0 means no limit.

    Returns:
        The ``compute_stats`` dict (keys "ar", "az", "mi", "ALL") plus
        "dataset", "repo", "text_col", "total_rows" and "total_processed".
        Returns None if the dataset cannot be loaded or no text column is
        found. "total_rows" is ``len(ds)`` in regular mode. In streaming mode
        it is the number of rows read, which is the split size unless
        ``max_rows`` stopped reading early.
    """
    print(f"\n{'='*80}")
    print(f"Loading {name} ({repo})...", flush=True)
    load_kwargs = {"token": HF_TOKEN}
    if config:
        load_kwargs["name"] = config

    opened = _open_for_stats(repo, split, load_kwargs)
    if opened is None:
        return None
    ds, total_rows, cols = opened

    print(f" Columns: {cols}")
    if text_col is None or text_col not in cols:
        text_col = find_text_column_from_cols(cols, name)
        if text_col is None:
            return None

    lengths_by_script = {key: IncrementalStats() for key in ("ar", "az", "mi", "ALL")}
    total_processed = 0
    rows_seen = 0
    max_label = f" (max {max_rows:,})" if max_rows else ""
    of_total = f"/{total_rows:,}" if total_rows is not None else ""

    for text in _iter_texts(ds, text_col, total_rows):
        rows_seen += 1
        if not isinstance(text, str) or len(text.strip()) < 2:
            continue
        length = len(text)
        # Script bucket first, then "ALL": keeps the order of reservoir
        # sampling draws on the shared `random` generator unchanged.
        lengths_by_script[classify_script(text)].add(length)
        lengths_by_script["ALL"].add(length)
        total_processed += 1
        if total_processed % 100000 == 0:
            print(f" Processed {total_processed:,}{of_total}{max_label}...", flush=True)
        if max_rows and total_processed >= max_rows:
            print(f" Reached sample limit of {max_rows:,}", flush=True)
            break
    print(f" Processed {total_processed:,}{of_total}{max_label}", flush=True)

    stats = compute_stats(lengths_by_script)
    stats["dataset"] = name
    stats["repo"] = repo
    stats["text_col"] = text_col
    stats["total_rows"] = total_rows if total_rows is not None else rows_seen
    stats["total_processed"] = total_processed

    print(f"\n Results for {name}:")
    for script in ("ar", "az", "mi", "ALL"):
        s = stats[script]
        print(f" {script.upper():>4}: n={s['count']:>8,} min={s['min']:>5} "
              f"max={s['max']:>6} mean={s['mean']:>7.1f} median={s['median']:>5.0f}")

    # Drop the dataset handle before the next (possibly large) load.
    del ds
    gc.collect()
    return stats
def save_results(all_stats: list):
csv_path = os.path.join(OUTPUT_DIR, "dataset_stats.csv")
fieldnames = ["dataset", "repo", "text_col", "total_rows", "total_processed",
"ar_count", "ar_min", "ar_max", "ar_mean", "ar_median",
"az_count", "az_min", "az_max", "az_mean", "az_median",
"mi_count", "mi_min", "mi_max", "mi_mean", "mi_median",
"ALL_count", "ALL_min", "ALL_max", "ALL_mean", "ALL_median"]
with open(csv_path, "w", newline="") as f:
w = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
w.writeheader()
for s in all_stats:
row = {"dataset": s["dataset"], "repo": s["repo"], "text_col": s["text_col"],
"total_rows": s["total_rows"], "total_processed": s["total_processed"]}
for script in ("ar", "az", "mi", "ALL"):
for stat in ("count", "min", "max", "mean", "median"):
row[f"{script}_{stat}"] = s[script][stat]
w.writerow(row)
print(f"Saved CSV: {csv_path}")
json_path = os.path.join(OUTPUT_DIR, "dataset_stats.json")
with open(json_path, "w") as f:
json.dump(all_stats, f, indent=2, default=str)
print(f"Saved JSON: {json_path}")
def main():
    """Process every DATASETS entry, saving as it goes, then print a summary table.

    ``random`` is seeded first so reservoir-sampled medians are reproducible.
    Entries for which process_dataset returns a falsy value (load failure or
    no text column) are skipped.
    """
    random.seed(42)
    collected = []
    for spec in DATASETS:
        result = process_dataset(
            spec["name"],
            spec["repo"],
            spec["split"],
            text_col=spec.get("text_col"),
            config=spec.get("config"),
            max_rows=spec.get("max_rows", 0),
        )
        if not result:
            continue
        collected.append(result)
        # Rewrite the output files after every dataset so results already
        # computed survive a failure on a later dataset.
        save_results(collected)

    rule = 120
    header = (
        f"{'Dataset':<22} {'Script':>6} {'Sentences':>10} {'Min':>6} {'Max':>7} "
        f"{'Mean':>8} {'Median':>7}"
    )
    print("\n" + "=" * rule)
    print(header)
    print("-" * rule)
    for entry in collected:
        for script in ("ar", "az", "mi", "ALL"):
            cell = entry[script]
            label = entry["dataset"] if script != "ALL" else ""
            line = (
                f"{label:<22} {script:>6} {cell['count']:>10,} {cell['min']:>6} "
                f"{cell['max']:>7} {cell['mean']:>8.1f} {cell['median']:>7.0f}"
            )
            print(line)
        print("-" * rule)
    print("=" * rule)


if __name__ == "__main__":
    main()