darkfire514 commited on
Commit
a517e18
·
verified ·
1 Parent(s): 45d178d

Delete auth_proxy.py

Browse files
Files changed (1) hide show
  1. auth_proxy.py +0 -216
auth_proxy.py DELETED
@@ -1,216 +0,0 @@
1
- import os
2
- import logging
3
- import httpx
4
- import asyncio
5
- from fastapi import FastAPI, Request, Response, WebSocket, WebSocketDisconnect
6
- from fastapi.responses import HTMLResponse, RedirectResponse
7
- from fastapi.templating import Jinja2Templates
8
- from starlette.middleware.sessions import SessionMiddleware
9
- from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
10
- from authlib.integrations.starlette_client import OAuth
11
- import websockets
12
- from websockets.exceptions import ConnectionClosed
13
-
14
- # Configure logging
15
- logging.basicConfig(level=logging.INFO)
16
- logger = logging.getLogger("auth_proxy")
17
-
18
- # Environment Variables
19
- SECRET_KEY = os.getenv("AUTH_SECRET", os.urandom(24).hex())
20
- ALLOWED_USERS = [u.strip() for u in os.getenv("ALLOWED_USERS", "").split(",") if u.strip()]
21
- TTYD_URL = "http://127.0.0.1:7681"
22
- TTYD_WS_URL = "ws://127.0.0.1:7681"
23
-
24
- app = FastAPI()
25
-
26
- # Add ProxyHeadersMiddleware to trust the headers from HF load balancer
27
- app.add_middleware(ProxyHeadersMiddleware, trusted_hosts=["*"])
28
-
29
- # Configure SessionMiddleware
30
- app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY, https_only=True, same_site="lax")
31
-
32
- # OAuth Setup
33
- oauth = OAuth()
34
-
35
- # GitHub Configuration
36
- if os.getenv("GITHUB_CLIENT_ID") and os.getenv("GITHUB_CLIENT_SECRET"):
37
- oauth.register(
38
- name='github',
39
- client_id=os.getenv("GITHUB_CLIENT_ID"),
40
- client_secret=os.getenv("GITHUB_CLIENT_SECRET"),
41
- access_token_url='https://github.com/login/oauth/access_token',
42
- access_token_params=None,
43
- authorize_url='https://github.com/login/oauth/authorize',
44
- authorize_params=None,
45
- api_base_url='https://api.github.com/',
46
- client_kwargs={'scope': 'user:email'},
47
- )
48
-
49
- # Google Configuration
50
- if os.getenv("GOOGLE_CLIENT_ID") and os.getenv("GOOGLE_CLIENT_SECRET"):
51
- oauth.register(
52
- name='google',
53
- client_id=os.getenv("GOOGLE_CLIENT_ID"),
54
- client_secret=os.getenv("GOOGLE_CLIENT_SECRET"),
55
- server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
56
- client_kwargs={'scope': 'openid email profile'},
57
- )
58
-
59
- templates = Jinja2Templates(directory="templates")
60
-
61
- def get_user(request: Request):
62
- return request.session.get('user')
63
-
64
- @app.get("/login")
65
- async def login(request: Request):
66
- return templates.TemplateResponse("login.html", {
67
- "request": request,
68
- "github_enabled": bool(os.getenv("GITHUB_CLIENT_ID")),
69
- "google_enabled": bool(os.getenv("GOOGLE_CLIENT_ID"))
70
- })
71
-
72
- @app.get("/auth/login/{provider}")
73
- async def auth_login(request: Request, provider: str):
74
- if provider == 'github':
75
- redirect_uri = str(request.url_for('github_auth_callback'))
76
- elif provider == 'google':
77
- redirect_uri = str(request.url_for('google_auth_callback'))
78
- else:
79
- redirect_uri = str(request.url_for('auth_callback', provider=provider))
80
-
81
- if "http://" in redirect_uri and "localhost" not in redirect_uri:
82
- redirect_uri = redirect_uri.replace("http://", "https://")
83
-
84
- return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
85
-
86
- @app.get("/oauth2/callback")
87
- async def github_auth_callback(request: Request):
88
- return await process_oauth_callback(request, 'github')
89
-
90
- @app.get("/api/auth/callback")
91
- async def google_auth_callback(request: Request):
92
- return await process_oauth_callback(request, 'google')
93
-
94
- @app.get("/auth/callback/{provider}")
95
- async def auth_callback(request: Request, provider: str):
96
- return await process_oauth_callback(request, provider)
97
-
98
- async def process_oauth_callback(request: Request, provider: str):
99
- try:
100
- token = await oauth.create_client(provider).authorize_access_token(request)
101
- except Exception as e:
102
- logger.error(f"OAuth error: {e}")
103
- return RedirectResponse(url=f'/login?error=oauth_failed_{provider}')
104
-
105
- user_info = {}
106
- if provider == 'github':
107
- resp = await oauth.github.get('user', token=token)
108
- profile = resp.json()
109
- user_info = {'id': profile.get('login'), 'email': profile.get('email'), 'provider': 'github'}
110
- elif provider == 'google':
111
- user_info = token.get('userinfo')
112
- if not user_info:
113
- try:
114
- resp = await oauth.google.get('https://www.googleapis.com/oauth2/v3/userinfo', token=token)
115
- user_info = resp.json()
116
- except:
117
- pass
118
-
119
- email = user_info.get('email')
120
- user_info = {'id': email, 'email': email, 'provider': 'google'}
121
-
122
- identifier = user_info.get('id')
123
-
124
- if ALLOWED_USERS and identifier not in ALLOWED_USERS:
125
- return HTMLResponse(f"User {identifier} is not allowed to access this VPS.", status_code=403)
126
-
127
- request.session['user'] = user_info
128
- return RedirectResponse(url='/')
129
-
130
- @app.get("/logout")
131
- async def logout(request: Request):
132
- request.session.pop('user', None)
133
- return RedirectResponse(url='/login')
134
-
135
- @app.websocket("/ws")
136
- async def websocket_endpoint(websocket: WebSocket):
137
- await websocket.accept(subprotocol="tty")
138
-
139
- try:
140
- async with websockets.connect(f"{TTYD_WS_URL}/ws", subprotocols=["tty"]) as ttyd_ws:
141
- async def forward_client_to_ttyd():
142
- try:
143
- while True:
144
- data = await websocket.receive_bytes()
145
- await ttyd_ws.send(data)
146
- except Exception:
147
- pass
148
-
149
- async def forward_ttyd_to_client():
150
- try:
151
- async for message in ttyd_ws:
152
- await websocket.send_bytes(message)
153
- except Exception:
154
- pass
155
-
156
- await asyncio.gather(
157
- forward_client_to_ttyd(),
158
- forward_ttyd_to_client()
159
- )
160
- except Exception as e:
161
- logger.error(f"WebSocket connection error: {e}")
162
- await websocket.close()
163
-
164
- @app.api_route("/{path:path}", methods=["GET", "POST", "HEAD", "OPTIONS"])
165
- async def proxy_http(request: Request, path: str):
166
- if request.method == "HEAD":
167
- return Response(status_code=200)
168
-
169
- user = get_user(request)
170
- if not user:
171
- return RedirectResponse(url="/login")
172
-
173
- async with httpx.AsyncClient(http2=False) as client:
174
- url = f"{TTYD_URL}/{path}"
175
- if request.query_params:
176
- url += f"?{request.query_params}"
177
-
178
- headers = {k: v for k, v in request.headers.items() if k.lower() not in ['host', 'content-length']}
179
-
180
- try:
181
- body = await request.body()
182
- rp_resp = await client.request(
183
- request.method,
184
- url,
185
- headers=headers,
186
- content=body
187
- )
188
-
189
- # 核心修复:移除上游响应中过时的头部信息
190
- resp_headers = dict(rp_resp.headers)
191
-
192
- # 必须删除 Content-Length,因为 httpx 可能已经解压了内容
193
- # 如果不删除,Uvicorn 会因为实际内容长度与 Header 不符而抛出错误
194
- resp_headers.pop("Content-Length", None)
195
-
196
- # 必须删除 Transfer-Encoding: chunked,因为我们现在一次性发送完整 content
197
- resp_headers.pop("Transfer-Encoding", None)
198
-
199
- # 必须删除 Content-Encoding,因为 httpx 已经自动解压了 gzip/deflate
200
- resp_headers.pop("Content-Encoding", None)
201
-
202
- # 移除连接相关的头部
203
- resp_headers.pop("Connection", None)
204
- resp_headers.pop("Keep-Alive", None)
205
-
206
- if path == "favicon.ico" and rp_resp.status_code == 404:
207
- return Response(status_code=404)
208
-
209
- return Response(
210
- content=rp_resp.content,
211
- status_code=rp_resp.status_code,
212
- headers=resp_headers
213
- )
214
- except Exception as e:
215
- logger.error(f"Proxy Error for {path}: {e}")
216
- return Response(f"Proxy Error: {e}", status_code=502)