Leon4gr45's picture
Redeploy with build fixes and framework integrity
21927d0 verified
from abc import abstractmethod
import json
import threading
import os
from functools import wraps
from pathlib import Path
from typing import Union, Dict, Any
from flask import (
Request,
Response,
jsonify,
Flask,
session,
request,
send_file,
redirect,
url_for,
)
from werkzeug.wrappers.response import Response as BaseResponse
from agent import AgentContext
from helpers.print_style import PrintStyle
from helpers.errors import format_error
from helpers import files, cache
ThreadLockType = Union[threading.Lock, threading.RLock]
CACHE_AREA = "api_handlers(api)"
Input = dict
Output = Union[Dict[str, Any], Response]
class ApiHandler:
def __init__(self, app: Flask, thread_lock: ThreadLockType):
self.app = app
self.thread_lock = thread_lock
@classmethod
def requires_loopback(cls) -> bool:
return False
@classmethod
def requires_api_key(cls) -> bool:
return False
@classmethod
def requires_auth(cls) -> bool:
return True
@classmethod
def get_methods(cls) -> list[str]:
return ["POST"]
@classmethod
def requires_csrf(cls) -> bool:
return cls.requires_auth()
@abstractmethod
async def process(self, input: Input, request: Request) -> Output:
pass
async def handle_request(self, request: Request) -> Response:
try:
# input data from request based on type
input_data: Input = {}
if request.is_json:
try:
if request.data:
input_data = request.get_json()
except Exception as e:
PrintStyle().print(f"Error parsing JSON: {str(e)}")
input_data = {}
else:
input_data = {}
output = await self.process(input_data, request)
if isinstance(output, Response):
return output
else:
response_json = json.dumps(output)
return Response(
response=response_json, status=200, mimetype="application/json"
)
except Exception as e:
error = format_error(e)
PrintStyle.error(f"API error: {error}")
return Response(response=error, status=500, mimetype="text/plain")
def use_context(self, ctxid: str, create_if_not_exists: bool = True):
from helpers.context_utils import use_context as _use_context
return _use_context(self.thread_lock, ctxid, create_if_not_exists)
from helpers.network import is_loopback_address
def requires_api_key(f):
@wraps(f)
async def decorated(*args, **kwargs):
if os.getenv("HF_SPACE") == "true":
return await f(*args, **kwargs)
from helpers.settings import get_settings
valid_api_key = get_settings()["mcp_server_token"]
if api_key := request.headers.get("X-API-KEY"):
if api_key != valid_api_key:
return Response("Invalid API key", 401)
elif request.json and request.json.get("api_key"):
api_key = request.json.get("api_key")
if api_key != valid_api_key:
return Response("Invalid API key", 401)
else:
return Response("API key required", 401)
return await f(*args, **kwargs)
return decorated
def requires_loopback(f):
@wraps(f)
async def decorated(*args, **kwargs):
if os.getenv("HF_SPACE") == "true":
return await f(*args, **kwargs)
if not is_loopback_address(str(request.remote_addr)):
return Response("Access denied.", 403, {})
return await f(*args, **kwargs)
return decorated
def requires_auth(f):
@wraps(f)
async def decorated(*args, **kwargs):
if os.getenv("HF_SPACE") == "true":
return await f(*args, **kwargs)
from helpers import login
user_pass_hash = login.get_credentials_hash()
if not user_pass_hash:
return await f(*args, **kwargs)
if session.get("authentication") != user_pass_hash:
return redirect(url_for("login_handler"))
return await f(*args, **kwargs)
return decorated
def csrf_protect(f):
@wraps(f)
async def decorated(*args, **kwargs):
if os.getenv("HF_SPACE") == "true":
return await f(*args, **kwargs)
from helpers import runtime
token = session.get("csrf_token")
header = request.headers.get("X-CSRF-Token")
cookie = request.cookies.get("csrf_token_" + runtime.get_runtime_id())
sent = header or cookie
if not token or not sent or token != sent:
return Response("CSRF token missing or invalid", 403)
return await f(*args, **kwargs)
return decorated
def register_api_route(app: Flask, lock: ThreadLockType) -> None:
from helpers.modules import load_classes_from_file
from helpers import plugins
async def _dispatch(path: str) -> BaseResponse:
cached = cache.get(CACHE_AREA, path)
if cached is not None:
return await cached()
handler_cls: type[ApiHandler] | None = None
builtin_file = files.get_abs_path(f"api/{path}.py")
if files.is_in_dir(builtin_file, files.get_abs_path("api")) and files.exists(
builtin_file
):
classes = load_classes_from_file(builtin_file, ApiHandler)
if classes:
handler_cls = classes[0]
if handler_cls is None and path.startswith("plugins/"):
parts = path.split("/", 2)
if len(parts) == 3:
_, plugin_name, handler_name = parts
plugin_dir = plugins.find_plugin_dir(plugin_name)
if plugin_dir:
plugin_file = Path(plugin_dir) / "api" / f"{handler_name}.py"
if plugin_file.is_file():
classes = load_classes_from_file(str(plugin_file), ApiHandler)
if classes:
handler_cls = classes[0]
if handler_cls is None:
return Response(f"API endpoint not found: {path}", 404)
if request.method not in handler_cls.get_methods():
return Response(f"Method {request.method} not allowed for: {path}", 405)
async def call_handler() -> BaseResponse:
instance = handler_cls(app, lock)
return await instance.handle_request(request=request)
handler_fn = call_handler
if handler_cls.requires_csrf():
handler_fn = csrf_protect(handler_fn)
if handler_cls.requires_api_key():
handler_fn = requires_api_key(handler_fn)
if handler_cls.requires_auth():
handler_fn = requires_auth(handler_fn)
if handler_cls.requires_loopback():
handler_fn = requires_loopback(handler_fn)
cache.add(CACHE_AREA, path, handler_fn)
return await handler_fn()
app.add_url_rule(
"/api/<path:path>",
"api_dispatch",
_dispatch,
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
)