Spaces:
Running
Running
| from __future__ import annotations | |
| import json | |
| import os | |
| from typing import TYPE_CHECKING, Any | |
| from urllib.error import HTTPError, URLError | |
| from urllib.parse import urlencode | |
| from urllib.request import Request, urlopen | |
| from .aliases import REPO_SORT_KEYS | |
| from .constants import ( | |
| DEFAULT_TIMEOUT_SEC, | |
| ) | |
| from .registry import REPO_API_ADAPTERS, REPO_SEARCH_DEFAULT_EXPAND | |
| from .validation import _endpoint_allowed, _normalize_endpoint, _sanitize_params | |
| if TYPE_CHECKING: | |
| from huggingface_hub import HfApi | |
| def _load_request_token() -> str | None: | |
| try: | |
| from fast_agent.mcp.auth.context import request_bearer_token # type: ignore | |
| token = request_bearer_token.get() | |
| if token: | |
| return token | |
| except Exception: | |
| pass | |
| return None | |
def _load_token() -> str | None:
    """Resolve the token used to authenticate outgoing requests.

    A request-scoped token (see ``_load_request_token``) takes precedence;
    the ``HF_TOKEN`` environment variable is the fallback.

    Returns:
        The first non-empty token found, or ``None`` when neither is set.
    """
    return _load_request_token() or os.getenv("HF_TOKEN") or None
| def _json_best_effort(raw: bytes) -> Any: | |
| try: | |
| return json.loads(raw) | |
| except Exception: | |
| return raw.decode("utf-8", errors="replace") | |
| def _clamp_int(value: Any, *, default: int, minimum: int, maximum: int) -> int: | |
| try: | |
| out = int(value) | |
| except Exception: | |
| out = default | |
| return max(minimum, min(out, maximum)) | |
| def _as_int(value: Any) -> int | None: | |
| try: | |
| return int(value) | |
| except Exception: | |
| return None | |
| def _canonical_repo_type(value: Any, *, default: str = "model") -> str: | |
| raw = str(value or "").strip().lower() | |
| aliases = { | |
| "model": "model", | |
| "models": "model", | |
| "dataset": "dataset", | |
| "datasets": "dataset", | |
| "space": "space", | |
| "spaces": "space", | |
| } | |
| return aliases.get(raw, default) | |
def _normalize_repo_sort_key(
    repo_type: str, sort_value: Any
) -> tuple[str | None, str | None]:
    """Validate a requested sort key for the given repo type.

    Args:
        repo_type: Repo type; canonicalised with ``_canonical_repo_type``.
        sort_value: Requested sort key; falsy or blank means "no sort".

    Returns:
        A ``(key, error)`` pair. ``(None, None)`` when no sort was requested,
        ``(key, None)`` when the key is valid, and ``(None, message)`` when
        the key is unknown or not listed in ``REPO_SORT_KEYS`` for the type.
    """
    raw = str(sort_value or "").strip()
    if raw == "":
        return None, None

    recognised = (
        "created_at",
        "downloads",
        "last_modified",
        "likes",
        "trending_score",
    )
    if raw not in recognised:
        return None, f"Invalid sort key '{raw}'"

    rt = _canonical_repo_type(repo_type)
    allowed = REPO_SORT_KEYS.get(rt, set())
    if raw in allowed:
        return raw, None
    return (
        None,
        f"Invalid sort key '{raw}' for repo_type='{rt}'. Allowed: {', '.join(sorted(allowed))}",
    )
def _repo_api_adapter(repo_type: str) -> Any:
    """Look up the ``REPO_API_ADAPTERS`` entry for *repo_type*.

    Args:
        repo_type: Repo type in any spelling ``_canonical_repo_type`` accepts.

    Returns:
        The adapter registered for the canonical repo type.

    Raises:
        ValueError: If no adapter is registered for the repo type.
    """
    adapter = REPO_API_ADAPTERS.get(_canonical_repo_type(repo_type, default=""))
    if adapter is not None:
        return adapter
    raise ValueError(f"Unsupported repo_type '{repo_type}'")
def _repo_list_call(api: HfApi, repo_type: str, **kwargs: Any) -> list[Any]:
    """Call the adapter's list method on *api* and collect the results.

    Args:
        api: Client object exposing the method named by the adapter's
            ``list_method_name``.
        repo_type: Repo type used to select the adapter.
        **kwargs: Forwarded unchanged to the list method.

    Returns:
        The method's results materialised into a list.

    Raises:
        ValueError: If *repo_type* has no registered adapter.
    """
    list_method = getattr(api, _repo_api_adapter(repo_type).list_method_name)
    return [*list_method(**kwargs)]
def _repo_detail_call(api: HfApi, repo_type: str, repo_id: str) -> Any:
    """Fetch details for one repo via the adapter's detail method.

    For spaces the call also passes ``expand`` with the ``"space"`` entry of
    ``REPO_SEARCH_DEFAULT_EXPAND``; other repo types are fetched with the
    repo id alone.

    Args:
        api: Client object exposing the method named by the adapter's
            ``detail_method_name``.
        repo_type: Repo type used to select the adapter.
        repo_id: Identifier of the repo to fetch.

    Returns:
        Whatever the detail method returns.

    Raises:
        ValueError: If *repo_type* has no registered adapter.
    """
    fetch = getattr(api, _repo_api_adapter(repo_type).detail_method_name)
    if _canonical_repo_type(repo_type) != "space":
        return fetch(repo_id)
    return fetch(repo_id, expand=list(REPO_SEARCH_DEFAULT_EXPAND["space"]))
| def _coerce_str_list(value: Any) -> list[str]: | |
| if value is None: | |
| return [] | |
| if isinstance(value, str): | |
| raw = [value] | |
| elif isinstance(value, (list, tuple, set)): | |
| raw = list(value) | |
| else: | |
| raise ValueError("Expected a string or list of strings") | |
| return [str(v).strip() for v in raw if str(v).strip()] | |
| def _optional_str_list(value: Any) -> list[str] | None: | |
| if value is None: | |
| return None | |
| if isinstance(value, str): | |
| out = [value.strip()] if value.strip() else [] | |
| return out or None | |
| if isinstance(value, (list, tuple, set)): | |
| out = [str(v).strip() for v in value if str(v).strip()] | |
| return out or None | |
| return None | |
def _space_runtime_to_dict(value: Any) -> dict[str, Any] | None:
    """Flatten a space runtime description into a small plain dict.

    Accepts either a mapping or an object with attributes.

    * Mapping: ``hardware`` may itself be a mapping with ``current`` and
      ``requested`` entries, or a plain value, in which case the requested
      hardware is read from ``requested_hardware`` / ``requestedHardware``.
      The sleep time comes from ``gcTimeout`` when that is not ``None``,
      otherwise from ``sleep_time`` / ``sleepTime``.
    * Object: ``stage``, ``hardware``, ``requested_hardware`` and
      ``sleep_time`` attributes are read directly.

    Returns:
        A dict holding only the non-``None`` entries among ``stage``,
        ``hardware``, ``requested_hardware`` and ``sleep_time``; ``None``
        when *value* is ``None`` or every entry is ``None``.
    """
    if value is None:
        return None

    if isinstance(value, dict):
        hardware = value.get("hardware")
        if isinstance(hardware, dict):
            current_hw = hardware.get("current")
            requested_hw = hardware.get("requested")
        else:
            current_hw = hardware
            requested_hw = value.get("requested_hardware") or value.get(
                "requestedHardware"
            )

        timeout = value.get("gcTimeout")
        if timeout is None:
            timeout = value.get("sleep_time") or value.get("sleepTime")

        fields = {
            "stage": value.get("stage"),
            "hardware": current_hw,
            "requested_hardware": requested_hw,
            "sleep_time": _as_int(timeout),
        }
    else:
        fields = {
            "stage": getattr(value, "stage", None),
            "hardware": getattr(value, "hardware", None),
            "requested_hardware": getattr(value, "requested_hardware", None),
            "sleep_time": _as_int(getattr(value, "sleep_time", None)),
        }

    compact = {name: item for name, item in fields.items() if item is not None}
    return compact or None
def _extract_num_params(num_params: Any = None, safetensors: Any = None) -> int | None:
    """Resolve a parameter count from an explicit value or safetensors data.

    Args:
        num_params: Explicit count; used when it converts to an int.
        safetensors: Fallback source; its ``total`` attribute is read, or its
            ``"total"`` key when it is a dict and the attribute is absent.

    Returns:
        The count as an int, or ``None`` when neither source yields one.
    """
    count = _as_int(num_params)
    if count is None:
        total = getattr(safetensors, "total", None)
        if total is None and isinstance(safetensors, dict):
            total = safetensors.get("total")
        count = _as_int(total)
    return count
def _extract_num_params_from_object(row: Any) -> int | None:
    """Read a parameter count from an attribute-style row.

    The first non-``None`` attribute among ``num_params``, ``numParameters``
    and ``num_parameters`` is handed to ``_extract_num_params`` together with
    the row's ``safetensors`` attribute.
    """
    declared = None
    for attr in ("num_params", "numParameters", "num_parameters"):
        declared = getattr(row, attr, None)
        if declared is not None:
            break
    return _extract_num_params(declared, getattr(row, "safetensors", None))
def _extract_num_params_from_dict(row: dict[str, Any]) -> int | None:
    """Read a parameter count from a dict-style row.

    The first non-``None`` value among the ``num_params``, ``numParameters``
    and ``num_parameters`` keys is handed to ``_extract_num_params`` together
    with the row's ``safetensors`` entry.
    """
    declared = None
    for key in ("num_params", "numParameters", "num_parameters"):
        declared = row.get(key)
        if declared is not None:
            break
    return _extract_num_params(declared, row.get("safetensors"))
| def _extract_author_names(value: Any) -> list[str] | None: | |
| if not isinstance(value, (list, tuple)): | |
| return None | |
| names: list[str] = [] | |
| for item in value: | |
| if isinstance(item, str) and item.strip(): | |
| names.append(item.strip()) | |
| continue | |
| if isinstance(item, dict): | |
| name = item.get("name") | |
| if isinstance(name, str) and name.strip(): | |
| names.append(name.strip()) | |
| continue | |
| name = getattr(item, "name", None) | |
| if isinstance(name, str) and name.strip(): | |
| names.append(name.strip()) | |
| return names or None | |
| def _extract_profile_name(value: Any) -> str | None: | |
| if isinstance(value, str) and value.strip(): | |
| return value.strip() | |
| if isinstance(value, dict): | |
| for key in ("user", "name", "fullname", "handle"): | |
| candidate = value.get(key) | |
| if isinstance(candidate, str) and candidate.strip(): | |
| return candidate.strip() | |
| return None | |
| for attr in ("user", "name", "fullname", "handle"): | |
| candidate = getattr(value, attr, None) | |
| if isinstance(candidate, str) and candidate.strip(): | |
| return candidate.strip() | |
| return None | |
| def _author_from_any(value: Any) -> str | None: | |
| if isinstance(value, str) and value: | |
| return value | |
| if isinstance(value, dict): | |
| for key in ("name", "username", "user", "login"): | |
| candidate = value.get(key) | |
| if isinstance(candidate, str) and candidate: | |
| return candidate | |
| return None | |
| def _dt_to_str(value: Any) -> str | None: | |
| if value is None: | |
| return None | |
| iso = getattr(value, "isoformat", None) | |
| if callable(iso): | |
| try: | |
| return str(iso()) | |
| except Exception: | |
| pass | |
| return str(value) | |
def _repo_web_url(repo_type: str, repo_id: str | None) -> str | None:
    """Build the web URL of a repo.

    The base URL is the ``HF_ENDPOINT`` environment variable (default
    ``https://huggingface.co``) without trailing slashes. Datasets and spaces
    live under ``/datasets/`` and ``/spaces/``; every other repo type maps
    directly under the base.

    Returns:
        The URL, or ``None`` when *repo_id* is not a non-empty string.
    """
    if not (isinstance(repo_id, str) and repo_id):
        return None
    rt = _canonical_repo_type(repo_type, default="")
    base = os.getenv("HF_ENDPOINT", "https://huggingface.co").rstrip("/")
    if rt == "space":
        return f"{base}/spaces/{repo_id}"
    if rt == "dataset":
        return f"{base}/datasets/{repo_id}"
    return f"{base}/{repo_id}"
def _build_repo_row(
    *,
    repo_id: Any,
    repo_type: str,
    author: Any = None,
    likes: Any = None,
    downloads: Any = None,
    created_at: Any = None,
    last_modified: Any = None,
    pipeline_tag: Any = None,
    num_params: Any = None,
    private: Any = None,
    trending_score: Any = None,
    tags: Any = None,
    sha: Any = None,
    gated: Any = None,
    library_name: Any = None,
    description: Any = None,
    paperswithcode_id: Any = None,
    sdk: Any = None,
    models: Any = None,
    datasets: Any = None,
    subdomain: Any = None,
    runtime: Any = None,
    runtime_stage: Any = None,
) -> dict[str, Any]:
    """Assemble the normalised row dict shared by all repo listings.

    All arguments are keyword-only. Notable normalisation steps:

    * ``id``, ``slug`` and ``repo_id`` all carry *repo_id*.
    * When *author* is not a string and *repo_id* is a string containing
      ``/``, the author is taken from the part before the first slash.
    * Counts go through ``_as_int`` and dates through ``_dt_to_str``.
    * ``tags``, ``models`` and ``datasets`` go through ``_optional_str_list``.
    * *runtime* is flattened with ``_space_runtime_to_dict``; an explicit
      *runtime_stage* wins over the stage found in the flattened runtime.

    Returns:
        The row dict, with every key present even when its value is ``None``.
    """
    rt = _canonical_repo_type(repo_type)
    repo_id_is_str = isinstance(repo_id, str)

    author_value = author
    if not isinstance(author, str) and repo_id_is_str and "/" in repo_id:
        author_value = repo_id.partition("/")[0]

    runtime_payload = _space_runtime_to_dict(runtime)
    if runtime_stage is not None:
        resolved_runtime_stage = runtime_stage
    elif isinstance(runtime_payload, dict):
        resolved_runtime_stage = runtime_payload.get("stage")
    else:
        resolved_runtime_stage = None

    trending_value = None
    if trending_score is not None:
        trending_value = _as_int(trending_score)

    row: dict[str, Any] = {}
    # Identity: the same identifier is exposed under three keys.
    row["id"] = repo_id
    row["slug"] = repo_id
    row["repo_id"] = repo_id
    row["repo_type"] = rt
    row["author"] = author_value
    # Metrics and timestamps.
    row["likes"] = _as_int(likes)
    row["downloads"] = _as_int(downloads)
    row["created_at"] = _dt_to_str(created_at)
    row["last_modified"] = _dt_to_str(last_modified)
    row["pipeline_tag"] = pipeline_tag
    row["num_params"] = _as_int(num_params)
    row["private"] = private
    row["trending_score"] = trending_value
    row["repo_url"] = _repo_web_url(rt, repo_id if repo_id_is_str else None)
    # Descriptive metadata, passed through or list-normalised.
    row["tags"] = _optional_str_list(tags)
    row["sha"] = sha
    row["gated"] = gated
    row["library_name"] = library_name
    row["description"] = description
    row["paperswithcode_id"] = paperswithcode_id
    row["sdk"] = sdk
    row["models"] = _optional_str_list(models)
    row["datasets"] = _optional_str_list(datasets)
    row["subdomain"] = subdomain
    row["runtime_stage"] = resolved_runtime_stage
    row["runtime"] = runtime_payload
    return row
def _normalize_repo_search_row(row: Any, repo_type: str) -> dict[str, Any]:
    """Normalise an attribute-style search result into a repo row dict.

    Every field is read with ``getattr(row, name, None)``, so missing
    attributes become ``None``. The repo id comes from the ``id`` attribute
    and the parameter count from ``_extract_num_params_from_object``.

    Args:
        row: Result object exposing repo fields as attributes.
        repo_type: Repo type the row belongs to.

    Returns:
        The dict produced by ``_build_repo_row``.
    """
    # Attributes whose name matches the ``_build_repo_row`` keyword exactly.
    same_named = (
        "author",
        "likes",
        "downloads",
        "created_at",
        "last_modified",
        "pipeline_tag",
        "private",
        "trending_score",
        "tags",
        "sha",
        "gated",
        "library_name",
        "description",
        "paperswithcode_id",
        "sdk",
        "models",
        "datasets",
        "subdomain",
        "runtime",
    )
    fields = {name: getattr(row, name, None) for name in same_named}
    return _build_repo_row(
        repo_id=getattr(row, "id", None),
        repo_type=repo_type,
        num_params=_extract_num_params_from_object(row),
        **fields,
    )
def _normalize_repo_detail_row(
    detail: Any, repo_type: str, repo_id: str
) -> dict[str, Any]:
    """Normalise a repo detail object, backfilling the id from *repo_id*.

    The detail is normalised like a search row. If the row carries no repo
    id, the requested *repo_id* is used; ``id`` and ``slug`` are filled with
    that id only when they are empty, and ``repo_url`` is rebuilt from it.

    Returns:
        The normalised row dict.
    """
    normalized = _normalize_repo_search_row(detail, repo_type)
    effective_id = normalized.get("repo_id") or repo_id
    for key in ("id", "slug"):
        if not normalized.get(key):
            normalized[key] = effective_id
    normalized["repo_id"] = effective_id
    normalized["repo_url"] = _repo_web_url(repo_type, effective_id)
    return normalized
def _normalize_trending_row(
    repo: dict[str, Any], default_repo_type: str, rank: int | None = None
) -> dict[str, Any]:
    """Normalise one trending entry (a dict) into a repo row dict.

    Args:
        repo: Raw entry. The repo type is read from ``type`` or ``repoType``;
            dates and the trending score use camelCase keys (``createdAt``,
            ``lastModified``, ``trendingScore``).
        default_repo_type: Used when the entry names no repo type.
        rank: When given, stored under ``trending_rank``.

    Returns:
        The dict produced by ``_build_repo_row``, plus ``trending_rank`` when
        *rank* is not ``None``.
    """
    # ``_build_repo_row`` keyword -> key in the raw entry.
    source_keys = {
        "repo_id": "id",
        "author": "author",
        "likes": "likes",
        "downloads": "downloads",
        "created_at": "createdAt",
        "last_modified": "lastModified",
        "pipeline_tag": "pipeline_tag",
        "private": "private",
        "trending_score": "trendingScore",
        "tags": "tags",
        "sha": "sha",
        "gated": "gated",
        "library_name": "library_name",
        "description": "description",
        "paperswithcode_id": "paperswithcode_id",
        "sdk": "sdk",
        "models": "models",
        "datasets": "datasets",
        "subdomain": "subdomain",
        "runtime": "runtime",
    }
    fields = {param: repo.get(key) for param, key in source_keys.items()}

    normalized = _build_repo_row(
        repo_type=repo.get("type") or repo.get("repoType") or default_repo_type,
        num_params=_extract_num_params_from_dict(repo),
        runtime_stage=repo.get("runtime_stage") or repo.get("runtimeStage"),
        **fields,
    )
    if rank is not None:
        normalized["trending_rank"] = rank
    return normalized
def _normalize_daily_paper_row(
    row: dict[str, Any], rank: int | None = None
) -> dict[str, Any]:
    """Normalise one daily-papers entry into a flat dict.

    Fields are read from *row* first and from its nested ``paper`` dict as a
    fallback where both are consulted. A ``paper`` value that is not a dict
    is treated as empty.

    Args:
        row: Raw entry.
        rank: Stored as-is under ``rank``.

    Returns:
        A dict with snake_case keys; entries that are absent in the source
        are ``None``.
    """
    paper = row.get("paper")
    if not isinstance(paper, dict):
        paper = {}

    # Prefer the row-level organization; fall back to the paper's.
    org = row.get("organization")
    if not isinstance(org, dict):
        org = paper.get("organization")
    organization = None
    if isinstance(org, dict):
        organization = org.get("name") or org.get("fullname")

    # Only a real bool is accepted for the participation flag.
    participating = row.get("isAuthorParticipating")
    if not isinstance(participating, bool):
        participating = None

    summary = row.get("summary") or paper.get("summary") or paper.get("ai_summary")
    submitter = row.get("submittedBy") or paper.get("submittedOnDailyBy")

    return {
        "paper_id": paper.get("id"),
        "title": row.get("title") or paper.get("title"),
        "summary": summary,
        "published_at": row.get("publishedAt") or paper.get("publishedAt"),
        "submitted_on_daily_at": paper.get("submittedOnDailyAt"),
        "authors": _extract_author_names(paper.get("authors")),
        "organization": organization,
        "submitted_by": _extract_profile_name(submitter),
        "discussion_id": paper.get("discussionId"),
        "upvotes": _as_int(paper.get("upvotes")),
        "github_repo_url": paper.get("githubRepo"),
        "github_stars": _as_int(paper.get("githubStars")),
        "project_page_url": paper.get("projectPage"),
        "num_comments": _as_int(row.get("numComments")),
        "is_author_participating": participating,
        "repo_id": row.get("repo_id") or paper.get("repo_id"),
        "rank": rank,
    }
def _normalize_collection_repo_item(row: dict[str, Any]) -> dict[str, Any] | None:
    """Normalise one collection item into a repo row dict.

    Several fields are accepted under alternative spellings (camelCase and
    snake_case); the first truthy one wins.

    Args:
        row: Raw collection item.

    Returns:
        The dict produced by ``_build_repo_row``, or ``None`` when the item
        has no string repo id or its type is not model, dataset or space.
    """

    def pick(*keys: str) -> Any:
        # Same result as chaining ``row.get(k1) or row.get(k2) ...``.
        found = None
        for key in keys:
            found = row.get(key)
            if found:
                break
        return found

    repo_id = pick("id", "repoId", "repo_id")
    if not (isinstance(repo_id, str) and repo_id):
        return None

    repo_type = _canonical_repo_type(pick("repoType", "repo_type", "type"), default="")
    if repo_type not in {"model", "dataset", "space"}:
        return None

    author = row.get("author") or _author_from_any(row.get("authorData"))
    return _build_repo_row(
        repo_id=repo_id,
        repo_type=repo_type,
        author=author,
        likes=row.get("likes"),
        downloads=row.get("downloads"),
        created_at=pick("createdAt", "created_at"),
        last_modified=pick("lastModified", "last_modified"),
        pipeline_tag=pick("pipeline_tag", "pipelineTag"),
        num_params=_extract_num_params_from_dict(row),
        private=row.get("private"),
        tags=row.get("tags"),
        gated=row.get("gated"),
        library_name=pick("library_name", "libraryName"),
        description=row.get("description"),
        paperswithcode_id=pick("paperswithcode_id", "paperswithcodeId"),
        sdk=row.get("sdk"),
        models=row.get("models"),
        datasets=row.get("datasets"),
        subdomain=row.get("subdomain"),
        runtime=row.get("runtime"),
        runtime_stage=pick("runtime_stage", "runtimeStage"),
    )
def _sort_repo_rows(
    rows: list[dict[str, Any]], sort_key: str | None
) -> list[dict[str, Any]]:
    """Sort normalised repo rows in descending order of *sort_key*.

    Args:
        rows: Row dicts as produced by the normalisers in this module.
        sort_key: ``likes``, ``downloads`` and ``trending_score`` sort
            numerically; ``created_at`` and ``last_modified`` sort by their
            string form. Any other value (including ``None``) leaves the
            order untouched.

    Returns:
        A new sorted list, or *rows* itself when no sorting applies. Rows
        whose value is missing or not convertible to an int sort last for
        numeric keys; rows without a date sort last for date keys.
    """
    if not sort_key:
        return rows

    if sort_key in {"likes", "downloads", "trending_score"}:

        def numeric_key(row: dict[str, Any]) -> int:
            value = _as_int(row.get(sort_key))
            # Test against None explicitly: `value or -1` would turn a real
            # 0 into -1 and make it tie with rows that have no value at all.
            return -1 if value is None else value

        return sorted(rows, key=numeric_key, reverse=True)

    if sort_key in {"created_at", "last_modified"}:
        return sorted(rows, key=lambda row: str(row.get(sort_key) or ""), reverse=True)

    return rows
def call_api_host(
    endpoint: str,
    *,
    method: str = "GET",
    params: dict[str, Any] | None = None,
    json_body: dict[str, Any] | None = None,
    timeout_sec: int = DEFAULT_TIMEOUT_SEC,
    strict_mode: bool = False,
) -> dict[str, Any]:
    """Perform a GET or POST against the Hub HTTP API and wrap the outcome.

    The base URL is the ``HF_ENDPOINT`` environment variable (default
    ``https://huggingface.co``). When ``_load_token`` yields a token it is
    sent as a bearer ``Authorization`` header.

    Args:
        endpoint: API path; normalised with ``_normalize_endpoint`` and
            checked with ``_endpoint_allowed``.
        method: ``"GET"`` or ``"POST"`` (case-insensitive).
        params: Query parameters; passed through ``_sanitize_params`` and
            URL-encoded with ``doseq=True``.
        json_body: JSON body for POST requests; ``{}`` is sent when omitted.
        timeout_sec: Timeout handed to ``urlopen``.
        strict_mode: Forwarded to ``_endpoint_allowed``.

    Returns:
        A dict with keys ``ok``, ``status``, ``url``, ``data`` and ``error``.
        HTTP errors give ``ok=False`` with the response status, the decoded
        body in ``data`` and a summary of at most 1000 characters in
        ``error``. Other ``URLError`` failures give ``status=0`` and
        ``data=None``.

    Raises:
        ValueError: For an unsupported method, a disallowed endpoint, or a
            ``/api/recent-activity`` call lacking ``feedType=user|org`` or
            ``entity``.
    """
    method_u = method.upper().strip()
    if method_u not in {"GET", "POST"}:
        raise ValueError("Only GET and POST are supported")

    ep = _normalize_endpoint(endpoint)
    if not _endpoint_allowed(ep, strict_mode):
        raise ValueError(f"Endpoint not allowed: {ep}")

    params = _sanitize_params(ep, params)
    if ep == "/api/recent-activity":
        # Validate the required query parameters before issuing the request.
        feed_type = str((params or {}).get("feedType", "")).strip().lower()
        if feed_type not in {"user", "org"}:
            raise ValueError("/api/recent-activity requires feedType=user|org")
        if not str((params or {}).get("entity", "")).strip():
            raise ValueError("/api/recent-activity requires entity")

    base = os.getenv("HF_ENDPOINT", "https://huggingface.co").rstrip("/")
    q = urlencode(params or {}, doseq=True)
    url = f"{base}{ep}" + (f"?{q}" if q else "")

    headers = {"Accept": "application/json"}
    token = _load_token()
    if token:
        headers["Authorization"] = f"Bearer {token}"

    data = None
    if method_u == "POST":
        headers["Content-Type"] = "application/json"
        data = json.dumps(json_body or {}).encode("utf-8")

    req = Request(url, method=method_u, headers=headers, data=data)
    try:
        with urlopen(req, timeout=timeout_sec) as res:
            payload = _json_best_effort(res.read())
            return {
                "ok": True,
                "status": int(res.status),
                "url": url,
                "data": payload,
                "error": None,
            }
    except HTTPError as e:
        # HTTPError is a URLError subclass, so it must be handled first.
        payload = _json_best_effort(e.read())
        err_text = (
            payload
            if isinstance(payload, str)
            else json.dumps(payload, ensure_ascii=False)
        )
        return {
            "ok": False,
            "status": int(e.code),
            "url": url,
            "data": payload,
            # Cap the summary for text bodies too (e.g. an HTML error page);
            # the complete body stays available under "data".
            "error": err_text[:1000],
        }
    except URLError as e:
        return {
            "ok": False,
            "status": 0,
            "url": url,
            "data": None,
            "error": f"Network error: {e}",
        }