Rafael Uzarowski commited on
Commit
054426d
·
unverified ·
1 Parent(s): f06d316

feat: API: support api_key and localhost without auth, signal handler for term/int

Browse files
Files changed (2) hide show
  1. python/helpers/api.py +12 -0
  2. run_ui.py +104 -5
python/helpers/api.py CHANGED
@@ -20,6 +20,18 @@ class ApiHandler:
20
  self.app = app
21
  self.thread_lock = thread_lock
22
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  @abstractmethod
24
  async def process(self, input: Input, request: Request) -> Output:
25
  pass
 
20
  self.app = app
21
  self.thread_lock = thread_lock
22
 
23
+ @classmethod
24
+ def requires_loopback(cls):
25
+ return False
26
+
27
+ @classmethod
28
+ def requires_api_key(cls):
29
+ return False
30
+
31
+ @classmethod
32
+ def requires_auth(cls):
33
+ return True
34
+
35
  @abstractmethod
36
  async def process(self, input: Input, request: Request) -> Output:
37
  pass
run_ui.py CHANGED
@@ -1,6 +1,6 @@
1
  from functools import wraps
2
- import os
3
  import threading
 
4
  from flask import Flask, request, Response
5
  from flask_basicauth import BasicAuth
6
  from python.helpers import errors, files, git
@@ -10,6 +10,10 @@ from python.helpers.cloudflare_tunnel import CloudflareTunnel
10
  from python.helpers.extract_tools import load_classes_from_folder
11
  from python.helpers.api import ApiHandler
12
  from python.helpers.print_style import PrintStyle
 
 
 
 
13
 
14
 
15
  # initialize the internal Flask server
@@ -22,6 +26,69 @@ lock = threading.Lock()
22
  basic_auth = BasicAuth(app)
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # require authentication for handlers
26
  def requires_auth(f):
27
  @wraps(f)
@@ -49,7 +116,7 @@ async def serve_index():
49
  gitinfo = None
50
  try:
51
  gitinfo = git.get_git_info()
52
- except Exception as e:
53
  gitinfo = {
54
  "version": "unknown",
55
  "commit_time": "unknown",
@@ -110,9 +177,23 @@ def run():
110
  name = handler.__module__.split(".")[-1]
111
  instance = handler(app, lock)
112
 
113
- @requires_auth
114
- async def handle_request():
115
- return await instance.handle_request(request=request)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  app.add_url_rule(
118
  f"/{name}",
@@ -134,6 +215,24 @@ def run():
134
  request_handler=NoRequestLoggingWSGIRequestHandler,
135
  threaded=True,
136
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  process.set_server(server)
138
  server.log_startup()
139
  server.serve_forever()
 
1
  from functools import wraps
 
2
  import threading
3
+ import signal
4
  from flask import Flask, request, Response
5
  from flask_basicauth import BasicAuth
6
  from python.helpers import errors, files, git
 
10
  from python.helpers.extract_tools import load_classes_from_folder
11
  from python.helpers.api import ApiHandler
12
  from python.helpers.print_style import PrintStyle
13
+ import sys
14
+ import asyncio
15
+ import socket
16
+ import struct
17
 
18
 
19
  # initialize the internal Flask server
 
26
  basic_auth = BasicAuth(app)
27
 
28
 
29
+ def is_loopback_address(address):
30
+ loopback_checker = {
31
+ socket.AF_INET: lambda x: struct.unpack('!I', socket.inet_aton(x))[0] >> (32 - 8) == 127,
32
+ socket.AF_INET6: lambda x: x == '::1'
33
+ }
34
+ address_type = "hostname"
35
+ try:
36
+ socket.inet_pton(socket.AF_INET6, address)
37
+ address_type = "ipv6"
38
+ except socket.error:
39
+ try:
40
+ socket.inet_pton(socket.AF_INET, address)
41
+ address_type = "ipv4"
42
+ except socket.error:
43
+ address_type = "hostname"
44
+
45
+ if address_type == "ipv4":
46
+ return loopback_checker[socket.AF_INET](address)
47
+ elif address_type == "ipv6":
48
+ return loopback_checker[socket.AF_INET6](address)
49
+ else:
50
+ for family in (socket.AF_INET, socket.AF_INET6):
51
+ try:
52
+ r = socket.getaddrinfo(address, None, family, socket.SOCK_STREAM)
53
+ except socket.gaierror:
54
+ return False
55
+ for family, _, _, _, sockaddr in r:
56
+ if not loopback_checker[family](sockaddr[0]):
57
+ return False
58
+ return True
59
+
60
+
61
+ def requires_api_key(f):
62
+ @wraps(f)
63
+ async def decorated(*args, **kwargs):
64
+ valid_api_key = dotenv.get_dotenv_value("API_KEY")
65
+ if api_key := request.headers.get("X-API-KEY"):
66
+ if api_key != valid_api_key:
67
+ return Response("API key required", 401)
68
+ elif request.json and request.json.get("api_key"):
69
+ api_key = request.json.get("api_key")
70
+ if api_key != valid_api_key:
71
+ return Response("API key required", 401)
72
+ else:
73
+ return Response("API key required", 401)
74
+ return await f(*args, **kwargs)
75
+ return decorated
76
+
77
+
78
+ # allow only loopback addresses
79
+ def requires_loopback(f):
80
+ @wraps(f)
81
+ async def decorated(*args, **kwargs):
82
+ if not is_loopback_address(request.remote_addr):
83
+ return Response(
84
+ "Access denied.",
85
+ 403,
86
+ {},
87
+ )
88
+ return await f(*args, **kwargs)
89
+ return decorated
90
+
91
+
92
  # require authentication for handlers
93
  def requires_auth(f):
94
  @wraps(f)
 
116
  gitinfo = None
117
  try:
118
  gitinfo = git.get_git_info()
119
+ except Exception:
120
  gitinfo = {
121
  "version": "unknown",
122
  "commit_time": "unknown",
 
177
  name = handler.__module__.split(".")[-1]
178
  instance = handler(app, lock)
179
 
180
+ if handler.requires_loopback():
181
+ @requires_loopback
182
+ async def handle_request():
183
+ return await instance.handle_request(request=request)
184
+ elif handler.requires_auth():
185
+ @requires_auth
186
+ async def handle_request():
187
+ return await instance.handle_request(request=request)
188
+ elif handler.requires_api_key():
189
+ @requires_api_key
190
+ async def handle_request():
191
+ return await instance.handle_request(request=request)
192
+ else:
193
+ # Fallback to requires_auth
194
+ @requires_auth
195
+ async def handle_request():
196
+ return await instance.handle_request(request=request)
197
 
198
  app.add_url_rule(
199
  f"/{name}",
 
215
  request_handler=NoRequestLoggingWSGIRequestHandler,
216
  threaded=True,
217
  )
218
+
219
+ printer = PrintStyle()
220
+
221
+ def signal_handler(sig=None, frame=None):
222
+ nonlocal tunnel, server, printer
223
+ with lock:
224
+ printer.print("Caught signal, stopping server...")
225
+ server.shutdown()
226
+ process.stop_server()
227
+ if tunnel:
228
+ tunnel.stop()
229
+ tunnel = None
230
+ printer.print("Server stopped")
231
+ sys.exit(0)
232
+
233
+ signal.signal(signal.SIGINT, signal_handler)
234
+ signal.signal(signal.SIGTERM, signal_handler)
235
+
236
  process.set_server(server)
237
  server.log_startup()
238
  server.serve_forever()