Spaces:
Running
Running
| from __future__ import annotations | |
| import ast | |
| import os | |
| import re | |
| import tokenize | |
| from io import StringIO | |
| from typing import Any, Callable, cast | |
| from .constants import ( | |
| GRAPH_SCAN_LIMIT_CAP, | |
| LIKES_SCAN_LIMIT_CAP, | |
| OUTPUT_ITEMS_TRUNCATION_LIMIT, | |
| SELECTIVE_ENDPOINT_RETURN_HARD_CAP, | |
| TRENDING_ENDPOINT_MAX_LIMIT, | |
| ) | |
| from .registry import ( | |
| ALLOWLIST_PATTERNS, | |
| HELPER_EXTERNALS, | |
| STRICT_ALLOWLIST_PATTERNS, | |
| ) | |
def _resolve_helper_functions(
    namespace: dict[str, Any],
) -> dict[str, Callable[..., Any]]:
    """Look up every name listed in ``HELPER_EXTERNALS`` inside *namespace*.

    Args:
        namespace: Mapping from helper name to the object bound to it.

    Returns:
        A new dict mapping each helper name to its callable, in the
        iteration order of ``HELPER_EXTERNALS``.

    Raises:
        RuntimeError: If a listed helper is absent from *namespace* or the
            object stored under its name is not callable.
    """
    helpers: dict[str, Callable[..., Any]] = {}
    for name in HELPER_EXTERNALS:
        fn = namespace.get(name)
        if callable(fn):
            helpers[name] = cast(Callable[..., Any], fn)
            continue
        # A missing key yields None from .get(), which is not callable either.
        raise RuntimeError(f"Helper '{name}' is not defined or not callable")
    return helpers
def _normalize_endpoint(endpoint: str) -> str:
    """Turn a caller-supplied endpoint into a path that starts with ``/api/``.

    Surrounding whitespace is stripped, a leading ``/`` and the ``/api``
    prefix are added when missing, and ``/api/collections/search`` (with or
    without a trailing slash) is rewritten to ``/api/collections``.

    Args:
        endpoint: Path-only endpoint; a falsy value is treated as empty.

    Returns:
        The normalized path.

    Raises:
        ValueError: If the endpoint is empty, contains ``?``, starts with
            ``http://`` or ``https://``, or contains ``..`` after
            normalization.
    """
    path = (endpoint or "").strip()
    if not path:
        raise ValueError("endpoint is required")
    if "?" in path:
        raise ValueError("endpoint must not include query string; use params")
    if path.startswith(("http://", "https://")):
        raise ValueError("endpoint must be path-only")

    if not path.startswith("/"):
        path = f"/{path}"
    if not path.startswith("/api/"):
        path = f"/api{path}"

    # The search-style collections route is served by the plain collections path.
    if path in ("/api/collections/search", "/api/collections/search/"):
        path = "/api/collections"

    # Checked last so the test runs against the fully prefixed path.
    if ".." in path:
        raise ValueError("path traversal not allowed")
    return path
def _endpoint_allowed(endpoint: str, strict_mode: bool) -> bool:
    """Report whether *endpoint* matches one of the allowlist patterns.

    Args:
        endpoint: Endpoint path; anything from the first ``?`` on is ignored.
        strict_mode: Use ``STRICT_ALLOWLIST_PATTERNS`` when true, otherwise
            ``ALLOWLIST_PATTERNS``.

    Returns:
        True if ``re.match`` succeeds for at least one pattern.
    """
    if strict_mode:
        allowlist = STRICT_ALLOWLIST_PATTERNS
    else:
        allowlist = ALLOWLIST_PATTERNS

    bare_path = endpoint.partition("?")[0]
    for pattern in allowlist:
        if re.match(pattern, bare_path):
            return True
    return False
def _sanitize_params(endpoint: str, params: dict[str, Any] | None) -> dict[str, Any]:
    """Return a cleaned copy of *params* for a request to *endpoint*.

    The input mapping is never mutated. Three adjustments are made:

    * ``/api/collections``: ``search`` is copied to ``q`` when ``q`` is
      absent, and ``search`` is then removed.
    * ``/api/trending``: plural ``type`` values are mapped to their singular
      form, and ``limit`` is clamped to ``1..TRENDING_ENDPOINT_MAX_LIMIT``
      (an unparseable limit becomes the maximum).
    * every other path: an integer-convertible ``limit`` is clamped to a
      per-route cap; an unparseable limit is left untouched.

    Args:
        endpoint: Endpoint path; anything from the first ``?`` on is ignored.
        params: Query parameters, or None for no parameters.

    Returns:
        A new dict holding the sanitized parameters.
    """
    sanitized = dict(params or {})
    route = endpoint.partition("?")[0]

    if route == "/api/collections":
        if "search" in sanitized and "q" not in sanitized:
            sanitized["q"] = sanitized.get("search")
        sanitized.pop("search", None)

    if route == "/api/trending":
        kind = str(sanitized.get("type") or "").strip().lower()
        singular = {"models": "model", "datasets": "dataset", "spaces": "space"}.get(kind)
        if singular is not None:
            sanitized["type"] = singular

        requested = sanitized.get("limit")
        if requested is not None:
            try:
                parsed = int(requested)
            except Exception:
                # An unparseable trending limit falls back to the maximum.
                parsed = TRENDING_ENDPOINT_MAX_LIMIT
            sanitized["limit"] = max(1, min(parsed, TRENDING_ENDPOINT_MAX_LIMIT))
        return sanitized

    requested = sanitized.get("limit")
    if requested is None:
        return sanitized
    try:
        parsed = int(requested)
    except Exception:
        # Outside trending, an unparseable limit is passed through unchanged.
        return sanitized

    if re.match(r"^/api/users/[^/]+/(followers|following)$", route):
        ceiling = GRAPH_SCAN_LIMIT_CAP
    elif re.match(r"^/api/users/[^/]+/likes$", route):
        ceiling = LIKES_SCAN_LIMIT_CAP
    else:
        ceiling = SELECTIVE_ENDPOINT_RETURN_HARD_CAP
    sanitized["limit"] = max(1, min(parsed, ceiling))
    return sanitized
def _truncate_result_payload(output: Any) -> Any:
    """Cap the ``items`` list of a dict payload at ``OUTPUT_ITEMS_TRUNCATION_LIMIT``.

    Anything that is not a dict, has a non-list ``items`` value, or is already
    within the limit is returned as the same object. Otherwise a shallow copy
    is returned with:

    * ``items`` cut to the first ``OUTPUT_ITEMS_TRUNCATION_LIMIT`` entries,
    * ``item`` set to the single remaining entry, or None when more than one
      entry remains,
    * a truncation note appended to ``steps`` (``steps`` is replaced by a
      one-element list when it was not a list).

    Args:
        output: Arbitrary result payload.

    Returns:
        The original object, or a truncated shallow copy of it.
    """
    if not isinstance(output, dict):
        return output
    rows = output.get("items")
    if not isinstance(rows, list):
        return output
    cap = OUTPUT_ITEMS_TRUNCATION_LIMIT
    if len(rows) <= cap:
        return output

    kept = rows[:cap]
    shortened = dict(output)
    shortened["items"] = kept
    shortened["item"] = kept[0] if len(kept) == 1 else None

    message = f"truncated items to first {cap} rows for token efficiency"
    earlier_steps = shortened.get("steps")
    if isinstance(earlier_steps, list):
        # Build a new list so the caller's steps list is not mutated.
        shortened["steps"] = [*earlier_steps, message]
    else:
        shortened["steps"] = [message]
    return shortened
def _verbose_result_meta_enabled() -> bool:
    """Report whether ``MONTY_VERBOSE_RESULT_META`` is set to a truthy word.

    The environment variable is read on every call; after trimming and
    lower-casing, the values ``1``, ``true``, ``yes`` and ``on`` count as
    enabled. An unset variable counts as disabled.
    """
    flag = os.environ.get("MONTY_VERBOSE_RESULT_META", "").strip().lower()
    return flag in ("1", "true", "yes", "on")
def _is_helper_meta_dict(value: Any) -> bool:
    """Report whether *value* looks like a helper metadata dict.

    A match is a dict whose ``source`` is a string and that either has
    ``normalized`` set to exactly True or contains a ``budget_used`` or
    ``budget_remaining`` key.
    """
    if not isinstance(value, dict):
        return False
    if not isinstance(value.get("source"), str):
        return False
    if value.get("normalized") is True:
        return True
    # Key presence is enough here; the stored values are not inspected.
    return "budget_used" in value or "budget_remaining" in value
def _helper_meta_is_partial(value: dict[str, Any]) -> bool:
    """Report whether helper metadata signals an incomplete result.

    The result is considered partial when any of these hold:

    * ``truncated``, ``limit_boundary_hit``, ``ranking_window_hit`` or
      ``hard_cap_applied`` is exactly True,
    * ``sample_complete``, ``exact_count`` or ``ranking_complete`` is exactly
      False,
    * ``more_available`` is anything other than False or None.

    Args:
        value: Helper metadata dict.

    Returns:
        True if at least one of the signals above is present.
    """
    more_available = value.get("more_available")
    return any(
        (
            value.get("truncated") is True,
            # Tuple membership compares by equality and never hashes, so an
            # unhashable value (list, dict) counts as "more available"
            # instead of raising TypeError as set membership would.
            more_available not in (False, None),
            value.get("limit_boundary_hit") is True,
            value.get("sample_complete") is False,
            value.get("exact_count") is False,
            value.get("ranking_complete") is False,
            value.get("ranking_window_hit") is True,
            value.get("hard_cap_applied") is True,
        )
    )
def _compact_helper_meta(value: dict[str, Any]) -> dict[str, Any]:
    """Reduce helper metadata to a fixed set of fields plus a ``partial`` flag.

    Args:
        value: Full helper metadata dict.

    Returns:
        A new dict that starts with ``partial`` (computed by
        ``_helper_meta_is_partial``) followed by each retained field whose
        value is not None. When ``total`` is missing or None, a non-None
        ``total_available`` is reported as ``total`` instead.
    """
    retained_fields = (
        "source",
        "returned",
        "total",
        "matched",
        "more_available",
        "truncated",
        "truncated_by",
        "exact_count",
        "sample_complete",
        "hard_cap_applied",
        "limit_boundary_hit",
        "can_request_more",
        "next_request_hint",
        "ranking_window",
        "ranking_window_hit",
        "ranking_complete",
        "ranking_next_request_hint",
        "relation",
        "username",
        "organization",
        "entity",
        "entity_type",
        "handle",
    )
    compact: dict[str, Any] = {"partial": _helper_meta_is_partial(value)}
    compact.update(
        (field, value.get(field))
        for field in retained_fields
        if value.get(field) is not None
    )

    # Only non-None values were copied, so an absent key means "no total".
    if "total" not in compact:
        fallback_total = value.get("total_available")
        if fallback_total is not None:
            compact["total"] = fallback_total
    return compact
def _compact_result_metadata(value: Any) -> Any:
    """Recursively replace helper metadata dicts with their compact form.

    When verbose result metadata is enabled (see
    ``_verbose_result_meta_enabled``) the value is returned untouched.
    Otherwise dicts and lists are rebuilt recursively, helper metadata dicts
    are passed through ``_compact_helper_meta``, and all other values are
    returned as they are.
    """
    if _verbose_result_meta_enabled():
        return value
    if isinstance(value, list):
        return [_compact_result_metadata(entry) for entry in value]
    if not isinstance(value, dict):
        return value
    if _is_helper_meta_dict(value):
        # Compacted metadata is a leaf: its contents are not recursed into.
        return _compact_helper_meta(value)
    return {name: _compact_result_metadata(entry) for name, entry in value.items()}
def _is_helper_envelope(output: Any) -> bool:
    """Report whether *output* has the shape of a helper result envelope.

    An envelope is a dict whose ``ok`` value is a bool and that contains the
    keys ``items``, ``meta`` and ``error`` (their values are not inspected).
    """
    if not isinstance(output, dict):
        return False
    if not isinstance(output.get("ok"), bool):
        return False
    return all(field in output for field in ("items", "meta", "error"))
def _summarize_limit_hit(helper_name: str, result: Any) -> dict[str, Any] | None:
    """Build a summary of the limits a helper call ran into, if any.

    Args:
        helper_name: Name recorded under the ``helper`` key of the summary.
        result: Value returned by the helper.

    Returns:
        None when *result* is not a helper envelope or its metadata shows no
        limit hit. Otherwise a dict with ``helper`` plus a fixed set of
        metadata fields (copied even when None), followed by optional scan,
        page and ranking fields that are copied only when not None.
    """
    if not _is_helper_envelope(result):
        return None

    # Metadata that is not a dict is treated as empty metadata.
    raw_meta = result.get("meta")
    meta: dict[str, Any] = raw_meta if isinstance(raw_meta, dict) else {}

    cause = str(meta.get("truncated_by") or "")
    partial = _helper_meta_is_partial(meta)
    if not (partial or cause in ("scan_limit", "page_limit", "multiple")):
        return None

    summary: dict[str, Any] = {"helper": helper_name}
    for field in (
        "source",
        "returned",
        "total",
        "truncated",
        "truncated_by",
        "more_available",
        "requested_limit",
        "applied_limit",
        "next_request_hint",
        "limit_boundary_hit",
    ):
        summary[field] = meta.get(field)

    for field in (
        "scan_limit",
        "applied_max_pages",
        "ranking_window",
        "requested_ranking_window",
        "ranking_window_applied",
        "ranking_window_hit",
        "ranking_complete",
        "ranking_next_request_hint",
    ):
        observed = meta.get(field)
        if observed is not None:
            summary[field] = observed
    return summary
def _wrap_raw_result(
    result: Any,
    *,
    ok: bool,
    api_calls: int,
    elapsed_ms: int,
    limit_summaries: list[dict[str, Any]] | None = None,
    error: str | None = None,
) -> dict[str, Any]:
    """Wrap *result* in a ``{"result": ..., "meta": ...}`` envelope.

    Args:
        result: Payload stored unchanged under ``result``.
        ok: Stored as ``meta["ok"]``.
        api_calls: Stored as ``meta["api_calls"]``.
        elapsed_ms: Stored as ``meta["elapsed_ms"]``.
        limit_summaries: Limit summaries; at most the first 10 are kept, each
            as a shallow copy.
        error: Stored as ``meta["error"]`` only when not None.

    Returns:
        The envelope dict; ``meta["limits_reached"]`` is True exactly when at
        least one limit summary was kept.
    """
    reported = list(map(dict, (limit_summaries or [])[:10]))

    envelope_meta: dict[str, Any] = {}
    envelope_meta["ok"] = ok
    envelope_meta["api_calls"] = api_calls
    envelope_meta["elapsed_ms"] = elapsed_ms
    envelope_meta["limits_reached"] = bool(reported)
    envelope_meta["limit_summary"] = reported
    if error is not None:
        envelope_meta["error"] = error

    return {"result": result, "meta": envelope_meta}
def _validate_generated_code(code: str) -> None:
    """Reject generated code that does not follow the execution contract.

    Checks run in this order, and the first failure raises:

    1. the code is not blank;
    2. no blocked textual pattern occurs (imports at line start, ``exec(``,
       ``eval(``, ``open(``, ``__import__``, ``while true``);
    3. the code parses as a Python module (top-level ``await`` allowed);
    4. the parsed tree contains no import statement anywhere;
    5. the last statement is the bare expression ``result``;
    6. ``result`` is assigned somewhere in the code;
    7. ``call_api(...)`` is never called by name;
    8. at least one name from ``HELPER_EXTERNALS`` is called.

    Args:
        code: Source text of the generated program.

    Raises:
        ValueError: If any of the checks above fails.
    """
    if not code.strip():
        raise ValueError("Generated code is empty")

    # Cheap textual screen. These patterns also fire inside strings and
    # comments; that strictness is kept as-is.
    blocked_patterns: list[tuple[str, str]] = [
        (r"(?m)^\s*import\s+\S", "import statement"),
        (r"(?m)^\s*from\s+\S+\s+import\s+\S", "from-import statement"),
        (r"\bexec\s*\(", "exec("),
        (r"\beval\s*\(", "eval("),
        (r"\bopen\s*\(", "open("),
        (r"\b__import__\b", "__import__"),
        (r"(?i)\bwhile\s+true\b", "while true"),
    ]
    for pattern, label in blocked_patterns:
        if re.search(pattern, code):
            raise ValueError(f"Generated code contains blocked pattern: {label}")

    try:
        parsed = compile(  # noqa: S102 - compile is used for AST validation only.
            code,
            "<generated-monty-code>",
            "exec",
            flags=ast.PyCF_ONLY_AST | ast.PyCF_ALLOW_TOP_LEVEL_AWAIT,
            dont_inherit=True,
        )
    except SyntaxError as e:
        message = e.msg or "invalid syntax"
        raise ValueError(f"Generated code is not valid Python: {message}") from e

    if not isinstance(parsed, ast.Module):
        raise ValueError("Generated code must be a Python module")
    if not parsed.body:
        raise ValueError("Generated code is empty")

    nodes = list(ast.walk(parsed))

    # The import regexes above are anchored at line start, so an import that
    # shares a line with other code (``x = 1; import os`` or
    # ``if c: import os``) gets past them. Reject imports on the parsed tree.
    for node in nodes:
        if isinstance(node, ast.Import):
            raise ValueError("Generated code contains blocked pattern: import statement")
        if isinstance(node, ast.ImportFrom):
            raise ValueError("Generated code contains blocked pattern: from-import statement")

    final_stmt = parsed.body[-1]
    final_is_result = (
        isinstance(final_stmt, ast.Expr)
        and isinstance(final_stmt.value, ast.Name)
        and final_stmt.value.id == "result"
    )
    if not final_is_result:
        raise ValueError(
            "Generated code must assign the final output to `result` and end with a final line containing only `result` (do not stop after `result = ...`)."
        )

    has_result_assignment = any(
        isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store) and node.id == "result"
        for node in nodes
    )
    if not has_result_assignment:
        raise ValueError(
            "Generated code must assign the final output to `result` before the final `result` line."
        )

    # Only direct calls by bare name are inspected (``call_api(...)``,
    # ``hf_x(...)``); attribute calls are not matched by these checks.
    direct_call_names = [
        node.func.id
        for node in nodes
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
    ]
    if "call_api" in direct_call_names:
        raise ValueError(
            "Generated code must use documented hf_* helpers only; raw `call_api(...)` is not part of the prompt contract."
        )

    helper_name_set = set(HELPER_EXTERNALS)
    if not any(name in helper_name_set for name in direct_call_names):
        raise ValueError(
            "Generated code must call at least one documented hf_* helper."
        )
def _coerce_jsonish_python_literals(code: str) -> str:
    """Normalize common JSON literals into valid Python names in generated code.

    The bare names ``true``, ``false`` and ``null`` become ``True``,
    ``False`` and ``None``. Only NAME tokens are rewritten, so the same words
    inside string literals or comments are left alone. All other text,
    including whitespace and comments, is returned exactly as given.

    Args:
        code: Source text of the generated program.

    Returns:
        The source with the JSON-style literals replaced.

    Raises:
        tokenize.TokenError, SyntaxError: Propagated unchanged from the
            tokenizer when *code* cannot be tokenized.
    """
    replacements = {
        "true": "True",
        "false": "False",
        "null": "None",
    }
    tokens = list(tokenize.generate_tokens(StringIO(code).readline))

    # Split exactly like StringIO.readline does (on "\n" only) so tokenizer
    # row numbers index straight into this list.
    lines = code.split("\n")
    patched_in_place = True
    # Walk backwards so an edit never shifts the columns of tokens that are
    # still waiting to be patched on the same line.
    for tok in reversed(tokens):
        if tok.type != tokenize.NAME or tok.string not in replacements:
            continue
        row, col = tok.start
        end = col + len(tok.string)
        if row > len(lines) or lines[row - 1][col:end] != tok.string:
            # Reported position does not line up with the text we hold.
            patched_in_place = False
            break
        line = lines[row - 1]
        lines[row - 1] = line[:col] + replacements[tok.string] + line[end:]
    if patched_in_place:
        return "\n".join(lines)

    # Fallback: rebuild the program from (type, string) pairs. This keeps the
    # token stream correct but lets untokenize choose its own spacing.
    return tokenize.untokenize(
        (
            tok.type,
            replacements.get(tok.string, tok.string)
            if tok.type == tokenize.NAME
            else tok.string,
        )
        for tok in tokens
    )