"""Core helpers for the QuickStart Hugging Face repo assistant.""" from __future__ import annotations import html import inspect import os import re import tempfile import textwrap import zipfile from pathlib import Path from typing import Any from huggingface_hub import HfApi from huggingface_hub.utils import HfHubHTTPError VALID_REPO_TYPES = {"model", "dataset", "space"} RE_REPO_SEGMENT = re.compile(r"^(?!.*(?:--|\.\.))[A-Za-z0-9][A-Za-z0-9_.-]{0,95}$") SENSITIVE_FILENAME_PATTERNS = [ r"(^|/)\.env$", r"secrets?", r"token", r"api[_-]?key", r"credentials?", r"id_rsa", r"\.pem$", r"\.p12$", r"\.kdbx$", ] def esc(value: Any) -> str: """HTML-escape values before injecting them into custom Gradio HTML.""" return html.escape("" if value is None else str(value), quote=True) def norm_type(value: str | None) -> str: repo_type = (value or "model").strip().lower() return repo_type if repo_type in VALID_REPO_TYPES else "model" def norm_id(value: str | None) -> str: return (value or "").strip().strip("/") def is_valid_repo_id(repo_id: str) -> bool: repo_id = (repo_id or "").strip() parts = repo_id.split("/") if len(parts) not in {1, 2}: return False return all( RE_REPO_SEGMENT.match(part) and not part.startswith(("-", ".")) and not part.endswith(("-", ".")) for part in parts ) def human_bytes(num_bytes: int | None) -> str: if not isinstance(num_bytes, int) or num_bytes <= 0: return "N/A" units = ["B", "KB", "MB", "GB", "TB"] value = float(num_bytes) unit_index = 0 while value >= 1024 and unit_index < len(units) - 1: value /= 1024 unit_index += 1 return f"{value:.2f} {units[unit_index]}" def safe_str(value: Any, max_chars: int = 500) -> str: text = "" if value is None else str(value) text = re.sub(r"\s+", " ", text).strip() if len(text) > max_chars: return text[: max_chars - 3] + "..." 

def py_literal(value: Any) -> str:
    """Return a safe Python string literal for generated snippets."""
    return repr("" if value is None else str(value))


def parse_hf_input(user_input: str) -> tuple[str, str]:
    """Parse a Hugging Face URL, typed repo path, or plain owner/repo ID."""
    value = (user_input or "").strip()
    if not value:
        return "model", ""
    if "huggingface.co" in value or "hf.co" in value:
        scoped_match = re.search(r"(?:huggingface\.co|hf\.co)/(datasets|spaces)/([^?#]+)", value)
        if scoped_match:
            repo_type = "dataset" if scoped_match.group(1) == "datasets" else "space"
            repo_id = _strip_hf_file_path(scoped_match.group(2))
            return repo_type, repo_id
        model_match = re.search(r"(?:huggingface\.co|hf\.co)/([^?#]+)", value)
        if model_match:
            repo_id = _strip_hf_file_path(model_match.group(1))
            return "model", repo_id
    if value.startswith("datasets/"):
        return "dataset", value.replace("datasets/", "", 1).strip("/")
    if value.startswith("spaces/"):
        return "space", value.replace("spaces/", "", 1).strip("/")
    return "model", value.strip("/")


def _strip_hf_file_path(path: str) -> str:
    path = (path or "").strip("/")
    path = re.split(r"/(tree|blob|resolve|raw|viewer|discussions)/", path)[0].strip("/")
    return path


def hf_url(repo_type: str, repo_id: str) -> str:
    repo_type = norm_type(repo_type)
    repo_id = norm_id(repo_id)
    if repo_type == "dataset":
        return f"https://huggingface.co/datasets/{repo_id}"
    if repo_type == "space":
        return f"https://huggingface.co/spaces/{repo_id}"
    return f"https://huggingface.co/{repo_id}"


def safe_hf_error(error: HfHubHTTPError) -> str:
    status = getattr(getattr(error, "response", None), "status_code", "N/A")
    message = getattr(error, "server_message", None) or str(error)
    return f"Hugging Face Error: {status} - {safe_str(message, 500)}"


def call_with_supported_kwargs(fn: Any, *args: Any, **kwargs: Any) -> Any:
    """Call SDK functions with only supported kwargs without swallowing API errors."""
    try:
        signature = inspect.signature(fn)
    except (TypeError, ValueError):
        return fn(*args, **kwargs)
    # If the callable accepts **kwargs, every keyword is supported; filtering
    # against named parameters alone would silently drop valid options.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in signature.parameters.values()):
        return fn(*args, **kwargs)
    allowed = set(signature.parameters)
    supported_kwargs = {key: value for key, value in kwargs.items() if key in allowed}
    return fn(*args, **supported_kwargs)


def extract_file_entries(info_obj: Any) -> list[dict[str, Any]]:
    entries: list[dict[str, Any]] = []
    siblings = getattr(info_obj, "siblings", None) or []
    for sibling in siblings:
        path = getattr(sibling, "rfilename", None) or getattr(sibling, "path", None)
        if not path:
            continue
        size = getattr(sibling, "size", None)
        if size is None:
            lfs = getattr(sibling, "lfs", None)
            size = getattr(lfs, "size", None) if lfs is not None else None
        entries.append({"path": str(path), "size": int(size) if isinstance(size, int) else None})
    return entries


def files_risk_report(files: list[dict[str, Any]]) -> dict[str, Any]:
    paths = [item.get("path", "") for item in files if item.get("path")]
    total_known = sum(int(item["size"]) for item in files if isinstance(item.get("size"), int))
    lower_paths = [path.lower() for path in paths]
    suspicious_names = [
        path
        for path in paths
        if any(re.search(pattern, path.lower()) for pattern in SENSITIVE_FILENAME_PATTERNS)
    ]
    return {
        "files_count": len(paths),
        "total_size_known": total_known if total_known > 0 else None,
        "has_gguf": any(path.endswith(".gguf") for path in lower_paths),
        "has_onnx": any(path.endswith(".onnx") for path in lower_paths),
        "has_safetensors": any(path.endswith(".safetensors") for path in lower_paths),
        "has_bin": any(path.endswith(".bin") for path in lower_paths),
        "suspicious_names": suspicious_names[:30],
    }
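
# Illustrative parses (comments only; "owner/name" is a placeholder):
#   parse_hf_input("https://huggingface.co/datasets/owner/name/tree/main")
#       -> ("dataset", "owner/name")
#   parse_hf_input("spaces/owner/demo") -> ("space", "owner/demo")
#   parse_hf_input("owner/name")        -> ("model", "owner/name")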

def warnings_from_meta(meta: dict[str, Any]) -> list[str]:
    warnings: list[str] = []
    risk = meta.get("_risk", {}) or {}
    if meta.get("Gated") == "Yes" or meta.get("Private") == "Yes":
        warnings.append("Repo may require HF_TOKEN because it is private or gated.")
    total_size = risk.get("total_size_known")
    if isinstance(total_size, int) and total_size > 8 * 1024**3:
        warnings.append("Large repo size detected (>8GB). Prefer selective download when possible.")
    if risk.get("has_gguf"):
        warnings.append(
            "GGUF detected. Use a llama.cpp / llama-cpp-python flow instead of generic Transformers."
        )
    if risk.get("suspicious_names"):
        warnings.append(
            "Potentially sensitive filenames detected. This is filename-based only; review before use."
        )
    if meta.get("Pipeline") == "text-generation":
        warnings.append("Text-generation models can be slow without adequate GPU/VRAM.")
    return warnings


def to_files_table(files: list[dict[str, Any]], limit: int = 250) -> list[list[Any]]:
    return [
        [item.get("path", ""), human_bytes(item.get("size")) if isinstance(item.get("size"), int) else "N/A"]
        for item in (files or [])[:limit]
    ]


def filter_files(files: list[dict[str, Any]], query: str, limit: int = 250) -> list[list[Any]]:
    query = (query or "").strip().lower()
    if not query:
        return to_files_table(files, limit=limit)
    rows: list[list[Any]] = []
    for item in files or []:
        path = item.get("path") or ""
        if query in path.lower():
            size = human_bytes(item.get("size")) if isinstance(item.get("size"), int) else "N/A"
            rows.append([path, size])
            if len(rows) >= limit:
                break
    return rows


def first_file_with_ext(files: list[dict[str, Any]], extension: str) -> str | None:
    extension = (extension or "").lower()
    for item in files or []:
        path = item.get("path") or ""
        if path.lower().endswith(extension):
            return path
    return None


def compute_requirements(repo_type: str, meta: dict[str, Any]) -> list[str]:
    repo_type = norm_type(repo_type)
    pipeline_tag = (meta or {}).get("_pipeline_tag", "N/A")
    sdk = (meta or {}).get("_sdk", "N/A")
    has_gguf = bool((meta or {}).get("_risk", {}).get("has_gguf") or (meta or {}).get("_has_gguf"))
    if repo_type == "dataset":
        return ["datasets", "huggingface_hub"]
    if repo_type == "space":
        if sdk == "streamlit":
            return ["streamlit", "huggingface_hub", "requests"]
        if sdk == "gradio":
            return ["gradio", "huggingface_hub", "requests"]
        return ["huggingface_hub", "requests"]
    if has_gguf:
        return ["huggingface_hub", "llama-cpp-python"]
    if pipeline_tag == "text-generation":
        return ["transformers", "huggingface_hub", "torch", "accelerate"]
    if pipeline_tag in {"image-classification", "image-to-text", "image-segmentation", "object-detection"}:
        return ["transformers", "huggingface_hub", "torch", "pillow", "requests"]
    return ["transformers", "huggingface_hub", "torch"]


def generate_install(repo_type: str, meta: dict[str, Any]) -> str:
    return "python -m pip install " + " ".join(compute_requirements(repo_type, meta))
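
# Illustrative outputs, derived directly from the rules above (comments only):
#   compute_requirements("dataset", {})
#       -> ["datasets", "huggingface_hub"]
#   compute_requirements("model", {"_pipeline_tag": "text-generation"})
#       -> ["transformers", "huggingface_hub", "torch", "accelerate"]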

def generate_quickstart(repo_type: str, repo_id: str, meta: dict[str, Any]) -> str:
    repo_type = norm_type(repo_type)
    repo_id = norm_id(repo_id)
    pipeline_tag = (meta or {}).get("_pipeline_tag", "N/A")
    sdk = (meta or {}).get("_sdk", "N/A")
    risk = (meta or {}).get("_risk", {}) or {}
    has_gguf = bool(risk.get("has_gguf") or (meta or {}).get("_has_gguf"))
    files = (meta or {}).get("_files", []) or []
    repo_id_literal = py_literal(repo_id)
    if repo_type == "dataset":
        return textwrap.dedent(
            f"""
            from datasets import load_dataset

            ds = load_dataset({repo_id_literal})
            print(ds)
            """
        ).strip()
    if repo_type == "space":
        repo_dir_literal = py_literal(repo_id.split("/")[-1])
        space_url_literal = py_literal(hf_url("space", repo_id))
        if sdk == "streamlit":
            return textwrap.dedent(
                f"""
                import os
                import subprocess

                subprocess.check_call(["git", "clone", {space_url_literal}])
                os.chdir({repo_dir_literal})
                subprocess.check_call(["python", "-m", "pip", "install", "-r", "requirements.txt"])
                subprocess.check_call(["streamlit", "run", "app.py"])
                """
            ).strip()
        return textwrap.dedent(
            f"""
            import os
            import subprocess

            subprocess.check_call(["git", "clone", {space_url_literal}])
            os.chdir({repo_dir_literal})
            subprocess.check_call(["python", "-m", "pip", "install", "-r", "requirements.txt"])
            subprocess.check_call(["python", "app.py"])
            """
        ).strip()
    if has_gguf:
        gguf_name = first_file_with_ext(files, ".gguf") or "MODEL.gguf"
        gguf_name_literal = py_literal(gguf_name)
        return textwrap.dedent(
            f"""
            from huggingface_hub import hf_hub_download
            from llama_cpp import Llama

            gguf_path = hf_hub_download(repo_id={repo_id_literal}, filename={gguf_name_literal})
            llm = Llama(model_path=gguf_path, n_ctx=4096)
            out = llm("Q: Hello!\\nA:", max_tokens=128)
            print(out["choices"][0]["text"])
            """
        ).strip()
    if pipeline_tag == "text-generation":
        return textwrap.dedent(
            f"""
            from transformers import pipeline

            pipe = pipeline(
                "text-generation",
                model={repo_id_literal},
                device_map="auto",
            )
            out = pipe("Hello, Hugging Face!", max_new_tokens=64)
            print(out[0]["generated_text"])
            """
        ).strip()
    if pipeline_tag == "text-classification":
        return textwrap.dedent(
            f"""
            from transformers import pipeline

            clf = pipeline("text-classification", model={repo_id_literal})
            print(clf("I love this project."))
            """
        ).strip()
    if pipeline_tag == "image-classification":
        return textwrap.dedent(
            f"""
            from io import BytesIO

            import requests
            from PIL import Image
            from transformers import pipeline

            image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
            image = Image.open(BytesIO(requests.get(image_url, timeout=20).content))
            pipe = pipeline("image-classification", model={repo_id_literal})
            print(pipe(image))
            """
        ).strip()
    return textwrap.dedent(
        f"""
        from transformers import AutoModel, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained({repo_id_literal})
        model = AutoModel.from_pretrained({repo_id_literal})
        print(type(tokenizer))
        print(type(model))
        """
    ).strip()


def generate_snapshot_download(repo_type: str, repo_id: str) -> str:
    repo_type = norm_type(repo_type)
    repo_id = norm_id(repo_id)
    local_dir = f"./{repo_id.split('/')[-1]}"
    lines = [
        "from huggingface_hub import snapshot_download",
        "",
        "path = snapshot_download(",
        f"    repo_id={py_literal(repo_id)},",
    ]
    if repo_type != "model":
        lines.append(f"    repo_type={py_literal(repo_type)},")
    lines.extend(
        [
            f"    local_dir={py_literal(local_dir)},",
            ")",
            'print(f"Downloaded to: {path}")',
        ]
    )
    return "\n".join(lines)


def generate_cli_download(repo_type: str, repo_id: str) -> str:
    repo_type = norm_type(repo_type)
    repo_id = norm_id(repo_id)
    return f'hf download {repo_id} --repo-type {repo_type} --local-dir "./downloaded_repo"'


def generate_badge(repo_type: str, repo_id: str) -> str:
    repo_type = norm_type(repo_type)
    repo_id = norm_id(repo_id)
    url = hf_url(repo_type, repo_id)
    encoded = repo_id.replace("/", "%2F")
    return (
        f"[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-{encoded}-blue)]({url})"
    )
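
# Illustrative output (comments only; "owner/name" is a placeholder):
#   generate_cli_download("model", "owner/name")
#       -> 'hf download owner/name --repo-type model --local-dir "./downloaded_repo"'
# The `hf` entry point assumes a recent huggingface_hub release; on older
# installs the equivalent command is `huggingface-cli download`.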

def token_allowed_for_repo(repo_id: str) -> bool:
    """Return whether the configured server token may be used for this repo.

    Server-token mode is intentionally fail-closed: enabling ALLOW_SERVER_TOKEN
    is not enough on its own. TOKEN_ALLOWED_OWNERS must also scope the token to
    trusted Hugging Face owners.
    """
    owners = os.getenv("TOKEN_ALLOWED_OWNERS", "").strip()
    if not owners:
        return False
    allowed_owners = {owner.strip().lower() for owner in owners.split(",") if owner.strip()}
    owner = (norm_id(repo_id).split("/")[0] if "/" in norm_id(repo_id) else "").lower()
    return bool(owner) and owner in allowed_owners


def get_effective_token(repo_id: str) -> str | None:
    if os.getenv("ALLOW_SERVER_TOKEN", "").strip() != "1":
        return None
    token = (os.getenv("HF_TOKEN") or "").strip()
    if not token:
        return None
    return token if token_allowed_for_repo(repo_id) else None


def fetch_repo_info(
    repo_type: str, repo_id: str, token: str | None
) -> tuple[bool, dict[str, Any] | None, str | None]:
    api = HfApi()
    repo_type = norm_type(repo_type)
    repo_id = norm_id(repo_id)
    token = (token or "").strip() or None
    if not repo_id:
        return False, None, "Empty Repo ID."
    if not is_valid_repo_id(repo_id):
        return False, None, "Invalid Repo ID. Expected: repo-name or owner/name"
    try:
        if repo_type == "dataset":
            info = call_with_supported_kwargs(api.dataset_info, repo_id, token=token, files_metadata=True)
        elif repo_type == "space":
            info = call_with_supported_kwargs(api.space_info, repo_id, token=token, files_metadata=True)
        else:
            info = call_with_supported_kwargs(api.model_info, repo_id, token=token, files_metadata=True)
        # Attribute names drifted across huggingface_hub releases, so fall back
        # from the snake_case names to the legacy camelCase ones.
        card = getattr(info, "card_data", None) or getattr(info, "cardData", None) or {}
        license_name = card.get("license") or getattr(info, "license", None) or "N/A"
        gated = getattr(info, "gated", None)
        private = getattr(info, "private", None)
        pipeline = getattr(info, "pipeline_tag", None) or "N/A"
        sdk = getattr(info, "sdk", None) or "N/A"
        files = extract_file_entries(info)
        if not files:
            try:
                names = api.list_repo_files(repo_id=repo_id, repo_type=repo_type, token=token)
                files = [{"path": name, "size": None} for name in (names or [])]
            except Exception:
                files = []
        risk = files_risk_report(files)
        total_size = human_bytes(risk.get("total_size_known")) if risk.get("total_size_known") else "N/A"
        last_modified = getattr(info, "last_modified", None) or getattr(info, "lastModified", None)
        preview: dict[str, Any] = {
            "Repo ID": getattr(info, "id", repo_id),
            "Type": repo_type,
            "Author": getattr(info, "author", None) or getattr(info, "owner", None) or "N/A",
            "Likes": getattr(info, "likes", 0) or 0,
            "Downloads": getattr(info, "downloads", 0) or 0,
            "Last Modified": safe_str(last_modified or "N/A", 200),
            "License": str(license_name) if license_name else "N/A",
            "Pipeline": str(pipeline) if pipeline else "N/A",
            # The Hub reports `gated` as False, "auto", or "manual"; a strict
            # `is True` check would misreport gated repos as not gated.
            "Gated": "No" if gated is False else ("Yes" if gated else "N/A"),
            "Private": "Yes" if private is True else ("No" if private is False else "N/A"),
            "Total Size": total_size,
            "Files Count": risk.get("files_count", 0),
        }
        if repo_type == "space":
            preview["SDK"] = sdk or "N/A"
            hardware = getattr(info, "hardware", None)
            if hardware:
                preview["Hardware"] = safe_str(hardware, 200)
        preview.update(
            {
                "_pipeline_tag": pipeline or "N/A",
                "_sdk": sdk or "N/A",
                "_files": files,
                "_risk": risk,
                "_has_gguf": bool(risk.get("has_gguf")),
                "_rid": repo_id,
                "_rt": repo_type,
            }
        )
        return True, preview, None
    except HfHubHTTPError as error:
        return False, None, safe_hf_error(error)
    except Exception as error:
        return False, None, f"Unexpected Error: {safe_str(error, 500)}"


_PUBLIC_CACHE: dict[tuple[str, str], tuple[bool, dict[str, Any] | None, str | None]] = {}
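
# Illustrative server-token configuration (hypothetical values; never commit
# real tokens):
#   ALLOW_SERVER_TOKEN=1
#   HF_TOKEN=hf_xxx
#   TOKEN_ALLOWED_OWNERS=org-a,org-b
# With that environment, get_effective_token("org-a/model") returns the token,
# while get_effective_token("other/model") and get_effective_token("gpt2")
# return None (fail-closed).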

def cached_public(repo_type: str, repo_id: str) -> tuple[bool, dict[str, Any] | None, str | None]:
    """Fetch public repo metadata and cache successful responses only.

    Transient network errors should not become sticky until process restart.
    """
    key = (norm_type(repo_type), norm_id(repo_id))
    if key in _PUBLIC_CACHE:
        return _PUBLIC_CACHE[key]
    result = fetch_repo_info(key[0], key[1], token=None)
    if result[0]:
        _PUBLIC_CACHE[key] = result
    return result


def build_export_files(state: dict[str, Any]) -> dict[str, str]:
    if not isinstance(state, dict) or not state.get("Repo ID"):
        raise ValueError("No repo loaded yet.")
    repo_type = norm_type(state.get("Type", "model"))
    repo_id = norm_id(state.get("Repo ID", "")) or norm_id(state.get("_rid", ""))
    if not is_valid_repo_id(repo_id):
        raise ValueError("Invalid Repo ID. Expected: repo-name or owner/name")
    install = generate_install(repo_type, state)
    quickstart = generate_quickstart(repo_type, repo_id, state)
    snapshot = generate_snapshot_download(repo_type, repo_id)
    requirements = compute_requirements(repo_type, state)
    readme = textwrap.dedent(
        f"""
        # QuickStart: {repo_id}

        Minimal first-run scaffold generated for `{repo_id}`.

        ## Setup

        ```bash
        python -m venv .venv
        source .venv/bin/activate  # Windows: .venv\\Scripts\\activate
        python -m pip install -r requirements.txt
        ```

        ## Run

        ```bash
        python run.py
        ```

        ## Download full snapshot

        ```bash
        python download.py
        ```

        ## Reference install

        ```bash
        {install}
        ```
        """
    ).strip()
    run_py = "\n".join(
        [
            "def main():",
            '    print("Install/reference command:")',
            f"    print({install!r})",
            "",
            textwrap.indent(quickstart, "    "),
            "",
            "",
            'if __name__ == "__main__":',
            "    main()",
        ]
    )
    download_py = snapshot.strip()
    return {
        "README.md": readme + "\n",
        "requirements.txt": "\n".join(requirements) + "\n",
        ".env.example": "HF_TOKEN=\n",
        "run.py": run_py + "\n",
        "download.py": download_py + "\n",
    }


def build_quickstart_zip(state: dict[str, Any]) -> tuple[str | None, str]:
    try:
        files = build_export_files(state)
    except ValueError as error:
        return None, str(error)
    repo_id = norm_id(state.get("Repo ID", "")) or norm_id(state.get("_rid", "repo"))
    temp_dir = Path(tempfile.mkdtemp(prefix="quickstart_"))
    zip_path = temp_dir / f"{repo_id.replace('/', '__')}_quickstart.zip"
    project_dir = temp_dir / "project"
    project_dir.mkdir(parents=True, exist_ok=True)
    for name, content in files.items():
        path = project_dir / name
        path.write_text(content, encoding="utf-8")
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for name in files:
            archive.write(project_dir / name, arcname=name)
    return str(zip_path), "Zip built. Download it, unzip it, then run: python run.py"
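
if __name__ == "__main__":
    # Minimal offline smoke check (illustrative only: exercises the pure
    # helpers with a placeholder repo ID, no network calls).
    demo_type, demo_id = parse_hf_input("https://huggingface.co/datasets/owner/name")
    print(demo_type, demo_id)  # -> dataset owner/name
    print(generate_install(demo_type, {}))
    print(generate_snapshot_download(demo_type, demo_id))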