Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import unicodedata | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Any | |
| import pandas as pd | |
| from datasets import get_dataset_config_names, load_dataset | |
| import pycountry | |
| from tqdm.auto import tqdm | |
| from language import ALL_LANGS, LANG_ISO2_TO_ISO3, canonical_lang | |
| from sentence_sampling import sample_multi_group_bundle, sample_single_group_bundle | |
# Dataset identifier passed to `load_dataset` and `get_dataset_config_names`.
SIB200_DATASET = "Davlan/sib200"
# Cache directory: a "data/sib200" folder located next to this module file.
SIB200_CACHE_DIR = Path(__file__).with_name("data") / "sib200"
# Default parquet file written by `build_sib200_text_parquet` and read by
# `load_sib200_table`.
SIB200_PARQUET_PATH = SIB200_CACHE_DIR / "sib200_text.parquet"
# Sort rank per split name (lower sorts first); splits not listed here are
# ranked 99 by the code that consults this mapping.
SIB200_SPLIT_ORDER = {"train": 0, "validation": 1, "test": 2}
| def _normalize_text_key(text: str) -> str: | |
| normalized = unicodedata.normalize("NFKC", text) | |
| normalized = " ".join(normalized.split()) | |
| return normalized.casefold().strip() | |
| def _normalize_source_lang(config_name: str) -> str: | |
| base = (config_name or "").strip().split("_", 1)[0].lower() | |
| if not base: | |
| return "" | |
| if len(base) == 3: | |
| language = pycountry.languages.get(alpha_3=base) | |
| if language is not None: | |
| alpha_2 = getattr(language, "alpha_2", None) | |
| if alpha_2: | |
| return canonical_lang(alpha_2.lower()) | |
| language = canonical_lang(base) | |
| return language if language in ALL_LANGS else base | |
| def _normalize_split_name(split_name: str) -> str: | |
| split = (split_name or "").strip().lower() | |
| if split == "dev": | |
| return "validation" | |
| return split | |
| def _row_to_sentence(row: pd.Series) -> dict[str, Any]: | |
| source_lang = str(row.get("source_lang", "")).strip() | |
| lang_iso2 = str(row.get("lang_iso2", "")).strip() | |
| lang_iso3 = str(row.get("lang_iso3", "")).strip() | |
| label = row.get("label", -1) | |
| topic = str(row.get("topic", "")).strip() | |
| return { | |
| "text": str(row.get("text", "")).strip(), | |
| "raw_text": str(row.get("text", "")).strip(), | |
| "source": "sib200", | |
| "source_lang": source_lang, | |
| "lang_iso2": lang_iso2, | |
| "lang_iso3": lang_iso3 or LANG_ISO2_TO_ISO3.get(lang_iso2, ""), | |
| "language": source_lang, | |
| "split": str(row.get("split", "")).strip(), | |
| "sib200_id": int(row.get("index_id", -1)) if str(row.get("index_id", "-1")).strip().lstrip("-").isdigit() else -1, | |
| "sib200_label": int(label) if str(label).strip().lstrip("-").isdigit() else -1, | |
| "sib200_topic": topic, | |
| } | |
def _parse_dataset_int(value: Any, default: int = -1) -> int:
    """Return ``value`` as an int, or ``default`` when it is not integer-like."""
    text = str(value).strip()
    if not text.lstrip("-").isdigit():
        return default
    try:
        return int(text)
    except ValueError:
        # str.isdigit() accepts strings int() rejects ("--5", "²", ...).
        return default


def _dataset_label_names(dataset: Any) -> list[str]:
    """Return the class names of the first split that exposes a ``label`` feature.

    Yields an empty list when no split has a ``label`` feature or when that
    feature carries no ``names`` (i.e. it is not a class-label feature).
    """
    for split_name in ("train", "validation", "test"):
        if split_name not in dataset:
            continue
        features = dataset[split_name].features
        if "label" not in features:
            continue
        # Only class-label features have `names`; plain value features do not.
        return list(getattr(features["label"], "names", None) or [])
    return []


def _frame_from_dataset(config_name: str) -> pd.DataFrame:
    """Load one SIB-200 config and return its de-duplicated rows as a frame.

    Args:
        config_name: Name of the dataset configuration to load.

    Returns:
        A frame with the columns ``index_id``, ``text``, ``label``, ``topic``,
        ``source_lang`` (the raw config name), ``lang_iso2``, ``lang_iso3``,
        ``source`` and ``split``. Rows with empty text are skipped. Rows whose
        normalized text repeats within the config are collapsed to one,
        keeping the earliest split per ``SIB200_SPLIT_ORDER`` and then the
        lowest ``index_id``. An empty frame is returned when the config name
        yields no language code, the dataset files are missing, or no usable
        row exists.
    """
    # Resolve the language first so an unusable config name never triggers a load.
    source_lang = _normalize_source_lang(config_name)
    if not source_lang:
        return pd.DataFrame()

    try:
        dataset = load_dataset(SIB200_DATASET, config_name)
    except FileNotFoundError:
        return pd.DataFrame()
    if not dataset:
        return pd.DataFrame()

    label_names = _dataset_label_names(dataset)
    # These values are identical for every row of the config.
    lang_iso2 = source_lang
    lang_iso3 = LANG_ISO2_TO_ISO3.get(lang_iso2, "")

    records: list[dict[str, Any]] = []
    for split_name, split_ds in dataset.items():
        normalized_split = _normalize_split_name(split_name)
        for row in split_ds:
            raw_text = row.get("text")
            # A null text must be skipped, not stored as the string "None".
            text = "" if raw_text is None else str(raw_text).strip()
            if not text:
                continue
            label_int = _parse_dataset_int(row.get("label", -1))
            topic = label_names[label_int] if 0 <= label_int < len(label_names) else ""
            records.append(
                {
                    "index_id": _parse_dataset_int(row.get("index_id", -1)),
                    "text": text,
                    "label": label_int,
                    "topic": topic,
                    "source_lang": config_name,
                    "lang_iso2": lang_iso2,
                    "lang_iso3": lang_iso3,
                    "source": "sib200",
                    "split": normalized_split,
                }
            )
    if not records:
        return pd.DataFrame()

    frame = pd.DataFrame.from_records(records)
    # Temporary helper columns drive the duplicate resolution below.
    frame["text_key"] = frame["text"].astype(str).map(_normalize_text_key)
    frame["split_rank"] = frame["split"].map(lambda split: SIB200_SPLIT_ORDER.get(str(split), 99))
    frame = frame.sort_values(by=["source_lang", "text_key", "split_rank", "index_id"], kind="stable")
    # After the sort, the first row of each duplicate group is the preferred one.
    frame = frame.drop_duplicates(subset=["source_lang", "text_key"], keep="first")
    frame = frame.drop(columns=["text_key", "split_rank"], errors="ignore").reset_index(drop=True)
    return frame
def build_sib200_text_parquet(parquet_path: str | Path = SIB200_PARQUET_PATH) -> Path:
    """Download SIB-200 and persist a lean parquet cache for offline sampling.

    Every config reported by ``get_dataset_config_names`` is loaded through
    ``_frame_from_dataset``; configs that produce an empty frame are reported
    and skipped. The remaining frames are concatenated, ordered by config
    name, split rank (``SIB200_SPLIT_ORDER``) and ``index_id``, and written to
    ``parquet_path``.

    Args:
        parquet_path: Destination file; its parent directory is created if
            it does not exist.

    Returns:
        The destination as a ``Path``.

    Raises:
        RuntimeError: If no config produced any rows.
    """
    target = Path(parquet_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    usable_frames: list[pd.DataFrame] = []
    progress = tqdm(get_dataset_config_names(SIB200_DATASET), desc="SIB-200 configs")
    for name in progress:
        config_frame = _frame_from_dataset(name)
        if config_frame.empty:
            tqdm.write(f"Skipping SIB-200 config without a direct TSV layout: {name}")
            continue
        usable_frames.append(config_frame)

    if not usable_frames:
        raise RuntimeError("No usable SIB-200 rows were loaded.")

    def rank_of(split: Any) -> int:
        # Unknown split names sort after the known ones.
        return SIB200_SPLIT_ORDER.get(str(split), 99)

    merged = pd.concat(usable_frames, ignore_index=True)
    merged = (
        merged.assign(split_rank=merged["split"].map(rank_of))
        .sort_values(by=["source_lang", "split_rank", "index_id"], kind="stable")
        .reset_index(drop=True)
        .drop(columns=["split_rank"], errors="ignore")
    )
    merged.to_parquet(target, index=False)

    print(
        f"Built lean SIB-200 parquet with {len(merged):,} rows "
        f"and {len(merged.columns)} columns at {target}."
    )
    return target
def load_sib200_table(parquet_path: str | Path = SIB200_PARQUET_PATH) -> pd.DataFrame:
    """Read the cached SIB-200 parquet file into a DataFrame.

    Args:
        parquet_path: Location of the parquet cache.

    Returns:
        The table stored in the cache file.

    Raises:
        FileNotFoundError: If the cache file does not exist.
        RuntimeError: If the stored table has no ``text`` column.
    """
    cache_file = Path(parquet_path)
    if cache_file.exists():
        table = pd.read_parquet(cache_file)
        if "text" in table.columns:
            return table
        raise RuntimeError("SIB-200 parquet cache is missing the text column.")

    raise FileNotFoundError(
        f"Missing SIB-200 cache at {cache_file}. "
        "Run `./.venv/bin/python sib200_cache.py` once while online to build it."
    )
def fetch_random_sib200_sentence(
    *,
    attempts: int = 8,
    parquet_path: str | Path = SIB200_PARQUET_PATH,
) -> dict[str, Any]:
    """Sample a single-language bundle from the cached SIB-200 table.

    The table is loaded from ``parquet_path`` and, when it has a
    ``lang_iso2`` column, restricted to rows whose code is in ``ALL_LANGS``.
    The actual sampling is delegated to ``sample_single_group_bundle``, which
    groups by ``lang_iso2`` and receives ``attempts`` unchanged.

    Returns:
        Whatever ``sample_single_group_bundle`` returns for the candidates.
    """
    table = load_sib200_table(parquet_path)
    if "lang_iso2" in table.columns:
        supported = table["lang_iso2"].isin(ALL_LANGS)
        table = table[supported]

    return sample_single_group_bundle(
        table,
        group_column="lang_iso2",
        row_to_sentence=_row_to_sentence,
        attempts=attempts,
    )
def fetch_random_sib200_sentence_mix(
    *,
    min_groups: int = 2,
    max_groups: int = 3,
    parquet_path: str | Path = SIB200_PARQUET_PATH,
) -> dict[str, Any]:
    """Sample a multi-language bundle from the cached SIB-200 table.

    The table is loaded from ``parquet_path`` and, when it has a
    ``lang_iso2`` column, restricted to rows whose code is in ``ALL_LANGS``.
    Sampling is delegated to ``sample_multi_group_bundle``, grouped by
    ``lang_iso2`` and bounded by ``min_groups`` / ``max_groups``.

    Returns:
        A copy of the sampled bundle whose ``source`` entry is set to
        ``"sib200-mix"``.
    """
    table = load_sib200_table(parquet_path)
    if "lang_iso2" in table.columns:
        supported = table["lang_iso2"].isin(ALL_LANGS)
        table = table[supported]

    sampled = sample_multi_group_bundle(
        table,
        group_column="lang_iso2",
        row_to_sentence=_row_to_sentence,
        min_groups=min_groups,
        max_groups=max_groups,
    )
    # Copy before tagging so the sampler's own mapping is left untouched.
    tagged = dict(sampled)
    tagged["source"] = "sib200-mix"
    return tagged
def main() -> None:
    """Command-line entry point: build the parquet cache at ``--output``.

    ``--output`` defaults to ``SIB200_PARQUET_PATH``. The path actually
    written is printed once the build completes.
    """
    cli = argparse.ArgumentParser(description="Build the cached text-only SIB-200 parquet.")
    cli.add_argument(
        "--output",
        default=str(SIB200_PARQUET_PATH),
        help="Output parquet path for the cached SIB-200 text rows.",
    )
    destination = cli.parse_args().output
    written = build_sib200_text_parquet(destination)
    print(f"Wrote SIB-200 text cache to {written}")
# Build the cache only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()