File size: 6,165 Bytes
c2b6b36
 
 
 
 
 
 
d00b7fc
c2b6b36
 
 
 
d00b7fc
c2b6b36
 
 
 
 
 
 
 
 
d00b7fc
c2b6b36
d00b7fc
c2b6b36
 
d00b7fc
c2b6b36
 
d00b7fc
 
c2b6b36
 
 
 
 
 
 
 
d00b7fc
c2b6b36
 
 
 
 
 
 
d00b7fc
c2b6b36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d00b7fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2b6b36
 
 
 
d00b7fc
c2b6b36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d00b7fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import ast
import zipfile

import pytest

from quickstart_core import (
    build_export_files,
    cached_public,
    compute_requirements,
    generate_cli_download,
    generate_quickstart,
    generate_snapshot_download,
    get_effective_token,
    is_valid_repo_id,
    parse_hf_input,
)


@pytest.mark.parametrize(
    "value,expected",
    [
        # Bare repo IDs and type-prefixed IDs.
        ("owner/repo", ("model", "owner/repo")),
        ("bert-base-uncased", ("model", "bert-base-uncased")),
        ("datasets/owner/repo", ("dataset", "owner/repo")),
        ("datasets/squad", ("dataset", "squad")),
        ("spaces/owner/repo", ("space", "owner/repo")),
        # Full URLs, including ones with trailing path segments.
        ("https://huggingface.co/owner/repo", ("model", "owner/repo")),
        ("https://hf.co/owner/repo", ("model", "owner/repo")),
        ("https://huggingface.co/datasets/owner/repo/blob/main/data.csv", ("dataset", "owner/repo")),
        ("https://huggingface.co/spaces/owner/repo/tree/main", ("space", "owner/repo")),
        ("https://huggingface.co/datasets/squad/viewer/plain_text/train", ("dataset", "squad")),
        ("https://huggingface.co/owner/repo/discussions/1", ("model", "owner/repo")),
    ],
)
def test_parse_hf_input(value, expected):
    """parse_hf_input must turn each raw input into the expected (type, repo ID) pair."""
    parsed = parse_hf_input(value)
    assert parsed == expected


@pytest.mark.parametrize(
    "repo_id",
    [
        "owner/repo",
        "org-name/model.name",
        "user_1/repo_2",
        "bert-base-uncased",
        "squad",
    ],
)
def test_valid_repo_ids(repo_id):
    """Each listed repo ID must be accepted (truthy result) by is_valid_repo_id."""
    accepted = is_valid_repo_id(repo_id)
    assert accepted


@pytest.mark.parametrize(
    "repo_id",
    [
        "",
        "/owner/repo",
        "owner/repo/extra",
        "owner/..repo",
        "owner/repo--bad",
        "owner/.repo",
    ],
)
def test_invalid_repo_ids(repo_id):
    """Each listed repo ID must be rejected (falsy result) by is_valid_repo_id."""
    rejected = not is_valid_repo_id(repo_id)
    assert rejected


def test_gguf_quickstart_is_valid_python():
    """The quickstart for a GGUF-flagged model must parse as Python and embed the sample prompt."""
    gguf_meta = {
        "_risk": {"has_gguf": True},
        "_files": [{"path": "models/model.gguf"}],
    }
    snippet = generate_quickstart("model", "owner/repo", gguf_meta)

    # ast.parse raises SyntaxError if the generated text is not valid Python.
    ast.parse(snippet)
    # The prompt must appear with its newline still escaped inside the literal.
    assert '"Q: Hello!\\nA:"' in snippet


def test_generated_gguf_snippet_escapes_unusual_filenames():
    """A GGUF path containing a double quote must still yield parseable Python."""
    quoted_path_meta = {
        "_risk": {"has_gguf": True},
        "_files": [{"path": 'models/bad"name.gguf'}],
    }
    snippet = generate_quickstart("model", "owner/repo", quoted_path_meta)

    # Parsing fails if the quote in the filename broke out of its string literal.
    ast.parse(snippet)
    assert 'bad"name.gguf' in snippet


def test_generated_snippets_compile_for_supported_paths():
    """Quickstart and snapshot-download snippets must parse for every listed scenario."""
    scenarios = (
        ("dataset", "owner/repo", {}),
        ("space", "owner/repo", {"_sdk": "streamlit"}),
        ("space", "owner/repo", {"_sdk": "gradio"}),
        ("model", "owner/repo", {"_pipeline_tag": "text-generation"}),
        ("model", "owner/repo", {"_pipeline_tag": "text-classification"}),
        ("model", "owner/repo", {"_pipeline_tag": "image-classification"}),
        ("model", "owner/repo", {}),
    )
    for scenario in scenarios:
        kind, ident, _meta = scenario
        # ast.parse raises SyntaxError on the first snippet that is not valid Python.
        quickstart_src = generate_quickstart(*scenario)
        ast.parse(quickstart_src)
        download_src = generate_snapshot_download(kind, ident)
        ast.parse(download_src)


def test_export_rejects_invalid_repo_id():
    """build_export_files must raise ValueError for a repo ID containing a double quote."""
    bad_row = {"Repo ID": 'owner/bad"repo', "Type": "model"}
    with pytest.raises(ValueError, match="Invalid Repo ID"):
        build_export_files(bad_row)


def test_snapshot_download_code_avoids_removed_symlink_argument():
    """The dataset download snippet must parse, omit ``local_dir_use_symlinks``, and set the repo type."""
    snippet = generate_snapshot_download("dataset", "owner/repo")
    ast.parse(snippet)

    # NOTE(review): the test name implies this keyword was removed upstream — confirm.
    assert "local_dir_use_symlinks" not in snippet
    assert "repo_type='dataset'" in snippet


def test_cli_uses_modern_hf_command():
    """The CLI download line must start with ``hf download`` and never mention ``huggingface-cli``."""
    cli_line = generate_cli_download("model", "owner/repo")
    assert cli_line.startswith("hf download owner/repo")
    assert "huggingface-cli" not in cli_line


def test_requirements_are_type_aware():
    """compute_requirements must vary its package list with repo type and metadata."""
    dataset_reqs = compute_requirements("dataset", {})
    assert dataset_reqs == ["datasets", "huggingface_hub"]

    gguf_reqs = compute_requirements("model", {"_risk": {"has_gguf": True}})
    assert "llama-cpp-python" in gguf_reqs

    textgen_reqs = compute_requirements("model", {"_pipeline_tag": "text-generation"})
    assert "accelerate" in textgen_reqs


def test_export_files_compile_for_gguf():
    """The GGUF model export must hold exactly five files, with both scripts parseable."""
    gguf_row = {
        "Repo ID": "owner/repo",
        "Type": "model",
        "_risk": {"has_gguf": True},
        "_files": [{"path": "model.gguf"}],
    }
    exported = build_export_files(gguf_row)

    # Both generated scripts have to be syntactically valid Python.
    for script_name in ("run.py", "download.py"):
        ast.parse(exported[script_name])

    expected_names = {"README.md", "requirements.txt", ".env.example", "run.py", "download.py"}
    assert set(exported) == expected_names


def test_export_zip_contract(tmp_path, monkeypatch):
    """build_quickstart_zip must return a zip holding exactly the five export files.

    ``tempfile.mkdtemp`` is replaced by a stub so that any directory obtained
    through it lives under pytest's ``tmp_path``.
    """

    def fake_mkdtemp(suffix=None, prefix=None, dir=None):  # noqa: A002 - mirrors tempfile.mkdtemp
        # Accept the full real signature so the stub does not raise if the
        # builder passes suffix/dir; both are ignored to keep output in tmp_path.
        folder = tmp_path / (prefix or "tmp").rstrip("_")
        # The real mkdtemp creates the directory before returning its path;
        # do the same so the stub is a faithful stand-in.
        folder.mkdir(parents=True, exist_ok=True)
        return str(folder)

    # Keep tempfile output under pytest tmp path for this test.
    monkeypatch.setattr("tempfile.mkdtemp", fake_mkdtemp)
    from quickstart_core import build_quickstart_zip

    zip_path, message = build_quickstart_zip({"Repo ID": "owner/repo", "Type": "dataset"})
    assert zip_path is not None
    assert "Zip built" in message
    with zipfile.ZipFile(zip_path) as archive:
        assert set(archive.namelist()) == {
            "README.md",
            "requirements.txt",
            ".env.example",
            "run.py",
            "download.py",
        }


def test_server_token_requires_allowed_owner(monkeypatch):
    """The server token is returned only once the repo owner is in TOKEN_ALLOWED_OWNERS."""
    server_env = (("ALLOW_SERVER_TOKEN", "1"), ("HF_TOKEN", "secret-token"))
    for env_name, env_value in server_env:
        monkeypatch.setenv(env_name, env_value)
    monkeypatch.delenv("TOKEN_ALLOWED_OWNERS", raising=False)

    # With no allow-list set, no token is handed out.
    token_without_allowlist = get_effective_token("owner/repo")
    assert token_without_allowlist is None

    # Once "owner" is listed, the HF_TOKEN value is returned.
    monkeypatch.setenv("TOKEN_ALLOWED_OWNERS", "owner")
    token_with_allowlist = get_effective_token("owner/repo")
    assert token_with_allowlist == "secret-token"


def test_cached_public_only_caches_success(monkeypatch):
    """cached_public must not cache a failed fetch but must cache a successful one.

    The stubbed fetcher fails on its first call and succeeds afterwards, so
    three lookups should cause exactly two fetches: the failure, then the
    success that later lookups reuse.
    """
    import quickstart_core

    calls = {"count": 0}

    def fake_fetch(repo_type, repo_id, token):
        calls["count"] += 1
        if calls["count"] == 1:
            return False, None, "temporary error"
        return True, {"Repo ID": repo_id, "Type": repo_type}, None

    quickstart_core._PUBLIC_CACHE.clear()
    monkeypatch.setattr(quickstart_core, "fetch_repo_info", fake_fetch)

    try:
        first = cached_public("model", "owner/repo")
        second = cached_public("model", "owner/repo")
        third = cached_public("model", "owner/repo")

        assert first[0] is False
        assert second[0] is True
        assert third[0] is True
        assert calls["count"] == 2
    finally:
        # monkeypatch restores fetch_repo_info but not the cache contents;
        # drop the fabricated entry so it cannot leak into later tests.
        quickstart_core._PUBLIC_CACHE.clear()


def test_gradio_ui_builds():
    """app.build_ui must return a demo, a theme, and CSS containing the ``.hero`` selector."""
    import app

    built = app.build_ui()
    demo, theme, css = built

    assert demo is not None
    assert theme is not None
    assert ".hero" in css