MarisUK's picture
Maris AI model sync
f440f03 verified
"""Maris Hugging Face compatibility helpers for fully sanitized artifacts."""
from __future__ import annotations
import base64
import json
import logging
import os
import shutil
import tempfile
import zlib
from collections.abc import Iterator
from contextlib import contextmanager, suppress
from pathlib import Path
from typing import Any
from maris_core.utils.env import get_env_any, validate_hf_model
# Identity of the compatibility manifest: file name, type tag and schema version.
MARIS_COMPATIBILITY_ARTIFACT_NAME = "maris-compatibility.json"
MARIS_COMPATIBILITY_ARTIFACT_TYPE = "maris-hf-compatibility"
MARIS_COMPATIBILITY_VERSION = 1

# Neutral Maris identity values written into sanitized Hugging Face JSON files.
MARIS_MODEL_TYPE = "maris"
MARIS_MODEL_ARCHITECTURE = "MarisCompatibleCausalLM"
MARIS_TOKENIZER_CLASS = "MarisCompatibleTokenizer"
MARIS_CONFIG_CLASS = "MarisCompatibleConfig"
MARIS_PARENT_LIBRARY = "maris.compat"

logger = logging.getLogger(__name__)

# Per-artifact field values merged into each JSON file when sanitizing
# (apply_maris_compatibility_identity calls dict.update with these).
_SANITIZED_COMPATIBILITY_FIELDS: dict[str, dict[str, Any]] = {
    "config.json": {
        "model_type": MARIS_MODEL_TYPE,
        "architectures": [MARIS_MODEL_ARCHITECTURE],
        "tokenizer_class": MARIS_TOKENIZER_CLASS,
        "auto_map": {
            "AutoConfig": MARIS_CONFIG_CLASS,
            "AutoModelForCausalLM": MARIS_MODEL_ARCHITECTURE,
        },
    },
    "tokenizer_config.json": {
        "tokenizer_class": MARIS_TOKENIZER_CLASS,
        "auto_map": {
            "AutoTokenizer": [MARIS_TOKENIZER_CLASS, None],
        },
    },
    "adapter_config.json": {
        "base_model_class": MARIS_MODEL_ARCHITECTURE,
        "parent_library": MARIS_PARENT_LIBRARY,
        "auto_mapping": {
            "base_model_class": MARIS_MODEL_ARCHITECTURE,
            "parent_library": MARIS_PARENT_LIBRARY,
        },
    },
}

# Fields whose pre-sanitization values are captured in the manifest. Derived
# from the sanitized table (dicts keep insertion order), so every sanitized
# field is restorable and the two tables cannot drift apart.
_RESTORABLE_COMPATIBILITY_FIELDS: dict[str, tuple[str, ...]] = {
    artifact_name: tuple(sanitized)
    for artifact_name, sanitized in _SANITIZED_COMPATIBILITY_FIELDS.items()
}
def _load_json_dict(path: Path) -> dict[str, Any] | None:
    """Read *path* as UTF-8 JSON and return it only if it is a JSON object.

    Returns ``None`` when *path* is not a regular file or when the decoded
    document is not an object. Malformed JSON propagates the ``json`` error.
    """
    if path.is_file():
        decoded = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(decoded, dict):
            return decoded
    return None
def _save_json(path: Path, payload: dict[str, Any]) -> None:
    """Write *payload* to *path* as indented UTF-8 JSON, creating parent dirs.

    Non-ASCII characters are written as-is (``ensure_ascii=False``).
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    rendered = json.dumps(payload, indent=2, ensure_ascii=False)
    path.write_text(rendered, encoding="utf-8")
def _encode_restore_payload(payload: dict[str, Any]) -> str:
    """Serialize *payload* as compact JSON, zlib-compress it, and URL-safe base64 it.

    Returns:
        ASCII text suitable for storing in the compatibility manifest.

    Raises:
        ValueError: if the payload is not JSON-serializable or compression fails.
    """
    try:
        compact_json = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
        packed = zlib.compress(compact_json.encode("utf-8"), level=9)
        encoded_bytes = base64.urlsafe_b64encode(packed)
    except (TypeError, ValueError, zlib.error) as exc:
        raise ValueError("Neizdevās serializēt Maris HF compatibility payload.") from exc
    # base64 output is pure ASCII, so this decode cannot fail.
    return encoded_bytes.decode("ascii")
def _decode_restore_payload(payload: str) -> dict[str, Any]:
    """Reverse ``_encode_restore_payload``: base64 -> zlib -> UTF-8 JSON.

    Returns:
        The decoded object, or ``{}`` when the JSON document is not an object.

    Raises:
        ValueError: if any decoding stage fails (bad base64/ASCII, corrupt
            zlib stream, invalid UTF-8 or invalid JSON).
    """
    try:
        compressed = base64.urlsafe_b64decode(payload.encode("ascii"))
        json_text = zlib.decompress(compressed).decode("utf-8")
        decoded = json.loads(json_text)
    except (ValueError, json.JSONDecodeError, zlib.error) as exc:
        raise ValueError("Maris HF compatibility payload ir bojāts vai nederīgs.") from exc
    if not isinstance(decoded, dict):
        return {}
    return decoded
def _build_restore_entries(output_dir: Path) -> dict[str, dict[str, str]]:
    """Collect encoded original values of the restorable fields in *output_dir*.

    For every artifact in ``_RESTORABLE_COMPATIBILITY_FIELDS`` that exists in
    *output_dir* as a JSON object, each restorable field is captured when:

    * it is present and differs from its sanitized value (the current value
      is recorded); otherwise
    * the existing manifest already stores a value for it (that stored value is
      carried forward, so re-running after sanitization keeps the originals).

    Artifacts with nothing captured are left out of the result.

    Returns:
        Mapping of artifact file name to ``{"encoding": ..., "payload": ...}``.

    Raises:
        ValueError: if a previous manifest payload for an existing artifact is
            corrupt (from ``_decode_restore_payload``).
    """
    previous_manifest = _load_compatibility_manifest(output_dir) or {}
    previous_artifacts = previous_manifest.get("artifacts", {})
    if not isinstance(previous_artifacts, dict):
        previous_artifacts = {}

    result: dict[str, dict[str, str]] = {}
    for artifact_name, field_names in _RESTORABLE_COMPATIBILITY_FIELDS.items():
        current = _load_json_dict(output_dir / artifact_name)
        if current is None:
            continue

        # Only decode the previous entry once we know the artifact exists.
        previous_entry = previous_artifacts.get(artifact_name)
        previous_blob = (
            previous_entry.get("payload") if isinstance(previous_entry, dict) else None
        )
        previous_fields: dict[str, Any] = (
            _decode_restore_payload(previous_blob) if isinstance(previous_blob, str) else {}
        )

        sanitized = _SANITIZED_COMPATIBILITY_FIELDS.get(artifact_name, {})
        captured: dict[str, Any] = {}
        for field_name in field_names:
            if field_name in current and current[field_name] != sanitized.get(field_name):
                captured[field_name] = current[field_name]
            elif field_name in previous_fields:
                captured[field_name] = previous_fields[field_name]

        if captured:
            result[artifact_name] = {
                "encoding": "base64+zlib+json",
                "payload": _encode_restore_payload(captured),
            }
    return result
def write_maris_compatibility_artifact(output_dir: Path, *, maris_model_id: str) -> None:
    """Write (or remove) the compatibility manifest in *output_dir*.

    The restore entries are computed from the JSON artifacts currently in
    *output_dir* together with any existing manifest. If there is nothing to
    record, an existing manifest file is deleted. Otherwise the manifest is
    (re)written with the artifact type, version, *maris_model_id* and entries.

    Args:
        output_dir: Directory holding the model artifacts.
        maris_model_id: Identifier stored verbatim in the manifest.
    """
    manifest_path = output_dir / MARIS_COMPATIBILITY_ARTIFACT_NAME
    restore_entries = _build_restore_entries(output_dir)
    if restore_entries:
        manifest = {
            "artifact_type": MARIS_COMPATIBILITY_ARTIFACT_TYPE,
            "compatibility_version": MARIS_COMPATIBILITY_VERSION,
            "maris_origin": "Maris AI",
            "maris_model_id": maris_model_id,
            "artifacts": restore_entries,
        }
        _save_json(manifest_path, manifest)
        return
    # Nothing to restore: drop any stale manifest left from an earlier run.
    if manifest_path.exists():
        manifest_path.unlink()
def apply_maris_compatibility_identity(output_dir: Path) -> None:
    """Overwrite identity fields in the known JSON artifacts of *output_dir*.

    Each artifact listed in ``_SANITIZED_COMPATIBILITY_FIELDS`` that exists as a
    JSON object has the sanitized values merged in (``dict.update``) and is
    saved back. Missing or non-object artifacts are skipped.

    NOTE(review): fields absent from the original file are *added* here, while
    the restore path only merges stored values back and never deletes keys, so
    such added fields survive restoration. Confirm this is intended.
    """
    for artifact_name, overrides in _SANITIZED_COMPATIBILITY_FIELDS.items():
        target = output_dir / artifact_name
        document = _load_json_dict(target)
        if document is not None:
            document.update(overrides)
            _save_json(target, document)
def _load_compatibility_manifest(model_dir: Path) -> dict[str, Any] | None:
    """Return the parsed compatibility manifest in *model_dir*, or ``None`` if absent/non-object."""
    manifest_path = model_dir / MARIS_COMPATIBILITY_ARTIFACT_NAME
    return _load_json_dict(manifest_path)
def has_maris_compatibility_artifact(model_dir: Path) -> bool:
    """Return True if *model_dir* holds a usable Maris compatibility manifest.

    A manifest qualifies when it is a non-empty JSON object whose
    ``artifact_type`` matches ``MARIS_COMPATIBILITY_ARTIFACT_TYPE`` and whose
    ``artifacts`` value is an object.
    """
    manifest = _load_compatibility_manifest(model_dir)
    if not manifest:
        return False
    if manifest.get("artifact_type") != MARIS_COMPATIBILITY_ARTIFACT_TYPE:
        return False
    return isinstance(manifest.get("artifacts"), dict)
def _restore_compatibility_artifact(
    payload: dict[str, Any], restore_fields: dict[str, Any]
) -> dict[str, Any]:
    """Return a new dict: *payload* with *restore_fields* layered on top.

    Neither input is mutated. Existing keys keep their position; keys only in
    *restore_fields* are appended.
    """
    return {**payload, **restore_fields}
def _prepare_symlinked_model_dir(source_dir: Path) -> Path:
    """Mirror the top level of *source_dir* into a new temporary directory.

    Each direct child is linked with ``os.symlink``. If linking raises
    ``OSError`` (for example, no symlink privilege), the child is copied instead:
    ``shutil.copytree`` for directories, ``shutil.copy2`` for files. The caller
    owns the returned directory and is responsible for removing it.

    Args:
        source_dir: Directory whose entries are mirrored. May be relative to
            the current working directory.

    Returns:
        Path of the freshly created ``maris-hf-compat-*`` temporary directory.

    Raises:
        Any exception raised while listing or copying; the temporary directory
        is removed before the exception propagates.
    """
    # A relative symlink target is resolved against the directory containing
    # the link, not the CWD. Links built from a relative source_dir would
    # therefore all dangle inside the temp dir. absolute() anchors the path to
    # the CWD without following any symlinks.
    source_dir = source_dir.absolute()
    prepared_dir = Path(tempfile.mkdtemp(prefix="maris-hf-compat-"))
    try:
        for child in source_dir.iterdir():
            target = prepared_dir / child.name
            try:
                os.symlink(child, target, target_is_directory=child.is_dir())
            except OSError:
                if child.is_dir():
                    shutil.copytree(child, target, dirs_exist_ok=True)
                else:
                    shutil.copy2(child, target)
        return prepared_dir
    except Exception:
        shutil.rmtree(prepared_dir, ignore_errors=True)
        raise
def _resolve_repo_snapshot(
    model_name_or_path: str, *, allow_remote_snapshot: bool | None = None
) -> Path | None:
    """Return a local directory for *model_name_or_path*, or ``None``.

    * An existing local path is returned as-is (as a ``Path``).
    * Otherwise a remote snapshot is downloaded only if allowed. When
      *allow_remote_snapshot* is ``None``, permission comes from the first set
      variable among ``MARIS_HF_COMPAT_ALLOW_REMOTE_SNAPSHOT``,
      ``MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT`` and
      ``HF_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT`` (true for ``1/true/yes/on``,
      case-insensitive).
    * ``None`` is also returned when ``huggingface_hub`` is not installed or
      the download fails; the failure is logged as a warning.
    """
    local_candidate = Path(model_name_or_path)
    if local_candidate.exists():
        return local_candidate

    if allow_remote_snapshot is None:
        raw_flag = get_env_any(
            "MARIS_HF_COMPAT_ALLOW_REMOTE_SNAPSHOT",
            "MARIS_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT",
            "HF_RUNTIME_COMPAT_ALLOW_REMOTE_SNAPSHOT",
            default="",
        )
        allow_remote_snapshot = raw_flag.strip().lower() in {"1", "true", "yes", "on"}
    if not allow_remote_snapshot:
        return None

    repo_id = validate_hf_model(
        model_name_or_path,
        "MARIS_RUNTIME_TEXT_MODEL/HF_RUNTIME_TEXT_MODEL/model_name_or_path",
    )
    # huggingface_hub is optional; without it there is simply nothing to restore.
    try:
        from huggingface_hub import snapshot_download  # type: ignore
    except ImportError:
        return None

    token = get_env_any("MARIS_REPO_TOKEN", "MARIS_TOKEN", "HF_TOKEN")
    logger.info(
        "Maris HF compatibility restore uses remote snapshot download for runtime model %s.",
        repo_id,
    )
    try:
        logger.info("Lejupielādē runtime modeli compatibility restore vajadzībām: %s", repo_id)
        snapshot_dir = snapshot_download(repo_id=repo_id, repo_type="model", token=token)
    except Exception as exc:  # noqa: BLE001
        # Best effort: a failed download just disables the compatibility restore.
        logger.warning(
            "Neizdevās lejupielādēt Maris HF compatibility snapshot modeli %s; "
            "turpinu bez compatibility restore: %s",
            repo_id,
            exc,
        )
        return None
    return Path(snapshot_dir)
def _is_plain_artifact_name(artifact_name: object) -> bool:
    """Return True when *artifact_name* names a direct child file (no path parts)."""
    if not isinstance(artifact_name, str) or artifact_name in {"", ".", ".."}:
        return False
    return (
        os.path.basename(artifact_name) == artifact_name
        and not os.path.isabs(artifact_name)
    )


def _restore_manifest_artifacts(prepared_dir: Path, manifest: dict[str, Any]) -> None:
    """Rewrite each manifest-listed JSON artifact inside *prepared_dir* with its stored fields.

    Entries are skipped when they are not objects, lack a string ``payload``,
    name anything other than a direct child of *prepared_dir*, or point at a
    file that is missing or not a JSON object.

    Raises:
        ValueError: if an entry's payload cannot be decoded.
    """
    artifacts = manifest.get("artifacts", {})
    if not isinstance(artifacts, dict):
        return
    for artifact_name, encoded_entry in artifacts.items():
        if not isinstance(encoded_entry, dict):
            continue
        payload_blob = encoded_entry.get("payload")
        if not isinstance(payload_blob, str):
            continue
        # The manifest may come from a downloaded snapshot and is untrusted.
        # Names such as "../x.json", "/abs/x.json" or "sub/x.json" (through a
        # symlinked directory) would otherwise make the write below modify
        # files outside the temporary mirror, including the source snapshot.
        if not _is_plain_artifact_name(artifact_name):
            logger.warning(
                "Ignoring Maris HF compatibility artifact with unsafe name: %r",
                artifact_name,
            )
            continue
        restore_fields = _decode_restore_payload(payload_blob)
        prepared_artifact_path = prepared_dir / artifact_name
        current_payload = _load_json_dict(prepared_artifact_path)
        if current_payload is None:
            continue
        # Replace the link with a real file so the write never reaches the
        # source file the symlink points at.
        if prepared_artifact_path.is_symlink():
            prepared_artifact_path.unlink()
        _save_json(
            prepared_artifact_path,
            _restore_compatibility_artifact(current_payload, restore_fields),
        )


@contextmanager
def maris_hf_compatible_path(
    model_name_or_path: str, *, allow_remote_snapshot: bool | None = None
) -> Iterator[str]:
    """Yield a path from which the model loads with its original HF identity restored.

    * If no local directory can be resolved (see ``_resolve_repo_snapshot``),
      *model_name_or_path* is yielded unchanged.
    * If the directory has no Maris compatibility manifest, the directory
      itself is yielded.
    * Otherwise the directory is mirrored into a temporary directory, the
      manifest-listed JSON artifacts are rewritten there with their stored
      original fields, and the temporary path is yielded. The temporary
      directory is removed when the context exits.

    Args:
        model_name_or_path: Local path or Hugging Face repo id.
        allow_remote_snapshot: Forwarded to ``_resolve_repo_snapshot``.

    Raises:
        ValueError: if a manifest payload is corrupt.
    """
    source_dir = _resolve_repo_snapshot(
        model_name_or_path, allow_remote_snapshot=allow_remote_snapshot
    )
    if source_dir is None or not source_dir.is_dir():
        yield model_name_or_path
        return
    if not has_maris_compatibility_artifact(source_dir):
        yield str(source_dir)
        return
    manifest = _load_compatibility_manifest(source_dir)
    if manifest is None:
        yield model_name_or_path
        return
    prepared_dir = _prepare_symlinked_model_dir(source_dir)
    try:
        _restore_manifest_artifacts(prepared_dir, manifest)
        yield str(prepared_dir)
    finally:
        with suppress(OSError):
            shutil.rmtree(prepared_dir, ignore_errors=True)