"""Maris Hugging Face compatibility helpers for fully sanitized artifacts.""" from __future__ import annotations import base64 import json import logging import os import shutil import tempfile import zlib from collections.abc import Iterator from contextlib import contextmanager, suppress from pathlib import Path from typing import Any from maris_core.utils.env import get_env_any, validate_hf_model MARIS_COMPATIBILITY_ARTIFACT_NAME = "maris-compatibility.json" MARIS_COMPATIBILITY_ARTIFACT_TYPE = "maris-hf-compatibility" MARIS_COMPATIBILITY_VERSION = 1 MARIS_MODEL_TYPE = "maris" MARIS_MODEL_ARCHITECTURE = "MarisCompatibleCausalLM" MARIS_TOKENIZER_CLASS = "MarisCompatibleTokenizer" MARIS_CONFIG_CLASS = "MarisCompatibleConfig" MARIS_PARENT_LIBRARY = "maris.compat" logger = logging.getLogger(__name__) _SANITIZED_COMPATIBILITY_FIELDS: dict[str, dict[str, Any]] = { "config.json": { "model_type": MARIS_MODEL_TYPE, "architectures": [MARIS_MODEL_ARCHITECTURE], "tokenizer_class": MARIS_TOKENIZER_CLASS, "auto_map": { "AutoConfig": MARIS_CONFIG_CLASS, "AutoModelForCausalLM": MARIS_MODEL_ARCHITECTURE, }, }, "tokenizer_config.json": { "tokenizer_class": MARIS_TOKENIZER_CLASS, "auto_map": { "AutoTokenizer": [MARIS_TOKENIZER_CLASS, None], }, }, "adapter_config.json": { "base_model_class": MARIS_MODEL_ARCHITECTURE, "parent_library": MARIS_PARENT_LIBRARY, "auto_mapping": { "base_model_class": MARIS_MODEL_ARCHITECTURE, "parent_library": MARIS_PARENT_LIBRARY, }, }, } _RESTORABLE_COMPATIBILITY_FIELDS: dict[str, tuple[str, ...]] = { "config.json": ("model_type", "architectures", "tokenizer_class", "auto_map"), "tokenizer_config.json": ("tokenizer_class", "auto_map"), "adapter_config.json": ("base_model_class", "parent_library", "auto_mapping"), } def _load_json_dict(path: Path) -> dict[str, Any] | None: if not path.is_file(): return None payload = json.loads(path.read_text(encoding="utf-8")) return payload if isinstance(payload, dict) else None def _save_json(path: Path, payload: dict[str, Any]) -> None: path.parent.mkdir(parents=True, exist_ok=True) path.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8") def _encode_restore_payload(payload: dict[str, Any]) -> str: try: serialized = json.dumps(payload, separators=(",", ":"), ensure_ascii=False).encode("utf-8") compressed = zlib.compress(serialized, level=9) return base64.urlsafe_b64encode(compressed).decode("ascii") except (TypeError, ValueError, zlib.error) as exc: raise ValueError("Neizdevās serializēt Maris HF compatibility payload.") from exc def _decode_restore_payload(payload: str) -> dict[str, Any]: try: decoded = base64.urlsafe_b64decode(payload.encode("ascii")) restored = json.loads(zlib.decompress(decoded).decode("utf-8")) except (ValueError, json.JSONDecodeError, zlib.error) as exc: raise ValueError("Maris HF compatibility payload ir bojāts vai nederīgs.") from exc return restored if isinstance(restored, dict) else {} def _build_restore_entries(output_dir: Path) -> dict[str, dict[str, str]]: existing_manifest = _load_compatibility_manifest(output_dir) or {} existing_artifacts = existing_manifest.get("artifacts", {}) entries: dict[str, dict[str, str]] = {} for artifact_name, field_names in _RESTORABLE_COMPATIBILITY_FIELDS.items(): artifact_path = output_dir / artifact_name payload = _load_json_dict(artifact_path) if payload is None: continue existing_restore_fields: dict[str, Any] = {} existing_entry = ( existing_artifacts.get(artifact_name) if isinstance(existing_artifacts, dict) else None ) if isinstance(existing_entry, dict) and isinstance(existing_entry.get("payload"), str): existing_restore_fields = _decode_restore_payload(existing_entry["payload"]) sanitized_fields = _SANITIZED_COMPATIBILITY_FIELDS.get(artifact_name, {}) restore_fields: dict[str, Any] = {} for field_name in field_names: if field_name not in payload and field_name not in existing_restore_fields: continue current_value = payload.get(field_name) sanitized_value = sanitized_fields.get(field_name) if field_name in payload and current_value != sanitized_value: restore_fields[field_name] = current_value continue if field_name in existing_restore_fields: restore_fields[field_name] = existing_restore_fields[field_name] if not restore_fields: continue entries[artifact_name] = { "encoding": "base64+zlib+json", "payload": _encode_restore_payload(restore_fields), } return entries def write_maris_compatibility_artifact(output_dir: Path, *, maris_model_id: str) -> None: entries = _build_restore_entries(output_dir) if not entries: compatibility_path = output_dir / MARIS_COMPATIBILITY_ARTIFACT_NAME if compatibility_path.exists(): compatibility_path.unlink() return _save_json( output_dir / MARIS_COMPATIBILITY_ARTIFACT_NAME, { "artifact_type": MARIS_COMPATIBILITY_ARTIFACT_TYPE, "compatibility_version": MARIS_COMPATIBILITY_VERSION, "maris_origin": "Maris AI", "maris_model_id": maris_model_id, "artifacts": entries, }, ) def apply_maris_compatibility_identity(output_dir: Path) -> None: for artifact_name, sanitized_fields in _SANITIZED_COMPATIBILITY_FIELDS.items(): artifact_path = output_dir / artifact_name payload = _load_json_dict(artifact_path) if payload is None: continue payload.update(sanitized_fields) _save_json(artifact_path, payload) def _load_compatibility_manifest(model_dir: Path) -> dict[str, Any] | None: return _load_json_dict(model_dir / MARIS_COMPATIBILITY_ARTIFACT_NAME) def has_maris_compatibility_artifact(model_dir: Path) -> bool: manifest = _load_compatibility_manifest(model_dir) return bool( manifest and manifest.get("artifact_type") == MARIS_COMPATIBILITY_ARTIFACT_TYPE and isinstance(manifest.get("artifacts"), dict) ) def _restore_compatibility_artifact( payload: dict[str, Any], restore_fields: dict[str, Any] ) -> dict[str, Any]: restored = dict(payload) restored.update(restore_fields) return restored def _prepare_symlinked_model_dir(source_dir: Path) -> Path: prepared_dir = Path(tempfile.mkdtemp(prefix="maris-hf-compat-")) try: for child in source_dir.iterdir(): target = prepared_dir / child.name try: os.symlink(child, target, target_is_directory=child.is_dir()) except OSError: if child.is_dir(): shutil.copytree(child, target, dirs_exist_ok=True) else: shutil.copy2(child, target) return prepared_dir except Exception: shutil.rmtree(prepared_dir, ignore_errors=True) raise def _resolve_repo_snapshot( model_name_or_path: str, *, allow_remote_snapshot: bool | None = None ) -> Path | None: if Path(model_name_or_path).exists(): return Path(model_name_or_path) if allow_remote_snapshot is None: allow_remote_snapshot = ( get_env_any( "MARIS_HF_COMPAT_ALLOW_REMOTE_SNAPSHOT", "MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT", "HF_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT", default="", ) .strip() .lower() in {"1", "true", "yes", "on"} ) if not allow_remote_snapshot: return None model_name_or_path = validate_hf_model( model_name_or_path, "MARIS_RUNTIME_TEXT_MODEL/HF_RUNTIME_TEXT_MODEL/model_name_or_path", ) try: from huggingface_hub import snapshot_download # type: ignore except ImportError: return None token = get_env_any("MARIS_REPO_TOKEN", "MARIS_TOKEN", "HF_TOKEN") logger.info( "Maris HF compatibility restore uses remote snapshot download for runtime model %s.", model_name_or_path, ) try: logger.info( "Lejupielādē runtime modeli compatibility restore vajadzībām: %s", model_name_or_path ) snapshot_dir = snapshot_download(repo_id=model_name_or_path, repo_type="model", token=token) except Exception as exc: # noqa: BLE001 logger.warning( "Neizdevās lejupielādēt Maris HF compatibility snapshot modeli %s; " "turpinu bez compatibility restore: %s", model_name_or_path, exc, ) return None return Path(snapshot_dir) @contextmanager def maris_hf_compatible_path( model_name_or_path: str, *, allow_remote_snapshot: bool | None = None ) -> Iterator[str]: source_dir = _resolve_repo_snapshot( model_name_or_path, allow_remote_snapshot=allow_remote_snapshot ) if source_dir is None or not source_dir.is_dir(): yield model_name_or_path return if not has_maris_compatibility_artifact(source_dir): yield str(source_dir) return manifest = _load_compatibility_manifest(source_dir) if manifest is None: yield model_name_or_path return prepared_dir = _prepare_symlinked_model_dir(source_dir) try: artifacts = manifest.get("artifacts", {}) if isinstance(artifacts, dict): for artifact_name, encoded_entry in artifacts.items(): if not isinstance(encoded_entry, dict): continue payload_blob = encoded_entry.get("payload") if not isinstance(payload_blob, str): continue restore_fields = _decode_restore_payload(payload_blob) prepared_artifact_path = prepared_dir / artifact_name current_payload = _load_json_dict(prepared_artifact_path) if current_payload is None: continue if prepared_artifact_path.is_symlink(): prepared_artifact_path.unlink() _save_json( prepared_artifact_path, _restore_compatibility_artifact(current_payload, restore_fields), ) yield str(prepared_dir) finally: with suppress(OSError): shutil.rmtree(prepared_dir, ignore_errors=True)