# MarisUK's picture
# Maris AI model sync
# f440f03 verified
"""Dataset configuration and loading."""
from __future__ import annotations
import errno
import json
import logging
import os
import tempfile
from collections.abc import Callable, Iterator
from pathlib import Path
from typing import Any
from maris_core.data.preprocessing import record_to_training_text
from maris_core.utils.env import get_env_any, get_env_any_or_default
logger = logging.getLogger(__name__)
# Maps a lower-cased file suffix to the dataset format name used for loading.
# Both ``.json`` and ``.jsonl`` resolve to the "json" format; any suffix that is
# not a key here is treated as unsupported by the snapshot/repo scanners.
_SUPPORTED_DATASET_FORMATS = {
    ".json": "json",
    ".jsonl": "json",
    ".csv": "csv",
    ".parquet": "parquet",
}
class HFDatasetError(FileNotFoundError):
    """Configuration or content error of the memory repository.

    The instance carries ``discovered_files``: the file names that were found
    while looking for data. It is always a list and is empty when the caller
    supplied nothing (or an empty list).
    """

    def __init__(self, message: str, *, discovered_files: list[str] | None = None) -> None:
        super().__init__(message)
        if discovered_files:
            self.discovered_files = discovered_files
        else:
            # Covers both ``None`` and an empty list: store a fresh list.
            self.discovered_files = []
def _iter_exception_chain(exc: Exception) -> Iterator[BaseException]:
"""Iziet cauri izņēmuma cēloņu ķēdei bez cikliem."""
pending: list[BaseException] = [exc]
seen: set[int] = set()
while pending:
current = pending.pop()
current_id = id(current)
if current_id in seen:
continue
seen.add(current_id)
yield current
cause = getattr(current, "__cause__", None)
context = getattr(current, "__context__", None)
if cause is not None:
pending.append(cause)
if context is not None:
pending.append(context)
def _is_empty_dataset_error(exc: Exception) -> bool:
    """Report whether the repository lacks directly loadable data files.

    An exception anywhere in the cause/context chain matches when its class is
    named ``EmptyDatasetError`` or its message contains a known fragment.
    """
    known_fragments = (
        "doesn't contain any data files",
        "No (supported) data files found in",
    )
    for link in _iter_exception_chain(exc):
        if link.__class__.__name__ == "EmptyDatasetError":
            return True
        text = str(link)
        if any(fragment in text for fragment in known_fragments):
            return True
    return False
def _is_schema_cast_error(exc: Exception) -> bool:
    """Report whether loading failed on a schema cast (varying object schema).

    Matches when any exception in the cause/context chain has a message that
    contains one of the cast-failure fragments.
    """
    cast_fragments = (
        "Couldn't cast",
        "Couldn't cast array of type",
    )
    for link in _iter_exception_chain(exc):
        text = str(link)
        if any(fragment in text for fragment in cast_fragments):
            return True
    return False
def _is_dataset_generation_error(exc: Exception) -> bool:
    """Report whether the chain contains an exception named ``DatasetGenerationError``."""
    for link in _iter_exception_chain(exc):
        # Matched by class name, so the exception type itself is never imported here.
        if link.__class__.__name__ == "DatasetGenerationError":
            return True
    return False
def _should_fallback_to_snapshot(exc: Exception) -> bool:
    """Decide whether training data should be loaded via the snapshot fallback.

    True when the error is an empty-dataset error, a schema cast error, or a
    dataset generation error; the checks run in that order and stop at the
    first match.
    """
    classifiers = (
        _is_empty_dataset_error,
        _is_schema_cast_error,
        _is_dataset_generation_error,
    )
    return any(classifier(exc) for classifier in classifiers)
def _is_hf_cache_lock_io_error(exc: Exception) -> bool:
    """Report whether an I/O error hit a Hugging Face cache lock file.

    An ``OSError`` in the cause/context chain matches when its file name and
    message (compared case-insensitively) together:

    * mention ``huggingface`` or ``datasets--``,
    * mention ``lock``, and
    * indicate an I/O failure (``errno == EIO`` or the text
      ``input/output error``).

    Args:
        exc: The exception whose chain is inspected.

    Returns:
        ``True`` if a matching ``OSError`` is found, otherwise ``False``.
    """
    for current in _iter_exception_chain(exc):
        if not isinstance(current, OSError):
            continue
        message = str(current)
        # ``OSError.filename`` is not always ``str``: it may be ``bytes``, a
        # path-like object or an integer file descriptor. This classifier runs
        # inside an ``except`` handler, so it must never raise on such values.
        raw_filename = getattr(current, "filename", None)
        if isinstance(raw_filename, (str, bytes, os.PathLike)):
            filename = os.fsdecode(raw_filename)
        elif raw_filename:
            filename = str(raw_filename)
        else:
            filename = ""
        details = " ".join(part for part in (filename, message) if part).lower()
        if "huggingface" not in details and "datasets--" not in details:
            continue
        if "lock" not in details:
            continue
        if current.errno == errno.EIO or "input/output error" in details:
            return True
    return False
def _resolve_hf_cache_recovery_dir() -> str:
    """Return a usable fallback cache directory for Hugging Face downloads.

    The directory named by ``MARIS_HF_FALLBACK_CACHE_DIR`` is preferred. When
    that value is unset or unusable (it resolves to a filesystem root, exists
    but is not a directory, or cannot be created), the directory
    ``maris-hf-cache`` inside the system temp directory is used instead. The
    chosen directory is created if it does not exist.

    Returns:
        The selected directory as a string.
    """
    env_value = get_env_any("MARIS_HF_FALLBACK_CACHE_DIR")
    default_dir = (Path(tempfile.gettempdir()) / "maris-hf-cache").resolve()
    if env_value:
        preferred = Path(env_value).expanduser()
    else:
        preferred = default_dir
    try:
        target = preferred.resolve(strict=False)
        points_at_root = target == Path(target.anchor)
        if points_at_root:
            raise ValueError(f"nedroša cache direktorija: {target}")
        if target.exists() and not target.is_dir():
            raise ValueError(f"cache ceļš nav direktorija: {target}")
        target.mkdir(parents=True, exist_ok=True)
    except (OSError, ValueError) as problem:
        # Only warn when the user actually configured something that was rejected.
        if env_value:
            logger.warning(
                "Ignorējam nederīgu MARIS_HF_FALLBACK_CACHE_DIR=%s: %s; lietojam %s.",
                env_value,
                problem,
                default_dir,
            )
        default_dir.mkdir(parents=True, exist_ok=True)
        return str(default_dir)
    return str(target)
def _call_with_hf_cache_recovery(
action: str,
call: Callable[..., Any],
*args: Any,
recovery_cache_dir: str | None = None,
**kwargs: Any,
) -> tuple[Any, str | None]:
"""Izsauc HF funkciju un pie cache lock I/O kļūdas pārslēdzas uz fallback cache."""
call_kwargs = dict(kwargs)
if recovery_cache_dir is not None:
call_kwargs["cache_dir"] = recovery_cache_dir
return call(*args, **call_kwargs), recovery_cache_dir
try:
return call(*args, **call_kwargs), None
except Exception as exc: # noqa: BLE001
if not _is_hf_cache_lock_io_error(exc):
raise
recovery_cache_dir = _resolve_hf_cache_recovery_dir()
logger.warning(
"Hugging Face cache lock kļūda (%s); atkārtojam ar fallback cache %s.",
action,
recovery_cache_dir,
)
call_kwargs["cache_dir"] = recovery_cache_dir
return call(*args, **call_kwargs), recovery_cache_dir
def _preview_discovered_files(discovered_files: list[str]) -> str:
"""Atgriež īsu atrasto failu sarakstu kļūdu ziņojumiem."""
if not discovered_files:
return "nav"
return ", ".join(sorted(discovered_files)[:5])
def _preview_data_files(data_files: list[str]) -> str:
"""Atgriež īsu datu failu sarakstu kļūdu ziņojumiem."""
preview_paths: list[str] = []
for file_name in sorted(data_files)[:5]:
file_path = Path(file_name)
if "data" in file_path.parts:
data_index = file_path.parts.index("data")
preview_paths.append("/".join(file_path.parts[data_index:]))
continue
preview_paths.append(file_path.name)
return ", ".join(preview_paths) if preview_paths else "nav"
def _summarize_exception_messages(exc: Exception) -> str:
    """Build a short, user-facing summary of an exception chain.

    Distinct non-empty messages from the chain are collected in traversal
    order and the first three are joined with ``" | "``. When no exception
    carries a message, the class name of ``exc`` is returned instead.
    """
    unique_messages: list[str] = []
    for link in _iter_exception_chain(exc):
        text = str(link).strip()
        if text and text not in unique_messages:
            unique_messages.append(text)
    if not unique_messages:
        return exc.__class__.__name__
    return " | ".join(unique_messages[:3])
def _build_empty_repo_message(
    repo_id: str,
    snapshot_dir: Path,
    discovered_files: list[str],
) -> str:
    """Compose the error message for a memory repository without supported data.

    The message names the repository, the snapshot directory, a preview of the
    files that were found, and how to fix the repository.
    """
    preview = _preview_discovered_files(discovered_files)
    sentences = (
        f"Maris atmiņas repo {repo_id} pašlaik nesatur nevienu atbalstītu datu failu "
        "(.jsonl, .json, .csv, .parquet).",
        f"Snapshot direktorijā {snapshot_dir} tika atrasti tikai: {preview}.",
        "Lai salabotu origin repozitorijā, augšupielādē vismaz vienu datu failu zem "
        "data/<type>/, piemēram data/conversation/bootstrap.jsonl, un tad palaid "
        "apmācību vēlreiz.",
        "Ja lieto Git LFS, pārliecinies, ka repozitorijā ir pats datu fails, nevis "
        "tikai .gitattributes ieraksts.",
    )
    return " ".join(sentences)
def _build_incomplete_snapshot_message(
    repo_id: str,
    snapshot_dir: Path,
    discovered_files: list[str],
    repo_data_files: list[str],
) -> str:
    """Compose the error message for a snapshot that lacks the repo's data files.

    Used when the repository lists supported data files but the downloaded
    snapshot does not contain them. At most five repository files are named.
    """
    preview = _preview_discovered_files(discovered_files)
    repo_preview = ", ".join(repo_data_files[:5])
    sentences = (
        f"Maris atmiņas repo {repo_id} satur atbalstītus datu failus ({repo_preview}), "
        f"bet snapshot direktorijā {snapshot_dir} tie netika atrasti.",
        f"Snapshotā redzami: {preview}.",
        "Ja dataset repozitorijs glabā datus ar Git LFS, pārliecinieties, ka tie ir "
        "pilnībā lejupielādēti un pieejami origin repozitorijā.",
    )
    return " ".join(sentences)
def _build_invalid_data_files_message(
    repo_id: str,
    data_files: list[str],
    exc: Exception,
) -> str:
    """Compose the error message for data files that could not be read.

    The message names the repository, a preview of the affected files, and a
    summary of the exception chain that caused the failure.
    """
    preview = _preview_data_files(data_files)
    details = _summarize_exception_messages(exc)
    sentences = (
        f"Maris atmiņas repo {repo_id} datu failus neizdevās nolasīt apmācībai ({preview}).",
        f"Cēlonis: {details}.",
        "Pārbaudi, vai faili ir pilnībā augšupielādēti, UTF-8 kodējumā un satur "
        "derīgus JSON/JSONL, CSV vai Parquet ierakstus.",
    )
    return " ".join(sentences)
def _find_snapshot_data_files(snapshot_dir: Path) -> tuple[str, list[str]]:
    """Locate loadable data files inside a dataset snapshot directory.

    The ``data`` sub-directory is searched first, then the snapshot root. The
    first search root that contains supported files wins.

    Args:
        snapshot_dir: Directory of the downloaded snapshot.

    Returns:
        A pair ``(dataset_format, files)`` with the format name and the paths
        of all files of that format found under the winning search root.

    Raises:
        ValueError: A search root contains more than one supported format.
        HFDatasetError: The directory is invalid, or no supported files were
            found. In the latter case ``discovered_files`` lists each file seen
            (relative to the snapshot directory) exactly once.
    """
    snapshot_dir = _validate_snapshot_dir(snapshot_dir)
    search_roots = [snapshot_dir / "data", snapshot_dir]
    discovered_files: list[str] = []
    # Scanning the snapshot root revisits everything under ``data``; remember
    # what was already recorded so no file is reported twice.
    recorded: set[str] = set()
    for root in search_roots:
        if not root.exists():
            continue
        files_by_format: dict[str, list[str]] = {}
        for file_path in _iter_snapshot_files(root):
            relative_name = str(file_path.relative_to(snapshot_dir))
            if relative_name not in recorded:
                recorded.add(relative_name)
                discovered_files.append(relative_name)
            dataset_format = _SUPPORTED_DATASET_FORMATS.get(file_path.suffix.lower())
            if dataset_format is None:
                continue
            files_by_format.setdefault(dataset_format, []).append(str(file_path))
        if not files_by_format:
            continue
        if len(files_by_format) > 1:
            formats = ", ".join(sorted(files_by_format))
            raise ValueError(
                f"Snapshot satur vairākus atbalstītus datu formātus vienlaikus: {formats}"
            )
        dataset_format, files = next(iter(files_by_format.items()))
        return dataset_format, files
    if discovered_files:
        preview = _preview_discovered_files(discovered_files)
        raise HFDatasetError(
            "Dataset snapshot direktorijā nav atrasti atbalstīti dati: "
            f"{snapshot_dir}. Atrastie faili: {preview}. "
            "Ja dataset repozitorijs glabā datus ar Git LFS, pārliecinieties, ka "
            "tie ir pilnībā lejupielādēti.",
            discovered_files=discovered_files,
        )
    raise HFDatasetError(
        f"Dataset snapshot direktorijā nav atrasti atbalstīti dati: {snapshot_dir}"
    )
def _validate_snapshot_dir(snapshot_dir: Path) -> Path:
"""Normalizē un pārbauda snapshot direktoriju pirms rekursīvas skenēšanas."""
resolved = snapshot_dir.expanduser().resolve()
if not resolved.exists() or not resolved.is_dir():
raise HFDatasetError(f"Dataset snapshot direktorija nav derīga: {resolved}")
if resolved == Path(resolved.anchor):
raise HFDatasetError(
f"Dataset snapshot direktorija nav droša rekursīvai skenēšanai: {resolved}"
)
return resolved
def _is_invalid_snapshot_dir_error(exc: HFDatasetError) -> bool:
"""Nosaka, vai snapshot fallback atgrieza nederīgu vai bīstamu ceļu."""
message = str(exc)
return "nav derīga" in message or "nav droša rekursīvai skenēšanai" in message
def _iter_snapshot_files(root: Path) -> Iterator[Path]:
"""Iterē snapshot failus, izlaižot nederīgus sistēmas mezglus un OSError ceļus."""
def handle_walk_error(exc: OSError) -> None:
logger.warning(
"Skipping unreadable dataset snapshot path %s: %s",
getattr(exc, "filename", root),
exc,
)
for current_root, dirnames, filenames in os.walk(
root,
topdown=True,
onerror=handle_walk_error,
followlinks=False,
):
dirnames.sort()
for filename in sorted(filenames):
candidate = Path(current_root) / filename
try:
if candidate.is_file():
yield candidate
except OSError as exc:
logger.warning("Skipping unreadable dataset snapshot file %s: %s", candidate, exc)
def _list_repo_data_files(repo_id: str, token: str | None) -> list[str]:
    """List the supported data files directly from the memory repository.

    Queries the dataset repository through ``HfApi.list_repo_files`` and keeps
    only paths whose lower-cased suffix is a supported format.

    Returns:
        The matching repository paths, sorted.
    """
    from huggingface_hub import HfApi  # type: ignore

    client = HfApi(token=token)
    supported = [
        path
        for path in client.list_repo_files(repo_id=repo_id, repo_type="dataset")
        if Path(path).suffix.lower() in _SUPPORTED_DATASET_FORMATS
    ]
    supported.sort()
    return supported
def _coerce_training_record(record: Any) -> dict[str, Any]:
"""Normalizē vienu ierakstu uz stabilu text-only formu apmācības datasetam."""
if isinstance(record, dict):
return {"text": record_to_training_text(record)}
try:
serialized = json.dumps(record, ensure_ascii=False, sort_keys=True)
except TypeError:
serialized = str(record)
return {"text": serialized}
def _build_train_dataset(record_factory: Callable[[], Iterator[dict[str, Any]]]) -> Any:
    """Build a dataset with a single ``train`` split from a record generator.

    Args:
        record_factory: Zero-argument callable handed to
            ``Dataset.from_generator``.

    Returns:
        A ``DatasetDict`` whose only key is ``"train"``.
    """
    from datasets import Dataset, DatasetDict  # type: ignore

    train_split = Dataset.from_generator(record_factory)
    return DatasetDict({"train": train_split})
def _iter_dataset_splits(dataset: Any) -> Iterator[tuple[str, Any]]:
"""Atgriež pieejamos dataset splitus neatkarīgi no konkrētā tipa."""
if isinstance(dataset, dict):
yield from dataset.items()
return
items = getattr(dataset, "items", None)
if callable(items):
yield from items()
return
keys = getattr(dataset, "keys", None)
if callable(keys):
for key in keys():
yield key, dataset[key]
def _normalize_training_dataset(dataset: Any, repo_id: str) -> Any:
    """Convert a multi-split dataset without ``train`` into a single train split.

    The dataset is returned unchanged when it exposes no splits or already has
    a ``train`` split. Otherwise a warning is logged and the records of all
    splits are coerced to text-only form and combined into one ``train`` split.
    """
    splits = list(_iter_dataset_splits(dataset))
    split_names = [name for name, _ in splits]
    if not splits or "train" in split_names:
        return dataset

    logger.warning(
        "Maris atmiņas repo %s atgrieza splitus bez 'train'; apvienojam tos vienā train split apmācībai.",
        repo_id,
    )

    def iter_records() -> Iterator[dict[str, Any]]:
        # Splits are concatenated in the order they were reported.
        for _, split in splits:
            yield from map(_coerce_training_record, split)

    return _build_train_dataset(iter_records)
def _iter_json_payload_records(payload: Any) -> Iterator[dict[str, Any]]:
"""Izvērš JSON payload uz ierakstu plūsmu."""
if isinstance(payload, list):
for item in payload:
yield _coerce_training_record(item)
return
yield _coerce_training_record(payload)
def _iter_json_file_records(data_files: list[str]) -> Iterator[dict[str, Any]]:
"""Nolasa JSON/JSONL failus bez stingras nested-object shēmas."""
for file_name in data_files:
file_path = Path(file_name)
suffix = file_path.suffix.lower()
if suffix == ".jsonl":
with file_path.open(encoding="utf-8") as handle:
for line in handle:
stripped = line.strip()
if not stripped:
continue
yield _coerce_training_record(json.loads(stripped))
continue
if suffix == ".json":
payload = json.loads(file_path.read_text(encoding="utf-8"))
yield from _iter_json_payload_records(payload)
continue
raise ValueError(f"Neatbalstīts JSON datu fails: {file_path}")
def _load_data_files_as_training_dataset(
    repo_id: str,
    dataset_format: str,
    data_files: list[str],
    *,
    recovery_cache_dir: str | None = None,
) -> Any:
    """Load the discovered data files for training as a ``train`` split.

    JSON files are streamed through the schema-tolerant reader; every other
    format goes through ``datasets.load_dataset`` with cache-lock recovery.

    Raises:
        HFDatasetError: Any failure while loading; the original exception is
            attached as the cause.
    """
    try:
        if dataset_format != "json":
            from datasets import load_dataset  # type: ignore

            loaded, _ = _call_with_hf_cache_recovery(
                f"load_dataset({dataset_format})",
                load_dataset,
                dataset_format,
                data_files={"train": data_files},
                recovery_cache_dir=recovery_cache_dir,
            )
            return loaded
        return _build_train_dataset(lambda: _iter_json_file_records(data_files))
    except Exception as exc:  # noqa: BLE001
        message = _build_invalid_data_files_message(repo_id, data_files, exc)
        raise HFDatasetError(message) from exc
def load_hf_dataset(repo_id: str | None = None) -> Any:
    """Load the Maris memory dataset, with a snapshot-based fallback.

    The repository is first loaded directly with ``datasets.load_dataset``.
    If that fails with an error classified by ``_should_fallback_to_snapshot``,
    the repository snapshot is downloaded and its data files are loaded
    instead. When the snapshot holds no supported files, the repository file
    listing is consulted and, if it names data files, exactly those files are
    downloaded in a second snapshot attempt.

    Args:
        repo_id: Dataset repository id. When falsy, it is taken from the
            ``MARIS_MEMORY_REPO`` / ``MARIS_DATASET_REPO`` / ``HF_DATASET_REPO``
            settings with the default ``MarisUK/maris-ai-memory``.

    Returns:
        The result of ``_normalize_training_dataset`` for a direct load, or of
        ``_load_data_files_as_training_dataset`` for the fallback path.

    Raises:
        HFDatasetError: The snapshot directory is invalid, the repository has
            no supported data files, the snapshot is incomplete, or the data
            files cannot be read.
        Exception: Any other error is logged and re-raised unchanged.
    """
    repo_id = repo_id or get_env_any_or_default(
        "MARIS_MEMORY_REPO",
        "MARIS_DATASET_REPO",
        "HF_DATASET_REPO",
        default="MarisUK/maris-ai-memory",
    )
    token = get_env_any("MARIS_REPO_TOKEN", "MARIS_TOKEN", "HF_TOKEN")
    # Set once a cache-lock recovery happened, then reused by every later call.
    recovery_cache_dir: str | None = None
    try:
        from datasets import load_dataset  # type: ignore
        logger.info("Ielādē dataset: %s", repo_id)
        dataset, recovery_cache_dir = _call_with_hf_cache_recovery(
            f"load_dataset({repo_id})",
            load_dataset,
            repo_id,
            token=token,
            recovery_cache_dir=recovery_cache_dir,
        )
        return _normalize_training_dataset(dataset, repo_id)
    except Exception as exc:  # noqa: BLE001
        # Only specific failure classes are worth a snapshot retry; everything
        # else is reported and propagated.
        if not _should_fallback_to_snapshot(exc):
            logger.error("Dataseta ielādes kļūda: %s", exc)
            raise
        try:
            from huggingface_hub import snapshot_download  # type: ignore
            logger.warning(
                "Maris atmiņas repo %s nevarēja ielādēt tieši; mēģinām apmācības snapshot fallback.",
                repo_id,
            )
            snapshot_dir_str, recovery_cache_dir = _call_with_hf_cache_recovery(
                f"snapshot_download({repo_id})",
                snapshot_download,
                repo_id=repo_id,
                repo_type="dataset",
                token=token,
                recovery_cache_dir=recovery_cache_dir,
            )
            snapshot_dir = Path(snapshot_dir_str)
            try:
                dataset_format, data_files = _find_snapshot_data_files(snapshot_dir)
            except HFDatasetError as missing_exc:
                # An invalid or unsafe snapshot path cannot be repaired by
                # downloading again, so it is propagated as is.
                if _is_invalid_snapshot_dir_error(missing_exc):
                    raise
                repo_data_files = _list_repo_data_files(repo_id, token)
                if not repo_data_files:
                    # The repository itself lists no supported data files.
                    raise HFDatasetError(
                        _build_empty_repo_message(
                            repo_id,
                            snapshot_dir,
                            missing_exc.discovered_files,
                        ),
                        discovered_files=missing_exc.discovered_files,
                    ) from missing_exc
                logger.warning(
                    "Snapshot cache nesatur dataset failus; lejupielādējam atrastos data failus tieši: %s",
                    ", ".join(repo_data_files[:5]),
                )
                # Second attempt: request exactly the data files the
                # repository listing reported.
                snapshot_dir_str, recovery_cache_dir = _call_with_hf_cache_recovery(
                    f"snapshot_download({repo_id}, allow_patterns)",
                    snapshot_download,
                    repo_id=repo_id,
                    repo_type="dataset",
                    token=token,
                    allow_patterns=repo_data_files,
                    recovery_cache_dir=recovery_cache_dir,
                )
                snapshot_dir = Path(snapshot_dir_str)
                try:
                    dataset_format, data_files = _find_snapshot_data_files(snapshot_dir)
                except HFDatasetError as final_exc:
                    # The repository lists data files, yet the snapshot still
                    # does not contain them.
                    raise HFDatasetError(
                        _build_incomplete_snapshot_message(
                            repo_id,
                            snapshot_dir,
                            final_exc.discovered_files,
                            repo_data_files,
                        ),
                        discovered_files=final_exc.discovered_files,
                    ) from final_exc
            return _load_data_files_as_training_dataset(
                repo_id,
                dataset_format,
                data_files,
                recovery_cache_dir=recovery_cache_dir,
            )
        except Exception as fallback_exc:  # noqa: BLE001
            # Every fallback failure is logged once here and then re-raised.
            logger.error("Dataseta ielādes kļūda: %s", fallback_exc)
            raise