from __future__ import annotations

import ast
import os
import re
import tokenize
from io import StringIO
from typing import Any, Callable, cast

from .constants import (
    GRAPH_SCAN_LIMIT_CAP,
    LIKES_SCAN_LIMIT_CAP,
    OUTPUT_ITEMS_TRUNCATION_LIMIT,
    SELECTIVE_ENDPOINT_RETURN_HARD_CAP,
    TRENDING_ENDPOINT_MAX_LIMIT,
)
from .registry import (
    ALLOWLIST_PATTERNS,
    HELPER_EXTERNALS,
    STRICT_ALLOWLIST_PATTERNS,
)


def _resolve_helper_functions(
    namespace: dict[str, Any],
) -> dict[str, Callable[..., Any]]:
    """Look up every name listed in HELPER_EXTERNALS inside *namespace*.

    Returns:
        A mapping of helper name -> callable, in HELPER_EXTERNALS order.

    Raises:
        RuntimeError: if a helper is missing from *namespace* or is not callable.
    """
    resolved: dict[str, Callable[..., Any]] = {}
    for helper_name in HELPER_EXTERNALS:
        candidate = namespace.get(helper_name)
        if not callable(candidate):
            raise RuntimeError(f"Helper '{helper_name}' is not defined or not callable")
        resolved[helper_name] = cast(Callable[..., Any], candidate)
    return resolved


def _normalize_endpoint(endpoint: str) -> str:
    """Normalize a caller-supplied endpoint into an ``/api/...`` path.

    The input is stripped of surrounding whitespace. A leading ``/`` and an
    ``/api`` prefix are added when missing. ``/api/collections/search`` (with
    or without a trailing slash) is mapped to ``/api/collections``.

    Raises:
        ValueError: if the endpoint is empty, contains a ``?`` query string,
            is an absolute ``http``/``https`` URL (scheme matched
            case-insensitively), or contains ``..``.
    """
    ep = (endpoint or "").strip()
    if not ep:
        raise ValueError("endpoint is required")
    if "?" in ep:
        raise ValueError("endpoint must not include query string; use params")
    # URI schemes are case-insensitive (RFC 3986 §3.1): "HTTPS://host/..." is
    # just as absolute as "https://host/..." and must be rejected too, rather
    # than being silently rewritten into "/api/HTTPS://...".
    if ep.lower().startswith(("http://", "https://")):
        raise ValueError("endpoint must be path-only")
    if not ep.startswith("/"):
        ep = "/" + ep
    if not ep.startswith("/api/"):
        ep = "/api" + ep
    if ep in {"/api/collections/search", "/api/collections/search/"}:
        ep = "/api/collections"
    # Plain substring test: this also rejects names that merely contain "..".
    if ".." in ep:
        raise ValueError("path traversal not allowed")
    return ep


def _endpoint_allowed(endpoint: str, strict_mode: bool) -> bool:
    """Return True if the endpoint's path matches one of the allowlist regexes.

    Any ``?query`` suffix is ignored. STRICT_ALLOWLIST_PATTERNS is used when
    *strict_mode* is true, otherwise ALLOWLIST_PATTERNS.
    """
    path = endpoint.split("?", 1)[0]
    patterns = STRICT_ALLOWLIST_PATTERNS if strict_mode else ALLOWLIST_PATTERNS
    # NOTE(review): re.match anchors only at the start of the string; full-path
    # matching presumably relies on each pattern ending in "$" — confirm in
    # the registry.
    return any(re.match(p, path) for p in patterns)


def _sanitize_params(endpoint: str, params: dict[str, Any] | None) -> dict[str, Any]:
    """Return a copy of *params* adjusted for the target endpoint.

    The caller's mapping is never mutated.

    * ``/api/collections``: ``search`` is copied to ``q`` when ``q`` is not
      already set, and ``search`` is always removed.
    * ``/api/trending``: a plural ``type`` (models/datasets/spaces) is
      singularized. ``limit`` is clamped to ``[1, TRENDING_ENDPOINT_MAX_LIMIT]``,
      and an unparseable limit is replaced by that maximum.
    * Any other endpoint: an int-coercible ``limit`` is clamped to
      ``[1, cap]``. The cap is GRAPH_SCAN_LIMIT_CAP for
      ``/api/users/<u>/followers|following``, LIKES_SCAN_LIMIT_CAP for
      ``/api/users/<u>/likes``, and SELECTIVE_ENDPOINT_RETURN_HARD_CAP
      otherwise. A limit that ``int()`` rejects is left as-is.
    """
    clean = dict(params or {})
    path = endpoint.split("?", 1)[0]

    if path == "/api/collections":
        if "q" not in clean and "search" in clean:
            clean["q"] = clean.get("search")
        clean.pop("search", None)

    if path == "/api/trending":
        t = str(clean.get("type") or "").strip().lower()
        aliases = {"models": "model", "datasets": "dataset", "spaces": "space"}
        if t in aliases:
            clean["type"] = aliases[t]
        lim = clean.get("limit")
        if lim is not None:
            try:
                n = int(lim)
            except Exception:
                n = TRENDING_ENDPOINT_MAX_LIMIT
            clean["limit"] = max(1, min(n, TRENDING_ENDPOINT_MAX_LIMIT))
        return clean

    lim = clean.get("limit")
    if lim is None:
        return clean
    try:
        n = int(lim)
    except Exception:
        return clean
    endpoint_limit_max = SELECTIVE_ENDPOINT_RETURN_HARD_CAP
    if re.match(r"^/api/users/[^/]+/(followers|following)$", path):
        endpoint_limit_max = GRAPH_SCAN_LIMIT_CAP
    elif re.match(r"^/api/users/[^/]+/likes$", path):
        endpoint_limit_max = LIKES_SCAN_LIMIT_CAP
    clean["limit"] = max(1, min(n, endpoint_limit_max))
    return clean


def _truncate_result_payload(output: Any) -> Any:
    """Cap ``output["items"]`` at OUTPUT_ITEMS_TRUNCATION_LIMIT entries.

    The same object is returned unchanged when *output* is not a dict, when
    its ``items`` is not a list, or when ``items`` is already within the limit.
    Otherwise a shallow copy is returned in which:

    * ``items`` is truncated;
    * ``item`` holds the sole element if exactly one remains, else None;
    * a truncation note is appended to ``steps`` (created if not a list).
    """
    if not isinstance(output, dict):
        return output
    items = output.get("items")
    if not isinstance(items, list) or len(items) <= OUTPUT_ITEMS_TRUNCATION_LIMIT:
        return output
    trimmed = dict(output)
    trimmed_items = items[:OUTPUT_ITEMS_TRUNCATION_LIMIT]
    trimmed["items"] = trimmed_items
    trimmed["item"] = trimmed_items[0] if len(trimmed_items) == 1 else None
    note = f"truncated items to first {OUTPUT_ITEMS_TRUNCATION_LIMIT} rows for token efficiency"
    steps = trimmed.get("steps")
    if isinstance(steps, list):
        trimmed["steps"] = [*steps, note]
    else:
        trimmed["steps"] = [note]
    return trimmed
def _verbose_result_meta_enabled() -> bool:
    """Return True when MONTY_VERBOSE_RESULT_META asks for uncompacted metadata.

    Accepted values are ``1``, ``true``, ``yes`` and ``on``. Matching is
    case-insensitive and ignores surrounding whitespace.
    """
    value = os.environ.get("MONTY_VERBOSE_RESULT_META", "")
    return value.strip().lower() in {"1", "true", "yes", "on"}


def _is_helper_meta_dict(value: Any) -> bool:
    """Heuristically detect a helper ``meta`` dict.

    A value qualifies when it is a dict with a string ``source`` and at least
    one of: ``normalized is True``, a ``budget_used`` key, or a
    ``budget_remaining`` key.
    """
    return (
        isinstance(value, dict)
        and isinstance(value.get("source"), str)
        and (
            value.get("normalized") is True
            or "budget_used" in value
            or "budget_remaining" in value
        )
    )


def _helper_meta_is_partial(value: dict[str, Any]) -> bool:
    """Return True if helper metadata signals an incomplete or capped result."""
    return any(
        [
            value.get("truncated") is True,
            # Tuple membership compares with ==/is and needs no hashing, so an
            # unhashable value (e.g. a cursor dict) counts as "more available"
            # instead of raising TypeError. 0 still equals False here.
            value.get("more_available") not in (False, None),
            value.get("limit_boundary_hit") is True,
            value.get("sample_complete") is False,
            value.get("exact_count") is False,
            value.get("ranking_complete") is False,
            value.get("ranking_window_hit") is True,
            value.get("hard_cap_applied") is True,
        ]
    )


def _compact_helper_meta(value: dict[str, Any]) -> dict[str, Any]:
    """Reduce helper metadata to a fixed whitelist of non-None keys.

    The result always contains ``partial``. ``total`` falls back to
    ``total_available`` when ``total`` is absent or None.
    """
    partial = _helper_meta_is_partial(value)
    compact: dict[str, Any] = {
        "partial": partial,
    }
    for key in (
        "source",
        "returned",
        "total",
        "matched",
        "more_available",
        "truncated",
        "truncated_by",
        "exact_count",
        "sample_complete",
        "hard_cap_applied",
        "limit_boundary_hit",
        "can_request_more",
        "next_request_hint",
        "ranking_window",
        "ranking_window_hit",
        "ranking_complete",
        "ranking_next_request_hint",
        "relation",
        "username",
        "organization",
        "entity",
        "entity_type",
        "handle",
    ):
        if value.get(key) is not None:
            compact[key] = value.get(key)
    if compact.get("total") is None and value.get("total_available") is not None:
        compact["total"] = value.get("total_available")
    return compact


def _compact_result_metadata(value: Any) -> Any:
    """Recursively replace helper meta dicts inside *value* with compact forms.

    Dicts and lists are rebuilt; other values are returned as-is. When verbose
    metadata is enabled, *value* itself is returned untouched.
    """
    # Read the environment once per call rather than once per nested node.
    if _verbose_result_meta_enabled():
        return value

    def compact(node: Any) -> Any:
        if _is_helper_meta_dict(node):
            return _compact_helper_meta(node)
        if isinstance(node, dict):
            return {key: compact(item) for key, item in node.items()}
        if isinstance(node, list):
            return [compact(item) for item in node]
        return node

    return compact(value)


def _is_helper_envelope(output: Any) -> bool:
    """Return True if *output* has the helper-envelope shape.

    That is a dict with a bool ``ok`` and the keys ``items``, ``meta`` and
    ``error``. The values of those three keys are not checked.
    """
    return (
        isinstance(output, dict)
        and isinstance(output.get("ok"), bool)
        and "items" in output
        and "meta" in output
        and "error" in output
    )


def _summarize_limit_hit(helper_name: str, result: Any) -> dict[str, Any] | None:
    """Summarize why a helper result was cut short, or return None.

    A summary is produced only when *result* is a helper envelope and either:

    * its meta is partial per ``_helper_meta_is_partial``, or
    * its ``truncated_by`` is ``scan_limit``, ``page_limit`` or ``multiple``.

    A non-dict ``meta`` is treated as empty.
    """
    if not _is_helper_envelope(result):
        return None
    raw_meta = result.get("meta")
    meta: dict[str, Any] = raw_meta if isinstance(raw_meta, dict) else {}
    truncated_by = str(meta.get("truncated_by") or "")
    limit_hit = _helper_meta_is_partial(meta) or truncated_by in {
        "scan_limit",
        "page_limit",
        "multiple",
    }
    if not limit_hit:
        return None
    summary: dict[str, Any] = {
        "helper": helper_name,
        "source": meta.get("source"),
        "returned": meta.get("returned"),
        "total": meta.get("total"),
        "truncated": meta.get("truncated"),
        "truncated_by": meta.get("truncated_by"),
        "more_available": meta.get("more_available"),
        "requested_limit": meta.get("requested_limit"),
        "applied_limit": meta.get("applied_limit"),
        "next_request_hint": meta.get("next_request_hint"),
        "limit_boundary_hit": meta.get("limit_boundary_hit"),
    }
    if meta.get("scan_limit") is not None:
        summary["scan_limit"] = meta.get("scan_limit")
    if meta.get("applied_max_pages") is not None:
        summary["applied_max_pages"] = meta.get("applied_max_pages")
    for key in (
        "ranking_window",
        "requested_ranking_window",
        "ranking_window_applied",
        "ranking_window_hit",
        "ranking_complete",
        "ranking_next_request_hint",
    ):
        if meta.get(key) is not None:
            summary[key] = meta.get(key)
    return summary


def _wrap_raw_result(
    result: Any,
    *,
    ok: bool,
    api_calls: int,
    elapsed_ms: int,
    limit_summaries: list[dict[str, Any]] | None = None,
    error: str | None = None,
) -> dict[str, Any]:
    """Wrap *result* together with execution metadata.

    Only the first 10 limit summaries are kept, each shallow-copied.
    ``limits_reached`` is True when at least one was kept. The ``error`` key
    is present only when *error* is not None.
    """
    hits = [dict(summary) for summary in (limit_summaries or [])[:10]]
    meta: dict[str, Any] = {
        "ok": ok,
        "api_calls": api_calls,
        "elapsed_ms": elapsed_ms,
        "limits_reached": bool(hits),
        "limit_summary": hits,
    }
    if error is not None:
        meta["error"] = error
    return {
        "result": result,
        "meta": meta,
    }


def _validate_generated_code(code: str) -> None:
    """Statically validate generated code before it is executed.

    Checks, in order:

    1. the code is not blank;
    2. no blocked pattern occurs in the raw text (a regex scan, so matches
       inside strings and comments count too);
    3. the code parses as a module (top-level ``await`` allowed);
    4. the last statement is the bare expression ``result``;
    5. ``result`` is bound somewhere (any Name in Store context);
    6. ``call_api(...)`` is never called by bare name;
    7. at least one name in HELPER_EXTERNALS is called by bare name.

    Raises:
        ValueError: describing the first failed check.
    """
    if not code.strip():
        raise ValueError("Generated code is empty")
    blocked_patterns: list[tuple[str, str]] = [
        (r"(?m)^\s*import\s+\S", "import statement"),
        (r"(?m)^\s*from\s+\S+\s+import\s+\S", "from-import statement"),
        (r"\bexec\s*\(", "exec("),
        (r"\beval\s*\(", "eval("),
        (r"\bopen\s*\(", "open("),
        (r"\b__import__\b", "__import__"),
        (r"(?i)\bwhile\s+true\b", "while true"),
    ]
    for pattern, label in blocked_patterns:
        if re.search(pattern, code):
            raise ValueError(f"Generated code contains blocked pattern: {label}")
    try:
        parsed = compile(  # noqa: S102 - compile is used for AST validation only.
            code,
            "",
            "exec",
            flags=ast.PyCF_ONLY_AST | ast.PyCF_ALLOW_TOP_LEVEL_AWAIT,
            dont_inherit=True,
        )
    except SyntaxError as e:
        message = e.msg or "invalid syntax"
        raise ValueError(f"Generated code is not valid Python: {message}") from e
    if not isinstance(parsed, ast.Module):
        raise ValueError("Generated code must be a Python module")
    if not parsed.body:
        raise ValueError("Generated code is empty")
    final_stmt = parsed.body[-1]
    final_is_result = (
        isinstance(final_stmt, ast.Expr)
        and isinstance(final_stmt.value, ast.Name)
        and final_stmt.value.id == "result"
    )
    if not final_is_result:
        raise ValueError(
            "Generated code must assign the final output to `result` and end with a final line containing only `result` (do not stop after `result = ...`)."
        )
    has_result_assignment = any(
        isinstance(node, ast.Name)
        and isinstance(node.ctx, ast.Store)
        and node.id == "result"
        for node in ast.walk(parsed)
    )
    if not has_result_assignment:
        raise ValueError(
            "Generated code must assign the final output to `result` before the final `result` line."
        )
    for node in ast.walk(parsed):
        if not isinstance(node, ast.Call):
            continue
        if isinstance(node.func, ast.Name) and node.func.id == "call_api":
            raise ValueError(
                "Generated code must use documented hf_* helpers only; raw `call_api(...)` is not part of the prompt contract."
            )
    helper_name_set = set(HELPER_EXTERNALS)
    has_external_call = any(
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id in helper_name_set
        for node in ast.walk(parsed)
    )
    if not has_external_call:
        raise ValueError(
            "Generated code must call at least one documented hf_* helper."
        )


def _coerce_jsonish_python_literals(code: str) -> str:
    """Normalize common JSON literals into valid Python names in generated code.

    Bare NAME tokens ``true``, ``false`` and ``null`` become ``True``,
    ``False`` and ``None``. String literals and comments are never touched
    because they are not NAME tokens. Names used as attributes (``obj.null``)
    are skipped, since ``obj.None`` would be a syntax error.

    Every replacement has exactly the same length as the word it replaces.
    The new text is therefore spliced in at the positions that ``tokenize``
    reports, and all other characters are preserved: spacing, comments and
    backslash continuations. Line and column numbers in later errors still
    match the code. (Rebuilding via ``tokenize.untokenize`` on 2-tuples
    would reflow every token and join continued lines.)

    Exceptions raised by ``tokenize.generate_tokens`` on input it cannot
    tokenize (e.g. ``tokenize.TokenError``) propagate to the caller.
    """
    replacements = {
        "true": "True",
        "false": "False",
        "null": "None",
    }
    trivia = {tokenize.NL, tokenize.COMMENT}
    # Tokenize fully first; drop trivia so "previous token" means the previous
    # significant token even across comments or bracketed line breaks.
    tokens = [
        tok
        for tok in tokenize.generate_tokens(StringIO(code).readline)
        if tok.type not in trivia
    ]
    # Split lines exactly as StringIO.readline does, so token rows line up.
    lines = StringIO(code).readlines()
    for index, tok in enumerate(tokens):
        if tok.type != tokenize.NAME or tok.string not in replacements:
            continue
        if index > 0:
            prev_tok = tokens[index - 1]
            if prev_tok.type == tokenize.OP and prev_tok.string == ".":
                continue  # attribute access such as `row.null`
        row, col = tok.start
        if not 0 < row <= len(lines):
            continue
        line = lines[row - 1]
        end = col + len(tok.string)
        # Defensive: only splice when the reported position really holds the name.
        if line[col:end] != tok.string:
            continue
        lines[row - 1] = line[:col] + replacements[tok.string] + line[end:]
    return "".join(lines)