"""Datasets konfigurācija un ielāde.""" from __future__ import annotations import errno import json import logging import os import tempfile from collections.abc import Callable, Iterator from pathlib import Path from typing import Any from maris_core.data.preprocessing import record_to_training_text from maris_core.utils.env import get_env_any, get_env_any_or_default logger = logging.getLogger(__name__) _SUPPORTED_DATASET_FORMATS = { ".json": "json", ".jsonl": "json", ".csv": "csv", ".parquet": "parquet", } class HFDatasetError(FileNotFoundError): """Atmiņas repozitorija konfigurācijas vai satura kļūda.""" def __init__(self, message: str, *, discovered_files: list[str] | None = None) -> None: super().__init__(message) self.discovered_files = discovered_files or [] def _iter_exception_chain(exc: Exception) -> Iterator[BaseException]: """Iziet cauri izņēmuma cēloņu ķēdei bez cikliem.""" pending: list[BaseException] = [exc] seen: set[int] = set() while pending: current = pending.pop() current_id = id(current) if current_id in seen: continue seen.add(current_id) yield current cause = getattr(current, "__cause__", None) context = getattr(current, "__context__", None) if cause is not None: pending.append(cause) if context is not None: pending.append(context) def _is_empty_dataset_error(exc: Exception) -> bool: """Nosaka, vai repozitorijam trūkst tieši ielādējamu datu failu.""" markers = ( "doesn't contain any data files", "No (supported) data files found in", ) return any( current.__class__.__name__ == "EmptyDatasetError" or any(marker in str(current) for marker in markers) for current in _iter_exception_chain(exc) ) def _is_schema_cast_error(exc: Exception) -> bool: """Nosaka, vai JSON loader nokrīt uz mainīgu objektu shēmu.""" markers = ( "Couldn't cast", "Couldn't cast array of type", ) return any( any(marker in str(current) for marker in markers) for current in _iter_exception_chain(exc) ) def _is_dataset_generation_error(exc: Exception) -> bool: """Nosaka, vai datasets slānis meta ģenerisku DatasetGenerationError.""" return any( current.__class__.__name__ == "DatasetGenerationError" for current in _iter_exception_chain(exc) ) def _should_fallback_to_snapshot(exc: Exception) -> bool: """Nosaka, vai jāizmanto snapshot-based datu ielāde apmācībai.""" return ( _is_empty_dataset_error(exc) or _is_schema_cast_error(exc) or _is_dataset_generation_error(exc) ) def _is_hf_cache_lock_io_error(exc: Exception) -> bool: """Nosaka, vai Hugging Face cache lock failu sistēmā radās I/O kļūda.""" for current in _iter_exception_chain(exc): if not isinstance(current, OSError): continue message = str(current) filename = os.fspath(getattr(current, "filename", "") or "") details = " ".join(part for part in (filename, message) if part).lower() if "huggingface" not in details and "datasets--" not in details: continue if "lock" not in details: continue if current.errno == errno.EIO or "input/output error" in details: return True return False def _resolve_hf_cache_recovery_dir() -> str: """Atgriež drošu fallback cache direktoriju HF lejupielādēm.""" configured = get_env_any("MARIS_HF_FALLBACK_CACHE_DIR") fallback_dir = (Path(tempfile.gettempdir()) / "maris-hf-cache").resolve() candidate = Path(configured).expanduser() if configured else fallback_dir try: resolved = candidate.resolve(strict=False) if resolved == Path(resolved.anchor): raise ValueError(f"nedroša cache direktorija: {resolved}") if resolved.exists() and not resolved.is_dir(): raise ValueError(f"cache ceļš nav direktorija: {resolved}") resolved.mkdir(parents=True, exist_ok=True) return str(resolved) except (OSError, ValueError) as exc: if configured: logger.warning( "Ignorējam nederīgu MARIS_HF_FALLBACK_CACHE_DIR=%s: %s; lietojam %s.", configured, exc, fallback_dir, ) fallback_dir.mkdir(parents=True, exist_ok=True) return str(fallback_dir) def _call_with_hf_cache_recovery( action: str, call: Callable[..., Any], *args: Any, recovery_cache_dir: str | None = None, **kwargs: Any, ) -> tuple[Any, str | None]: """Izsauc HF funkciju un pie cache lock I/O kļūdas pārslēdzas uz fallback cache.""" call_kwargs = dict(kwargs) if recovery_cache_dir is not None: call_kwargs["cache_dir"] = recovery_cache_dir return call(*args, **call_kwargs), recovery_cache_dir try: return call(*args, **call_kwargs), None except Exception as exc: # noqa: BLE001 if not _is_hf_cache_lock_io_error(exc): raise recovery_cache_dir = _resolve_hf_cache_recovery_dir() logger.warning( "Hugging Face cache lock kļūda (%s); atkārtojam ar fallback cache %s.", action, recovery_cache_dir, ) call_kwargs["cache_dir"] = recovery_cache_dir return call(*args, **call_kwargs), recovery_cache_dir def _preview_discovered_files(discovered_files: list[str]) -> str: """Atgriež īsu atrasto failu sarakstu kļūdu ziņojumiem.""" if not discovered_files: return "nav" return ", ".join(sorted(discovered_files)[:5]) def _preview_data_files(data_files: list[str]) -> str: """Atgriež īsu datu failu sarakstu kļūdu ziņojumiem.""" preview_paths: list[str] = [] for file_name in sorted(data_files)[:5]: file_path = Path(file_name) if "data" in file_path.parts: data_index = file_path.parts.index("data") preview_paths.append("/".join(file_path.parts[data_index:])) continue preview_paths.append(file_path.name) return ", ".join(preview_paths) if preview_paths else "nav" def _summarize_exception_messages(exc: Exception) -> str: """Savāc īsu izņēmumu ķēdes kopsavilkumu lietotājam.""" details: list[str] = [] for current in _iter_exception_chain(exc): message = str(current).strip() if not message or message in details: continue details.append(message) return " | ".join(details[:3]) if details else exc.__class__.__name__ def _build_empty_repo_message( repo_id: str, snapshot_dir: Path, discovered_files: list[str], ) -> str: """Izveido precīzu kļūdas ziņojumu tukšam atmiņas repozitorijam.""" preview = _preview_discovered_files(discovered_files) return ( f"Maris atmiņas repo {repo_id} pašlaik nesatur nevienu atbalstītu datu failu " f"(.jsonl, .json, .csv, .parquet). Snapshot direktorijā {snapshot_dir} tika " f"atrasti tikai: {preview}. Lai salabotu origin repozitorijā, augšupielādē vismaz " "vienu datu failu zem data//, piemēram " "data/conversation/bootstrap.jsonl, un tad palaid apmācību vēlreiz. Ja lieto " "Git LFS, pārliecinies, ka repozitorijā ir pats datu fails, nevis tikai " ".gitattributes ieraksts." ) def _build_incomplete_snapshot_message( repo_id: str, snapshot_dir: Path, discovered_files: list[str], repo_data_files: list[str], ) -> str: """Izveido kļūdas ziņojumu, ja repo ir faili, bet snapshot tos nesatur.""" preview = _preview_discovered_files(discovered_files) repo_preview = ", ".join(repo_data_files[:5]) return ( f"Maris atmiņas repo {repo_id} satur atbalstītus datu failus ({repo_preview}), bet " f"snapshot direktorijā {snapshot_dir} tie netika atrasti. Snapshotā redzami: " f"{preview}. Ja dataset repozitorijs glabā datus ar Git LFS, pārliecinieties, ka " "tie ir pilnībā lejupielādēti un pieejami origin repozitorijā." ) def _build_invalid_data_files_message( repo_id: str, data_files: list[str], exc: Exception, ) -> str: """Izveido kļūdas ziņojumu bojātiem vai nederīgiem datu failiem.""" preview = _preview_data_files(data_files) details = _summarize_exception_messages(exc) return ( f"Maris atmiņas repo {repo_id} datu failus neizdevās nolasīt apmācībai ({preview}). " f"Cēlonis: {details}. Pārbaudi, vai faili ir pilnībā augšupielādēti, UTF-8 kodējumā " "un satur derīgus JSON/JSONL, CSV vai Parquet ierakstus." ) def _find_snapshot_data_files(snapshot_dir: Path) -> tuple[str, list[str]]: """Atrod ielādējamus datu failus dataset snapshot direktorijā.""" snapshot_dir = _validate_snapshot_dir(snapshot_dir) search_roots = [snapshot_dir / "data", snapshot_dir] discovered_files: list[str] = [] for root in search_roots: if not root.exists(): continue files_by_format: dict[str, list[str]] = {} for file_path in _iter_snapshot_files(root): discovered_files.append(str(file_path.relative_to(snapshot_dir))) dataset_format = _SUPPORTED_DATASET_FORMATS.get(file_path.suffix.lower()) if dataset_format is None: continue files_by_format.setdefault(dataset_format, []).append(str(file_path)) if not files_by_format: continue if len(files_by_format) > 1: formats = ", ".join(sorted(files_by_format)) raise ValueError( f"Snapshot satur vairākus atbalstītus datu formātus vienlaikus: {formats}" ) dataset_format, files = next(iter(files_by_format.items())) return dataset_format, files if discovered_files: preview = _preview_discovered_files(discovered_files) raise HFDatasetError( "Dataset snapshot direktorijā nav atrasti atbalstīti dati: " f"{snapshot_dir}. Atrastie faili: {preview}. " "Ja dataset repozitorijs glabā datus ar Git LFS, pārliecinieties, ka " "tie ir pilnībā lejupielādēti.", discovered_files=discovered_files, ) raise HFDatasetError( f"Dataset snapshot direktorijā nav atrasti atbalstīti dati: {snapshot_dir}" ) def _validate_snapshot_dir(snapshot_dir: Path) -> Path: """Normalizē un pārbauda snapshot direktoriju pirms rekursīvas skenēšanas.""" resolved = snapshot_dir.expanduser().resolve() if not resolved.exists() or not resolved.is_dir(): raise HFDatasetError(f"Dataset snapshot direktorija nav derīga: {resolved}") if resolved == Path(resolved.anchor): raise HFDatasetError( f"Dataset snapshot direktorija nav droša rekursīvai skenēšanai: {resolved}" ) return resolved def _is_invalid_snapshot_dir_error(exc: HFDatasetError) -> bool: """Nosaka, vai snapshot fallback atgrieza nederīgu vai bīstamu ceļu.""" message = str(exc) return "nav derīga" in message or "nav droša rekursīvai skenēšanai" in message def _iter_snapshot_files(root: Path) -> Iterator[Path]: """Iterē snapshot failus, izlaižot nederīgus sistēmas mezglus un OSError ceļus.""" def handle_walk_error(exc: OSError) -> None: logger.warning( "Skipping unreadable dataset snapshot path %s: %s", getattr(exc, "filename", root), exc, ) for current_root, dirnames, filenames in os.walk( root, topdown=True, onerror=handle_walk_error, followlinks=False, ): dirnames.sort() for filename in sorted(filenames): candidate = Path(current_root) / filename try: if candidate.is_file(): yield candidate except OSError as exc: logger.warning("Skipping unreadable dataset snapshot file %s: %s", candidate, exc) def _list_repo_data_files(repo_id: str, token: str | None) -> list[str]: """Atrod atbalstītos datu failus tieši atmiņas repozitorijā.""" from huggingface_hub import HfApi # type: ignore api = HfApi(token=token) repo_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset") return sorted( path for path in repo_files if _SUPPORTED_DATASET_FORMATS.get(Path(path).suffix.lower()) is not None ) def _coerce_training_record(record: Any) -> dict[str, Any]: """Normalizē vienu ierakstu uz stabilu text-only formu apmācības datasetam.""" if isinstance(record, dict): return {"text": record_to_training_text(record)} try: serialized = json.dumps(record, ensure_ascii=False, sort_keys=True) except TypeError: serialized = str(record) return {"text": serialized} def _build_train_dataset(record_factory: Callable[[], Iterator[dict[str, Any]]]) -> Any: """Izveido train split datasetu no ierakstu ģeneratora.""" from datasets import Dataset, DatasetDict # type: ignore return DatasetDict({"train": Dataset.from_generator(record_factory)}) def _iter_dataset_splits(dataset: Any) -> Iterator[tuple[str, Any]]: """Atgriež pieejamos dataset splitus neatkarīgi no konkrētā tipa.""" if isinstance(dataset, dict): yield from dataset.items() return items = getattr(dataset, "items", None) if callable(items): yield from items() return keys = getattr(dataset, "keys", None) if callable(keys): for key in keys(): yield key, dataset[key] def _normalize_training_dataset(dataset: Any, repo_id: str) -> Any: """Pārveido multi-split datasetu uz train split apmācībai.""" split_items = list(_iter_dataset_splits(dataset)) if not split_items or any(name == "train" for name, _ in split_items): return dataset logger.warning( "Maris atmiņas repo %s atgrieza splitus bez 'train'; apvienojam tos vienā train split apmācībai.", repo_id, ) def iter_records() -> Iterator[dict[str, Any]]: for _, split in split_items: for record in split: yield _coerce_training_record(record) return _build_train_dataset(iter_records) def _iter_json_payload_records(payload: Any) -> Iterator[dict[str, Any]]: """Izvērš JSON payload uz ierakstu plūsmu.""" if isinstance(payload, list): for item in payload: yield _coerce_training_record(item) return yield _coerce_training_record(payload) def _iter_json_file_records(data_files: list[str]) -> Iterator[dict[str, Any]]: """Nolasa JSON/JSONL failus bez stingras nested-object shēmas.""" for file_name in data_files: file_path = Path(file_name) suffix = file_path.suffix.lower() if suffix == ".jsonl": with file_path.open(encoding="utf-8") as handle: for line in handle: stripped = line.strip() if not stripped: continue yield _coerce_training_record(json.loads(stripped)) continue if suffix == ".json": payload = json.loads(file_path.read_text(encoding="utf-8")) yield from _iter_json_payload_records(payload) continue raise ValueError(f"Neatbalstīts JSON datu fails: {file_path}") def _load_data_files_as_training_dataset( repo_id: str, dataset_format: str, data_files: list[str], *, recovery_cache_dir: str | None = None, ) -> Any: """Ielādē atrastos datu failus apmācībai kā train split.""" try: if dataset_format == "json": return _build_train_dataset(lambda: _iter_json_file_records(data_files)) from datasets import load_dataset # type: ignore dataset, _ = _call_with_hf_cache_recovery( f"load_dataset({dataset_format})", load_dataset, dataset_format, data_files={"train": data_files}, recovery_cache_dir=recovery_cache_dir, ) return dataset except Exception as exc: # noqa: BLE001 raise HFDatasetError(_build_invalid_data_files_message(repo_id, data_files, exc)) from exc def load_hf_dataset(repo_id: str | None = None) -> Any: """Ielādē Maris atmiņas datasets.""" repo_id = repo_id or get_env_any_or_default( "MARIS_MEMORY_REPO", "MARIS_DATASET_REPO", "HF_DATASET_REPO", default="MarisUK/maris-ai-memory", ) token = get_env_any("MARIS_REPO_TOKEN", "MARIS_TOKEN", "HF_TOKEN") recovery_cache_dir: str | None = None try: from datasets import load_dataset # type: ignore logger.info("Ielādē dataset: %s", repo_id) dataset, recovery_cache_dir = _call_with_hf_cache_recovery( f"load_dataset({repo_id})", load_dataset, repo_id, token=token, recovery_cache_dir=recovery_cache_dir, ) return _normalize_training_dataset(dataset, repo_id) except Exception as exc: # noqa: BLE001 if not _should_fallback_to_snapshot(exc): logger.error("Dataseta ielādes kļūda: %s", exc) raise try: from huggingface_hub import snapshot_download # type: ignore logger.warning( "Maris atmiņas repo %s nevarēja ielādēt tieši; mēģinām apmācības snapshot fallback.", repo_id, ) snapshot_dir_str, recovery_cache_dir = _call_with_hf_cache_recovery( f"snapshot_download({repo_id})", snapshot_download, repo_id=repo_id, repo_type="dataset", token=token, recovery_cache_dir=recovery_cache_dir, ) snapshot_dir = Path(snapshot_dir_str) try: dataset_format, data_files = _find_snapshot_data_files(snapshot_dir) except HFDatasetError as missing_exc: if _is_invalid_snapshot_dir_error(missing_exc): raise repo_data_files = _list_repo_data_files(repo_id, token) if not repo_data_files: raise HFDatasetError( _build_empty_repo_message( repo_id, snapshot_dir, missing_exc.discovered_files, ), discovered_files=missing_exc.discovered_files, ) from missing_exc logger.warning( "Snapshot cache nesatur dataset failus; lejupielādējam atrastos data failus tieši: %s", ", ".join(repo_data_files[:5]), ) snapshot_dir_str, recovery_cache_dir = _call_with_hf_cache_recovery( f"snapshot_download({repo_id}, allow_patterns)", snapshot_download, repo_id=repo_id, repo_type="dataset", token=token, allow_patterns=repo_data_files, recovery_cache_dir=recovery_cache_dir, ) snapshot_dir = Path(snapshot_dir_str) try: dataset_format, data_files = _find_snapshot_data_files(snapshot_dir) except HFDatasetError as final_exc: raise HFDatasetError( _build_incomplete_snapshot_message( repo_id, snapshot_dir, final_exc.discovered_files, repo_data_files, ), discovered_files=final_exc.discovered_files, ) from final_exc return _load_data_files_as_training_dataset( repo_id, dataset_format, data_files, recovery_cache_dir=recovery_cache_dir, ) except Exception as fallback_exc: # noqa: BLE001 logger.error("Dataseta ielādes kļūda: %s", fallback_exc) raise