File size: 2,944 Bytes
f440f03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""Tests for saving Maris AI origin metadata through HFIntegration."""

from __future__ import annotations

import builtins

import pytest

from maris_core.utils.hf_integration import HFIntegration


@pytest.mark.asyncio
async def test_save_generation_adds_maris_metadata(monkeypatch) -> None:
    """save_generation() must push the caller's metadata plus the Maris origin fields."""
    pushed: dict[str, object] = {}

    async def _record_push(_self, entry):  # noqa: ANN001
        # Replaces _push_to_dataset: keep the entry for inspection instead of sending it.
        pushed.update(entry)

    monkeypatch.setattr(HFIntegration, "_push_to_dataset", _record_push)

    await HFIntegration().save_generation("image", "Saulriets", {"style": "clean"})

    expected_metadata = {
        "style": "clean",
        "generated_by": "Maris AI",
        "maris_framework": "maris-ai-core",
        "maris_origin": True,
    }
    assert pushed["type"] == "image"
    assert pushed["prompt"] == "Saulriets"
    assert pushed["metadata"] == expected_metadata


@pytest.mark.asyncio
async def test_save_conversation_adds_maris_metadata(monkeypatch) -> None:
    """save_conversation() must merge the Maris origin fields into the caller's metadata."""
    pushed: dict[str, object] = {}

    async def _record_push(_self, entry):  # noqa: ANN001
        # Replaces _push_to_dataset: capture the entry instead of sending it anywhere.
        pushed.update(entry)

    monkeypatch.setattr(HFIntegration, "_push_to_dataset", _record_push)

    conversation_store = HFIntegration()
    await conversation_store.save_conversation(
        "Sveiki",
        "Labdien",
        metadata={"request_id": "req-1"},
    )

    maris_fields = {
        "generated_by": "Maris AI",
        "maris_framework": "maris-ai-core",
        "maris_origin": True,
    }
    assert pushed["type"] == "conversation"
    assert pushed["metadata"] == {"request_id": "req-1", **maris_fields}


def test_get_api_returns_none_when_huggingface_hub_import_crashes(monkeypatch, caplog) -> None:
    """A non-ImportError crash while importing huggingface_hub is logged and remembered.

    _get_api() must return None on every call, yet attempt the failing import only once.
    """
    integration = HFIntegration()
    real_import = builtins.__import__
    hub_import_calls: list[str] = []

    def broken_import(name, global_vars=None, local_vars=None, fromlist=(), level=0):  # noqa: ANN001
        # Only huggingface_hub blows up; every other import passes straight through.
        if name != "huggingface_hub":
            return real_import(name, global_vars, local_vars, fromlist, level)
        hub_import_calls.append(name)
        raise RuntimeError("boom")

    monkeypatch.setattr(builtins, "__import__", broken_import)

    with caplog.at_level("WARNING"):
        for _ in range(2):
            assert integration._get_api() is None

    assert len(hub_import_calls) == 1
    # Expected log text (Latvian): "Publishing client is not available: boom".
    assert "Publicēšanas klients nav pieejams: boom" in caplog.text


def test_hf_integration_accepts_generic_hf_repo_and_token_alias(monkeypatch) -> None:
    """HF_DATASET_REPO and the HUGGING_FACE_HUB_TOKEN alias must be picked up.

    The other token variables are cleared first so they cannot shadow the alias.
    """
    for competing_token_var in (
        "HF_TOKEN",
        "MARIS_REPO_TOKEN",
        "MARIS_TOKEN",
        "HUGGINGFACEHUB_API_TOKEN",
    ):
        monkeypatch.delenv(competing_token_var, raising=False)
    monkeypatch.setenv("HF_DATASET_REPO", "custom-user/private-memory")
    monkeypatch.setenv("HUGGING_FACE_HUB_TOKEN", "secret-token")

    configured = HFIntegration()

    assert configured.dataset_repo == "custom-user/private-memory"
    assert configured.token == "secret-token"