"""Eksportē modeli uz Maris origin repozitoriju."""
from __future__ import annotations
import argparse
import json
import logging
import os
from pathlib import Path
from maris_core.utils.env import validate_maris_model
# Module-level logger for export progress and failure messages.
logger = logging.getLogger(__name__)
def _getenv_any(*names: str, default: str) -> str:
for name in names:
value = os.getenv(name, "").strip()
if value:
return value
return default
def _upload_folder(api: object, folder_path: str, repo_id: str, commit_message: str) -> None:
    """Ensure the model repository exists, then upload a local folder into it.

    Args:
        api: Hub client exposing ``create_repo`` and ``upload_folder``
            (``export_model`` passes an ``HfApi`` instance).
        folder_path: Local directory whose contents are uploaded.
        repo_id: Target repository id.
        commit_message: Commit message used for the upload.
    """
    logger.info("Eksportē modeli: %s -> %s", folder_path, repo_id)
    # Both hub calls address the same model-type repository.
    target = {"repo_id": repo_id, "repo_type": "model"}
    # exist_ok=True: an already existing repository is not treated as an error.
    api.create_repo(exist_ok=True, **target)
    api.upload_folder(folder_path=folder_path, commit_message=commit_message, **target)
def _branch_export_repos() -> dict[str, str]:
    """Return the target repository id for every known suite branch.

    Each value is the first non-blank environment variable listed for that
    branch, or the built-in default. ``_getenv_any`` is used for every branch
    so that an empty variable falls back to the default instead of yielding "".
    """
    return {
        "master": _getenv_any(
            "MARIS_TEXT_MODEL_REPO",
            "HF_TEXT_MODEL_REPO",
            "TEXT_MODEL",
            default="MarisUK/maris-ai-text",
        ),
        "coder": _getenv_any(
            "MARIS_CODEX_MODEL_REPO",
            "HF_CODEX_MODEL_REPO",
            "CODEX_MODEL",
            default="MarisUK/maris-ai-codex",
        ),
        "image": _getenv_any("IMAGE_MODEL", default="MarisUK/maris-ai-image"),
        "music": _getenv_any("MUSIC_MODEL", default="MarisUK/maris-ai-music"),
        "tts": _getenv_any("TTS_MODEL", default="MarisUK/maris-tts-runtime"),
        "stt": _getenv_any("STT_MODEL", default="MarisUK/maris-stt-runtime"),
        "video": _getenv_any("VIDEO_MODEL", default="MarisUK/maris-ai-video"),
    }


def _branch_suite_export_targets(
    model_path: str,
    repo_id: str,
    commit_message: str,
) -> list[tuple[str, str, str]]:
    """Build the ordered list of ``(folder, repo_id, commit_message)`` uploads.

    The first target is always *model_path* uploaded to *repo_id* with
    *commit_message*. For every known branch whose directory can be located
    by ``_resolve_branch_output_dir``, another target is appended. Lookup may
    consult an optional ``branch-suite.json`` manifest inside *model_path*.
    Each such target uploads that directory to the branch's repository, with
    `` (<branch>)`` appended to the commit message. Repeated
    (directory, repository) pairs are emitted only once.

    Raises:
        json.JSONDecodeError: If ``branch-suite.json`` exists but is not valid JSON.
        Whatever ``validate_maris_model`` raises for an invalid branch repository.
    """
    output_dir = Path(model_path)
    manifest_path = output_dir / "branch-suite.json"
    suite: object = {}
    if manifest_path.is_file():
        suite = json.loads(manifest_path.read_text(encoding="utf-8"))
        if not isinstance(suite, dict):
            # Only a JSON object can describe branches; any other shape is ignored
            # instead of crashing later on ``suite.get``.
            logger.warning("Ignoring %s: expected a JSON object", manifest_path)
            suite = {}

    targets: list[tuple[str, str, str]] = [(str(output_dir), repo_id, commit_message)]
    # Compare resolved paths so the same directory, reached through different
    # spellings (relative vs. absolute), is not uploaded twice to one repository.
    seen: set[tuple[str, str]] = {(str(output_dir.resolve()), repo_id)}
    for branch_name, target_repo in _branch_export_repos().items():
        branch_output_dir = _resolve_branch_output_dir(output_dir, suite, branch_name)
        if branch_output_dir is None:
            continue
        validated_repo = validate_maris_model(target_repo, f"{branch_name} export target")
        normalized_dir = str(branch_output_dir.resolve())
        key = (normalized_dir, validated_repo)
        if key in seen:
            continue
        seen.add(key)
        targets.append((normalized_dir, validated_repo, f"{commit_message} ({branch_name})"))
    return targets
def _resolve_branch_output_dir(
output_dir: Path,
suite: dict[str, object],
branch_name: str,
) -> Path | None:
branches = suite.get("branches", {})
if isinstance(branches, dict):
branch_payload = branches.get(branch_name)
if isinstance(branch_payload, dict):
candidate = str(branch_payload.get("output_dir", "")).strip()
if candidate:
branch_dir = Path(candidate)
candidate_roots = (Path.cwd(), output_dir) if not branch_dir.is_absolute() else (None,)
for root in candidate_roots:
resolved_dir = branch_dir if root is None else (root / candidate).resolve()
if resolved_dir.is_dir():
return resolved_dir
fallback_names = {
"master": ("master", "text"),
"coder": ("coder", "codex"),
}.get(branch_name, (branch_name,))
for fallback_name in fallback_names:
fallback_dir = (output_dir / fallback_name).resolve()
if fallback_dir.is_dir():
return fallback_dir
return None
def export_model(
    model_path: str,
    repo_id: str | None = None,
    commit_message: str = "Maris AI model export",
) -> None:
    """Upload the model to the origin repository.

    *model_path* is uploaded to *repo_id*. Additional branch directories found
    by ``_branch_suite_export_targets`` are uploaded to their own
    repositories in the same run.

    Args:
        model_path: Local directory containing the exported model.
        repo_id: Target repository. When omitted or blank, the first non-blank
            of ``MARIS_MODEL_REPO`` / ``HF_MODEL_REPO`` is used, falling back to
            ``MarisUK/maris-ai-master``.
        commit_message: Commit message. Branch uploads append ``" (<branch>)"``.

    Raises:
        FileNotFoundError: If *model_path* is not an existing directory.
        Exception: Errors from the hub client or target resolution are logged
            and re-raised. ``validate_maris_model`` on the root repository runs
            before the logging ``try`` block, so its errors propagate unlogged.
    """
    # Blank values (explicit or from the environment) fall through to the next source.
    repo_id = (repo_id or "").strip() or _getenv_any(
        "MARIS_MODEL_REPO",
        "HF_MODEL_REPO",
        default="MarisUK/maris-ai-master",
    )
    repo_id = validate_maris_model(repo_id, "MARIS_MODEL_REPO/HF_MODEL_REPO/--repo-id")
    # Stripping removes e.g. a trailing newline from a token read out of a file.
    # None (no usable token) defers to HfApi's default credential lookup.
    token = _getenv_any("MARIS_REPO_TOKEN", "MARIS_TOKEN", "HF_TOKEN", default="") or None
    try:
        # Check locally before any hub call, so a mistyped path does not create
        # an empty remote repository via create_repo.
        if not Path(model_path).is_dir():
            raise FileNotFoundError(f"Modeļa direktorija nav atrasta: {model_path}")
        from huggingface_hub import HfApi  # type: ignore

        api = HfApi(token=token)
        # Build (and validate) every target before the first upload starts.
        targets = _branch_suite_export_targets(model_path, repo_id, commit_message)
        for export_path, export_repo, export_commit in targets:
            _upload_folder(api, export_path, export_repo, export_commit)
        logger.info("Eksportēšana pabeigta!")
    except Exception as exc:  # noqa: BLE001
        logger.error("Eksportēšanas kļūda: %s", exc)
        raise
if __name__ == "__main__":
    # Command-line entry point: parse arguments, enable INFO logging, run the export.
    parser = argparse.ArgumentParser(description="Eksportē Maris AI modeli uz origin repozitoriju")
    parser.add_argument("--model-path", required=True, help="Modeļa direktorija")
    parser.add_argument("--repo-id", help="Origin repo ID")
    # Exposes export_model's commit_message; the default matches its signature.
    parser.add_argument(
        "--commit-message",
        default="Maris AI model export",
        help="Commit ziņojums",
    )
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    export_model(args.model_path, args.repo_id, args.commit_message)
|