Spaces:
Running
Running
File size: 10,429 Bytes
e8aa13a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 | from __future__ import annotations
import argparse
import importlib
from collections import Counter
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Any
import yaml
from hakari_bench.datasets import DatasetRegistry, resolve_dataset_splits
# Detection results keyed by (dataset_id, task_name); each value is the dict returned by
# build_language_detection_metadata (top-level "languages" and "language_detection" keys).
DetectionByTask = dict[tuple[str, str], dict[str, Any]]
# A language detector: takes one text and returns a raw result (a dict, a language string,
# or a list of those) that _normalize_detect_result reduces to a language code.
Detector = Callable[[str], Any]
# Languages whose share of a side's samples is below this fraction (0.5%) are dropped from
# that side's reported distribution.
DEFAULT_MIN_LANGUAGE_RATIO = 0.005
# A language counts as a "main" language of a task when its share on the query side or the
# document side reaches this fraction (10%).
DEFAULT_MAIN_LANGUAGE_RATIO = 0.10
def main() -> None:
    """Detect query/document languages for dataset tasks and store them in the YAML configs.

    Datasets come from the registry loaded from ``--config-root`` and can be narrowed with
    ``--dataset`` and ``--task``. Each successful detection is printed. Unless ``--dry-run``
    is given, every ``datasets/*.yaml`` and ``dataset_collections/*.yaml`` file under the
    config root is then passed to :func:`update_languages_in_file`, and every file that was
    rewritten is reported.
    """
    parser = argparse.ArgumentParser(description="Update query/document language detection in dataset YAML metadata.")
    parser.add_argument("--config-root", type=Path, default=Path("config"))
    parser.add_argument("--dataset", default=None, help="Optional dataset name or id to update.")
    parser.add_argument("--task", default=None, help="Optional split/task name to update.")
    parser.add_argument("--dataset-revision", default=None)
    parser.add_argument("--min-language-ratio", type=float, default=DEFAULT_MIN_LANGUAGE_RATIO)
    parser.add_argument("--main-language-ratio", type=float, default=DEFAULT_MAIN_LANGUAGE_RATIO)
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()
    # Both ratios are fractions of a side's sample count. A value above 1 can never be
    # reached (every task would collapse to "unknown") and negative values are meaningless,
    # so reject them before spending time on dataset downloads.
    for option, value in (
        ("--min-language-ratio", args.min_language_ratio),
        ("--main-language-ratio", args.main_language_ratio),
    ):
        if not 0.0 <= value <= 1.0:
            parser.error(f"{option} must be between 0 and 1, got {value}")

    registry = DatasetRegistry.load_from_root(args.config_root)
    detections = _collect_detections(
        _select_datasets(registry, args.dataset),
        task_filter=args.task,
        revision=args.dataset_revision,
        min_language_ratio=args.min_language_ratio,
        main_language_ratio=args.main_language_ratio,
    )
    if args.dry_run:
        return
    for subdirectory in ("datasets", "dataset_collections"):
        for path in sorted((args.config_root / subdirectory).glob("*.yaml")):
            if update_languages_in_file(path, detections):
                print(f"updated {path}", flush=True)


def _select_datasets(registry: Any, dataset_name: str | None) -> list[Any]:
    """Return the requested dataset, or every registered dataset exactly once, sorted by name."""
    if dataset_name:
        return [registry.get_dataset(dataset_name)]
    # NOTE(review): reads the registry's private ``_datasets_by_key``. The same dataset object
    # is presumably registered under several keys (e.g. name and id), hence the
    # de-duplication by object identity. A public accessor on DatasetRegistry would be safer.
    unique = {id(dataset): dataset for dataset in registry._datasets_by_key.values()}
    return sorted(unique.values(), key=lambda item: item.name)


def _collect_detections(
    datasets: Iterable[Any],
    *,
    task_filter: str | None,
    revision: str | None,
    min_language_ratio: float,
    main_language_ratio: float,
) -> DetectionByTask:
    """Run language detection for each selected (dataset, split), keyed by (dataset_id, task)."""
    import sys

    detections: DetectionByTask = {}
    for dataset in datasets:
        for split_name in resolve_dataset_splits(dataset):
            task_name = _task_name_for_split(dataset, split_name)
            # ``--task`` may name either the raw split or the resolved task name.
            if task_filter and task_filter not in {split_name, task_name}:
                continue
            try:
                detection = load_language_detection(
                    dataset_id=dataset.dataset_id,
                    corpus_config=dataset.corpus_config,
                    queries_config=dataset.queries_config,
                    split_name=split_name,
                    revision=revision,
                    min_language_ratio=min_language_ratio,
                    main_language_ratio=main_language_ratio,
                )
            except Exception as exc:  # pragma: no cover - exercised manually against remote datasets
                # Best effort: one unreachable or broken dataset must not abort the whole run.
                # Warnings go to stderr so they do not mix with the per-task results on stdout.
                print(
                    f"warning: failed to load {dataset.dataset_id}/{split_name}: {type(exc).__name__}: {exc}",
                    file=sys.stderr,
                    flush=True,
                )
                continue
            detections[(dataset.dataset_id, task_name)] = detection
            print(
                f"{dataset.name}/{task_name}: languages={detection['languages']} "
                f"query={detection['language_detection']['query']['languages']} "
                f"document={detection['language_detection']['document']['languages']}",
                flush=True,
            )
    return detections
def load_language_detection(
    *,
    dataset_id: str,
    corpus_config: str,
    queries_config: str,
    split_name: str,
    revision: str | None,
    min_language_ratio: float = DEFAULT_MIN_LANGUAGE_RATIO,
    main_language_ratio: float = DEFAULT_MAIN_LANGUAGE_RATIO,
) -> dict[str, Any]:
    """Load a dataset split's queries and corpus and detect their languages.

    Both configs (``queries_config`` and ``corpus_config`` of ``dataset_id``) are loaded with
    the same ``split_name`` and ``revision`` through ``datasets.load_dataset``. Only their
    ``text`` column is used.

    Returns the metadata dict built by :func:`build_language_detection_metadata`.
    """
    from datasets import load_dataset

    queries = load_dataset(dataset_id, queries_config, split=split_name, revision=revision)
    corpus = load_dataset(dataset_id, corpus_config, split=split_name, revision=revision)
    # Null cells must be skipped explicitly: ``str(None)`` is the non-empty string "None",
    # which would otherwise be fed to the detector and counted as a real sample.
    return build_language_detection_metadata(
        query_texts=[str(text) for text in queries["text"] if text is not None and str(text)],
        document_texts=[str(text) for text in corpus["text"] if text is not None and str(text)],
        min_language_ratio=min_language_ratio,
        main_language_ratio=main_language_ratio,
    )
def build_language_detection_metadata(
    *,
    query_texts: Iterable[str],
    document_texts: Iterable[str],
    detector: Detector | None = None,
    min_language_ratio: float = DEFAULT_MIN_LANGUAGE_RATIO,
    main_language_ratio: float = DEFAULT_MAIN_LANGUAGE_RATIO,
) -> dict[str, Any]:
    """Build the ``languages`` / ``language_detection`` metadata for one task.

    The query side is measured first, then the document side, each with
    :func:`detect_language_distribution`. The task's main languages are then chosen from both
    sides with :func:`select_main_languages`. Thresholds are recorded as percentages rounded
    to three decimals.

    NOTE(review): the ``"detector"`` label is always "fast-langdetect", even when a custom
    ``detector`` callable is supplied. Confirm whether that is intended.
    """
    distributions = {
        side: detect_language_distribution(texts, detector=detector, min_language_ratio=min_language_ratio)
        for side, texts in (("query", query_texts), ("document", document_texts))
    }
    main_languages = select_main_languages(
        distributions["query"]["ratios"],
        distributions["document"]["ratios"],
        main_language_ratio=main_language_ratio,
    )
    summary: dict[str, Any] = {
        "detector": "fast-langdetect",
        "min_language_percent": round(min_language_ratio * 100, 3),
        "main_language_percent": round(main_language_ratio * 100, 3),
    }
    # Insertion order puts "query" before "document", matching the YAML layout.
    for side, distribution in distributions.items():
        summary[side] = _language_detection_payload(distribution)
    return {"languages": main_languages, "language_detection": summary}
def detect_language_distribution(
    texts: Iterable[str],
    *,
    detector: Detector | None = None,
    min_language_ratio: float = DEFAULT_MIN_LANGUAGE_RATIO,
) -> dict[str, Any]:
    """Detect the language of every non-blank text and summarize the distribution.

    Texts are stripped before detection. Blank texts are skipped and are not samples.
    Detector failures and unrecognized results are counted under ``"unknown"``. When
    ``detector`` is not given, the built-in fast-langdetect wrapper is used.

    Returns a dict with:
    - ``sample_count``: number of non-blank texts;
    - ``counts``: raw per-language counts, unfiltered (``{"unknown": 0}`` when there were no
      samples);
    - ``ratios``: per-language share of the samples for languages at or above
      ``min_language_ratio``, ordered by descending share then code. It is
      ``{"unknown": 1.0}`` when no language qualifies or there were no samples.
    """
    detect = detector if detector else _detect_language
    tally: Counter[str] = Counter()
    for raw in texts:
        stripped = str(raw).strip()
        if not stripped:
            continue
        code, _ = _normalize_detect_result(_safe_detect(detect, stripped))
        tally[code or "unknown"] += 1

    sample_count = sum(tally.values())
    if sample_count == 0:
        return {"sample_count": 0, "counts": {"unknown": 0}, "ratios": {"unknown": 1.0}}

    shares = [(code, count / sample_count) for code, count in tally.items()]
    shares.sort(key=lambda pair: (-pair[1], pair[0]))
    kept = {code: share for code, share in shares if share >= min_language_ratio}
    return {
        "sample_count": sample_count,
        "counts": dict(tally),
        "ratios": kept or {"unknown": 1.0},
    }
def select_main_languages(
    query_ratios: dict[str, float],
    document_ratios: dict[str, float],
    *,
    main_language_ratio: float = DEFAULT_MAIN_LANGUAGE_RATIO,
) -> list[str]:
    """Pick a task's main languages from its query and document language shares.

    A language qualifies when its share on either side is at least ``main_language_ratio``.
    Qualifying languages are ordered by their larger share, then document share, then query
    share (all descending), then by code. Returns ``["unknown"]`` when nothing qualifies.

    NOTE(review): ``"unknown"`` is treated like any other code, so it can appear next to real
    languages (e.g. when one side had no samples). Confirm whether that is intended.
    """
    def peak(language: str) -> float:
        return max(query_ratios.get(language, 0.0), document_ratios.get(language, 0.0))

    candidates = {*query_ratios, *document_ratios}
    main = [language for language in candidates if peak(language) >= main_language_ratio]
    main.sort(
        key=lambda language: (
            -peak(language),
            -document_ratios.get(language, 0.0),
            -query_ratios.get(language, 0.0),
            language,
        )
    )
    return main or ["unknown"]
def update_languages_in_file(path: Path, detections: DetectionByTask) -> bool:
    """Apply ``detections`` to one YAML config file and rewrite it if anything was updated.

    The top-level mapping is updated first, then every mapping inside its ``datasets`` list.
    Documents whose root is not a mapping are left untouched. The file is re-serialized with
    ``yaml.safe_dump`` (key order kept, Unicode unescaped). Because this is a PyYAML load/dump
    round trip, comments and original formatting are not preserved.

    Returns True when the file was rewritten.
    """
    document = yaml.safe_load(path.read_text(encoding="utf-8"))
    if not isinstance(document, dict):
        return False

    touched = _update_languages_for_mapping(document, detections)
    nested = document.get("datasets")
    for entry in nested if isinstance(nested, list) else ():
        if isinstance(entry, dict) and _update_languages_for_mapping(entry, detections):
            touched = True

    if touched:
        path.write_text(yaml.safe_dump(document, sort_keys=False, allow_unicode=True), encoding="utf-8")
    return touched
def _update_languages_for_mapping(data: dict[str, Any], detections: DetectionByTask) -> bool:
dataset_id = str(data.get("dataset_id", ""))
task_metadata = data.get("task_metadata")
if not isinstance(task_metadata, dict):
return False
changed = False
for task_name, metadata in task_metadata.items():
if not isinstance(metadata, dict):
continue
task_detection = detections.get((dataset_id, str(task_name)))
if task_detection is None:
continue
metadata["languages"] = _main_languages_for_metadata(metadata, task_detection)
metadata["language_detection"] = task_detection["language_detection"]
changed = True
return changed
def _main_languages_for_metadata(metadata: dict[str, Any], task_detection: dict[str, Any]) -> list[str]:
if metadata.get("category") == "code" and metadata.get("language") == "en":
return ["en"]
return list(task_detection["languages"])
def _language_detection_payload(detection: dict[str, Any]) -> dict[str, Any]:
return {
"sample_count": detection["sample_count"],
"languages": {
language: round(ratio * 100, 3)
for language, ratio in sorted(detection["ratios"].items(), key=lambda item: (-item[1], item[0]))
},
}
def _safe_detect(detector: Detector, text: str) -> Any:
try:
return detector(text)
except Exception:
return None
def _detect_language(text: str) -> Any:
    """Detect ``text``'s language with fast-langdetect and return its raw result.

    The ``fast_langdetect`` module is imported on call rather than at module import. A script
    that only rewrites YAML, or that passes its own detector, therefore does not need it
    installed.
    """
    fast_langdetect_module = importlib.import_module("fast_langdetect")
    detect_fn = getattr(fast_langdetect_module, "detect")
    # fastText-based detection works on a single line. fastText's ``predict`` raises on input
    # containing "\n" (depending on the fast-langdetect version, this may not be handled for
    # us). That would make every multi-line document fail and be counted as "unknown".
    # Newlines are only word separators for the model, so spaces are equivalent.
    return detect_fn(text.replace("\n", " "))
def _normalize_detect_result(result: Any) -> tuple[str | None, float]:
candidate = result[0] if isinstance(result, list) and result else result
if isinstance(candidate, dict):
language = candidate.get("lang") or candidate.get("language")
score = candidate.get("score") or candidate.get("probability") or 0.0
return (_normalize_language_code(str(language)) if language else None, float(score))
if isinstance(candidate, str):
return _normalize_language_code(candidate), 1.0
return None, 0.0
def _normalize_language_code(language: str) -> str:
return language.lower().replace("_", "-").split("-")[0]
def _task_name_for_split(dataset: Any, split_name: str) -> str:
mapping = dataset.effective_split_mapping
if mapping is not None:
if split_name in mapping:
return str(mapping[split_name])
for _logical_name, mapped_split in mapping.items():
if split_name == mapped_split:
return str(mapped_split)
return split_name
# Run the command-line entry point when this file is executed directly.
if __name__ == "__main__":
    main()
|