Spaces:
No application file
No application file
import os

import mysql.connector
import requests
from flask import Flask, request, Response
app = Flask(__name__)

# --- SETTINGS ---
# Base URL of the internal NewAPI instance that requests are proxied to.
NEWAPI_INTERNAL = "http://127.0.0.1:3000"

# OpenAI key used for the moderation (safety) check.
# Read from the environment so no credential lives in source control; a
# missing variable fails at startup with a KeyError naming it instead of
# silently running with a broken safety check.
# NOTE(review): the key and database password that used to be hard-coded here
# should be treated as exposed and rotated.
MODERATION_KEY = os.environ["MODERATION_KEY"]

# Keyword arguments passed verbatim to ``mysql.connector.connect``.
# ``autocommit`` is on, so inserts need no explicit commit. TLS is verified
# against the system CA bundle.
TIDB_CONFIG = {
    "host": os.environ.get(
        "TIDB_HOST", "gateway01.eu-central-1.prod.aws.tidbcloud.com"
    ),
    "port": int(os.environ.get("TIDB_PORT", "4000")),
    "user": os.environ.get("TIDB_USER", "uiSKPXCQ9Gzb4co.root"),
    "password": os.environ["TIDB_PASSWORD"],
    "database": os.environ.get("TIDB_DATABASE", "test"),
    "autocommit": True,
    "use_pure": True,
    "ssl_ca": "/etc/ssl/certs/ca-certificates.crt",
    "ssl_verify_cert": True,
}
def log_violation(user, prompt):
    """Record a blocked prompt in the ``safety_violations`` table.

    This is best-effort: any failure is logged and swallowed so that a
    database problem never prevents the caller from returning its response.

    Args:
        user: Value stored in the ``user_id`` column.
        prompt: Offending text; only its first 1000 characters are stored.
    """
    try:
        conn = mysql.connector.connect(**TIDB_CONFIG)
        try:
            cursor = conn.cursor()
            # No explicit commit: TIDB_CONFIG enables autocommit.
            cursor.execute(
                "INSERT INTO safety_violations (user_id, prompt_content) VALUES (%s, %s)",
                (user, prompt[:1000]),
            )
        finally:
            # Always release the connection, even when the insert fails.
            conn.close()
    except Exception:
        # Deliberately broad: recording the violation must never raise into
        # the request handler. The traceback is logged instead of discarded.
        app.logger.exception("Failed to record safety violation")
def _message_text(message):
    """Return the plain text carried by one chat message, or "" if none."""
    if not isinstance(message, dict):
        return ""
    content = message.get("content")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Multi-part content: keep the string "text" field of each part and
        # ignore non-text parts (e.g. images).
        return " ".join(
            part["text"]
            for part in content
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        )
    # None (e.g. tool-call messages) or any other type carries no text.
    return ""


def protect_chat():
    """Run the moderation check on a chat request, then block or forward it.

    The text of all ``messages`` in the JSON body is sent to the OpenAI
    moderation endpoint. If the ``sexual/minors`` category is flagged, the
    violation is logged and a 403 error is returned. Otherwise the request
    is forwarded to NewAPI unchanged.

    The check fails open: if the moderation call errors out, times out, or
    returns no usable result, the request is forwarded and the failure is
    logged.

    Returns:
        ``(error_dict, 403)`` for a blocked request, otherwise whatever
        ``forward_to_newapi`` returns.
    """
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising, so the request is simply forwarded for NewAPI to reject.
    payload = request.get_json(silent=True)
    messages = payload.get("messages") if isinstance(payload, dict) else None
    if not isinstance(messages, list):
        messages = []
    content = " ".join(_message_text(m) for m in messages)

    if not content.strip():
        # Nothing to moderate; skip the round trip.
        return forward_to_newapi(request.path)

    # Safety Check
    try:
        reply = requests.post(
            "https://api.openai.com/v1/moderations",
            headers={"Authorization": f"Bearer {MODERATION_KEY}"},
            json={"input": content},
            timeout=3,
        )
        verdict = reply.json()
    except (requests.RequestException, ValueError):
        app.logger.exception(
            "Moderation check failed; forwarding request unchecked"
        )
        return forward_to_newapi(request.path)

    results = verdict.get("results") if isinstance(verdict, dict) else None
    if not (isinstance(results, list) and results and isinstance(results[0], dict)):
        # Typically an API error payload (bad key, rate limit, ...).
        app.logger.warning(
            "Moderation returned no usable result (HTTP %s); "
            "forwarding request unchecked",
            reply.status_code,
        )
        return forward_to_newapi(request.path)

    categories = results[0].get("categories")
    if isinstance(categories, dict) and categories.get("sexual/minors"):
        # NOTE(review): this stores the caller's raw Authorization header as
        # the user id; consider storing a hash of it instead.
        log_violation(request.headers.get("Authorization", "Anon"), content)
        return {"error": {"message": "Shield Block: Safety Violation"}}, 403

    return forward_to_newapi(request.path)
def catch_all(path):
    """Hand ``path`` straight to ``forward_to_newapi`` and return its result.

    No moderation check is performed here.
    """
    proxied = forward_to_newapi(path)
    return proxied
# Headers that apply to a single connection only and must not be relayed by
# a proxy.
_HOP_BY_HOP_HEADERS = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
})

# Not forwarded upstream. ``requests`` sets Host and Content-Length itself.
# Accept-Encoding is left to ``requests`` so it only advertises encodings it
# can decode.
_REQUEST_SKIP_HEADERS = _HOP_BY_HOP_HEADERS | {
    "host",
    "content-length",
    "accept-encoding",
}

# Not relayed to the client. ``resp.content`` is already decoded and
# de-chunked, so the upstream Content-Encoding and Content-Length no longer
# describe the body being sent.
_RESPONSE_SKIP_HEADERS = _HOP_BY_HOP_HEADERS | {
    "content-encoding",
    "content-length",
}


def forward_to_newapi(path):
    """Replay the current Flask request against NewAPI and relay the reply.

    Args:
        path: Request path to call on ``NEWAPI_INTERNAL``. A leading slash
            is accepted and ignored.

    Returns:
        A ``Response`` carrying the upstream body, status code and
        end-to-end headers, or ``(error_dict, 502)`` if NewAPI cannot be
        reached.
    """
    # Callers pass both "v1/x" and "/v1/x"; normalise to avoid "//" in the URL.
    target = f"{NEWAPI_INTERNAL}/{path.lstrip('/')}"
    outbound_headers = {
        name: value
        for name, value in request.headers
        if name.lower() not in _REQUEST_SKIP_HEADERS
    }

    try:
        resp = requests.request(
            method=request.method,
            url=target,
            headers=outbound_headers,
            data=request.get_data(),
            # multi=True keeps repeated query parameters (?a=1&a=2).
            params=list(request.args.items(multi=True)),
            allow_redirects=False,
        )
    except requests.RequestException:
        app.logger.exception("Upstream request to NewAPI failed")
        return {"error": {"message": "Bad Gateway: upstream unavailable"}}, 502

    relayed_headers = [
        (name, value)
        for name, value in resp.headers.items()
        if name.lower() not in _RESPONSE_SKIP_HEADERS
    ]
    # NOTE(review): the whole upstream body is buffered before it is
    # returned, so streamed responses reach the client only once complete.
    return Response(resp.content, resp.status_code, relayed_headers)
# Script entry point: start Flask's built-in server on every network
# interface, port 7860, when the file is executed directly.
if __name__ == "__main__":
    app.run(port=7860, host="0.0.0.0")