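# newapi / shield.py
# Last updated by Philips656 — "Update shield.py" (verified commit 960ee52).
#
# Flask front end for new-api. Requests to /v1/chat/completions are checked
# against the OpenAI moderation endpoint; "sexual/minors" hits are blocked and
# recorded in TiDB. Every other route is forwarded to new-api unchanged.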
import logging
import os

import mysql.connector
import requests
from flask import Flask, Response, request
# The Flask application object that every route in this module is registered on.
app = Flask(import_name=__name__)
# --- SETTINGS ---
# Secrets are read from the environment (e.g. host-provided secrets) so they
# never live in source control.
# NOTE(review): the database password/user and API key that were previously
# hard-coded here were published with this file — treat them as compromised
# and rotate them.

# Base URL of the new-api instance that requests are proxied to.
NEWAPI_INTERNAL = os.environ.get("NEWAPI_INTERNAL", "http://127.0.0.1:3000")

# OpenAI key used only for the /v1/moderations safety check.
MODERATION_KEY = os.environ.get("MODERATION_KEY", "")

# Connection settings for the TiDB database that stores safety violations.
TIDB_CONFIG = {
    "host": os.environ.get(
        "TIDB_HOST", "gateway01.eu-central-1.prod.aws.tidbcloud.com"
    ),
    "port": int(os.environ.get("TIDB_PORT", "4000")),
    "user": os.environ.get("TIDB_USER", ""),
    "password": os.environ.get("TIDB_PASSWORD", ""),
    "database": os.environ.get("TIDB_DATABASE", "test"),
    # Each INSERT commits immediately; the logging code never calls commit().
    "autocommit": True,
    "use_pure": True,
    "ssl_ca": os.environ.get("TIDB_SSL_CA", "/etc/ssl/certs/ca-certificates.crt"),
    "ssl_verify_cert": True,
}
def log_violation(user, prompt):
    """Best-effort: insert a blocked prompt into the ``safety_violations`` table.

    Args:
        user: value stored in ``user_id``.
        prompt: offending text; only its first 1000 characters are stored.

    Database failures are logged rather than raised, so a broken audit log
    never changes the response the caller returns.
    """
    logger = logging.getLogger(__name__)
    conn = None
    try:
        conn = mysql.connector.connect(**TIDB_CONFIG)
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO safety_violations (user_id, prompt_content) VALUES (%s, %s)",
            (user, prompt[:1000]),
        )
    except Exception:  # boundary: auditing must never break the request path
        logger.exception("Could not record safety violation")
    finally:
        # Close even when cursor()/execute() failed, so connections don't leak.
        if conn is not None:
            try:
                conn.close()
            except Exception:
                logger.warning("Error closing TiDB connection", exc_info=True)
def _message_text(message):
    """Return the text of one chat message, or "" when it carries none.

    Accepts ``content`` as a plain string or as a list of part dicts (only
    parts with a string ``text`` field are used). Missing/null content and
    non-dict messages yield "" instead of crashing the join.
    """
    if not isinstance(message, dict):
        return ""
    content = message.get("content")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return " ".join(
            part["text"]
            for part in content
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        )
    return ""


def _violates_policy(text):
    """Return True if OpenAI moderation flags *text* as ``sexual/minors``.

    Fails open: network, HTTP-status, and JSON-decoding errors are logged and
    treated as "not flagged", so a moderation outage does not block traffic.
    """
    if not text:
        return False
    try:
        resp = requests.post(
            "https://api.openai.com/v1/moderations",
            headers={"Authorization": f"Bearer {MODERATION_KEY}"},
            json={"input": text},
            timeout=3,
        )
        # Surface bad keys / quota errors in the log instead of silently passing.
        resp.raise_for_status()
        payload = resp.json()
    except (requests.RequestException, ValueError):
        logging.getLogger(__name__).warning(
            "Moderation check failed; forwarding request unchecked", exc_info=True
        )
        return False
    results = payload.get("results") if isinstance(payload, dict) else None
    if not isinstance(results, list) or not results or not isinstance(results[0], dict):
        return False
    categories = results[0].get("categories")
    return isinstance(categories, dict) and bool(categories.get("sexual/minors"))


@app.route('/v1/chat/completions', methods=['POST'])
def protect_chat():
    """Moderate a chat-completion request, then forward it to new-api.

    Returns:
        400 if the body is not a JSON object; 403 (after recording the
        violation) if moderation flags the combined message text; otherwise
        new-api's response.
    """
    # force=True parses regardless of Content-Type, so a mislabelled JSON
    # body cannot skip moderation and still be accepted upstream.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return {"error": {"message": "Shield Block: Invalid JSON body"}}, 400
    messages = data.get("messages")
    if not isinstance(messages, list):
        messages = []
    content = " ".join(_message_text(m) for m in messages)
    if _violates_policy(content):
        # NOTE(review): this stores the caller's raw Authorization header
        # (a bearer credential) in the database; consider storing a hash.
        log_violation(request.headers.get('Authorization', 'Anon'), content)
        return {"error": {"message": "Shield Block: Safety Violation"}}, 403
    # request.path begins with "/"; strip it so the upstream URL has no "//".
    return forward_to_newapi(request.path.lstrip("/"))
@app.route("/", methods=["GET", "POST", "PUT", "DELETE"], defaults={"path": ""})
@app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE"])
def catch_all(path):
    """Relay any request that has no dedicated route straight to new-api.

    ``path`` is the URL path without its leading slash ("" for the root URL).
    """
    upstream_path = path
    return forward_to_newapi(upstream_path)
# Hop-by-hop headers describe a single connection and must not be relayed by a proxy.
_HOP_BY_HOP_HEADERS = frozenset({
    "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
    "te", "trailer", "transfer-encoding", "upgrade",
})
# requests recomputes Host and Content-Length for the outgoing request.
_REQUEST_HEADERS_TO_DROP = _HOP_BY_HOP_HEADERS | {"host", "content-length"}
# requests transparently decompresses response bodies, so the upstream
# Content-Encoding/Content-Length would no longer match the bytes we send.
_RESPONSE_HEADERS_TO_DROP = _HOP_BY_HOP_HEADERS | {"content-encoding", "content-length"}


def forward_to_newapi(path):
    """Relay the current Flask request to new-api and stream its reply back.

    Args:
        path: upstream path, with or without a leading "/".

    Returns:
        A Flask ``Response`` that streams the upstream body with its status
        and end-to-end headers, or a 502 JSON error if new-api is unreachable.
    """
    url = f"{NEWAPI_INTERNAL}/{path.lstrip('/')}"
    try:
        upstream = requests.request(
            method=request.method,
            url=url,
            headers={
                k: v for k, v in request.headers
                if k.lower() not in _REQUEST_HEADERS_TO_DROP
            },
            data=request.get_data(),
            # multi=True keeps repeated query keys that a dict view would collapse.
            params=list(request.args.items(multi=True)),
            allow_redirects=False,
            # Stream so server-sent events reach the client as they are produced.
            stream=True,
            # Bound the connect phase only; long generations may legitimately
            # take a long time between bytes.
            timeout=(10, None),
        )
    except requests.RequestException:
        logging.getLogger(__name__).exception("Upstream request to %s failed", url)
        return {"error": {"message": "Shield: upstream unavailable"}}, 502

    def relay():
        # Release the upstream connection even if the client disconnects early.
        try:
            yield from upstream.iter_content(chunk_size=None)
        finally:
            upstream.close()

    headers = [
        (k, v) for k, v in upstream.headers.items()
        if k.lower() not in _RESPONSE_HEADERS_TO_DROP
    ]
    return Response(relay(), upstream.status_code, headers)
if __name__ == "__main__":
    # Run Flask's built-in server on every interface, port 7860.
    # NOTE(review): the built-in server is meant for development; confirm
    # whether a production WSGI server is intended for deployment.
    app.run(port=7860, host="0.0.0.0")