| """Eksportē modeli uz Maris origin repozitoriju.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| import os |
| from pathlib import Path |
|
|
| from maris_core.utils.env import validate_maris_model |
|
|
# Module-level logger named after this module; export progress and errors are
# reported through it.
logger = logging.getLogger(__name__)
|
|
|
|
| def _getenv_any(*names: str, default: str) -> str: |
| for name in names: |
| value = os.getenv(name, "").strip() |
| if value: |
| return value |
| return default |
|
|
|
|
def _upload_folder(api: object, folder_path: str, repo_id: str, commit_message: str) -> None:
    """Create (``exist_ok=True``) the model repo *repo_id*, then upload *folder_path* to it.

    *api* must provide ``create_repo`` and ``upload_folder`` accepting the
    keyword arguments used below; ``export_model`` passes an ``HfApi`` instance.
    """
    logger.info("Eksportē modeli: %s -> %s", folder_path, repo_id)
    # Both hub calls address the same repository with the same repo type.
    repo_kwargs = {"repo_id": repo_id, "repo_type": "model"}
    api.create_repo(exist_ok=True, **repo_kwargs)
    api.upload_folder(
        folder_path=folder_path,
        commit_message=commit_message,
        **repo_kwargs,
    )
|
|
|
|
def _branch_suite_export_targets(
    model_path: str,
    repo_id: str,
    commit_message: str,
) -> list[tuple[str, str, str]]:
    """Build the list of ``(folder, repo_id, commit_message)`` upload targets.

    The first target is always *model_path* uploaded to *repo_id* with
    *commit_message*. Then, for each known branch (master, coder, image,
    music, tts, stt, video), the branch's output directory is located via
    ``_resolve_branch_output_dir``. That helper consults ``branch-suite.json``
    inside *model_path* when present, and conventional sub-directories
    otherwise. Every located branch is added as an extra target, pushed to
    the repository named by that branch's environment variables (or a
    ``MarisUK/...`` default), with ``" (<branch>)"`` appended to the commit
    message. Identical ``(directory, repo)`` pairs are only emitted once.

    Raises:
        json.JSONDecodeError: ``branch-suite.json`` exists but is not valid JSON.
        Errors raised by ``validate_maris_model`` for a branch repo id propagate.
    """
    output_dir = Path(model_path)
    manifest_path = output_dir / "branch-suite.json"
    suite: dict[str, object] = {}
    if manifest_path.is_file():
        loaded = json.loads(manifest_path.read_text(encoding="utf-8"))
        if isinstance(loaded, dict):
            suite = loaded
        else:
            # A top-level array/scalar has no "branches" mapping; ignore it
            # rather than letting the resolver fail on a non-dict payload.
            logger.warning("%s nav JSON objekts, tiek ignorēts", manifest_path)

    # Every branch goes through _getenv_any so a blank/whitespace env var
    # falls back to the default instead of becoming an invalid repo id.
    branch_exports = {
        "master": _getenv_any(
            "MARIS_TEXT_MODEL_REPO",
            "HF_TEXT_MODEL_REPO",
            "TEXT_MODEL",
            default="MarisUK/maris-ai-text",
        ),
        "coder": _getenv_any(
            "MARIS_CODEX_MODEL_REPO",
            "HF_CODEX_MODEL_REPO",
            "CODEX_MODEL",
            default="MarisUK/maris-ai-codex",
        ),
        "image": _getenv_any("IMAGE_MODEL", default="MarisUK/maris-ai-image"),
        "music": _getenv_any("MUSIC_MODEL", default="MarisUK/maris-ai-music"),
        "tts": _getenv_any("TTS_MODEL", default="MarisUK/maris-tts-runtime"),
        "stt": _getenv_any("STT_MODEL", default="MarisUK/maris-stt-runtime"),
        "video": _getenv_any("VIDEO_MODEL", default="MarisUK/maris-ai-video"),
    }

    targets: list[tuple[str, str, str]] = [(str(output_dir), repo_id, commit_message)]
    # Dedup keys use resolved paths on both sides; otherwise a branch that
    # points back at the root directory (given relatively) is never matched.
    seen: set[tuple[str, str]] = {(str(output_dir.resolve()), repo_id)}
    for branch_name, target_repo in branch_exports.items():
        branch_output_dir = _resolve_branch_output_dir(output_dir, suite, branch_name)
        if branch_output_dir is None:
            continue
        # Only validate repos for branches that will actually be exported.
        validated_repo = validate_maris_model(target_repo, f"{branch_name} export target")
        normalized_dir = str(branch_output_dir.resolve())
        key = (normalized_dir, validated_repo)
        if key in seen:
            continue
        seen.add(key)
        targets.append(
            (
                normalized_dir,
                validated_repo,
                f"{commit_message} ({branch_name})",
            )
        )
    return targets
|
|
|
|
| def _resolve_branch_output_dir( |
| output_dir: Path, |
| suite: dict[str, object], |
| branch_name: str, |
| ) -> Path | None: |
| branches = suite.get("branches", {}) |
| if isinstance(branches, dict): |
| branch_payload = branches.get(branch_name) |
| if isinstance(branch_payload, dict): |
| candidate = str(branch_payload.get("output_dir", "")).strip() |
| if candidate: |
| branch_dir = Path(candidate) |
| candidate_roots = (Path.cwd(), output_dir) if not branch_dir.is_absolute() else (None,) |
| for root in candidate_roots: |
| resolved_dir = branch_dir if root is None else (root / candidate).resolve() |
| if resolved_dir.is_dir(): |
| return resolved_dir |
|
|
| fallback_names = { |
| "master": ("master", "text"), |
| "coder": ("coder", "codex"), |
| }.get(branch_name, (branch_name,)) |
| for fallback_name in fallback_names: |
| fallback_dir = (output_dir / fallback_name).resolve() |
| if fallback_dir.is_dir(): |
| return fallback_dir |
| return None |
|
|
|
|
def export_model(
    model_path: str,
    repo_id: str | None = None,
    commit_message: str = "Maris AI model export",
) -> None:
    """Upload a model directory and its branch sub-models to the origin repository.

    Args:
        model_path: Directory containing the exported model. It may also hold
            ``branch-suite.json`` and per-branch sub-directories.
        repo_id: Target repository for *model_path*. When omitted or empty,
            the first non-blank value of ``MARIS_MODEL_REPO`` /
            ``HF_MODEL_REPO`` is used. If neither is set, the default is
            ``MarisUK/maris-ai-master``. The value is checked with
            ``validate_maris_model``.
        commit_message: Commit message for the main upload. Branch uploads
            append ``" (<branch>)"``.

    The access token is the first non-blank value of ``MARIS_REPO_TOKEN``,
    ``MARIS_TOKEN`` or ``HF_TOKEN``. When none is set, ``HfApi`` receives
    ``token=None``.

    Raises:
        Errors from ``validate_maris_model`` propagate directly. Errors
        importing ``huggingface_hub`` or during the uploads are logged and
        then re-raised.
    """
    # _getenv_any skips blank values, so HF_MODEL_REPO="" falls through to the
    # default instead of yielding an empty repo id.
    repo_id = repo_id or _getenv_any(
        "MARIS_MODEL_REPO",
        "HF_MODEL_REPO",
        default="MarisUK/maris-ai-master",
    )
    repo_id = validate_maris_model(repo_id, "MARIS_MODEL_REPO/HF_MODEL_REPO/--repo-id")
    # Blank/whitespace-only tokens count as unset, and surrounding whitespace
    # (e.g. a trailing newline) is stripped; "" is never passed as a token.
    token = _getenv_any("MARIS_REPO_TOKEN", "MARIS_TOKEN", "HF_TOKEN", default="") or None

    try:
        # Imported here rather than at module top, presumably so the rest of
        # the module stays importable without huggingface_hub installed.
        from huggingface_hub import HfApi

        api = HfApi(token=token)
        for export_path, export_repo, export_commit in _branch_suite_export_targets(
            model_path,
            repo_id,
            commit_message,
        ):
            _upload_folder(api, export_path, export_repo, export_commit)
        logger.info("Eksportēšana pabeigta!")
    except Exception as exc:
        logger.error("Eksportēšanas kļūda: %s", exc)
        raise
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse arguments, enable INFO logging, and run the export.
    parser = argparse.ArgumentParser(description="Eksportē Maris AI modeli uz origin repozitoriju")
    parser.add_argument("--model-path", required=True, help="Modeļa direktorija")
    parser.add_argument("--repo-id", help="Origin repo ID")
    # Exposes export_model's commit_message; the default mirrors export_model's
    # own default, so omitting the flag keeps the same commit message.
    parser.add_argument(
        "--commit-message",
        default="Maris AI model export",
        help="Commit ziņojums",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    export_model(args.model_path, args.repo_id, args.commit_message)
|
|