from __future__ import annotations import argparse import importlib from collections import Counter from collections.abc import Callable, Iterable from pathlib import Path from typing import Any import yaml from hakari_bench.datasets import DatasetRegistry, resolve_dataset_splits DetectionByTask = dict[tuple[str, str], dict[str, Any]] Detector = Callable[[str], Any] DEFAULT_MIN_LANGUAGE_RATIO = 0.005 DEFAULT_MAIN_LANGUAGE_RATIO = 0.10 def main() -> None: parser = argparse.ArgumentParser(description="Update query/document language detection in dataset YAML metadata.") parser.add_argument("--config-root", type=Path, default=Path("config")) parser.add_argument("--dataset", default=None, help="Optional dataset name or id to update.") parser.add_argument("--task", default=None, help="Optional split/task name to update.") parser.add_argument("--dataset-revision", default=None) parser.add_argument("--min-language-ratio", type=float, default=DEFAULT_MIN_LANGUAGE_RATIO) parser.add_argument("--main-language-ratio", type=float, default=DEFAULT_MAIN_LANGUAGE_RATIO) parser.add_argument("--dry-run", action="store_true") args = parser.parse_args() registry = DatasetRegistry.load_from_root(args.config_root) datasets = [registry.get_dataset(args.dataset)] if args.dataset else sorted( {id(dataset): dataset for dataset in registry._datasets_by_key.values()}.values(), key=lambda item: item.name, ) detections: DetectionByTask = {} for dataset in datasets: split_names = resolve_dataset_splits(dataset) for split_name in split_names: task_name = _task_name_for_split(dataset, split_name) if args.task and args.task not in {split_name, task_name}: continue try: detection = load_language_detection( dataset_id=dataset.dataset_id, corpus_config=dataset.corpus_config, queries_config=dataset.queries_config, split_name=split_name, revision=args.dataset_revision, min_language_ratio=args.min_language_ratio, main_language_ratio=args.main_language_ratio, ) except Exception as exc: # pragma: no cover - exercised manually against remote datasets print(f"warning: failed to load {dataset.dataset_id}/{split_name}: {type(exc).__name__}: {exc}") continue detections[(dataset.dataset_id, task_name)] = detection print( f"{dataset.name}/{task_name}: languages={detection['languages']} " f"query={detection['language_detection']['query']['languages']} " f"document={detection['language_detection']['document']['languages']}", flush=True, ) if args.dry_run: return for path in sorted((args.config_root / "datasets").glob("*.yaml")): update_languages_in_file(path, detections) for path in sorted((args.config_root / "dataset_collections").glob("*.yaml")): update_languages_in_file(path, detections) def load_language_detection( *, dataset_id: str, corpus_config: str, queries_config: str, split_name: str, revision: str | None, min_language_ratio: float = DEFAULT_MIN_LANGUAGE_RATIO, main_language_ratio: float = DEFAULT_MAIN_LANGUAGE_RATIO, ) -> dict[str, Any]: from datasets import load_dataset queries = load_dataset(dataset_id, queries_config, split=split_name, revision=revision) corpus = load_dataset(dataset_id, corpus_config, split=split_name, revision=revision) return build_language_detection_metadata( query_texts=[str(text) for text in queries["text"] if str(text)], document_texts=[str(text) for text in corpus["text"] if str(text)], min_language_ratio=min_language_ratio, main_language_ratio=main_language_ratio, ) def build_language_detection_metadata( *, query_texts: Iterable[str], document_texts: Iterable[str], detector: Detector | None = None, min_language_ratio: float = DEFAULT_MIN_LANGUAGE_RATIO, main_language_ratio: float = DEFAULT_MAIN_LANGUAGE_RATIO, ) -> dict[str, Any]: query_detection = detect_language_distribution( query_texts, detector=detector, min_language_ratio=min_language_ratio, ) document_detection = detect_language_distribution( document_texts, detector=detector, min_language_ratio=min_language_ratio, ) languages = select_main_languages( query_detection["ratios"], document_detection["ratios"], main_language_ratio=main_language_ratio, ) return { "languages": languages, "language_detection": { "detector": "fast-langdetect", "min_language_percent": round(min_language_ratio * 100, 3), "main_language_percent": round(main_language_ratio * 100, 3), "query": _language_detection_payload(query_detection), "document": _language_detection_payload(document_detection), }, } def detect_language_distribution( texts: Iterable[str], *, detector: Detector | None = None, min_language_ratio: float = DEFAULT_MIN_LANGUAGE_RATIO, ) -> dict[str, Any]: counts: Counter[str] = Counter() total = 0 detect = detector or _detect_language for text in texts: normalized = str(text).strip() if not normalized: continue total += 1 language, _score = _normalize_detect_result(_safe_detect(detect, normalized)) counts[language or "unknown"] += 1 if total == 0: return {"sample_count": 0, "counts": {"unknown": 0}, "ratios": {"unknown": 1.0}} ratios = {language: count / total for language, count in counts.items()} filtered = { language: ratio for language, ratio in sorted(ratios.items(), key=lambda item: (-item[1], item[0])) if ratio >= min_language_ratio } if not filtered: filtered = {"unknown": 1.0} return { "sample_count": total, "counts": dict(counts), "ratios": filtered, } def select_main_languages( query_ratios: dict[str, float], document_ratios: dict[str, float], *, main_language_ratio: float = DEFAULT_MAIN_LANGUAGE_RATIO, ) -> list[str]: languages = set(query_ratios) | set(document_ratios) selected = [ language for language in languages if max(query_ratios.get(language, 0.0), document_ratios.get(language, 0.0)) >= main_language_ratio ] if not selected: return ["unknown"] return sorted( selected, key=lambda language: ( -max(query_ratios.get(language, 0.0), document_ratios.get(language, 0.0)), -document_ratios.get(language, 0.0), -query_ratios.get(language, 0.0), language, ), ) def update_languages_in_file(path: Path, detections: DetectionByTask) -> bool: data = yaml.safe_load(path.read_text(encoding="utf-8")) if not isinstance(data, dict): return False changed = _update_languages_for_mapping(data, detections) raw_datasets = data.get("datasets") if isinstance(raw_datasets, list): for raw_dataset in raw_datasets: if isinstance(raw_dataset, dict): changed = _update_languages_for_mapping(raw_dataset, detections) or changed if changed: path.write_text(yaml.safe_dump(data, sort_keys=False, allow_unicode=True), encoding="utf-8") return changed def _update_languages_for_mapping(data: dict[str, Any], detections: DetectionByTask) -> bool: dataset_id = str(data.get("dataset_id", "")) task_metadata = data.get("task_metadata") if not isinstance(task_metadata, dict): return False changed = False for task_name, metadata in task_metadata.items(): if not isinstance(metadata, dict): continue task_detection = detections.get((dataset_id, str(task_name))) if task_detection is None: continue metadata["languages"] = _main_languages_for_metadata(metadata, task_detection) metadata["language_detection"] = task_detection["language_detection"] changed = True return changed def _main_languages_for_metadata(metadata: dict[str, Any], task_detection: dict[str, Any]) -> list[str]: if metadata.get("category") == "code" and metadata.get("language") == "en": return ["en"] return list(task_detection["languages"]) def _language_detection_payload(detection: dict[str, Any]) -> dict[str, Any]: return { "sample_count": detection["sample_count"], "languages": { language: round(ratio * 100, 3) for language, ratio in sorted(detection["ratios"].items(), key=lambda item: (-item[1], item[0])) }, } def _safe_detect(detector: Detector, text: str) -> Any: try: return detector(text) except Exception: return None def _detect_language(text: str) -> Any: fast_langdetect_module = importlib.import_module("fast_langdetect") detect_fn = getattr(fast_langdetect_module, "detect") return detect_fn(text) def _normalize_detect_result(result: Any) -> tuple[str | None, float]: candidate = result[0] if isinstance(result, list) and result else result if isinstance(candidate, dict): language = candidate.get("lang") or candidate.get("language") score = candidate.get("score") or candidate.get("probability") or 0.0 return (_normalize_language_code(str(language)) if language else None, float(score)) if isinstance(candidate, str): return _normalize_language_code(candidate), 1.0 return None, 0.0 def _normalize_language_code(language: str) -> str: return language.lower().replace("_", "-").split("-")[0] def _task_name_for_split(dataset: Any, split_name: str) -> str: mapping = dataset.effective_split_mapping if mapping is not None: if split_name in mapping: return str(mapping[split_name]) for _logical_name, mapped_split in mapping.items(): if split_name == mapped_split: return str(mapped_split) return split_name if __name__ == "__main__": main()