| """Maris Hugging Face compatibility helpers for fully sanitized artifacts.""" |
|
|
| from __future__ import annotations |
|
|
| import base64 |
| import json |
| import logging |
| import os |
| import shutil |
| import tempfile |
| import zlib |
| from collections.abc import Iterator |
| from contextlib import contextmanager, suppress |
| from pathlib import Path |
| from typing import Any |
|
|
| from maris_core.utils.env import get_env_any, validate_hf_model |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# File name of the compatibility manifest, the ``artifact_type`` tag stored in
# it, and the version number recorded alongside.
MARIS_COMPATIBILITY_ARTIFACT_NAME = "maris-compatibility.json"
MARIS_COMPATIBILITY_ARTIFACT_TYPE = "maris-hf-compatibility"
MARIS_COMPATIBILITY_VERSION = 1

# Placeholder identity values that are written into sanitized artifacts.
MARIS_MODEL_TYPE = "maris"
MARIS_MODEL_ARCHITECTURE = "MarisCompatibleCausalLM"
MARIS_TOKENIZER_CLASS = "MarisCompatibleTokenizer"
MARIS_CONFIG_CLASS = "MarisCompatibleConfig"
MARIS_PARENT_LIBRARY = "maris.compat"

# The adapter identity appears both at the top level of ``adapter_config.json``
# and again under ``auto_mapping``; it is spelled out once here.
_ADAPTER_IDENTITY: dict[str, Any] = {
    "base_model_class": MARIS_MODEL_ARCHITECTURE,
    "parent_library": MARIS_PARENT_LIBRARY,
}

# Per-artifact field values that overwrite the real ones when an artifact is
# sanitized (keyed by artifact file name).
_SANITIZED_COMPATIBILITY_FIELDS: dict[str, dict[str, Any]] = {
    "config.json": {
        "model_type": MARIS_MODEL_TYPE,
        "architectures": [MARIS_MODEL_ARCHITECTURE],
        "tokenizer_class": MARIS_TOKENIZER_CLASS,
        "auto_map": {
            "AutoConfig": MARIS_CONFIG_CLASS,
            "AutoModelForCausalLM": MARIS_MODEL_ARCHITECTURE,
        },
    },
    "tokenizer_config.json": {
        "tokenizer_class": MARIS_TOKENIZER_CLASS,
        "auto_map": {
            "AutoTokenizer": [MARIS_TOKENIZER_CLASS, None],
        },
    },
    "adapter_config.json": {
        **_ADAPTER_IDENTITY,
        "auto_mapping": dict(_ADAPTER_IDENTITY),
    },
}

# Per-artifact field names whose original values are captured in the manifest
# so they can be put back later (keyed by artifact file name).
_RESTORABLE_COMPATIBILITY_FIELDS: dict[str, tuple[str, ...]] = {
    "config.json": ("model_type", "architectures", "tokenizer_class", "auto_map"),
    "tokenizer_config.json": ("tokenizer_class", "auto_map"),
    "adapter_config.json": ("base_model_class", "parent_library", "auto_mapping"),
}
|
|
|
|
def _load_json_dict(path: Path) -> dict[str, Any] | None:
    """Read *path* as UTF-8 JSON and return it only if it is a JSON object.

    Returns:
        The decoded ``dict``, or ``None`` when *path* is not an existing
        regular file or the top-level JSON value is not an object.

    Raises:
        json.JSONDecodeError: The file exists but does not contain valid JSON
            (the error is not handled here).
    """
    if path.is_file():
        decoded = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(decoded, dict):
            return decoded
    return None
|
|
|
|
def _save_json(path: Path, payload: dict[str, Any]) -> None:
    """Write *payload* to *path* as 2-space-indented UTF-8 JSON.

    Missing parent directories are created first.  Non-ASCII characters are
    written as-is rather than as ``\\uXXXX`` escapes.
    """
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    rendered = json.dumps(payload, indent=2, ensure_ascii=False)
    path.write_text(rendered, encoding="utf-8")
|
|
|
|
def _encode_restore_payload(payload: dict[str, Any]) -> str:
    """Pack *payload* into an ASCII string: compact JSON -> zlib -> URL-safe base64.

    Raises:
        ValueError: JSON serialization or compression failed; the underlying
            ``TypeError``/``ValueError``/``zlib.error`` is chained as the cause.
    """
    try:
        compact_json = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
        packed = zlib.compress(compact_json.encode("utf-8"), level=9)
        token = base64.urlsafe_b64encode(packed)
        return token.decode("ascii")
    except (TypeError, ValueError, zlib.error) as exc:
        raise ValueError("Neizdevās serializēt Maris HF compatibility payload.") from exc
|
|
|
|
def _decode_restore_payload(payload: str) -> dict[str, Any]:
    """Unpack a string made by URL-safe base64 -> zlib -> JSON back into a dict.

    Returns:
        The decoded JSON object, or an empty dict when the decoded JSON value
        is valid but is not an object.

    Raises:
        ValueError: The string cannot be base64-decoded, decompressed, decoded
            as UTF-8 or parsed as JSON; the original error is chained.
    """
    try:
        packed = base64.urlsafe_b64decode(payload.encode("ascii"))
        json_text = zlib.decompress(packed).decode("utf-8")
        restored = json.loads(json_text)
    except (ValueError, json.JSONDecodeError, zlib.error) as exc:
        raise ValueError("Maris HF compatibility payload ir bojāts vai nederīgs.") from exc
    if isinstance(restored, dict):
        return restored
    return {}
|
|
|
|
def _build_restore_entries(output_dir: Path) -> dict[str, dict[str, str]]:
    """Collect, per artifact in *output_dir*, the field values needed for a restore.

    For every artifact listed in ``_RESTORABLE_COMPATIBILITY_FIELDS`` that
    exists as a JSON object, each restorable field is resolved as follows:

    * if the field is present in the artifact and differs from its sanitized
      value, the artifact's current value is recorded;
    * otherwise, if a manifest already in *output_dir* holds a value for the
      field, that earlier value is carried over;
    * otherwise the field is left out.

    Artifacts with nothing to record are omitted from the result.

    Returns:
        Mapping of artifact file name to ``{"encoding": ..., "payload": ...}``
        where ``payload`` is the encoded field dict.

    Raises:
        ValueError: An existing manifest payload for an artifact that is
            present on disk cannot be decoded.
    """
    manifest = _load_compatibility_manifest(output_dir) or {}
    previous_entries = manifest.get("artifacts", {})
    if not isinstance(previous_entries, dict):
        previous_entries = {}

    result: dict[str, dict[str, str]] = {}
    for artifact_name, field_names in _RESTORABLE_COMPATIBILITY_FIELDS.items():
        document = _load_json_dict(output_dir / artifact_name)
        if document is None:
            continue

        # Values captured by an earlier run, if the old manifest has any.
        previously_saved: dict[str, Any] = {}
        previous_entry = previous_entries.get(artifact_name)
        if isinstance(previous_entry, dict):
            previous_blob = previous_entry.get("payload")
            if isinstance(previous_blob, str):
                previously_saved = _decode_restore_payload(previous_blob)

        sanitized = _SANITIZED_COMPATIBILITY_FIELDS.get(artifact_name, {})
        to_restore: dict[str, Any] = {}
        for field_name in field_names:
            if field_name in document and document[field_name] != sanitized.get(field_name):
                # The artifact still carries a non-sanitized value: record it.
                to_restore[field_name] = document[field_name]
            elif field_name in previously_saved:
                # Already sanitized (or absent): keep what was captured before.
                to_restore[field_name] = previously_saved[field_name]

        if to_restore:
            result[artifact_name] = {
                "encoding": "base64+zlib+json",
                "payload": _encode_restore_payload(to_restore),
            }
    return result
|
|
|
|
def write_maris_compatibility_artifact(output_dir: Path, *, maris_model_id: str) -> None:
    """Write, or remove, the compatibility manifest in *output_dir*.

    When there is at least one restore entry, the manifest is (re)written with
    the artifact type, version, origin, *maris_model_id* and the entries.
    When there is none, an existing manifest file is deleted and nothing is
    written.
    """
    entries = _build_restore_entries(output_dir)
    manifest_path = output_dir / MARIS_COMPATIBILITY_ARTIFACT_NAME
    if entries:
        manifest = {
            "artifact_type": MARIS_COMPATIBILITY_ARTIFACT_TYPE,
            "compatibility_version": MARIS_COMPATIBILITY_VERSION,
            "maris_origin": "Maris AI",
            "maris_model_id": maris_model_id,
            "artifacts": entries,
        }
        _save_json(manifest_path, manifest)
    elif manifest_path.exists():
        # Nothing to restore any more: do not leave a stale manifest behind.
        manifest_path.unlink()
|
|
|
|
def apply_maris_compatibility_identity(output_dir: Path) -> None:
    """Overwrite identity fields of the known artifacts in *output_dir*.

    Each artifact named in ``_SANITIZED_COMPATIBILITY_FIELDS`` that exists as a
    JSON object gets all of its sanitized fields set (added when missing,
    replaced when present) and is saved back in place.  Other keys are kept.
    Artifacts that are missing or are not JSON objects are skipped.
    """
    for artifact_name, overrides in _SANITIZED_COMPATIBILITY_FIELDS.items():
        artifact_path = output_dir / artifact_name
        document = _load_json_dict(artifact_path)
        if document is not None:
            _save_json(artifact_path, {**document, **overrides})
|
|
|
|
def _load_compatibility_manifest(model_dir: Path) -> dict[str, Any] | None:
    """Load the compatibility manifest from *model_dir*.

    Returns ``None`` when the manifest file is absent or is not a JSON object.
    The manifest's contents are not validated here.
    """
    manifest_path = model_dir / MARIS_COMPATIBILITY_ARTIFACT_NAME
    return _load_json_dict(manifest_path)
|
|
|
|
def has_maris_compatibility_artifact(model_dir: Path) -> bool:
    """Tell whether *model_dir* holds a usable compatibility manifest.

    A manifest counts as usable when it is a non-empty JSON object, its
    ``artifact_type`` equals ``MARIS_COMPATIBILITY_ARTIFACT_TYPE`` and its
    ``artifacts`` value is a dict.
    """
    manifest = _load_compatibility_manifest(model_dir)
    if not manifest:
        return False
    if manifest.get("artifact_type") != MARIS_COMPATIBILITY_ARTIFACT_TYPE:
        return False
    return isinstance(manifest.get("artifacts"), dict)
|
|
|
|
def _restore_compatibility_artifact(
    payload: dict[str, Any], restore_fields: dict[str, Any]
) -> dict[str, Any]:
    """Return a new dict: *payload* with *restore_fields* laid over it.

    Keys found in *restore_fields* replace (or extend) those of *payload*;
    keys only in *payload* are kept unchanged.  Neither argument is modified.
    The copy is shallow.
    """
    return {**payload, **restore_fields}
|
|
|
|
def _prepare_symlinked_model_dir(source_dir: Path) -> Path:
    """Create a temporary directory that mirrors the top level of *source_dir*.

    Every direct child of *source_dir* is symlinked into a fresh
    ``maris-hf-compat-*`` temporary directory.  When a symlink cannot be
    created (``OSError``), the child is copied instead.

    Args:
        source_dir: Directory to mirror; may be given as a relative path.

    Returns:
        Path of the new temporary directory.  The caller is responsible for
        deleting it.

    Raises:
        Exception: Any error while populating the directory is re-raised after
            the partially built temporary directory has been removed.
    """
    # A symlink stores its target text verbatim, and a relative target is
    # interpreted relative to the directory that contains the link.  Linking
    # children of a relative ``source_dir`` from inside the temp directory
    # would therefore produce dangling links, so anchor the source first.
    source_root = source_dir.resolve()
    prepared_dir = Path(tempfile.mkdtemp(prefix="maris-hf-compat-"))
    try:
        for child in source_root.iterdir():
            target = prepared_dir / child.name
            try:
                os.symlink(child, target, target_is_directory=child.is_dir())
            except OSError:
                # Symlinks unavailable or not permitted: fall back to a copy.
                if child.is_dir():
                    shutil.copytree(child, target, dirs_exist_ok=True)
                else:
                    shutil.copy2(child, target)
        return prepared_dir
    except Exception:
        shutil.rmtree(prepared_dir, ignore_errors=True)
        raise
|
|
|
|
def _resolve_repo_snapshot(
    model_name_or_path: str, *, allow_remote_snapshot: bool | None = None
) -> Path | None:
    """Map *model_name_or_path* to a local path, downloading only if allowed.

    Resolution order:

    1. If the argument names an existing filesystem path, that path is
       returned.
    2. Otherwise a remote snapshot download is attempted, but only when
       *allow_remote_snapshot* is true.  When it is ``None`` the decision is
       read from the ``*_COMPAT_ALLOW_REMOTE_SNAPSHOT`` environment settings
       via ``get_env_any`` (accepted values: ``1``, ``true``, ``yes``, ``on``,
       case-insensitive).

    Returns:
        The local path or downloaded snapshot directory, or ``None`` when
        remote access is disabled, ``huggingface_hub`` is not installed, or
        the download fails (a warning is logged in the last case).

    NOTE(review): ``validate_hf_model`` and ``get_env_any`` are project helpers
    not visible here; any exception they raise propagates to the caller.
    """
    local_candidate = Path(model_name_or_path)
    if local_candidate.exists():
        return local_candidate

    remote_allowed = allow_remote_snapshot
    if remote_allowed is None:
        raw_flag = get_env_any(
            "MARIS_HF_COMPAT_ALLOW_REMOTE_SNAPSHOT",
            "MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT",
            "HF_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT",
            default="",
        )
        remote_allowed = raw_flag.strip().lower() in {"1", "true", "yes", "on"}
    if not remote_allowed:
        return None

    repo_id = validate_hf_model(
        model_name_or_path,
        "MARIS_RUNTIME_TEXT_MODEL/HF_RUNTIME_TEXT_MODEL/model_name_or_path",
    )
    try:
        # Imported lazily so the module works without huggingface_hub installed.
        from huggingface_hub import snapshot_download
    except ImportError:
        return None

    token = get_env_any("MARIS_REPO_TOKEN", "MARIS_TOKEN", "HF_TOKEN")
    logger.info(
        "Maris HF compatibility restore uses remote snapshot download for runtime model %s.",
        repo_id,
    )
    try:
        logger.info(
            "Lejupielādē runtime modeli compatibility restore vajadzībām: %s", repo_id
        )
        snapshot_dir = snapshot_download(repo_id=repo_id, repo_type="model", token=token)
    except Exception as exc:
        # Best effort: a failed download must not stop the caller.
        logger.warning(
            "Neizdevās lejupielādēt Maris HF compatibility snapshot modeli %s; "
            "turpinu bez compatibility restore: %s",
            repo_id,
            exc,
        )
        return None
    return Path(snapshot_dir)
|
|
|
|
@contextmanager
def maris_hf_compatible_path(
    model_name_or_path: str, *, allow_remote_snapshot: bool | None = None
) -> Iterator[str]:
    """Yield a model path in which sanitized artifacts carry their saved fields.

    The model is first resolved to a local directory.  If that directory holds
    a compatibility manifest, a temporary mirror of it is built, the fields
    stored in the manifest are merged back into the mirrored JSON artifacts,
    and the mirror's path is yielded.  The mirror is deleted when the context
    exits.  The source directory itself is never written to.

    Args:
        model_name_or_path: Local path or model identifier.
        allow_remote_snapshot: Passed through to the snapshot resolver;
            ``None`` lets it decide from the environment.

    Yields:
        * *model_name_or_path* unchanged when it cannot be resolved to a
          directory;
        * the resolved directory when it has no compatibility manifest;
        * otherwise the path of the temporary restored mirror.

    Raises:
        ValueError: A manifest payload is corrupt and cannot be decoded.
    """
    source_dir = _resolve_repo_snapshot(
        model_name_or_path, allow_remote_snapshot=allow_remote_snapshot
    )
    if source_dir is None or not source_dir.is_dir():
        yield model_name_or_path
        return
    if not has_maris_compatibility_artifact(source_dir):
        yield str(source_dir)
        return

    manifest = _load_compatibility_manifest(source_dir)
    if manifest is None:
        yield model_name_or_path
        return

    prepared_dir = _prepare_symlinked_model_dir(source_dir)
    try:
        artifacts = manifest.get("artifacts", {})
        if isinstance(artifacts, dict):
            for artifact_name, encoded_entry in artifacts.items():
                # The manifest may come from a downloaded snapshot, so its keys
                # are untrusted.  Joining an absolute path discards
                # ``prepared_dir`` entirely, ``..`` climbs out of it, and a
                # nested name can pass through a symlinked sub-directory into
                # the source tree.  Only plain file names are safe to rewrite.
                if (
                    not isinstance(artifact_name, str)
                    or artifact_name in {"", ".", ".."}
                    or "/" in artifact_name
                    or "\\" in artifact_name
                    or Path(artifact_name).name != artifact_name
                ):
                    logger.warning(
                        "Skipping Maris HF compatibility entry with unsafe artifact name: %r",
                        artifact_name,
                    )
                    continue
                if not isinstance(encoded_entry, dict):
                    continue
                payload_blob = encoded_entry.get("payload")
                if not isinstance(payload_blob, str):
                    continue
                restore_fields = _decode_restore_payload(payload_blob)
                prepared_artifact_path = prepared_dir / artifact_name
                current_payload = _load_json_dict(prepared_artifact_path)
                if current_payload is None:
                    continue
                # Writing through a symlink would modify the source file, so
                # drop the link and create a real file in the mirror instead.
                if prepared_artifact_path.is_symlink():
                    prepared_artifact_path.unlink()
                _save_json(
                    prepared_artifact_path,
                    _restore_compatibility_artifact(current_payload, restore_fields),
                )
        yield str(prepared_dir)
    finally:
        # Cleanup is best effort and must not mask an error from the caller.
        with suppress(OSError):
            shutil.rmtree(prepared_dir, ignore_errors=True)
|
|