language-extractor-demo / sib200_cache.py
DerivedFunction1's picture
add
67bcd25
from __future__ import annotations
import argparse
import unicodedata
from functools import lru_cache
from pathlib import Path
from typing import Any
import pandas as pd
from datasets import get_dataset_config_names, load_dataset
import pycountry
from tqdm.auto import tqdm
from language import ALL_LANGS, LANG_ISO2_TO_ISO3, canonical_lang, is_latin_script_compatible
from sentence_sampling import sample_multi_group_bundle, sample_single_group_bundle
# Dataset identifier passed to `load_dataset` / `get_dataset_config_names`.
SIB200_DATASET = "Davlan/sib200"
# Cache directory: a "data/sib200" folder that sits next to this module file.
SIB200_CACHE_DIR = Path(__file__).with_name("data") / "sib200"
# Parquet file inside the cache directory that stores the flattened text rows.
SIB200_PARQUET_PATH = SIB200_CACHE_DIR / "sib200_text.parquet"
# Sort rank per split name. Lower ranks sort first, so when duplicate texts are
# dropped the copy from the lowest-ranked split is the one that is kept.
# Split names missing from this mapping are ranked 99 by the code that uses it.
SIB200_SPLIT_ORDER = {"train": 0, "validation": 1, "test": 2}
def _normalize_text_key(text: str) -> str:
normalized = unicodedata.normalize("NFKC", text)
normalized = " ".join(normalized.split())
return normalized.casefold().strip()
def _normalize_source_lang(config_name: str) -> str:
base = (config_name or "").strip().split("_", 1)[0].lower()
if not base:
return ""
if len(base) == 3:
language = pycountry.languages.get(alpha_3=base)
if language is not None:
alpha_2 = getattr(language, "alpha_2", None)
if alpha_2:
return canonical_lang(alpha_2.lower())
language = canonical_lang(base)
return language if language in ALL_LANGS else base
def _normalize_split_name(split_name: str) -> str:
split = (split_name or "").strip().lower()
if split == "dev":
return "validation"
return split
def _row_to_sentence(row: pd.Series) -> dict[str, Any]:
source_lang = str(row.get("source_lang", "")).strip()
lang_iso2 = str(row.get("lang_iso2", "")).strip()
lang_iso3 = str(row.get("lang_iso3", "")).strip()
label = row.get("label", -1)
topic = str(row.get("topic", "")).strip()
return {
"text": str(row.get("text", "")).strip(),
"raw_text": str(row.get("text", "")).strip(),
"source": "sib200",
"source_lang": source_lang,
"lang_iso2": lang_iso2,
"lang_iso3": lang_iso3 or LANG_ISO2_TO_ISO3.get(lang_iso2, ""),
"language": source_lang,
"split": str(row.get("split", "")).strip(),
"sib200_id": int(row.get("index_id", -1)) if str(row.get("index_id", "-1")).strip().lstrip("-").isdigit() else -1,
"sib200_label": int(label) if str(label).strip().lstrip("-").isdigit() else -1,
"sib200_topic": topic,
}
def _frame_from_dataset(config_name: str) -> pd.DataFrame:
    """Load one SIB-200 config and flatten it into a de-duplicated frame.

    Args:
        config_name: Dataset config to load. It is stored verbatim in the
            ``source_lang`` column; its normalized language code goes into
            ``lang_iso2``.

    Returns:
        A frame with the columns ``index_id``, ``text``, ``label``, ``topic``,
        ``source_lang``, ``lang_iso2``, ``lang_iso3``, ``source`` and
        ``split``, holding one row per distinct normalized text. An empty
        frame is returned when the config name yields no language code, the
        dataset files are not found, the dataset has no splits, or no row
        carries text.
    """
    # Resolve the language first so an unusable config name costs no download.
    source_lang = _normalize_source_lang(config_name)
    if not source_lang:
        return pd.DataFrame()

    try:
        dataset = load_dataset(SIB200_DATASET, config_name)
    except FileNotFoundError:
        return pd.DataFrame()
    if len(dataset) == 0:
        return pd.DataFrame()

    # Topic names come from the first split that declares a ``label`` feature.
    # A feature without ``names`` simply leaves every topic blank instead of
    # raising AttributeError.
    # NOTE(review): rows are read through the ``label`` field; confirm that it
    # matches the dataset's actual column names. If the topic is published
    # under a different column, every label stays -1 and every topic stays "".
    label_names: list[str] = []
    for split_name in ("train", "validation", "test"):
        if split_name in dataset and "label" in dataset[split_name].features:
            label_feature = dataset[split_name].features["label"]
            label_names = list(getattr(label_feature, "names", None) or [])
            break

    def _to_int(value: Any) -> int:
        """Parse *value* as an int, returning -1 instead of raising."""
        try:
            return int(str(value).strip())
        except ValueError:
            return -1

    # These are constant for the whole config, so compute them once.
    lang_iso2 = source_lang
    lang_iso3 = LANG_ISO2_TO_ISO3.get(lang_iso2, "")

    records: list[dict[str, Any]] = []
    for split_name, split_ds in dataset.items():
        normalized_split = _normalize_split_name(split_name)
        for row in split_ds:
            raw_text = row.get("text", "")
            # A null text must be skipped, not stored as the string "None".
            text = "" if raw_text is None else str(raw_text).strip()
            if not text:
                continue
            label_int = _to_int(row.get("label", -1))
            topic = label_names[label_int] if 0 <= label_int < len(label_names) else ""
            records.append(
                {
                    "index_id": _to_int(row.get("index_id", -1)),
                    "text": text,
                    "label": label_int,
                    "topic": topic,
                    "source_lang": config_name,
                    "lang_iso2": lang_iso2,
                    "lang_iso3": lang_iso3,
                    "source": "sib200",
                    "split": normalized_split,
                }
            )
    if not records:
        return pd.DataFrame()

    frame = pd.DataFrame.from_records(records)
    # Helper columns: a normalized text key for duplicate detection and a split
    # rank so that, among duplicates, the earliest-ranked split is kept.
    frame["text_key"] = frame["text"].astype(str).map(_normalize_text_key)
    frame["split_rank"] = frame["split"].map(lambda split: SIB200_SPLIT_ORDER.get(str(split), 99))
    frame = frame.sort_values(by=["source_lang", "text_key", "split_rank", "index_id"], kind="stable")
    frame = frame.drop_duplicates(subset=["source_lang", "text_key"], keep="first")
    frame = frame.drop(columns=["text_key", "split_rank"], errors="ignore").reset_index(drop=True)
    return frame
def build_sib200_text_parquet(parquet_path: str | Path = SIB200_PARQUET_PATH) -> Path:
    """Download SIB-200 and persist a lean parquet cache for offline sampling.

    Every dataset config is loaded through ``_frame_from_dataset``; configs
    that produce an empty frame are reported and skipped. The remaining frames
    are concatenated, ordered by ``source_lang``, split rank and ``index_id``,
    and written to ``parquet_path`` (parent directories are created).

    Args:
        parquet_path: Destination of the parquet file.

    Returns:
        The destination path as a ``Path``.

    Raises:
        RuntimeError: If no config produced any rows.
    """
    parquet_path = Path(parquet_path)
    parquet_path.parent.mkdir(parents=True, exist_ok=True)

    frames: list[pd.DataFrame] = []
    for config_name in tqdm(get_dataset_config_names(SIB200_DATASET), desc="SIB-200 configs"):
        frame = _frame_from_dataset(config_name)
        if frame.empty:
            tqdm.write(f"Skipping SIB-200 config without a direct TSV layout: {config_name}")
            continue
        frames.append(frame)
    if not frames:
        raise RuntimeError("No usable SIB-200 rows were loaded.")

    combined = pd.concat(frames, ignore_index=True)
    # Temporary rank column so splits sort in SIB200_SPLIT_ORDER, not alphabetically.
    combined["split_rank"] = combined["split"].map(lambda split: SIB200_SPLIT_ORDER.get(str(split), 99))
    combined = combined.sort_values(by=["source_lang", "split_rank", "index_id"], kind="stable").reset_index(drop=True)
    combined = combined.drop(columns=["split_rank"], errors="ignore")
    combined.to_parquet(parquet_path, index=False)

    # load_sib200_table memoizes the table it read; forget it so this process
    # does not keep serving the contents from before the rebuild.
    load_sib200_table.cache_clear()

    print(
        f"Built lean SIB-200 parquet with {len(combined):,} rows "
        f"and {len(combined.columns)} columns at {parquet_path}."
    )
    return parquet_path
@lru_cache(maxsize=1)
def load_sib200_table(parquet_path: str | Path = SIB200_PARQUET_PATH) -> pd.DataFrame:
    """Read the cached SIB-200 parquet file into a DataFrame.

    The result is memoized for the most recently used ``parquet_path``
    argument, so repeated calls with the same argument return the same
    DataFrame object without touching the disk again.

    Raises:
        FileNotFoundError: If the parquet file does not exist.
        RuntimeError: If the file has no ``text`` column.
    """
    cache_file = Path(parquet_path)
    if not cache_file.exists():
        raise FileNotFoundError(
            f"Missing SIB-200 cache at {cache_file}. "
            "Run `./.venv/bin/python sib200_cache.py` once while online to build it."
        )

    table = pd.read_parquet(cache_file)
    if "text" in table.columns:
        return table
    raise RuntimeError("SIB-200 parquet cache is missing the text column.")
def fetch_random_sib200_sentence(
    *,
    attempts: int = 8,
    parquet_path: str | Path = SIB200_PARQUET_PATH,
) -> dict[str, Any]:
    """Sample a sentence bundle for a single language from the cached table.

    Rows are first restricted to languages in ``ALL_LANGS`` (when the table
    has a ``lang_iso2`` column) and then to rows accepted by
    ``is_latin_script_compatible`` (when it has a ``source_lang`` column).
    The filtered rows, grouped by ``lang_iso2``, are handed to
    ``sample_single_group_bundle`` together with ``attempts``; its result is
    returned unchanged.
    """
    table = load_sib200_table(parquet_path)

    if "lang_iso2" in table.columns:
        candidates = table[table["lang_iso2"].isin(ALL_LANGS)]
    else:
        candidates = table

    if "source_lang" in candidates.columns:

        def _script_ok(row):
            # Values are stringified so missing cells never reach the checker raw.
            return is_latin_script_compatible(
                str(row.get("lang_iso2", "")),
                str(row.get("source_lang", "")),
            )

        keep = candidates.apply(_script_ok, axis=1)
        candidates = candidates[keep]

    return sample_single_group_bundle(
        candidates,
        group_column="lang_iso2",
        row_to_sentence=_row_to_sentence,
        attempts=attempts,
    )
def fetch_random_sib200_sentence_mix(
    *,
    min_groups: int = 2,
    max_groups: int = 3,
    parquet_path: str | Path = SIB200_PARQUET_PATH,
) -> dict[str, Any]:
    """Sample a sentence bundle that spans several languages.

    The cached table is filtered the same way as for the single-language
    sampler: languages in ``ALL_LANGS`` when ``lang_iso2`` exists, then rows
    accepted by ``is_latin_script_compatible`` when ``source_lang`` exists.
    The rows, grouped by ``lang_iso2``, go to ``sample_multi_group_bundle``
    with the ``min_groups`` / ``max_groups`` bounds. The returned bundle is
    copied and its ``source`` entry is set to ``"sib200-mix"``.
    """
    table = load_sib200_table(parquet_path)

    has_iso2 = "lang_iso2" in table.columns
    pool = table[table["lang_iso2"].isin(ALL_LANGS)] if has_iso2 else table

    if "source_lang" in pool.columns:

        def _is_compatible(row):
            iso2_code = str(row.get("lang_iso2", ""))
            config_tag = str(row.get("source_lang", ""))
            return is_latin_script_compatible(iso2_code, config_tag)

        pool = pool[pool.apply(_is_compatible, axis=1)]

    mixed = dict(
        sample_multi_group_bundle(
            pool,
            group_column="lang_iso2",
            row_to_sentence=_row_to_sentence,
            min_groups=min_groups,
            max_groups=max_groups,
        )
    )
    # Overwrite (or add) the source tag so consumers can tell mixes apart.
    mixed["source"] = "sib200-mix"
    return mixed
def main() -> None:
    """Command-line entry point: build the parquet cache and report its path.

    Accepts a single optional ``--output`` argument that defaults to
    ``SIB200_PARQUET_PATH``.
    """
    cli = argparse.ArgumentParser(description="Build the cached text-only SIB-200 parquet.")
    cli.add_argument(
        "--output",
        default=str(SIB200_PARQUET_PATH),
        help="Output parquet path for the cached SIB-200 text rows.",
    )
    output_path = cli.parse_args().output
    written = build_sib200_text_parquet(output_path)
    print(f"Wrote SIB-200 text cache to {written}")


# Run the builder only when executed as a script, not when imported.
if __name__ == "__main__":
    main()