# wenjiao's picture
# refactor repo code
# b15b21e
"""HuggingFace repository → streaming ZIP download endpoint.
The :func:`register_routes` helper mounts the ``/api/hf-model-zip`` route on a
FastAPI app **before** the Gradio sub-app is mounted, so that HuggingFace
Spaces' ``/api/*`` proxy does not intercept it.
"""
from __future__ import annotations
import logging
import os
import queue
import re
import threading
import zipfile
from urllib.parse import quote as _url_quote
import requests
from src.envs import HF_TOKEN, INC4AI_TOKEN, lvkaokao_TOKEN
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Base URL used to build every Hub request in this module; can be overridden
# through the HF_ENDPOINT environment variable.
_HF_BASE = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
# Unique end-of-stream marker: the ZIP producer thread enqueues it last so the
# consuming generator knows when to stop.
_ZIP_SENTINEL = object()
def get_token_for_artifact(artifact: str) -> str:
    """Pick the HF access token that matches the owner part of *artifact*.

    The owner is the text before the first ``/`` (after trimming whitespace).
    ``lvkaokao`` and ``INC4AI`` have dedicated tokens; every other owner, or a
    dedicated token that is unset, falls back to ``HF_TOKEN`` and finally to
    an empty string.
    """
    owner, _, _ = str(artifact or "").strip().partition("/")
    if owner == "lvkaokao":
        dedicated = lvkaokao_TOKEN
    elif owner == "INC4AI":
        dedicated = INC4AI_TOKEN
    else:
        dedicated = None
    return dedicated or HF_TOKEN or ""
def _encode_repo(repo: str) -> str:
return "/".join(_url_quote(part, safe="") for part in repo.split("/"))
def _encode_path(path: str) -> str:
return "/".join(_url_quote(part, safe="") for part in path.split("/"))
def _get_repo_api_path(repo_type: str, repo: str) -> str:
    """Return the Hub API path (``/api/<collection>/<repo>``) for *repo*.

    ``"dataset"`` and ``"space"`` map to their own collections; any other
    *repo_type* is treated as a model.
    """
    collection = {"dataset": "datasets", "space": "spaces"}.get(repo_type, "models")
    return f"/api/{collection}/{_encode_repo(repo)}"
def _get_repo_resolve_prefix(repo_type: str, repo: str) -> str:
    """Return the URL path prefix under which the files of *repo* are served.

    Datasets live under ``/datasets/`` and spaces under ``/spaces/``; any other
    *repo_type* is treated as a model, which has no namespace prefix.
    """
    namespace = {"dataset": "/datasets", "space": "/spaces"}.get(repo_type, "")
    return f"{namespace}/{_encode_repo(repo)}"
def _list_repo_files(repo: str, revision: str, repo_type: str, token: str) -> list:
    """List every file of *repo* at *revision* through the Hub tree API.

    The tree endpoint is paginated: when more entries exist, the response
    carries a ``Link: <...>; rel="next"`` header. All pages are followed so
    that large repositories are listed completely.

    Args:
        repo: Repository id (``owner/name``).
        revision: Branch, tag or commit; percent-encoded as one URL segment.
        repo_type: ``"dataset"``, ``"space"`` or anything else for a model.
        token: Bearer token, or an empty string for anonymous access.

    Returns:
        The tree entries (dicts) whose ``"type"`` is ``"file"``.

    Raises:
        RuntimeError: If a page request fails or returns a non-list payload.
    """
    api_path = _get_repo_api_path(repo_type, repo)
    url = f"{_HF_BASE}{api_path}/tree/{_url_quote(revision, safe='')}?recursive=1"
    hdr = {"User-Agent": "hf-zip-stream/1.0"}
    if token:
        hdr["Authorization"] = f"Bearer {token}"

    files: list = []
    visited = set()
    # Stop on a repeated URL so a misbehaving server cannot loop us forever.
    while url and url not in visited:
        visited.add(url)
        r = requests.get(url, headers=hdr, timeout=60)
        if not r.ok:
            raise RuntimeError(f"Failed to list files: {r.status_code} {r.text[:500]}")
        page = r.json()
        if not isinstance(page, list):
            raise RuntimeError(
                f"Failed to list files: unexpected response type {type(page).__name__}"
            )
        files.extend(
            item for item in page
            if isinstance(item, dict) and item.get("type") == "file"
        )
        # requests parses the Link header into ``links``; no "next" means last page.
        url = r.links.get("next", {}).get("url")
    return files
def _get_file_download_url(repo: str, repo_type: str, revision: str, path: str) -> str:
    """Build the ``.../resolve/<revision>/<path>`` download URL for one repo file.

    The revision is percent-encoded as a single segment (slashes included),
    while *path* is encoded segment by segment so its slashes survive.
    """
    pieces = (
        _HF_BASE,
        _get_repo_resolve_prefix(repo_type, repo),
        "/resolve/",
        _url_quote(revision, safe=""),
        "/",
        _encode_path(path),
    )
    return "".join(pieces)
def _safe_zip_filename(value: str) -> str:
value = re.sub(r"[^\w.-]+", "_", value)
return value.strip("_") or "repo"
class _QueueWriter:
"""File-like adapter that pushes ``write()`` payloads onto a queue."""
def __init__(self, q: "queue.Queue"):
self._q = q
def write(self, data):
if data:
self._q.put(bytes(data))
return len(data)
def flush(self):
pass
def _stream_repo_as_zip(repo: str, revision: str, repo_type: str, files: list, token: str):
    """Yield ZIP byte chunks as the producer thread fills them in.

    A background thread downloads every entry of *files* and writes it,
    uncompressed, into a ``zipfile.ZipFile`` whose output goes to a bounded
    queue; this generator drains that queue. When the generator is closed or
    garbage-collected before the archive is complete (for example because the
    HTTP client went away), the producer is told to stop instead of blocking
    forever on a full queue.

    Args:
        repo: Repository id (``owner/name``).
        revision: Branch, tag or commit to download from.
        repo_type: ``"dataset"``, ``"space"`` or anything else for a model.
        files: File descriptors; each needs a ``"path"`` key and may carry a
            ``"size"``.
        token: Bearer token, or an empty string for anonymous access.

    Yields:
        ``bytes`` chunks of the ZIP archive, in order.

    Raises:
        Exception: Whatever the producer thread failed with, re-raised after
            the chunks queued before the failure have been yielded.
    """
    consumer_gone = threading.Event()

    class _ConsumerGone(Exception):
        """Tells the producer that nobody is reading the stream any more."""

    class _AbortableQueue(queue.Queue):
        """Bounded queue whose ``put`` gives up once the consumer is gone."""

        def put(self, item, block=True, timeout=None):
            # Wait in short slices rather than indefinitely, so a full queue
            # cannot park the producer thread after the consumer has left.
            while not consumer_gone.is_set():
                try:
                    return super().put(item, block=True, timeout=0.5)
                except queue.Full:
                    continue
            raise _ConsumerGone()

    out_q: queue.Queue = _AbortableQueue(maxsize=16)

    def _offer(item):
        """Enqueue *item*, silently dropping it if the consumer already left."""
        try:
            out_q.put(item)
        except _ConsumerGone:
            pass

    def _producer():
        try:
            writer = _QueueWriter(out_q)
            hdr = {"User-Agent": "hf-zip-stream/1.0"}
            if token:
                hdr["Authorization"] = f"Bearer {token}"
            with zipfile.ZipFile(writer, mode="w", compression=zipfile.ZIP_STORED, allowZip64=True) as zf:
                for fi in files:
                    path = fi["path"]
                    url = _get_file_download_url(repo, repo_type, revision, path)
                    logger.info("[ZIP] downloading %s", path)
                    with requests.get(url, headers=hdr, stream=True, timeout=(20, 300)) as resp:
                        if not resp.ok:
                            raise RuntimeError(f"Failed to download {path}: HTTP {resp.status_code}")
                        zi = zipfile.ZipInfo(filename=path)
                        zi.compress_type = zipfile.ZIP_STORED
                        # Size hint only; zipfile records the real size on close.
                        if fi.get("size") is not None:
                            zi.file_size = int(fi["size"])
                        with zf.open(zi, mode="w", force_zip64=True) as entry:
                            for chunk in resp.iter_content(1 << 20):
                                if chunk:
                                    entry.write(chunk)
            logger.info("[ZIP] streaming complete for %s", repo)
        except _ConsumerGone:
            logger.info("[ZIP] consumer stopped reading, aborting %s", repo)
        except Exception as exc:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("[ZIP] producer error: %s", exc)
            _offer(exc)
        finally:
            _offer(_ZIP_SENTINEL)

    threading.Thread(target=_producer, daemon=True).start()
    try:
        while True:
            item = out_q.get()
            if item is _ZIP_SENTINEL:
                break
            if isinstance(item, Exception):
                raise item
            yield item
    finally:
        # Runs on normal completion, on error and on generator close alike;
        # from here on every producer ``put`` aborts instead of blocking.
        consumer_gone.set()
def register_routes(app):
    """Attach the ``/api/hf-model-zip`` endpoint to *app* (a FastAPI app).

    The endpoint takes ``repo`` (required, ``owner/name``), ``revision``
    (default ``main``) and ``repoType`` (``model``, ``dataset`` or ``space``;
    default ``model``) as query parameters and answers with a streamed ZIP of
    every file in that repository.

    Args:
        app: The application object to register the route on.

    Returns:
        The same *app* object.
    """
    # Imported here rather than at module level, so importing this module does
    # not itself import FastAPI.
    from fastapi import HTTPException, Query
    from fastapi.responses import StreamingResponse

    # A repo id is exactly ``owner/name``. Each part is built from ASCII
    # letters, digits, "_" and "-", optionally joined by single dots. This
    # rules out ".", ".." and extra "/" segments, which would otherwise let a
    # caller steer the authenticated Hub request to a different path.
    part = r"[A-Za-z0-9_-]+(?:\.[A-Za-z0-9_-]+)*"
    repo_id_re = re.compile(rf"{part}/{part}")

    @app.get("/api/hf-model-zip")
    def hf_model_zip(
        repo: str = Query(...),
        revision: str = Query("main"),
        repoType: str = Query("model"),
    ):
        """Stream a ZIP of all files in a HuggingFace repo."""
        repo = (repo or "").strip()
        revision = (revision or "main").strip() or "main"
        repo_type = (repoType or "model").strip().lower() or "model"

        if repo_type not in {"model", "dataset", "space"}:
            raise HTTPException(status_code=400, detail="repoType must be model, dataset, or space")
        if not repo_id_re.fullmatch(repo):
            raise HTTPException(status_code=400, detail="Invalid repo parameter, expected owner/name")
        # Dots are not percent-encoded, so a bare "." or ".." would survive
        # as a relative path segment in the request URL.
        if revision in {".", ".."}:
            raise HTTPException(status_code=400, detail="Invalid revision parameter")

        logger.info("[ZIP] repo=%s revision=%s repo_type=%s", repo, revision, repo_type)
        token = get_token_for_artifact(repo)

        try:
            files = _list_repo_files(repo, revision, repo_type, token)
        except Exception as exc:
            logger.exception("[ZIP] failed to list files for %s", repo)
            raise HTTPException(status_code=500, detail=str(exc)) from exc

        if not files:
            raise HTTPException(status_code=404, detail="No files found in repo")

        logger.info("[ZIP] %d files to stream for %s", len(files), repo)
        zip_name = f"{_safe_zip_filename(repo.split('/')[-1])}-{_safe_zip_filename(revision)}.zip"

        return StreamingResponse(
            _stream_repo_as_zip(repo, revision, repo_type, files, token),
            media_type="application/zip",
            headers={
                "Content-Disposition": f'attachment; filename="{zip_name}"',
                "Cache-Control": "no-store",
                # Asks a buffering reverse proxy to pass chunks through as they arrive.
                "X-Accel-Buffering": "no",
            },
        )

    return app