# QuickStart / tests / test_quickstart_core.py
# Tarek Masryo
# chore: prepare space release
# d00b7fc
import ast
import zipfile
import pytest
from quickstart_core import (
build_export_files,
cached_public,
compute_requirements,
generate_cli_download,
generate_quickstart,
generate_snapshot_download,
get_effective_token,
is_valid_repo_id,
parse_hf_input,
)
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        # Bare identifiers, with and without an owner segment.
        ("owner/repo", ("model", "owner/repo")),
        ("bert-base-uncased", ("model", "bert-base-uncased")),
        # Identifiers carrying an explicit type prefix.
        ("datasets/owner/repo", ("dataset", "owner/repo")),
        ("datasets/squad", ("dataset", "squad")),
        ("spaces/owner/repo", ("space", "owner/repo")),
        # Full URLs, including ones with trailing sub-paths.
        ("https://huggingface.co/owner/repo", ("model", "owner/repo")),
        ("https://hf.co/owner/repo", ("model", "owner/repo")),
        (
            "https://huggingface.co/datasets/owner/repo/blob/main/data.csv",
            ("dataset", "owner/repo"),
        ),
        (
            "https://huggingface.co/spaces/owner/repo/tree/main",
            ("space", "owner/repo"),
        ),
        (
            "https://huggingface.co/datasets/squad/viewer/plain_text/train",
            ("dataset", "squad"),
        ),
        (
            "https://huggingface.co/owner/repo/discussions/1",
            ("model", "owner/repo"),
        ),
    ],
)
def test_parse_hf_input(value, expected):
    """Check that ``parse_hf_input(value)`` equals the expected ``(type, repo_id)`` pair."""
    parsed = parse_hf_input(value)
    assert parsed == expected
@pytest.mark.parametrize(
    "repo_id",
    [
        "owner/repo",
        "org-name/model.name",
        "user_1/repo_2",
        "bert-base-uncased",
        "squad",
    ],
)
def test_valid_repo_ids(repo_id):
    """Every listed identifier must be accepted by ``is_valid_repo_id``."""
    accepted = is_valid_repo_id(repo_id)
    assert accepted
@pytest.mark.parametrize(
    "repo_id",
    [
        "",
        "/owner/repo",
        "owner/repo/extra",
        "owner/..repo",
        "owner/repo--bad",
        "owner/.repo",
    ],
)
def test_invalid_repo_ids(repo_id):
    """Every listed identifier must be rejected by ``is_valid_repo_id``."""
    verdict = is_valid_repo_id(repo_id)
    assert not verdict
def test_gguf_quickstart_is_valid_python():
    """The quickstart generated for GGUF-flagged metadata must parse as Python.

    Also checks that the prompt literal appears in the generated source with
    its newline still written as the two-character escape ``\\n``.
    """
    meta = {
        "_risk": {"has_gguf": True},
        "_files": [{"path": "models/model.gguf"}],
    }
    snippet = generate_quickstart("model", "owner/repo", meta)

    # ast.parse raises SyntaxError if the snippet is not valid Python.
    ast.parse(snippet)
    assert '"Q: Hello!\\nA:"' in snippet
def test_generated_gguf_snippet_escapes_unusual_filenames():
    """A GGUF path containing a double quote must not break the generated code."""
    meta = {
        "_risk": {"has_gguf": True},
        "_files": [{"path": 'models/bad"name.gguf'}],
    }
    snippet = generate_quickstart("model", "owner/repo", meta)

    # The snippet has to stay syntactically valid despite the embedded quote,
    # and the file name must still appear in the generated text.
    ast.parse(snippet)
    assert 'bad"name.gguf' in snippet
def test_generated_snippets_compile_for_supported_paths():
    """Quickstart and snapshot-download snippets must parse for each listed case."""
    supported = (
        ("dataset", "owner/repo", {}),
        ("space", "owner/repo", {"_sdk": "streamlit"}),
        ("space", "owner/repo", {"_sdk": "gradio"}),
        ("model", "owner/repo", {"_pipeline_tag": "text-generation"}),
        ("model", "owner/repo", {"_pipeline_tag": "text-classification"}),
        ("model", "owner/repo", {"_pipeline_tag": "image-classification"}),
        ("model", "owner/repo", {}),
    )
    for kind, repo, metadata in supported:
        quickstart_src = generate_quickstart(kind, repo, metadata)
        ast.parse(quickstart_src)

        download_src = generate_snapshot_download(kind, repo)
        ast.parse(download_src)
def test_export_rejects_invalid_repo_id():
    """``build_export_files`` must raise ``ValueError`` for a repo id containing a quote."""
    bad_row = {"Repo ID": 'owner/bad"repo', "Type": "model"}
    with pytest.raises(ValueError, match="Invalid Repo ID"):
        build_export_files(bad_row)
def test_snapshot_download_code_avoids_removed_symlink_argument():
    """The snapshot-download snippet must omit ``local_dir_use_symlinks``.

    It must also parse as Python and pass the repo type as
    ``repo_type='dataset'``.
    """
    snippet = generate_snapshot_download("dataset", "owner/repo")
    ast.parse(snippet)

    mentions_symlink_arg = "local_dir_use_symlinks" in snippet
    assert not mentions_symlink_arg
    assert "repo_type='dataset'" in snippet
def test_cli_uses_modern_hf_command():
    """The CLI download command must start with ``hf download`` and not mention ``huggingface-cli``."""
    cli_line = generate_cli_download("model", "owner/repo")

    assert cli_line.startswith("hf download owner/repo")
    assert "huggingface-cli" not in cli_line
def test_requirements_are_type_aware():
    """``compute_requirements`` must vary its output with repo type and metadata."""
    dataset_reqs = compute_requirements("dataset", {})
    assert dataset_reqs == ["datasets", "huggingface_hub"]

    gguf_reqs = compute_requirements("model", {"_risk": {"has_gguf": True}})
    assert "llama-cpp-python" in gguf_reqs

    textgen_reqs = compute_requirements("model", {"_pipeline_tag": "text-generation"})
    assert "accelerate" in textgen_reqs
def test_export_files_compile_for_gguf():
    """Exported files for a GGUF model row must parse and match the expected file set."""
    row = {
        "Repo ID": "owner/repo",
        "Type": "model",
        "_risk": {"has_gguf": True},
        "_files": [{"path": "model.gguf"}],
    }
    exported = build_export_files(row)

    # Both generated scripts have to be syntactically valid Python.
    ast.parse(exported["run.py"])
    ast.parse(exported["download.py"])

    expected_names = {
        "README.md",
        "requirements.txt",
        ".env.example",
        "run.py",
        "download.py",
    }
    assert set(exported) == expected_names
def test_export_zip_contract(tmp_path, monkeypatch):
    """``build_quickstart_zip`` must return a zip holding exactly the five export files.

    ``tempfile.mkdtemp`` is replaced so that any directory requested through
    it lands under pytest's ``tmp_path`` instead of the system temp location.
    """

    def fake_mkdtemp(suffix=None, prefix=None, dir=None):  # noqa: A002
        # Mirror the stdlib signature so keyword and positional calls both
        # work, and create the directory: the real mkdtemp always returns a
        # path that already exists, so the stand-in has to as well.
        target = tmp_path / (prefix or "tmp").rstrip("_")
        target.mkdir(parents=True, exist_ok=True)
        return str(target)

    # Keep tempfile output under pytest tmp path for this test.
    monkeypatch.setattr("tempfile.mkdtemp", fake_mkdtemp)

    from quickstart_core import build_quickstart_zip

    zip_path, message = build_quickstart_zip({"Repo ID": "owner/repo", "Type": "dataset"})
    assert zip_path is not None
    assert "Zip built" in message

    with zipfile.ZipFile(zip_path) as archive:
        assert set(archive.namelist()) == {
            "README.md",
            "requirements.txt",
            ".env.example",
            "run.py",
            "download.py",
        }
def test_server_token_requires_allowed_owner(monkeypatch):
    """The server token is returned only once the repo owner is listed in ``TOKEN_ALLOWED_OWNERS``."""
    server_env = (
        ("ALLOW_SERVER_TOKEN", "1"),
        ("HF_TOKEN", "secret-token"),
    )
    for name, setting in server_env:
        monkeypatch.setenv(name, setting)
    monkeypatch.delenv("TOKEN_ALLOWED_OWNERS", raising=False)

    # With no allow-list set, no token is expected.
    without_allowlist = get_effective_token("owner/repo")
    assert without_allowlist is None

    # Once "owner" is allowed, the configured token is expected.
    monkeypatch.setenv("TOKEN_ALLOWED_OWNERS", "owner")
    with_allowlist = get_effective_token("owner/repo")
    assert with_allowlist == "secret-token"
def test_cached_public_only_caches_success(monkeypatch):
    """``cached_public`` must retry after a failed fetch and reuse a successful one.

    The stand-in fetcher fails on its first call and succeeds afterwards.
    Three lookups are expected to reach it exactly twice: the failure, then
    one success, with the third lookup needing no further fetch.
    """
    import quickstart_core

    calls = {"count": 0}

    def fake_fetch(repo_type, repo_id, token):
        # First call reports an error; every later call reports success.
        calls["count"] += 1
        if calls["count"] == 1:
            return False, None, "temporary error"
        return True, {"Repo ID": repo_id, "Type": repo_type}, None

    # Start from an empty cache so earlier tests cannot influence the count.
    quickstart_core._PUBLIC_CACHE.clear()
    monkeypatch.setattr(quickstart_core, "fetch_repo_info", fake_fetch)
    try:
        first = cached_public("model", "owner/repo")
        second = cached_public("model", "owner/repo")
        third = cached_public("model", "owner/repo")
    finally:
        # monkeypatch restores fetch_repo_info but not the cache contents, so
        # drop the entry produced by fake_fetch before other tests can see it.
        quickstart_core._PUBLIC_CACHE.clear()

    assert first[0] is False
    assert second[0] is True
    assert third[0] is True
    assert calls["count"] == 2
def test_gradio_ui_builds():
    """``app.build_ui`` must return a demo, a theme, and CSS containing ``.hero``."""
    import app

    built = app.build_ui()
    demo, theme, css = built

    for component in (demo, theme):
        assert component is not None
    assert ".hero" in css