Files changed (1) hide show
  1. app.py +85 -0
app.py CHANGED
@@ -1,6 +1,9 @@
1
  import os
2
  import time
3
  import hashlib
 
 
 
4
  from fastapi import FastAPI, Request, HTTPException, status, Header
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from fastapi.responses import (
@@ -54,6 +57,14 @@ from helper.ratelimit import (
54
 
55
  app = FastAPI()
56
 
 
 
 
 
 
 
 
 
57
  app.add_middleware(
58
  CORSMiddleware,
59
  allow_origins=["*"],
@@ -62,6 +73,20 @@ app.add_middleware(
62
  )
63
  app.include_router(asset_router)
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  @app.get("/")
67
  async def reroute_to_home():
@@ -1032,6 +1057,66 @@ async def tiers():
1032
  content=paid_plans,
1033
  )
1034
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1035
 
1036
  @app.get("/portal")
1037
  @app.post("/portal")
 
1
  import os
2
  import time
3
  import hashlib
4
+ from fastapi import WebSocket, WebSocketDisconnect
5
+ from collections import defaultdict, deque
6
+ import json
7
  from fastapi import FastAPI, Request, HTTPException, status, Header
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from fastapi.responses import (
 
57
 
58
  app = FastAPI()
59
 
60
+ WEBSOCKET_KEY = os.getenv("WEBSOCKET_KEY")
61
+
62
+ # authentication attempt tracking
63
+ AUTH_ATTEMPTS = defaultdict(lambda: deque())
64
+
65
+ AUTH_WINDOW_SECONDS = 60
66
+ AUTH_MAX_ATTEMPTS = 10
67
+
68
  app.add_middleware(
69
  CORSMiddleware,
70
  allow_origins=["*"],
 
73
  )
74
  app.include_router(asset_router)
75
 
76
+ def check_ws_auth_rate_limit(ip: str):
77
+ now = time.time()
78
+ q = AUTH_ATTEMPTS[ip]
79
+
80
+ # purge old attempts
81
+ while q and now - q[0] > AUTH_WINDOW_SECONDS:
82
+ q.popleft()
83
+
84
+ if len(q) >= AUTH_MAX_ATTEMPTS:
85
+ return False
86
+
87
+ q.append(now)
88
+ return True
89
+
90
 
91
  @app.get("/")
92
  async def reroute_to_home():
 
1057
  content=paid_plans,
1058
  )
1059
 
1060
+ @app.websocket("/ws/chat")
1061
+ async def websocket_chat(ws: WebSocket):
1062
+ ip = ws.client.host
1063
+
1064
+ await ws.accept()
1065
+
1066
+ # rate limit auth attempts
1067
+ if not check_ws_auth_rate_limit(ip):
1068
+ await ws.close(code=4408)
1069
+ return
1070
+
1071
+ try:
1072
+ auth_msg = await ws.receive_text()
1073
+ auth_data = json.loads(auth_msg)
1074
+
1075
+ provided_key = auth_data.get("key")
1076
+
1077
+ if not WEBSOCKET_KEY or provided_key != WEBSOCKET_KEY:
1078
+ await ws.close(code=4403)
1079
+ return
1080
+
1081
+ # authenticated
1082
+ await ws.send_json({"type": "auth", "status": "ok"})
1083
+
1084
+ while True:
1085
+ msg = await ws.receive_text()
1086
+ data = json.loads(msg)
1087
+
1088
+ body = data.get("body")
1089
+ headers = data.get("headers", {})
1090
+
1091
+ if not body:
1092
+ await ws.send_json({"error": "Missing body"})
1093
+ continue
1094
+
1095
+ url = str(ws.url).replace("ws://", "http://").replace("wss://", "https://")
1096
+ url = url.split("/ws/chat")[0] + "/gen/chat/completions"
1097
+
1098
+ async with httpx.AsyncClient(timeout=None) as client:
1099
+ async with client.stream(
1100
+ "POST",
1101
+ url,
1102
+ json=body,
1103
+ headers=headers,
1104
+ ) as r:
1105
+
1106
+ async for line in r.aiter_lines():
1107
+ if not line:
1108
+ continue
1109
+
1110
+ await ws.send_text(line)
1111
+
1112
+ except WebSocketDisconnect:
1113
+ return
1114
+ except Exception as e:
1115
+ try:
1116
+ await ws.send_json({"error": str(e)})
1117
+ except:
1118
+ pass
1119
+ await ws.close()
1120
 
1121
  @app.get("/portal")
1122
  @app.post("/portal")