| """Datasets konfigurācija un ielāde.""" |
|
|
| from __future__ import annotations |
|
|
| import errno |
| import json |
| import logging |
| import os |
| import tempfile |
| from collections.abc import Callable, Iterator |
| from pathlib import Path |
| from typing import Any |
|
|
| from maris_core.data.preprocessing import record_to_training_text |
| from maris_core.utils.env import get_env_any, get_env_any_or_default |
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# Maps a data file suffix to the dataset format name used when loading it.
# Lookups in this module lower-case the suffix first; ".json" and ".jsonl"
# both map to the "json" format.
_SUPPORTED_DATASET_FORMATS = {
    ".json": "json",
    ".jsonl": "json",
    ".csv": "csv",
    ".parquet": "parquet",
}
|
|
|
|
class HFDatasetError(FileNotFoundError):
    """Configuration or content error of the memory repository.

    Attributes:
        discovered_files: File names found while searching for data; an
            empty list when the caller supplied none.
    """

    def __init__(self, message: str, *, discovered_files: list[str] | None = None) -> None:
        super().__init__(message)
        # ``None`` and an empty list both collapse to a fresh empty list.
        self.discovered_files = discovered_files if discovered_files else []
|
|
|
|
| def _iter_exception_chain(exc: Exception) -> Iterator[BaseException]: |
| """Iziet cauri izņēmuma cēloņu ķēdei bez cikliem.""" |
| pending: list[BaseException] = [exc] |
| seen: set[int] = set() |
|
|
| while pending: |
| current = pending.pop() |
| current_id = id(current) |
| if current_id in seen: |
| continue |
| seen.add(current_id) |
| yield current |
|
|
| cause = getattr(current, "__cause__", None) |
| context = getattr(current, "__context__", None) |
| if cause is not None: |
| pending.append(cause) |
| if context is not None: |
| pending.append(context) |
|
|
|
|
def _is_empty_dataset_error(exc: Exception) -> bool:
    """Tell whether the repository lacks directly loadable data files.

    Returns True when any exception in the chain of ``exc`` has the class
    name ``EmptyDatasetError`` or a message containing one of the known
    marker phrases.
    """
    known_phrases = (
        "doesn't contain any data files",
        "No (supported) data files found in",
    )
    for error in _iter_exception_chain(exc):
        if error.__class__.__name__ == "EmptyDatasetError":
            return True
        if any(phrase in str(error) for phrase in known_phrases):
            return True
    return False
|
|
|
|
def _is_schema_cast_error(exc: Exception) -> bool:
    """Tell whether the JSON loader failed on a varying object schema.

    Returns True when any message in the exception chain of ``exc``
    contains one of the cast-failure marker phrases.
    """
    cast_phrases = (
        "Couldn't cast",
        "Couldn't cast array of type",
    )
    for error in _iter_exception_chain(exc):
        if any(phrase in str(error) for phrase in cast_phrases):
            return True
    return False
|
|
|
|
def _is_dataset_generation_error(exc: Exception) -> bool:
    """Tell whether the chain contains a generic ``DatasetGenerationError``.

    The match is by class name only, so no import of the raising library
    is required.
    """
    for error in _iter_exception_chain(exc):
        if error.__class__.__name__ == "DatasetGenerationError":
            return True
    return False
|
|
|
|
def _should_fallback_to_snapshot(exc: Exception) -> bool:
    """Tell whether training data must be loaded through the snapshot fallback.

    True when ``exc`` is classified as an empty-dataset error, a schema
    cast error, or a dataset generation error (checked in that order).
    """
    classifiers = (
        _is_empty_dataset_error,
        _is_schema_cast_error,
        _is_dataset_generation_error,
    )
    return any(classify(exc) for classify in classifiers)
|
|
|
|
def _is_hf_cache_lock_io_error(exc: Exception) -> bool:
    """Tell whether the chain contains a Hugging Face cache lock I/O error.

    An ``OSError`` in the exception chain of ``exc`` qualifies when the
    combined filename/message text (case-insensitive):

    * mentions ``huggingface`` or ``datasets--``,
    * mentions ``lock``, and
    * the error carries ``errno.EIO`` or the text ``input/output error``.

    Args:
        exc: Exception whose cause/context chain is inspected.

    Returns:
        True if a qualifying ``OSError`` is found, otherwise False.
    """
    for current in _iter_exception_chain(exc):
        if not isinstance(current, OSError):
            continue

        # ``OSError.filename`` is not always a str: it may be bytes, a
        # PathLike, or an int file descriptor. ``os.fspath`` would raise
        # TypeError for an int and bytes cannot be joined with str, so the
        # value is normalised defensively; this classifier runs inside an
        # ``except`` handler and must not mask the original error.
        raw_filename = getattr(current, "filename", None)
        if not raw_filename:
            filename = ""
        elif isinstance(raw_filename, (str, bytes, os.PathLike)):
            filename = os.fsdecode(raw_filename)
        else:
            filename = str(raw_filename)

        message = str(current)
        details = " ".join(part for part in (filename, message) if part).lower()
        if "huggingface" not in details and "datasets--" not in details:
            continue
        if "lock" not in details:
            continue
        if current.errno == errno.EIO or "input/output error" in details:
            return True
    return False
|
|
|
|
def _resolve_hf_cache_recovery_dir() -> str:
    """Return a safe fallback cache directory for Hugging Face downloads.

    The directory named by ``MARIS_HF_FALLBACK_CACHE_DIR`` is used when it
    is set and usable; otherwise ``<tempdir>/maris-hf-cache`` is used. The
    chosen directory is created if missing. A filesystem root or a path
    that exists but is not a directory is rejected.
    """
    env_value = get_env_any("MARIS_HF_FALLBACK_CACHE_DIR")
    default_dir = (Path(tempfile.gettempdir()) / "maris-hf-cache").resolve()
    if env_value:
        candidate = Path(env_value).expanduser()
    else:
        candidate = default_dir

    try:
        target = candidate.resolve(strict=False)
        points_at_root = target == Path(target.anchor)
        if points_at_root:
            raise ValueError(f"nedroša cache direktorija: {target}")
        if target.exists() and not target.is_dir():
            raise ValueError(f"cache ceļš nav direktorija: {target}")
        target.mkdir(parents=True, exist_ok=True)
    except (OSError, ValueError) as error:
        # Only warn when the user explicitly configured the rejected path.
        if env_value:
            logger.warning(
                "Ignorējam nederīgu MARIS_HF_FALLBACK_CACHE_DIR=%s: %s; lietojam %s.",
                env_value,
                error,
                default_dir,
            )
        default_dir.mkdir(parents=True, exist_ok=True)
        return str(default_dir)
    else:
        return str(target)
|
|
|
|
| def _call_with_hf_cache_recovery( |
| action: str, |
| call: Callable[..., Any], |
| *args: Any, |
| recovery_cache_dir: str | None = None, |
| **kwargs: Any, |
| ) -> tuple[Any, str | None]: |
| """Izsauc HF funkciju un pie cache lock I/O kļūdas pārslēdzas uz fallback cache.""" |
| call_kwargs = dict(kwargs) |
| if recovery_cache_dir is not None: |
| call_kwargs["cache_dir"] = recovery_cache_dir |
| return call(*args, **call_kwargs), recovery_cache_dir |
|
|
| try: |
| return call(*args, **call_kwargs), None |
| except Exception as exc: |
| if not _is_hf_cache_lock_io_error(exc): |
| raise |
|
|
| recovery_cache_dir = _resolve_hf_cache_recovery_dir() |
| logger.warning( |
| "Hugging Face cache lock kļūda (%s); atkārtojam ar fallback cache %s.", |
| action, |
| recovery_cache_dir, |
| ) |
| call_kwargs["cache_dir"] = recovery_cache_dir |
| return call(*args, **call_kwargs), recovery_cache_dir |
|
|
|
|
| def _preview_discovered_files(discovered_files: list[str]) -> str: |
| """Atgriež īsu atrasto failu sarakstu kļūdu ziņojumiem.""" |
| if not discovered_files: |
| return "nav" |
| return ", ".join(sorted(discovered_files)[:5]) |
|
|
|
|
| def _preview_data_files(data_files: list[str]) -> str: |
| """Atgriež īsu datu failu sarakstu kļūdu ziņojumiem.""" |
| preview_paths: list[str] = [] |
| for file_name in sorted(data_files)[:5]: |
| file_path = Path(file_name) |
| if "data" in file_path.parts: |
| data_index = file_path.parts.index("data") |
| preview_paths.append("/".join(file_path.parts[data_index:])) |
| continue |
| preview_paths.append(file_path.name) |
| return ", ".join(preview_paths) if preview_paths else "nav" |
|
|
|
|
def _summarize_exception_messages(exc: Exception) -> str:
    """Collect a short user-facing summary of the exception chain.

    Joins up to three distinct, non-empty messages from the chain with
    ``" | "``; falls back to the class name of ``exc`` when every message
    is empty.
    """
    # A dict keeps first-seen order while dropping repeated messages.
    unique_messages: dict[str, None] = {}
    for error in _iter_exception_chain(exc):
        text = str(error).strip()
        if text:
            unique_messages.setdefault(text, None)

    if not unique_messages:
        return exc.__class__.__name__
    return " | ".join(list(unique_messages)[:3])
|
|
|
|
def _build_empty_repo_message(
    repo_id: str,
    snapshot_dir: Path,
    discovered_files: list[str],
) -> str:
    """Build the precise error message for an empty memory repository.

    The message names the repository, the snapshot directory, a preview of
    the files that were found, and instructions for fixing the repository.
    """
    found_preview = _preview_discovered_files(discovered_files)
    fragments = (
        f"Maris atmiņas repo {repo_id} pašlaik nesatur nevienu atbalstītu datu failu ",
        f"(.jsonl, .json, .csv, .parquet). Snapshot direktorijā {snapshot_dir} tika ",
        f"atrasti tikai: {found_preview}. Lai salabotu origin repozitorijā, augšupielādē vismaz ",
        "vienu datu failu zem data/<type>/, piemēram ",
        "data/conversation/bootstrap.jsonl, un tad palaid apmācību vēlreiz. Ja lieto ",
        "Git LFS, pārliecinies, ka repozitorijā ir pats datu fails, nevis tikai ",
        ".gitattributes ieraksts.",
    )
    return "".join(fragments)
|
|
|
|
def _build_incomplete_snapshot_message(
    repo_id: str,
    snapshot_dir: Path,
    discovered_files: list[str],
    repo_data_files: list[str],
) -> str:
    """Build the error message for a repo that has files the snapshot lacks.

    Lists up to five repository data files and a preview of what the
    snapshot directory actually contains.
    """
    snapshot_preview = _preview_discovered_files(discovered_files)
    listed_repo_files = ", ".join(repo_data_files[:5])
    fragments = (
        f"Maris atmiņas repo {repo_id} satur atbalstītus datu failus ({listed_repo_files}), bet ",
        f"snapshot direktorijā {snapshot_dir} tie netika atrasti. Snapshotā redzami: ",
        f"{snapshot_preview}. Ja dataset repozitorijs glabā datus ar Git LFS, pārliecinieties, ka ",
        "tie ir pilnībā lejupielādēti un pieejami origin repozitorijā.",
    )
    return "".join(fragments)
|
|
|
|
def _build_invalid_data_files_message(
    repo_id: str,
    data_files: list[str],
    exc: Exception,
) -> str:
    """Build the error message for corrupt or otherwise unreadable data files.

    Combines a preview of the data files with a summary of the messages in
    the exception chain of ``exc``.
    """
    files_preview = _preview_data_files(data_files)
    cause_summary = _summarize_exception_messages(exc)
    fragments = (
        f"Maris atmiņas repo {repo_id} datu failus neizdevās nolasīt apmācībai ({files_preview}). ",
        f"Cēlonis: {cause_summary}. Pārbaudi, vai faili ir pilnībā augšupielādēti, UTF-8 kodējumā ",
        "un satur derīgus JSON/JSONL, CSV vai Parquet ierakstus.",
    )
    return "".join(fragments)
|
|
|
|
def _find_snapshot_data_files(snapshot_dir: Path) -> tuple[str, list[str]]:
    """Find loadable data files in a dataset snapshot directory.

    ``<snapshot>/data`` is searched first and the snapshot root second; the
    first root that contains files with a supported suffix wins.

    Args:
        snapshot_dir: Snapshot directory to scan; it is validated and
            resolved before scanning.

    Returns:
        ``(dataset_format, files)``: the format name and the paths of all
        files of that format found under the winning root.

    Raises:
        HFDatasetError: If the directory is invalid or holds no supported
            data files; ``discovered_files`` lists every file seen, once.
        ValueError: If the winning root mixes several supported formats.
    """
    snapshot_dir = _validate_snapshot_dir(snapshot_dir)
    search_roots = [snapshot_dir / "data", snapshot_dir]
    # Used as an insertion-ordered set. The second root contains the first,
    # so files under data/ are walked twice and must not be reported twice.
    discovered: dict[str, None] = {}

    for root in search_roots:
        if not root.exists():
            continue

        files_by_format: dict[str, list[str]] = {}
        for file_path in _iter_snapshot_files(root):
            discovered.setdefault(str(file_path.relative_to(snapshot_dir)), None)
            dataset_format = _SUPPORTED_DATASET_FORMATS.get(file_path.suffix.lower())
            if dataset_format is None:
                continue

            files_by_format.setdefault(dataset_format, []).append(str(file_path))

        if not files_by_format:
            continue

        if len(files_by_format) > 1:
            formats = ", ".join(sorted(files_by_format))
            raise ValueError(
                f"Snapshot satur vairākus atbalstītus datu formātus vienlaikus: {formats}"
            )

        dataset_format, files = next(iter(files_by_format.items()))
        return dataset_format, files

    discovered_files = list(discovered)
    if discovered_files:
        preview = _preview_discovered_files(discovered_files)
        raise HFDatasetError(
            "Dataset snapshot direktorijā nav atrasti atbalstīti dati: "
            f"{snapshot_dir}. Atrastie faili: {preview}. "
            "Ja dataset repozitorijs glabā datus ar Git LFS, pārliecinieties, ka "
            "tie ir pilnībā lejupielādēti.",
            discovered_files=discovered_files,
        )

    raise HFDatasetError(
        f"Dataset snapshot direktorijā nav atrasti atbalstīti dati: {snapshot_dir}"
    )
|
|
|
|
| def _validate_snapshot_dir(snapshot_dir: Path) -> Path: |
| """Normalizē un pārbauda snapshot direktoriju pirms rekursīvas skenēšanas.""" |
| resolved = snapshot_dir.expanduser().resolve() |
| if not resolved.exists() or not resolved.is_dir(): |
| raise HFDatasetError(f"Dataset snapshot direktorija nav derīga: {resolved}") |
| if resolved == Path(resolved.anchor): |
| raise HFDatasetError( |
| f"Dataset snapshot direktorija nav droša rekursīvai skenēšanai: {resolved}" |
| ) |
| return resolved |
|
|
|
|
| def _is_invalid_snapshot_dir_error(exc: HFDatasetError) -> bool: |
| """Nosaka, vai snapshot fallback atgrieza nederīgu vai bīstamu ceļu.""" |
| message = str(exc) |
| return "nav derīga" in message or "nav droša rekursīvai skenēšanai" in message |
|
|
|
|
| def _iter_snapshot_files(root: Path) -> Iterator[Path]: |
| """Iterē snapshot failus, izlaižot nederīgus sistēmas mezglus un OSError ceļus.""" |
|
|
| def handle_walk_error(exc: OSError) -> None: |
| logger.warning( |
| "Skipping unreadable dataset snapshot path %s: %s", |
| getattr(exc, "filename", root), |
| exc, |
| ) |
|
|
| for current_root, dirnames, filenames in os.walk( |
| root, |
| topdown=True, |
| onerror=handle_walk_error, |
| followlinks=False, |
| ): |
| dirnames.sort() |
| for filename in sorted(filenames): |
| candidate = Path(current_root) / filename |
| try: |
| if candidate.is_file(): |
| yield candidate |
| except OSError as exc: |
| logger.warning("Skipping unreadable dataset snapshot file %s: %s", candidate, exc) |
|
|
|
|
def _list_repo_data_files(repo_id: str, token: str | None) -> list[str]:
    """List the supported data files directly in the memory repository.

    Queries the dataset repository's file listing and keeps only paths
    whose lower-cased suffix is a supported data format. The result is
    sorted.
    """
    from huggingface_hub import HfApi

    client = HfApi(token=token)
    all_files = client.list_repo_files(repo_id=repo_id, repo_type="dataset")
    supported = [
        name
        for name in all_files
        if Path(name).suffix.lower() in _SUPPORTED_DATASET_FORMATS
    ]
    supported.sort()
    return supported
|
|
|
|
| def _coerce_training_record(record: Any) -> dict[str, Any]: |
| """Normalizē vienu ierakstu uz stabilu text-only formu apmācības datasetam.""" |
| if isinstance(record, dict): |
| return {"text": record_to_training_text(record)} |
|
|
| try: |
| serialized = json.dumps(record, ensure_ascii=False, sort_keys=True) |
| except TypeError: |
| serialized = str(record) |
| return {"text": serialized} |
|
|
|
|
def _build_train_dataset(record_factory: Callable[[], Iterator[dict[str, Any]]]) -> Any:
    """Build a dataset with a single ``train`` split from a record generator.

    ``record_factory`` is handed to ``Dataset.from_generator`` and the
    resulting dataset is wrapped in a ``DatasetDict`` under ``"train"``.
    """
    from datasets import Dataset, DatasetDict

    train_split = Dataset.from_generator(record_factory)
    return DatasetDict({"train": train_split})
|
|
|
|
| def _iter_dataset_splits(dataset: Any) -> Iterator[tuple[str, Any]]: |
| """Atgriež pieejamos dataset splitus neatkarīgi no konkrētā tipa.""" |
| if isinstance(dataset, dict): |
| yield from dataset.items() |
| return |
|
|
| items = getattr(dataset, "items", None) |
| if callable(items): |
| yield from items() |
| return |
|
|
| keys = getattr(dataset, "keys", None) |
| if callable(keys): |
| for key in keys(): |
| yield key, dataset[key] |
|
|
|
|
def _normalize_training_dataset(dataset: Any, repo_id: str) -> Any:
    """Convert a multi-split dataset into one with a ``train`` split.

    The dataset is returned unchanged when it exposes no splits or already
    has a split named ``train``. Otherwise a warning is logged and the
    records of all splits are merged into a new train-only dataset.
    """
    splits = list(_iter_dataset_splits(dataset))
    has_train = any(split_name == "train" for split_name, _ in splits)
    if has_train or not splits:
        return dataset

    logger.warning(
        "Maris atmiņas repo %s atgrieza splitus bez 'train'; apvienojam tos vienā train split apmācībai.",
        repo_id,
    )

    def _merged_records() -> Iterator[dict[str, Any]]:
        # Splits are concatenated in the order they were listed.
        for _, split in splits:
            yield from map(_coerce_training_record, split)

    return _build_train_dataset(_merged_records)
|
|
|
|
def _iter_json_payload_records(payload: Any) -> Iterator[dict[str, Any]]:
    """Expand a parsed JSON payload into a stream of training records.

    A list payload contributes one record per element; any other payload
    is treated as a single record.
    """
    entries = payload if isinstance(payload, list) else [payload]
    for entry in entries:
        yield _coerce_training_record(entry)
|
|
|
|
def _iter_json_file_records(data_files: list[str]) -> Iterator[dict[str, Any]]:
    """Read JSON/JSONL files without enforcing a strict nested-object schema.

    Args:
        data_files: Paths of ``.json`` or ``.jsonl`` files, read in order.

    Yields:
        One coerced training record per JSONL line (blank lines skipped)
        or per element of a ``.json`` payload.

    Raises:
        ValueError: If a file has a suffix other than ``.json``/``.jsonl``;
            ``json.JSONDecodeError`` (a ``ValueError``) on malformed JSON.
    """
    for file_name in data_files:
        file_path = Path(file_name)
        suffix = file_path.suffix.lower()
        if suffix == ".jsonl":
            # "utf-8-sig" reads plain UTF-8 and also strips a leading BOM,
            # which json.loads would otherwise reject on the first line.
            with file_path.open(encoding="utf-8-sig") as handle:
                for line in handle:
                    stripped = line.strip()
                    if not stripped:
                        continue
                    yield _coerce_training_record(json.loads(stripped))
            continue

        if suffix == ".json":
            payload = json.loads(file_path.read_text(encoding="utf-8-sig"))
            yield from _iter_json_payload_records(payload)
            continue

        raise ValueError(f"Neatbalstīts JSON datu fails: {file_path}")
|
|
|
|
def _load_data_files_as_training_dataset(
    repo_id: str,
    dataset_format: str,
    data_files: list[str],
    *,
    recovery_cache_dir: str | None = None,
) -> Any:
    """Load the discovered data files for training as a ``train`` split.

    JSON data is read by this module's own tolerant reader; every other
    format is delegated to ``datasets.load_dataset``. Any failure is
    re-raised as ``HFDatasetError`` with a message describing the files
    and the underlying cause.
    """
    try:
        if dataset_format != "json":
            from datasets import load_dataset

            loaded, _ = _call_with_hf_cache_recovery(
                f"load_dataset({dataset_format})",
                load_dataset,
                dataset_format,
                data_files={"train": data_files},
                recovery_cache_dir=recovery_cache_dir,
            )
            return loaded

        return _build_train_dataset(lambda: _iter_json_file_records(data_files))
    except Exception as error:
        failure_message = _build_invalid_data_files_message(repo_id, data_files, error)
        raise HFDatasetError(failure_message) from error
|
|
|
|
def load_hf_dataset(repo_id: str | None = None) -> Any:
    """Load the Maris memory dataset.

    The repository is first loaded directly with ``datasets.load_dataset``.
    If that fails with an error classified as an empty-dataset, schema-cast
    or dataset-generation error, a snapshot of the repository is downloaded
    and its data files are loaded for training instead.

    Args:
        repo_id: Dataset repository id. When falsy, it is taken from the
            ``MARIS_MEMORY_REPO`` / ``MARIS_DATASET_REPO`` /
            ``HF_DATASET_REPO`` environment lookup with the default
            ``MarisUK/maris-ai-memory``.

    Returns:
        The loaded dataset, passed through ``_normalize_training_dataset``
        on the direct path, or the dataset built from snapshot data files.

    Raises:
        HFDatasetError: If the snapshot directory is invalid, the
            repository has no supported data files, the snapshot lacks the
            files the repository lists, or the data files cannot be read.
        Exception: Any other error is logged and re-raised unchanged.
    """
    # NOTE(review): presumably returns the first configured variable, else
    # the default; confirm against maris_core.utils.env.
    repo_id = repo_id or get_env_any_or_default(
        "MARIS_MEMORY_REPO",
        "MARIS_DATASET_REPO",
        "HF_DATASET_REPO",
        default="MarisUK/maris-ai-memory",
    )
    token = get_env_any("MARIS_REPO_TOKEN", "MARIS_TOKEN", "HF_TOKEN")
    # Set once a Hugging Face call had to switch to the fallback cache, so
    # that later calls use the same cache directory.
    recovery_cache_dir: str | None = None

    try:
        from datasets import load_dataset

        logger.info("Ielādē dataset: %s", repo_id)
        dataset, recovery_cache_dir = _call_with_hf_cache_recovery(
            f"load_dataset({repo_id})",
            load_dataset,
            repo_id,
            token=token,
            recovery_cache_dir=recovery_cache_dir,
        )
        return _normalize_training_dataset(dataset, repo_id)
    except Exception as exc:
        # Only the known "repository cannot be loaded directly" failures
        # trigger the snapshot fallback; everything else is re-raised.
        if not _should_fallback_to_snapshot(exc):
            logger.error("Dataseta ielādes kļūda: %s", exc)
            raise

        try:
            from huggingface_hub import snapshot_download

            logger.warning(
                "Maris atmiņas repo %s nevarēja ielādēt tieši; mēģinām apmācības snapshot fallback.",
                repo_id,
            )
            snapshot_dir_str, recovery_cache_dir = _call_with_hf_cache_recovery(
                f"snapshot_download({repo_id})",
                snapshot_download,
                repo_id=repo_id,
                repo_type="dataset",
                token=token,
                recovery_cache_dir=recovery_cache_dir,
            )
            snapshot_dir = Path(snapshot_dir_str)
            try:
                dataset_format, data_files = _find_snapshot_data_files(snapshot_dir)
            except HFDatasetError as missing_exc:
                # An invalid or unsafe snapshot path cannot be fixed by
                # downloading again.
                if _is_invalid_snapshot_dir_error(missing_exc):
                    raise
                # Ask the repository itself which supported files exist.
                repo_data_files = _list_repo_data_files(repo_id, token)
                if not repo_data_files:
                    raise HFDatasetError(
                        _build_empty_repo_message(
                            repo_id,
                            snapshot_dir,
                            missing_exc.discovered_files,
                        ),
                        discovered_files=missing_exc.discovered_files,
                    ) from missing_exc

                logger.warning(
                    "Snapshot cache nesatur dataset failus; lejupielādējam atrastos data failus tieši: %s",
                    ", ".join(repo_data_files[:5]),
                )
                # Second attempt: request exactly the listed data files.
                snapshot_dir_str, recovery_cache_dir = _call_with_hf_cache_recovery(
                    f"snapshot_download({repo_id}, allow_patterns)",
                    snapshot_download,
                    repo_id=repo_id,
                    repo_type="dataset",
                    token=token,
                    allow_patterns=repo_data_files,
                    recovery_cache_dir=recovery_cache_dir,
                )
                snapshot_dir = Path(snapshot_dir_str)
                try:
                    dataset_format, data_files = _find_snapshot_data_files(snapshot_dir)
                except HFDatasetError as final_exc:
                    raise HFDatasetError(
                        _build_incomplete_snapshot_message(
                            repo_id,
                            snapshot_dir,
                            final_exc.discovered_files,
                            repo_data_files,
                        ),
                        discovered_files=final_exc.discovered_files,
                    ) from final_exc
            return _load_data_files_as_training_dataset(
                repo_id,
                dataset_format,
                data_files,
                recovery_cache_dir=recovery_cache_dir,
            )
        except Exception as fallback_exc:
            # Log every fallback failure once, then propagate it unchanged.
            logger.error("Dataseta ielādes kļūda: %s", fallback_exc)
            raise
|
|