File size: 5,226 Bytes
f440f03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""Eksportē modeli uz Maris origin repozitoriju."""

from __future__ import annotations

import argparse
import json
import logging
import os
from pathlib import Path

from maris_core.utils.env import validate_maris_model

# Module-level logger shared by the helpers below. Nothing here configures
# handlers or levels at import time; the ``__main__`` block calls
# ``logging.basicConfig`` only when the file is run as a script.
logger = logging.getLogger(__name__)


def _getenv_any(*names: str, default: str) -> str:
    for name in names:
        value = os.getenv(name, "").strip()
        if value:
            return value
    return default


def _upload_folder(api: object, folder_path: str, repo_id: str, commit_message: str) -> None:
    """Make sure the model repo exists, then push ``folder_path`` into it.

    ``api`` is duck-typed: it must expose ``create_repo`` and
    ``upload_folder`` accepting the keyword arguments used below (the
    ``huggingface_hub.HfApi`` created in ``export_model`` is the real one).
    """
    target = {"repo_id": repo_id, "repo_type": "model"}
    logger.info("Eksportē modeli: %s -> %s", folder_path, repo_id)
    # exist_ok=True asks the API not to fail when the repository already exists.
    api.create_repo(**target, exist_ok=True)
    api.upload_folder(folder_path=folder_path, commit_message=commit_message, **target)


def _branch_suite_export_targets(
    model_path: str,
    repo_id: str,
    commit_message: str,
) -> list[tuple[str, str, str]]:
    """Build the ordered list of ``(folder, repo_id, commit_message)`` uploads.

    The first entry is always ``model_path`` itself going to ``repo_id`` with
    ``commit_message`` unchanged. Then, for each known branch (``master``,
    ``coder``, ``image``, ``music``, ``tts``, ``stt``, ``video``) whose output
    directory ``_resolve_branch_output_dir`` can locate, one more entry is
    appended. That lookup uses the optional ``branch-suite.json`` manifest
    inside ``model_path``. The new entry targets the branch's repository, and
    the branch name is appended to the commit message.

    Branch repositories are read from environment variables via
    ``_getenv_any``: the first non-blank value wins, otherwise a built-in
    default is used. Each is then passed through ``validate_maris_model``.
    A ``(directory, repo)`` pair already scheduled is not added twice.

    Args:
        model_path: Root export directory; may contain ``branch-suite.json``.
        repo_id: Repository for the root directory (``export_model`` validates
            it before calling this).
        commit_message: Base commit message.

    Returns:
        Upload targets in the order they should be performed.

    Raises:
        json.JSONDecodeError: If ``branch-suite.json`` exists but is not valid JSON.
        Whatever ``validate_maris_model`` raises for a rejected branch repo
        (its exception type is not visible from this module).
    """
    output_dir = Path(model_path)
    manifest_path = output_dir / "branch-suite.json"
    suite = json.loads(manifest_path.read_text(encoding="utf-8")) if manifest_path.is_file() else {}
    if not isinstance(suite, dict):
        # Only a JSON object can carry a "branches" mapping. Anything else (a list,
        # a string, ...) would make the downstream ``suite.get`` lookup crash.
        logger.warning(
            "Ignoring %s: expected a JSON object, got %s",
            manifest_path,
            type(suite).__name__,
        )
        suite = {}

    # Every branch goes through _getenv_any, so an empty or whitespace-only
    # variable falls back to the default instead of producing an invalid repo id.
    branch_exports = {
        "master": _getenv_any(
            "MARIS_TEXT_MODEL_REPO",
            "HF_TEXT_MODEL_REPO",
            "TEXT_MODEL",
            default="MarisUK/maris-ai-text",
        ),
        "coder": _getenv_any(
            "MARIS_CODEX_MODEL_REPO",
            "HF_CODEX_MODEL_REPO",
            "CODEX_MODEL",
            default="MarisUK/maris-ai-codex",
        ),
        "image": _getenv_any("IMAGE_MODEL", default="MarisUK/maris-ai-image"),
        "music": _getenv_any("MUSIC_MODEL", default="MarisUK/maris-ai-music"),
        "tts": _getenv_any("TTS_MODEL", default="MarisUK/maris-tts-runtime"),
        "stt": _getenv_any("STT_MODEL", default="MarisUK/maris-stt-runtime"),
        "video": _getenv_any("VIDEO_MODEL", default="MarisUK/maris-ai-video"),
    }

    targets: list[tuple[str, str, str]] = [(str(output_dir), repo_id, commit_message)]
    # Dedup on fully resolved paths. Branch directories are resolved below, so
    # the root must be resolved too, or "out" and "/abs/out" would never match.
    seen: set[tuple[str, str]] = {(str(output_dir.resolve()), repo_id)}
    for branch_name, target_repo in branch_exports.items():
        branch_output_dir = _resolve_branch_output_dir(output_dir, suite, branch_name)
        if branch_output_dir is None:
            continue
        validated_repo = validate_maris_model(target_repo, f"{branch_name} export target")
        normalized_dir = str(branch_output_dir.resolve())
        key = (normalized_dir, validated_repo)
        if key in seen:
            continue
        seen.add(key)
        targets.append(
            (
                normalized_dir,
                validated_repo,
                f"{commit_message} ({branch_name})",
            )
        )
    return targets


def _resolve_branch_output_dir(
    output_dir: Path,
    suite: dict[str, object],
    branch_name: str,
) -> Path | None:
    branches = suite.get("branches", {})
    if isinstance(branches, dict):
        branch_payload = branches.get(branch_name)
        if isinstance(branch_payload, dict):
            candidate = str(branch_payload.get("output_dir", "")).strip()
            if candidate:
                branch_dir = Path(candidate)
                candidate_roots = (Path.cwd(), output_dir) if not branch_dir.is_absolute() else (None,)
                for root in candidate_roots:
                    resolved_dir = branch_dir if root is None else (root / candidate).resolve()
                    if resolved_dir.is_dir():
                        return resolved_dir

    fallback_names = {
        "master": ("master", "text"),
        "coder": ("coder", "codex"),
    }.get(branch_name, (branch_name,))
    for fallback_name in fallback_names:
        fallback_dir = (output_dir / fallback_name).resolve()
        if fallback_dir.is_dir():
            return fallback_dir
    return None


def export_model(
    model_path: str,
    repo_id: str | None = None,
    commit_message: str = "Maris AI model export",
) -> None:
    """Upload the model folder, plus any branch-suite sub-models, to the origin repo.

    Uploads go through ``huggingface_hub.HfApi``.

    Args:
        model_path: Root export directory. It is always uploaded to ``repo_id``.
            Branch directories found under or via ``branch-suite.json`` go to
            their own repositories.
        repo_id: Target repository for ``model_path``. When ``None`` or blank,
            the first non-blank of ``MARIS_MODEL_REPO`` and ``HF_MODEL_REPO``
            is used, falling back to ``"MarisUK/maris-ai-master"``.
        commit_message: Commit message for the root upload. Branch uploads get
            the branch name appended.

    Raises:
        Whatever ``validate_maris_model`` raises for a rejected repo id.
        ImportError: If ``huggingface_hub`` is not installed.
        Any error from reading the manifest or from the Hub API is logged and
        re-raised unchanged.
    """
    # Blank or whitespace-only values (CLI or environment) count as "not set",
    # matching how _getenv_any treats the branch repositories.
    repo_id = (repo_id or "").strip() or _getenv_any(
        "MARIS_MODEL_REPO",
        "HF_MODEL_REPO",
        default="MarisUK/maris-ai-master",
    )
    repo_id = validate_maris_model(repo_id, "MARIS_MODEL_REPO/HF_MODEL_REPO/--repo-id")
    # Pass None (not "") when no token is configured, so HfApi applies its own
    # default token lookup.
    token = _getenv_any("MARIS_REPO_TOKEN", "MARIS_TOKEN", "HF_TOKEN", default="") or None

    try:
        from huggingface_hub import HfApi  # type: ignore

        api = HfApi(token=token)
        for export_path, export_repo, export_commit in _branch_suite_export_targets(
            model_path,
            repo_id,
            commit_message,
        ):
            _upload_folder(api, export_path, export_repo, export_commit)
        logger.info("Eksportēšana pabeigta!")
    except Exception as exc:  # noqa: BLE001
        # Boundary handler: record the failure in this module's log, then propagate.
        logger.error("Eksportēšanas kļūda: %s", exc)
        raise


if __name__ == "__main__":
    # CLI entry point: parse arguments, enable INFO-level logging, run the export.
    parser = argparse.ArgumentParser(description="Eksportē Maris AI modeli uz origin repozitoriju")
    parser.add_argument("--model-path", required=True, help="Modeļa direktorija")
    parser.add_argument("--repo-id", help="Origin repo ID")
    # Exposes export_model's commit_message. When omitted, export_model's own
    # default applies, so the default lives in exactly one place.
    parser.add_argument(
        "--commit-message",
        help="Commit message for the upload (branch uploads append the branch name)",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    export_kwargs = {} if args.commit_message is None else {"commit_message": args.commit_message}
    export_model(args.model_path, args.repo_id, **export_kwargs)