from abc import abstractmethod import json import threading import os from functools import wraps from pathlib import Path from typing import Union, Dict, Any from flask import ( Request, Response, jsonify, Flask, session, request, send_file, redirect, url_for, ) from werkzeug.wrappers.response import Response as BaseResponse from agent import AgentContext from helpers.print_style import PrintStyle from helpers.errors import format_error from helpers import files, cache ThreadLockType = Union[threading.Lock, threading.RLock] CACHE_AREA = "api_handlers(api)" Input = dict Output = Union[Dict[str, Any], Response] class ApiHandler: def __init__(self, app: Flask, thread_lock: ThreadLockType): self.app = app self.thread_lock = thread_lock @classmethod def requires_loopback(cls) -> bool: return False @classmethod def requires_api_key(cls) -> bool: return False @classmethod def requires_auth(cls) -> bool: return True @classmethod def get_methods(cls) -> list[str]: return ["POST"] @classmethod def requires_csrf(cls) -> bool: return cls.requires_auth() @abstractmethod async def process(self, input: Input, request: Request) -> Output: pass async def handle_request(self, request: Request) -> Response: try: # input data from request based on type input_data: Input = {} if request.is_json: try: if request.data: input_data = request.get_json() except Exception as e: PrintStyle().print(f"Error parsing JSON: {str(e)}") input_data = {} else: input_data = {} output = await self.process(input_data, request) if isinstance(output, Response): return output else: response_json = json.dumps(output) return Response( response=response_json, status=200, mimetype="application/json" ) except Exception as e: error = format_error(e) PrintStyle.error(f"API error: {error}") return Response(response=error, status=500, mimetype="text/plain") def use_context(self, ctxid: str, create_if_not_exists: bool = True): from helpers.context_utils import use_context as _use_context return _use_context(self.thread_lock, ctxid, create_if_not_exists) from helpers.network import is_loopback_address def requires_api_key(f): @wraps(f) async def decorated(*args, **kwargs): if os.getenv("HF_SPACE") == "true": return await f(*args, **kwargs) from helpers.settings import get_settings valid_api_key = get_settings()["mcp_server_token"] if api_key := request.headers.get("X-API-KEY"): if api_key != valid_api_key: return Response("Invalid API key", 401) elif request.json and request.json.get("api_key"): api_key = request.json.get("api_key") if api_key != valid_api_key: return Response("Invalid API key", 401) else: return Response("API key required", 401) return await f(*args, **kwargs) return decorated def requires_loopback(f): @wraps(f) async def decorated(*args, **kwargs): if os.getenv("HF_SPACE") == "true": return await f(*args, **kwargs) if not is_loopback_address(str(request.remote_addr)): return Response("Access denied.", 403, {}) return await f(*args, **kwargs) return decorated def requires_auth(f): @wraps(f) async def decorated(*args, **kwargs): if os.getenv("HF_SPACE") == "true": return await f(*args, **kwargs) from helpers import login user_pass_hash = login.get_credentials_hash() if not user_pass_hash: return await f(*args, **kwargs) if session.get("authentication") != user_pass_hash: return redirect(url_for("login_handler")) return await f(*args, **kwargs) return decorated def csrf_protect(f): @wraps(f) async def decorated(*args, **kwargs): if os.getenv("HF_SPACE") == "true": return await f(*args, **kwargs) from helpers import runtime token = session.get("csrf_token") header = request.headers.get("X-CSRF-Token") cookie = request.cookies.get("csrf_token_" + runtime.get_runtime_id()) sent = header or cookie if not token or not sent or token != sent: return Response("CSRF token missing or invalid", 403) return await f(*args, **kwargs) return decorated def register_api_route(app: Flask, lock: ThreadLockType) -> None: from helpers.modules import load_classes_from_file from helpers import plugins async def _dispatch(path: str) -> BaseResponse: cached = cache.get(CACHE_AREA, path) if cached is not None: return await cached() handler_cls: type[ApiHandler] | None = None builtin_file = files.get_abs_path(f"api/{path}.py") if files.is_in_dir(builtin_file, files.get_abs_path("api")) and files.exists( builtin_file ): classes = load_classes_from_file(builtin_file, ApiHandler) if classes: handler_cls = classes[0] if handler_cls is None and path.startswith("plugins/"): parts = path.split("/", 2) if len(parts) == 3: _, plugin_name, handler_name = parts plugin_dir = plugins.find_plugin_dir(plugin_name) if plugin_dir: plugin_file = Path(plugin_dir) / "api" / f"{handler_name}.py" if plugin_file.is_file(): classes = load_classes_from_file(str(plugin_file), ApiHandler) if classes: handler_cls = classes[0] if handler_cls is None: return Response(f"API endpoint not found: {path}", 404) if request.method not in handler_cls.get_methods(): return Response(f"Method {request.method} not allowed for: {path}", 405) async def call_handler() -> BaseResponse: instance = handler_cls(app, lock) return await instance.handle_request(request=request) handler_fn = call_handler if handler_cls.requires_csrf(): handler_fn = csrf_protect(handler_fn) if handler_cls.requires_api_key(): handler_fn = requires_api_key(handler_fn) if handler_cls.requires_auth(): handler_fn = requires_auth(handler_fn) if handler_cls.requires_loopback(): handler_fn = requires_loopback(handler_fn) cache.add(CACHE_AREA, path, handler_fn) return await handler_fn() app.add_url_rule( "/api/", "api_dispatch", _dispatch, methods=["GET", "POST", "PUT", "PATCH", "DELETE"], )