| from abc import abstractmethod |
| import json |
| import os |
| import threading |
| from functools import wraps |
| from pathlib import Path |
| from typing import Union, Dict, Any |
| from flask import ( |
| Request, |
| Response, |
| jsonify, |
| Flask, |
| session, |
| request, |
| send_file, |
| redirect, |
| url_for, |
| ) |
| from werkzeug.wrappers.response import Response as BaseResponse |
| from agent import AgentContext |
| from helpers.print_style import PrintStyle |
| from helpers.errors import format_error |
| from helpers import files, cache |
|
|
# Lock objects that may be handed to ApiHandler instances and register_api_route.
# NOTE(review): threading.RLock is a factory function rather than a class, so
# static type checkers may reject this alias; it is accepted at runtime.
ThreadLockType = Union[threading.Lock, threading.RLock]

# Namespace in the `cache` helper under which resolved API endpoints are stored.
CACHE_AREA = "api_handlers(api)"

# Request payload handed to ApiHandler.process().
Input = dict
# What ApiHandler.process() may return: a JSON-serialisable mapping or a Response.
Output = Union[Dict[str, Any], Response]
|
|
|
|
class ApiHandler:
    """Base class for HTTP API endpoints served under ``/api/<path>``.

    Subclasses implement :meth:`process` and may override the ``requires_*``
    class methods and :meth:`get_methods`; ``register_api_route`` reads them to
    decide which guards wrap the handler and which HTTP methods it accepts.
    """

    def __init__(self, app: Flask, thread_lock: ThreadLockType):
        # Flask app and the shared lock given to every handler instance.
        self.app = app
        self.thread_lock = thread_lock

    @classmethod
    def requires_loopback(cls) -> bool:
        """Return True to reject clients whose remote address is not loopback."""
        return False

    @classmethod
    def requires_api_key(cls) -> bool:
        """Return True to require a valid API key for this endpoint."""
        return False

    @classmethod
    def requires_auth(cls) -> bool:
        """Return True to require a logged-in session (always False when HF_SPACE=true)."""
        if os.getenv("HF_SPACE") == "true":
            return False
        return True

    @classmethod
    def get_methods(cls) -> list[str]:
        """Return the HTTP methods this endpoint accepts (POST only by default)."""
        return ["POST"]

    @classmethod
    def requires_csrf(cls) -> bool:
        """Return True to enforce the CSRF check; follows requires_auth() by default."""
        if os.getenv("HF_SPACE") == "true":
            return False
        return cls.requires_auth()

    # NOTE: ApiHandler does not derive from abc.ABC, so @abstractmethod is not
    # enforced here -- a subclass that forgets to override process() can still
    # be instantiated and this stub returns None.
    @abstractmethod
    async def process(self, input: Input, request: Request) -> Output:
        """Handle one request; return a JSON-serialisable dict or a Response."""
        pass

    async def handle_request(self, request: Request) -> BaseResponse:
        """Parse the request, run :meth:`process`, and build the HTTP response.

        A JSON body is parsed into the input payload; an empty, malformed or
        ``null`` body yields ``{}``. Any (flask or werkzeug) Response returned by
        ``process`` is passed through unchanged; any other value is serialised
        with ``json.dumps`` into a 200 ``application/json`` response. Every
        exception is logged and converted into a 500 ``text/plain`` response
        whose body is ``format_error(e)``.
        """
        PrintStyle.debug(f"API Request: {request.method} {request.path}")
        try:
            input_data: Input = {}
            if request.is_json:
                try:
                    parsed = request.get_json() if request.data else None
                except Exception as e:
                    # Malformed JSON is logged and treated as an empty payload.
                    PrintStyle().print(f"Error parsing JSON: {str(e)}")
                else:
                    # A literal JSON `null` would otherwise reach process() as None.
                    if parsed is not None:
                        input_data = parsed

            output = await self.process(input_data, request)

            # flask.Response subclasses werkzeug's Response; checking the base
            # class also passes through handlers that return a werkzeug Response
            # instead of failing inside json.dumps.
            if isinstance(output, BaseResponse):
                return output
            response_json = json.dumps(output)
            return Response(
                response=response_json, status=200, mimetype="application/json"
            )
        except Exception as e:
            error = format_error(e)
            PrintStyle.error(f"API error: {error}")
            return Response(response=error, status=500, mimetype="text/plain")

    def use_context(self, ctxid: str, create_if_not_exists: bool = True):
        """Look up an agent context by id, passing along this handler's thread lock.

        Delegates to ``helpers.context_utils.use_context``;
        ``create_if_not_exists`` is forwarded unchanged.
        """
        from helpers.context_utils import use_context as _use_context
        return _use_context(self.thread_lock, ctxid, create_if_not_exists)
|
|
|
|
| from helpers.network import is_loopback_address |
|
|
|
|
def _api_key_matches(candidate: object, expected: object) -> bool:
    """Return True when *candidate* equals *expected*, compared in constant time.

    Both values must be strings; anything else (e.g. a non-string ``api_key``
    in a JSON body, or an unset configured token) never matches.
    """
    import hmac

    if not isinstance(candidate, str) or not isinstance(expected, str):
        return False
    # compare_digest on str only accepts ASCII; comparing bytes handles any key.
    return hmac.compare_digest(candidate.encode("utf-8"), expected.encode("utf-8"))


def requires_api_key(f):
    """Wrap an async view so it only runs when a valid API key is supplied.

    The key is read from the ``X-API-KEY`` header, falling back to an
    ``api_key`` field in a JSON object body, and is compared with the
    ``mcp_server_token`` setting. A missing key gets 401 "API key required";
    a wrong key gets 401 "Invalid API key".
    """
    @wraps(f)
    async def decorated(*args, **kwargs):
        from helpers.settings import get_settings

        valid_api_key = get_settings()["mcp_server_token"]

        api_key = request.headers.get("X-API-KEY")
        if not api_key:
            # silent=True: a non-JSON, malformed, or non-object body simply
            # carries no key, instead of raising 400/415 from request.json.
            payload = request.get_json(silent=True)
            if isinstance(payload, dict):
                api_key = payload.get("api_key")
        if not api_key:
            return Response("API key required", 401)
        if not _api_key_matches(api_key, valid_api_key):
            return Response("Invalid API key", 401)
        return await f(*args, **kwargs)

    return decorated
|
|
|
|
def requires_loopback(f):
    """Wrap an async view so only clients on a loopback address may call it.

    Requests from any other remote address receive a 403 "Access denied."
    response and the wrapped view is not invoked.
    """
    @wraps(f)
    async def decorated(*args, **kwargs):
        client_addr = str(request.remote_addr)
        if is_loopback_address(client_addr):
            return await f(*args, **kwargs)
        return Response("Access denied.", 403, {})

    return decorated
|
|
|
|
def requires_auth(f):
    """Wrap an async view so it requires an authenticated session.

    The check is skipped when HF_SPACE=true or when no login credentials are
    configured. Otherwise, a session whose "authentication" value differs from
    the configured credentials hash is redirected to the ``login_handler`` view.
    """
    @wraps(f)
    async def decorated(*args, **kwargs):
        if os.getenv("HF_SPACE") != "true":
            from helpers import login

            expected_hash = login.get_credentials_hash()
            if expected_hash and session.get("authentication") != expected_hash:
                return redirect(url_for("login_handler"))
        return await f(*args, **kwargs)

    return decorated
|
|
|
|
def _csrf_token_matches(expected: object, sent: object) -> bool:
    """Return True when the submitted CSRF token equals the session token.

    Uses a constant-time comparison; empty or non-string values never match.
    """
    import hmac

    if not isinstance(expected, str) or not isinstance(sent, str):
        return False
    if not expected or not sent:
        return False
    return hmac.compare_digest(expected.encode("utf-8"), sent.encode("utf-8"))


def csrf_protect(f):
    """Wrap an async view with a CSRF token check (skipped when HF_SPACE=true).

    The session's "csrf_token" must match the ``X-CSRF-Token`` header or, when
    that header is absent, the ``csrf_token_<runtime id>`` cookie. Otherwise a
    403 "CSRF token missing or invalid" response is returned.
    """
    @wraps(f)
    async def decorated(*args, **kwargs):
        if os.getenv("HF_SPACE") == "true":
            return await f(*args, **kwargs)

        from helpers import runtime

        token = session.get("csrf_token")
        header = request.headers.get("X-CSRF-Token")
        # NOTE(review): with this cookie fallback, a request carrying only the
        # matching cookie (no header) is accepted; that is only safe if the
        # cookie is SameSite-restricted -- confirm before relying on it.
        cookie = request.cookies.get("csrf_token_" + runtime.get_runtime_id())
        sent = header or cookie
        if not _csrf_token_matches(token, sent):
            return Response("CSRF token missing or invalid", 403)
        return await f(*args, **kwargs)

    return decorated
|
|
|
|
def register_api_route(app: Flask, lock: ThreadLockType) -> None:
    """Register the catch-all ``/api/<path>`` route on *app*.

    Each request path is resolved to an ApiHandler subclass. The built-in
    ``api/<path>.py`` is tried first. For ``plugins/<plugin>/<handler>`` paths,
    ``<plugin dir>/api/<handler>.py`` is tried next. The class is wrapped in
    the loopback / auth / API-key / CSRF guards it requests, plus an
    HTTP-method check. The resulting callable is cached per path under
    CACHE_AREA.

    Args:
        app: Flask application the route is added to.
        lock: Lock handed to every ApiHandler instance.
    """
    from helpers.modules import load_classes_from_file
    from helpers import plugins

    def _find_builtin_handler(path: str):
        """Return the first ApiHandler class in ``api/<path>.py``, or None."""
        builtin_file = files.get_abs_path(f"api/{path}.py")
        if not (
            files.is_in_dir(builtin_file, files.get_abs_path("api"))
            and files.exists(builtin_file)
        ):
            return None
        classes = load_classes_from_file(builtin_file, ApiHandler)
        return classes[0] if classes else None

    def _find_plugin_handler(path: str):
        """Return the first ApiHandler class for ``plugins/<name>/<handler>``, or None."""
        parts = path.split("/", 2)
        if len(parts) != 3 or parts[0] != "plugins":
            return None
        _, plugin_name, handler_name = parts
        plugin_dir = plugins.find_plugin_dir(plugin_name)
        if not plugin_dir:
            return None
        api_dir = Path(plugin_dir) / "api"
        plugin_file = api_dir / f"{handler_name}.py"
        # handler_name comes straight from the URL: reject "../" segments or an
        # absolute component that would escape the plugin's api directory.
        if not Path(os.path.abspath(plugin_file)).is_relative_to(
            os.path.abspath(api_dir)
        ):
            return None
        if not plugin_file.is_file():
            return None
        classes = load_classes_from_file(str(plugin_file), ApiHandler)
        return classes[0] if classes else None

    def _build_endpoint(handler_cls, path: str):
        """Wrap *handler_cls* in its guards, with the HTTP-method check outermost."""
        allowed_methods = handler_cls.get_methods()

        async def call_handler() -> BaseResponse:
            instance = handler_cls(app, lock)
            return await instance.handle_request(request=request)

        handler_fn = call_handler
        if handler_cls.requires_csrf():
            handler_fn = csrf_protect(handler_fn)
        if handler_cls.requires_api_key():
            handler_fn = requires_api_key(handler_fn)
        if handler_cls.requires_auth():
            handler_fn = requires_auth(handler_fn)
        if handler_cls.requires_loopback():
            handler_fn = requires_loopback(handler_fn)
        guarded = handler_fn

        async def endpoint() -> BaseResponse:
            # The method check lives inside the cached callable so it also runs
            # on cache hits; otherwise any method would reach a cached handler.
            if request.method not in allowed_methods:
                return Response(
                    f"Method {request.method} not allowed for: {path}", 405
                )
            return await guarded()

        return endpoint

    async def _dispatch(path: str) -> BaseResponse:
        endpoint = cache.get(CACHE_AREA, path)
        if endpoint is None:
            handler_cls = _find_builtin_handler(path)
            if handler_cls is None:
                handler_cls = _find_plugin_handler(path)
            if handler_cls is None:
                return Response(f"API endpoint not found: {path}", 404)
            endpoint = _build_endpoint(handler_cls, path)
            cache.add(CACHE_AREA, path, endpoint)
        return await endpoint()

    app.add_url_rule(
        "/api/<path:path>",
        "api_dispatch",
        _dispatch,
        methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
    )
|
|
|
|
def register_watchdogs():
    """Clear cached API and WebSocket handlers when handler sources change.

    Registers an "api_handlers" watchdog on ``*.py`` files under
    ``files.API_DIR`` and its counterpart inside ``files.USER_DIR``. Any
    change clears both this module's CACHE_AREA and the WebSocket cache area.
    """
    from helpers import watchdog
    from helpers.ws import CACHE_AREA as WS_CACHE_AREA

    # NOTE(review): plugin api/ directories are not among these roots, so edits
    # to plugin handlers presumably do not invalidate the cache -- confirm.
    watched_roots = [
        files.get_abs_path(files.API_DIR),
        files.get_abs_path(files.USER_DIR, files.API_DIR),
    ]

    def _invalidate_handler_caches(items: list[watchdog.WatchItem]):
        PrintStyle.debug("API endpoint watchdog triggered:", items)
        for area in (CACHE_AREA, WS_CACHE_AREA):
            cache.clear(area)

    watchdog.add_watchdog(
        "api_handlers",
        roots=watched_roots,
        patterns=["*.py"],
        handler=_invalidate_handler_caches,
    )
|
|