Spaces:
Paused
Paused
| """ | |
| Tests: | |
| - custom_path false / no user auth: | |
| -- upload file(yes) | |
| -- download file(yes) | |
| -- websocket(yes) | |
| -- block __pycache__ access(yes) | |
| -- rel (yes) | |
| -- abs (yes) | |
| -- block user access(fail) http://localhost:45013/file=gpt_log/admin/chat_secrets.log | |
| -- fix(commit f6bf05048c08f5cd84593f7fdc01e64dec1f584a)-> block successful | |
| - custom_path yes("/cc/gptac") / no user auth: | |
| -- upload file(yes) | |
| -- download file(yes) | |
| -- websocket(yes) | |
| -- block __pycache__ access(yes) | |
| -- block user access(yes) | |
| - custom_path yes("/cc/gptac/") / no user auth: | |
| -- upload file(yes) | |
| -- download file(yes) | |
| -- websocket(yes) | |
| -- block user access(yes) | |
| - custom_path yes("/cc/gptac/") / + user auth: | |
| -- upload file(yes) | |
| -- download file(yes) | |
| -- websocket(yes) | |
| -- block user access(yes) | |
| -- block user-wise access (yes) | |
| - custom_path no + user auth: | |
| -- upload file(yes) | |
| -- download file(yes) | |
| -- websocket(yes) | |
| -- block user access(yes) | |
| -- block user-wise access (yes) | |
| queue cocurrent effectiveness | |
| -- upload file(yes) | |
| -- download file(yes) | |
| -- websocket(yes) | |
| """ | |
| import os, requests, threading, time | |
| import uvicorn | |
| def validate_path_safety(path_or_url, user): | |
| from toolbox import get_conf, default_user_name | |
| from toolbox import FriendlyException | |
| PATH_PRIVATE_UPLOAD, PATH_LOGGING = get_conf('PATH_PRIVATE_UPLOAD', 'PATH_LOGGING') | |
| sensitive_path = None | |
| path_or_url = os.path.relpath(path_or_url) | |
| if path_or_url.startswith(PATH_LOGGING): # 日志文件(按用户划分) | |
| sensitive_path = PATH_LOGGING | |
| elif path_or_url.startswith(PATH_PRIVATE_UPLOAD): # 用户的上传目录(按用户划分) | |
| sensitive_path = PATH_PRIVATE_UPLOAD | |
| elif path_or_url.startswith('tests') or path_or_url.startswith('build'): # 一个常用的测试目录 | |
| return True | |
| else: | |
| raise FriendlyException(f"输入文件的路径 ({path_or_url}) 存在,但位置非法。请将文件上传后再执行该任务。") # return False | |
| if sensitive_path: | |
| allowed_users = [user, 'autogen', 'arxiv_cache', default_user_name] # three user path that can be accessed | |
| for user_allowed in allowed_users: | |
| if f"{os.sep}".join(path_or_url.split(os.sep)[:2]) == os.path.join(sensitive_path, user_allowed): | |
| return True | |
| raise FriendlyException(f"输入文件的路径 ({path_or_url}) 存在,但属于其他用户。请将文件上传后再执行该任务。") # return False | |
| return True | |
| def _authorize_user(path_or_url, request, gradio_app): | |
| from toolbox import get_conf, default_user_name | |
| PATH_PRIVATE_UPLOAD, PATH_LOGGING = get_conf('PATH_PRIVATE_UPLOAD', 'PATH_LOGGING') | |
| sensitive_path = None | |
| path_or_url = os.path.relpath(path_or_url) | |
| if path_or_url.startswith(PATH_LOGGING): | |
| sensitive_path = PATH_LOGGING | |
| if path_or_url.startswith(PATH_PRIVATE_UPLOAD): | |
| sensitive_path = PATH_PRIVATE_UPLOAD | |
| if sensitive_path: | |
| token = request.cookies.get("access-token") or request.cookies.get("access-token-unsecure") | |
| user = gradio_app.tokens.get(token) # get user | |
| allowed_users = [user, 'autogen', 'arxiv_cache', default_user_name] # three user path that can be accessed | |
| for user_allowed in allowed_users: | |
| # exact match | |
| if f"{os.sep}".join(path_or_url.split(os.sep)[:2]) == os.path.join(sensitive_path, user_allowed): | |
| return True | |
| return False # "越权访问!" | |
| return True | |
| class Server(uvicorn.Server): | |
| # A server that runs in a separate thread | |
| def install_signal_handlers(self): | |
| pass | |
| def run_in_thread(self): | |
| self.thread = threading.Thread(target=self.run, daemon=True) | |
| self.thread.start() | |
| while not self.started: | |
| time.sleep(5e-2) | |
| def close(self): | |
| self.should_exit = True | |
| self.thread.join() | |
| def start_app(app_block, CONCURRENT_COUNT, AUTHENTICATION, PORT, SSL_KEYFILE, SSL_CERTFILE): | |
| import uvicorn | |
| import fastapi | |
| import gradio as gr | |
| from fastapi import FastAPI | |
| from gradio.routes import App | |
| from toolbox import get_conf | |
| CUSTOM_PATH, PATH_LOGGING = get_conf('CUSTOM_PATH', 'PATH_LOGGING') | |
| # --- --- configurate gradio app block --- --- | |
| app_block:gr.Blocks | |
| app_block.ssl_verify = False | |
| app_block.auth_message = '请登录' | |
| app_block.favicon_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "docs/logo.png") | |
| app_block.auth = AUTHENTICATION if len(AUTHENTICATION) != 0 else None | |
| app_block.blocked_paths = ["config.py", "__pycache__", "config_private.py", "docker-compose.yml", "Dockerfile", f"{PATH_LOGGING}/admin"] | |
| app_block.dev_mode = False | |
| app_block.config = app_block.get_config_file() | |
| app_block.enable_queue = True | |
| app_block.queue(concurrency_count=CONCURRENT_COUNT) | |
| app_block.validate_queue_settings() | |
| app_block.show_api = False | |
| app_block.config = app_block.get_config_file() | |
| max_threads = 40 | |
| app_block.max_threads = max( | |
| app_block._queue.max_thread_count if app_block.enable_queue else 0, max_threads | |
| ) | |
| app_block.is_colab = False | |
| app_block.is_kaggle = False | |
| app_block.is_sagemaker = False | |
| gradio_app = App.create_app(app_block) | |
| for route in list(gradio_app.router.routes): | |
| if route.path == "/proxy={url_path:path}": | |
| gradio_app.router.routes.remove(route) | |
| # --- --- replace gradio endpoint to forbid access to sensitive files --- --- | |
| if len(AUTHENTICATION) > 0: | |
| dependencies = [] | |
| endpoint = None | |
| for route in list(gradio_app.router.routes): | |
| if route.path == "/file/{path:path}": | |
| gradio_app.router.routes.remove(route) | |
| if route.path == "/file={path_or_url:path}": | |
| dependencies = route.dependencies | |
| endpoint = route.endpoint | |
| gradio_app.router.routes.remove(route) | |
| async def file(path_or_url: str, request: fastapi.Request): | |
| if not _authorize_user(path_or_url, request, gradio_app): | |
| return "越权访问!" | |
| stripped = path_or_url.lstrip().lower() | |
| if stripped.startswith("https://") or stripped.startswith("http://"): | |
| return "账户密码授权模式下, 禁止链接!" | |
| if '../' in stripped: | |
| return "非法路径!" | |
| return await endpoint(path_or_url, request) | |
| from fastapi import Request, status | |
| from fastapi.responses import FileResponse, RedirectResponse | |
| async def logout(): | |
| response = RedirectResponse(url=CUSTOM_PATH, status_code=status.HTTP_302_FOUND) | |
| response.delete_cookie('access-token') | |
| response.delete_cookie('access-token-unsecure') | |
| return response | |
| else: | |
| dependencies = [] | |
| endpoint = None | |
| for route in list(gradio_app.router.routes): | |
| if route.path == "/file/{path:path}": | |
| gradio_app.router.routes.remove(route) | |
| if route.path == "/file={path_or_url:path}": | |
| dependencies = route.dependencies | |
| endpoint = route.endpoint | |
| gradio_app.router.routes.remove(route) | |
| async def file(path_or_url: str, request: fastapi.Request): | |
| stripped = path_or_url.lstrip().lower() | |
| if stripped.startswith("https://") or stripped.startswith("http://"): | |
| return "账户密码授权模式下, 禁止链接!" | |
| if '../' in stripped: | |
| return "非法路径!" | |
| return await endpoint(path_or_url, request) | |
| # --- --- enable TTS (text-to-speech) functionality --- --- | |
| TTS_TYPE = get_conf("TTS_TYPE") | |
| if TTS_TYPE != "DISABLE": | |
| # audio generation functionality | |
| import httpx | |
| from fastapi import FastAPI, Request, HTTPException | |
| from starlette.responses import Response | |
| async def forward_request(request: Request, method: str) -> Response: | |
| async with httpx.AsyncClient() as client: | |
| try: | |
| # Forward the request to the target service | |
| if TTS_TYPE == "EDGE_TTS": | |
| import tempfile | |
| import edge_tts | |
| import wave | |
| import uuid | |
| from pydub import AudioSegment | |
| json = await request.json() | |
| voice = get_conf("EDGE_TTS_VOICE") | |
| tts = edge_tts.Communicate(text=json['text'], voice=voice) | |
| temp_folder = tempfile.gettempdir() | |
| temp_file_name = str(uuid.uuid4().hex) | |
| temp_file = os.path.join(temp_folder, f'{temp_file_name}.mp3') | |
| await tts.save(temp_file) | |
| try: | |
| mp3_audio = AudioSegment.from_file(temp_file, format="mp3") | |
| mp3_audio.export(temp_file, format="wav") | |
| with open(temp_file, 'rb') as wav_file: t = wav_file.read() | |
| os.remove(temp_file) | |
| return Response(content=t) | |
| except: | |
| raise RuntimeError("ffmpeg未安装,无法处理EdgeTTS音频。安装方法见`https://github.com/jiaaro/pydub#getting-ffmpeg-set-up`") | |
| if TTS_TYPE == "LOCAL_SOVITS_API": | |
| # Forward the request to the target service | |
| TARGET_URL = get_conf("GPT_SOVITS_URL") | |
| body = await request.body() | |
| resp = await client.post(TARGET_URL, content=body, timeout=60) | |
| # Return the response from the target service | |
| return Response(content=resp.content, status_code=resp.status_code, headers=dict(resp.headers)) | |
| except httpx.RequestError as e: | |
| raise HTTPException(status_code=400, detail=f"Request to the target service failed: {str(e)}") | |
| async def forward_post_request(request: Request): | |
| return await forward_request(request, "POST") | |
| # --- --- app_lifespan --- --- | |
| from contextlib import asynccontextmanager | |
| async def app_lifespan(app): | |
| async def startup_gradio_app(): | |
| if gradio_app.get_blocks().enable_queue: | |
| gradio_app.get_blocks().startup_events() | |
| async def shutdown_gradio_app(): | |
| pass | |
| await startup_gradio_app() # startup logic here | |
| yield # The application will serve requests after this point | |
| await shutdown_gradio_app() # cleanup/shutdown logic here | |
| # --- --- FastAPI --- --- | |
| fastapi_app = FastAPI(lifespan=app_lifespan) | |
| fastapi_app.mount(CUSTOM_PATH, gradio_app) | |
| # --- --- favicon and block fastapi api reference routes --- --- | |
| from starlette.responses import JSONResponse | |
| if CUSTOM_PATH != '/': | |
| from fastapi.responses import FileResponse | |
| async def favicon(): | |
| return FileResponse(app_block.favicon_path) | |
| async def middleware(request: Request, call_next): | |
| if request.scope['path'] in ["/docs", "/redoc", "/openapi.json"]: | |
| return JSONResponse(status_code=404, content={"message": "Not Found"}) | |
| response = await call_next(request) | |
| return response | |
| # --- --- uvicorn.Config --- --- | |
| ssl_keyfile = None if SSL_KEYFILE == "" else SSL_KEYFILE | |
| ssl_certfile = None if SSL_CERTFILE == "" else SSL_CERTFILE | |
| server_name = "0.0.0.0" | |
| config = uvicorn.Config( | |
| fastapi_app, | |
| host=server_name, | |
| port=PORT, | |
| reload=False, | |
| log_level="warning", | |
| ssl_keyfile=ssl_keyfile, | |
| ssl_certfile=ssl_certfile, | |
| ) | |
| server = Server(config) | |
| url_host_name = "localhost" if server_name == "0.0.0.0" else server_name | |
| if ssl_keyfile is not None: | |
| if ssl_certfile is None: | |
| raise ValueError( | |
| "ssl_certfile must be provided if ssl_keyfile is provided." | |
| ) | |
| path_to_local_server = f"https://{url_host_name}:{PORT}/" | |
| else: | |
| path_to_local_server = f"http://{url_host_name}:{PORT}/" | |
| if CUSTOM_PATH != '/': | |
| path_to_local_server += CUSTOM_PATH.lstrip('/').rstrip('/') + '/' | |
| # --- --- begin --- --- | |
| server.run_in_thread() | |
| # --- --- after server launch --- --- | |
| app_block.server = server | |
| app_block.server_name = server_name | |
| app_block.local_url = path_to_local_server | |
| app_block.protocol = ( | |
| "https" | |
| if app_block.local_url.startswith("https") or app_block.is_colab | |
| else "http" | |
| ) | |
| if app_block.enable_queue: | |
| app_block._queue.set_url(path_to_local_server) | |
| forbid_proxies = { | |
| "http": "", | |
| "https": "", | |
| } | |
| requests.get(f"{app_block.local_url}startup-events", verify=app_block.ssl_verify, proxies=forbid_proxies) | |
| app_block.is_running = True | |
| app_block.block_thread() |