Spaces:
Sleeping
Sleeping
File size: 6,165 Bytes
c2b6b36 d00b7fc c2b6b36 d00b7fc c2b6b36 d00b7fc c2b6b36 d00b7fc c2b6b36 d00b7fc c2b6b36 d00b7fc c2b6b36 d00b7fc c2b6b36 d00b7fc c2b6b36 d00b7fc c2b6b36 d00b7fc c2b6b36 d00b7fc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 | import ast
import zipfile
import pytest
from quickstart_core import (
build_export_files,
cached_public,
compute_requirements,
generate_cli_download,
generate_quickstart,
generate_snapshot_download,
get_effective_token,
is_valid_repo_id,
parse_hf_input,
)
# Raw user inputs (bare IDs, type-prefixed IDs and hub URLs) paired with the
# tuple that ``parse_hf_input`` is expected to return for each of them.
_PARSE_HF_INPUT_CASES = [
    ("owner/repo", ("model", "owner/repo")),
    ("bert-base-uncased", ("model", "bert-base-uncased")),
    ("datasets/owner/repo", ("dataset", "owner/repo")),
    ("datasets/squad", ("dataset", "squad")),
    ("spaces/owner/repo", ("space", "owner/repo")),
    ("https://huggingface.co/owner/repo", ("model", "owner/repo")),
    ("https://hf.co/owner/repo", ("model", "owner/repo")),
    (
        "https://huggingface.co/datasets/owner/repo/blob/main/data.csv",
        ("dataset", "owner/repo"),
    ),
    (
        "https://huggingface.co/spaces/owner/repo/tree/main",
        ("space", "owner/repo"),
    ),
    (
        "https://huggingface.co/datasets/squad/viewer/plain_text/train",
        ("dataset", "squad"),
    ),
    (
        "https://huggingface.co/owner/repo/discussions/1",
        ("model", "owner/repo"),
    ),
]


@pytest.mark.parametrize("value,expected", _PARSE_HF_INPUT_CASES)
def test_parse_hf_input(value, expected):
    """Check that ``parse_hf_input`` maps each input to its expected tuple."""
    parsed = parse_hf_input(value)
    assert parsed == expected
# Repo IDs that the validator is expected to accept.
_ACCEPTED_REPO_IDS = [
    "owner/repo",
    "org-name/model.name",
    "user_1/repo_2",
    "bert-base-uncased",
    "squad",
]


@pytest.mark.parametrize("repo_id", _ACCEPTED_REPO_IDS)
def test_valid_repo_ids(repo_id):
    """Check that ``is_valid_repo_id`` returns a truthy value for the ID."""
    accepted = is_valid_repo_id(repo_id)
    assert accepted
# Repo IDs that the validator is expected to reject.
_REJECTED_REPO_IDS = [
    "",
    "/owner/repo",
    "owner/repo/extra",
    "owner/..repo",
    "owner/repo--bad",
    "owner/.repo",
]


@pytest.mark.parametrize("repo_id", _REJECTED_REPO_IDS)
def test_invalid_repo_ids(repo_id):
    """Check that ``is_valid_repo_id`` returns a falsy value for the ID."""
    rejected = not is_valid_repo_id(repo_id)
    assert rejected
def test_gguf_quickstart_is_valid_python():
    """Check the quickstart generated for a GGUF-flagged model parses as Python.

    Also asserts that the generated source contains the prompt literal with
    its newline still written as an escape sequence.
    """
    metadata = {
        "_risk": {"has_gguf": True},
        "_files": [{"path": "models/model.gguf"}],
    }
    snippet = generate_quickstart("model", "owner/repo", metadata)

    # ast.parse raises SyntaxError if the snippet is not valid Python.
    ast.parse(snippet)
    assert '"Q: Hello!\\nA:"' in snippet
def test_generated_gguf_snippet_escapes_unusual_filenames():
    """Check a GGUF path containing a double quote still yields valid Python."""
    metadata = {
        "_risk": {"has_gguf": True},
        "_files": [{"path": 'models/bad"name.gguf'}],
    }
    snippet = generate_quickstart("model", "owner/repo", metadata)

    # An unescaped quote in the file name would make this raise SyntaxError.
    ast.parse(snippet)
    assert 'bad"name.gguf' in snippet
def test_generated_snippets_compile_for_supported_paths():
    """Check that generated snippets parse as Python for every listed case.

    For each ``(repo_type, repo_id, meta)`` combination both the quickstart
    and the snapshot-download snippet are parsed with ``ast.parse``.  Syntax
    errors are collected rather than raised immediately, so a single run
    reports every failing combination and names the generator involved.
    """
    cases = [
        ("dataset", "owner/repo", {}),
        ("space", "owner/repo", {"_sdk": "streamlit"}),
        ("space", "owner/repo", {"_sdk": "gradio"}),
        ("model", "owner/repo", {"_pipeline_tag": "text-generation"}),
        ("model", "owner/repo", {"_pipeline_tag": "text-classification"}),
        ("model", "owner/repo", {"_pipeline_tag": "image-classification"}),
        ("model", "owner/repo", {}),
    ]

    failures = []
    for repo_type, repo_id, meta in cases:
        snippets = {
            "generate_quickstart": generate_quickstart(repo_type, repo_id, meta),
            "generate_snapshot_download": generate_snapshot_download(repo_type, repo_id),
        }
        for generator_name, code in snippets.items():
            try:
                ast.parse(code)
            except SyntaxError as exc:
                failures.append(
                    f"{generator_name}({repo_type!r}, {repo_id!r}, meta={meta!r}): {exc}"
                )

    assert not failures, "Generated code failed to parse:\n" + "\n".join(failures)
def test_export_rejects_invalid_repo_id():
    """Check ``build_export_files`` raises ``ValueError`` for a quoted repo ID."""
    bad_row = {"Repo ID": 'owner/bad"repo', "Type": "model"}

    with pytest.raises(ValueError, match="Invalid Repo ID"):
        build_export_files(bad_row)
def test_snapshot_download_code_avoids_removed_symlink_argument():
    """Check the dataset snapshot-download snippet parses and omits the symlink flag.

    The snippet must not mention ``local_dir_use_symlinks`` and must pass the
    repo type as ``repo_type='dataset'``.
    """
    snippet = generate_snapshot_download("dataset", "owner/repo")

    ast.parse(snippet)
    assert "local_dir_use_symlinks" not in snippet
    assert "repo_type='dataset'" in snippet
def test_cli_uses_modern_hf_command():
    """Check the CLI download command starts with ``hf download`` for the repo.

    The legacy ``huggingface-cli`` name must not appear anywhere in it.
    """
    cli_line = generate_cli_download("model", "owner/repo")

    assert cli_line.startswith("hf download owner/repo")
    assert "huggingface-cli" not in cli_line
def test_requirements_are_type_aware():
    """Check ``compute_requirements`` varies its output with type and metadata."""
    # A plain dataset yields exactly these two packages, in this order.
    dataset_reqs = compute_requirements("dataset", {})
    assert dataset_reqs == ["datasets", "huggingface_hub"]

    # A model flagged as GGUF pulls in llama-cpp-python.
    gguf_reqs = compute_requirements("model", {"_risk": {"has_gguf": True}})
    assert "llama-cpp-python" in gguf_reqs

    # A text-generation model pulls in accelerate.
    textgen_reqs = compute_requirements("model", {"_pipeline_tag": "text-generation"})
    assert "accelerate" in textgen_reqs
def test_export_files_compile_for_gguf():
    """Check the export bundle for a GGUF model has valid scripts and the expected files."""
    row = {
        "Repo ID": "owner/repo",
        "Type": "model",
        "_risk": {"has_gguf": True},
        "_files": [{"path": "model.gguf"}],
    }
    bundle = build_export_files(row)

    # Both generated scripts must be syntactically valid Python.
    ast.parse(bundle["run.py"])
    ast.parse(bundle["download.py"])

    expected_names = {
        "README.md",
        "requirements.txt",
        ".env.example",
        "run.py",
        "download.py",
    }
    assert set(bundle) == expected_names
def test_export_zip_contract(tmp_path, monkeypatch):
    """Check ``build_quickstart_zip`` returns a zip holding exactly the export files.

    ``tempfile.mkdtemp`` is replaced so that any temporary directory requested
    through it during the call lands under pytest's ``tmp_path``.
    """

    def fake_mkdtemp(suffix=None, prefix=None, dir=None):
        # Mirror the real tempfile.mkdtemp contract: accept its full
        # signature (``dir`` is named after the real parameter) and create
        # the directory before returning its path.  ``dir`` is deliberately
        # ignored so output always stays inside tmp_path.
        name = (prefix or "tmp").rstrip("_") + (suffix or "")
        target = tmp_path / name
        target.mkdir(parents=True, exist_ok=True)
        return str(target)

    # Keep tempfile output under pytest tmp path for this test.
    monkeypatch.setattr("tempfile.mkdtemp", fake_mkdtemp)

    from quickstart_core import build_quickstart_zip

    zip_path, message = build_quickstart_zip({"Repo ID": "owner/repo", "Type": "dataset"})

    assert zip_path is not None
    assert "Zip built" in message
    with zipfile.ZipFile(zip_path) as archive:
        assert set(archive.namelist()) == {
            "README.md",
            "requirements.txt",
            ".env.example",
            "run.py",
            "download.py",
        }
def test_server_token_requires_allowed_owner(monkeypatch):
    """Check the server token is only returned once the owner is allow-listed."""
    # Start with the server token enabled but no owner allow-list configured.
    monkeypatch.delenv("TOKEN_ALLOWED_OWNERS", raising=False)
    monkeypatch.setenv("ALLOW_SERVER_TOKEN", "1")
    monkeypatch.setenv("HF_TOKEN", "secret-token")

    token_without_allowlist = get_effective_token("owner/repo")
    assert token_without_allowlist is None

    # Allow-listing the owner makes the same call return the token.
    monkeypatch.setenv("TOKEN_ALLOWED_OWNERS", "owner")
    token_with_allowlist = get_effective_token("owner/repo")
    assert token_with_allowlist == "secret-token"
def test_cached_public_only_caches_success(monkeypatch):
    """Check ``cached_public`` retries after a failure and reuses a success.

    ``fetch_repo_info`` is replaced with a fake that fails on its first call
    and succeeds afterwards.  Three lookups must therefore hit the fake only
    twice: the failed attempt and the first successful one.
    """
    import quickstart_core

    calls = {"count": 0}

    def fake_fetch(repo_type, repo_id, token):
        calls["count"] += 1
        if calls["count"] == 1:
            return False, None, "temporary error"
        return True, {"Repo ID": repo_id, "Type": repo_type}, None

    # Start from an empty cache so earlier tests cannot influence the count.
    quickstart_core._PUBLIC_CACHE.clear()
    monkeypatch.setattr(quickstart_core, "fetch_repo_info", fake_fetch)
    try:
        first = cached_public("model", "owner/repo")
        second = cached_public("model", "owner/repo")
        third = cached_public("model", "owner/repo")
    finally:
        # monkeypatch restores fetch_repo_info but not the cache contents;
        # drop the fabricated entry so it cannot leak into later tests.
        quickstart_core._PUBLIC_CACHE.clear()

    assert first[0] is False
    assert second[0] is True
    assert third[0] is True
    assert calls["count"] == 2
def test_gradio_ui_builds():
    """Check ``app.build_ui`` returns a demo, a theme, and CSS with ``.hero``."""
    import app

    built = app.build_ui()
    demo, theme, css = built

    assert demo is not None
    assert theme is not None
    assert ".hero" in css
|