# Source: Hugging Face Space page ("No application file"); file size 2,239 bytes.
import logging
import os

import mysql.connector
import requests
from flask import Flask, request, Response
# Flask application object; the route decorators below register on it.
app = Flask(__name__)
# --- SETTINGS ---
# Every value can be overridden from the environment so that secrets do not
# have to live in source control.
# NOTE(review): the key/password literals used as fallbacks below are
# committed credentials -- rotate them and drop the fallbacks once the
# deployment sets the environment variables. They are kept for now because an
# empty MODERATION_KEY would silently disable the (fail-open) safety check.
NEWAPI_INTERNAL = os.environ.get("NEWAPI_INTERNAL", "http://127.0.0.1:3000")
# OpenAI key for the safety check. NOTE(review): this literal looks truncated
# in this copy of the file -- confirm the deployed value.
MODERATION_KEY = os.environ.get("MODERATION_KEY", "sk-proj-LYW3iVcaE...")
# Keyword arguments for mysql.connector.connect() used by log_violation().
TIDB_CONFIG = {
    "host": os.environ.get(
        "TIDB_HOST", "gateway01.eu-central-1.prod.aws.tidbcloud.com"
    ),
    "port": int(os.environ.get("TIDB_PORT", "4000")),
    "user": os.environ.get("TIDB_USER", "uiSKPXCQ9Gzb4co.root"),
    "password": os.environ.get("TIDB_PASSWORD", "Bxg8rpU27gyH60E0"),
    "database": os.environ.get("TIDB_DATABASE", "test"),
    "autocommit": True,  # each INSERT is committed without an explicit commit()
    "use_pure": True,
    "ssl_ca": os.environ.get("TIDB_SSL_CA", "/etc/ssl/certs/ca-certificates.crt"),
    "ssl_verify_cert": True,
}
def log_violation(user, prompt):
    """Record a blocked request in the ``safety_violations`` table.

    Best-effort: any error (connection, SQL, bad argument) is logged and
    swallowed, so recording a violation can never change the response sent
    to the client.

    Args:
        user: Value stored in ``user_id``. In this file, ``protect_chat``
            passes the request's ``Authorization`` header, or ``'Anon'``.
        prompt: Offending text; only the first 1000 characters are stored.
    """
    try:
        conn = mysql.connector.connect(**TIDB_CONFIG)
        try:
            cursor = conn.cursor()
            # TIDB_CONFIG enables autocommit, so no explicit commit() is needed.
            cursor.execute(
                "INSERT INTO safety_violations (user_id, prompt_content) VALUES (%s, %s)",
                (user, prompt[:1000]),
            )
            cursor.close()
        finally:
            # Always release the connection, even when execute() raises.
            conn.close()
    except Exception:
        # Deliberately best-effort, but no longer silent.
        logging.getLogger(__name__).exception("Failed to record safety violation")
def _message_text(message):
    """Return the moderatable text of one chat message.

    Plain-string ``content`` is returned as-is. For list-form content, the
    ``text`` of every part that has a string ``text`` field is joined with
    spaces. Missing or ``None`` content and non-dict messages yield ``""``
    instead of crashing the join in ``protect_chat``.
    NOTE(review): non-text parts (e.g. images) are not moderated.
    """
    content = message.get('content') if isinstance(message, dict) else None
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return " ".join(
            part['text']
            for part in content
            if isinstance(part, dict) and isinstance(part.get('text'), str)
        )
    return ""


def _flags_sexual_minors(text):
    """Return True if OpenAI moderation puts *text* in ``sexual/minors``.

    Network errors and non-2xx replies raise ``requests.RequestException``.
    A non-JSON reply raises ``ValueError``, and an unexpectedly shaped reply
    may raise a lookup/attribute error. The caller decides how to fail.
    """
    reply = requests.post(
        "https://api.openai.com/v1/moderations",
        headers={"Authorization": f"Bearer {MODERATION_KEY}"},
        json={"input": text},
        timeout=3,
    )
    reply.raise_for_status()
    results = reply.json().get('results') or [{}]
    return bool(results[0].get('categories', {}).get('sexual/minors'))


@app.route('/v1/chat/completions', methods=['POST'])
def protect_chat():
    """Moderate a chat-completion request, then proxy it to NewAPI.

    All message text is joined and sent to the moderation endpoint. If it is
    flagged ``sexual/minors``, the request is recorded with
    ``log_violation`` and rejected with 403. Otherwise it is forwarded
    unchanged.
    NOTE(review): only this exact path is moderated; other endpoints reached
    through ``catch_all`` are not -- confirm that is intended.
    """
    # force=True parses the body whatever its Content-Type, so a mislabeled
    # JSON body cannot skip moderation. silent=True turns an unparsable body
    # into None, which is forwarded for NewAPI to reject.
    data = request.get_json(force=True, silent=True)
    messages = data.get('messages') if isinstance(data, dict) else None
    if not isinstance(messages, list):
        messages = []
    content = " ".join(_message_text(m) for m in messages)

    # Fail open: when the moderation call errors, the request is still
    # forwarded, as before -- but the failure is now logged.
    try:
        flagged = bool(content) and _flags_sexual_minors(content)
    except Exception:
        logging.getLogger(__name__).warning(
            "Moderation check failed; forwarding request unchecked", exc_info=True
        )
        flagged = False

    if flagged:
        # NOTE(review): this stores the caller's raw Authorization header
        # (i.e. their API key) as user_id -- consider hashing or redacting it.
        log_violation(request.headers.get('Authorization', 'Anon'), content)
        return {"error": {"message": "Shield Block: Safety Violation"}}, 403
    # request.path starts with '/'. Strip it so the upstream URL built as
    # f"{NEWAPI_INTERNAL}/{path}" has no '//'.
    return forward_to_newapi(request.path.lstrip('/'))
# HTTP methods relayed verbatim to NewAPI by the fallback routes.
_PASSTHROUGH_METHODS = ['GET', 'POST', 'PUT', 'DELETE']


@app.route('/', defaults={'path': ''}, methods=_PASSTHROUGH_METHODS)
@app.route('/<path:path>', methods=_PASSTHROUGH_METHODS)
def catch_all(path):
    """Relay any request not claimed by a more specific route to NewAPI.

    ``path`` is the URL path without its leading slash ('' for the root).
    """
    target = path
    return forward_to_newapi(target)
# Hop-by-hop headers describe a single connection and must not be relayed by
# a proxy (see RFC 9110 section 7.6.1, "Connection"). Compared lower-cased.
_HOP_BY_HOP_HEADERS = frozenset({
    'connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization',
    'te', 'trailer', 'transfer-encoding', 'upgrade',
})
# Not forwarded upstream: Host names this proxy, and requests recomputes
# Content-Length from the body it actually sends.
_REQUEST_SKIP_HEADERS = _HOP_BY_HOP_HEADERS | {'host', 'content-length'}
# Not relayed to the client: requests transparently decompresses the upstream
# body, so the upstream Content-Encoding/Content-Length no longer describe the
# bytes being sent.
_RESPONSE_SKIP_HEADERS = _HOP_BY_HOP_HEADERS | {'content-encoding', 'content-length'}


def _relay_body(upstream):
    """Yield the upstream body as it arrives, then close the upstream response."""
    try:
        yield from upstream.iter_content(chunk_size=None)
    finally:
        upstream.close()


def forward_to_newapi(path):
    """Proxy the current Flask request to the internal NewAPI server.

    Args:
        path: Target path relative to ``NEWAPI_INTERNAL``. A leading ``/`` is
            tolerated and stripped.

    Returns:
        A ``Response`` that streams the upstream body with the upstream status
        and end-to-end headers. If NewAPI cannot be reached, returns a JSON
        error with status 502 instead.
    """
    outbound_headers = {
        k: v for k, v in request.headers if k.lower() not in _REQUEST_SKIP_HEADERS
    }
    try:
        upstream = requests.request(
            method=request.method,
            url=f"{NEWAPI_INTERNAL}/{path.lstrip('/')}",
            headers=outbound_headers,
            data=request.get_data(),
            # items(multi=True) keeps repeated query parameters (?a=1&a=2).
            params=list(request.args.items(multi=True)),
            allow_redirects=False,
            # Stream so that SSE chat completions reach the client
            # incrementally. NOTE(review): there is still no timeout, as before.
            stream=True,
        )
    except requests.RequestException:
        logging.getLogger(__name__).exception(
            "Could not reach NewAPI at %s", NEWAPI_INTERNAL
        )
        return {"error": {"message": "Shield Error: upstream unavailable"}}, 502
    # upstream.raw.headers (urllib3) yields duplicate headers such as multiple
    # Set-Cookie lines separately; upstream.headers would comma-merge them.
    relayed_headers = [
        (k, v)
        for k, v in upstream.raw.headers.items()
        if k.lower() not in _RESPONSE_SKIP_HEADERS
    ]
    return Response(_relay_body(upstream), upstream.status_code, relayed_headers)
if __name__ == '__main__':
    # Development-server entry point: listen on all interfaces, port 7860.
    # (A stray " |" after this call, left over from the page scrape, was a
    # SyntaxError and has been removed.)
    app.run(host='0.0.0.0', port=7860)