File size: 2,239 Bytes
d4cd2e4
acda684
 
 
 
960ee52
d4cd2e4
960ee52
acda684
 
39f4cda
acda684
39f4cda
 
 
00a84e0
2f1c692
d4cd2e4
 
acda684
 
960ee52
acda684
 
 
960ee52
5adcdb9
960ee52
acda684
 
d4cd2e4
acda684
960ee52
d4cd2e4
960ee52
2f1c692
960ee52
 
 
 
 
 
 
d4cd2e4
960ee52
acda684
960ee52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
acda684
 
5adcdb9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os

import mysql.connector
import requests
from flask import Flask, request, Response

# WSGI application object; every route in this module is registered on it.
app = Flask(import_name=__name__)

# --- SETTINGS ---
# All secrets are read from the environment so that they never live in source
# control. Only non-secret values keep a built-in default.
# NOTE(review): the TiDB user/password and OpenAI key that used to be
# hard-coded here are exposed in VCS history and must be rotated.

# Base URL of the internal new-api service that requests are proxied to.
NEWAPI_INTERNAL = os.environ.get("NEWAPI_INTERNAL", "http://127.0.0.1:3000")

# OpenAI API key used only for the moderation (safety) check.
MODERATION_KEY = os.environ.get("MODERATION_KEY", "")

# Keyword arguments for mysql.connector.connect() used by the violation log.
TIDB_CONFIG = {
    "host": os.environ.get(
        "TIDB_HOST", "gateway01.eu-central-1.prod.aws.tidbcloud.com"
    ),
    "port": int(os.environ.get("TIDB_PORT", "4000")),
    "user": os.environ.get("TIDB_USER", ""),
    "password": os.environ.get("TIDB_PASSWORD", ""),
    "database": os.environ.get("TIDB_DATABASE", "test"),
    # Each INSERT is committed immediately; callers never call commit().
    "autocommit": True,
    "use_pure": True,
    "ssl_ca": os.environ.get("TIDB_SSL_CA", "/etc/ssl/certs/ca-certificates.crt"),
    "ssl_verify_cert": True,
}

def log_violation(user, prompt):
    """Best-effort insert of a blocked prompt into the ``safety_violations`` table.

    A fresh connection is opened from ``TIDB_CONFIG`` for every call and is
    always closed again. Any failure is logged and swallowed, so audit
    logging can never break the request that triggered it.

    Args:
        user: Value stored in the ``user_id`` column.
        prompt: Prompt text. Only its first 1000 characters are stored.
    """
    try:
        conn = mysql.connector.connect(**TIDB_CONFIG)
        try:
            cursor = conn.cursor()
            # No explicit commit: TIDB_CONFIG enables autocommit.
            cursor.execute(
                "INSERT INTO safety_violations (user_id, prompt_content) VALUES (%s, %s)",
                (user, prompt[:1000]),
            )
            cursor.close()
        finally:
            # Close even when execute() fails, so connections do not leak.
            conn.close()
    except Exception:
        # Deliberately broad: this is a side channel and must not fail the
        # request. Unlike the old bare `except: pass`, the failure is recorded.
        # The prompt itself is not logged on purpose.
        app.logger.exception("Failed to record safety violation")

def _extract_message_text(messages):
    """Join the textual content of chat *messages* with single spaces.

    Accepts both plain-string ``content`` and the OpenAI-style list of parts
    (``[{"type": "text", "text": "..."}, ...]``), so that multimodal requests
    are moderated instead of crashing the join. Non-dict messages and parts,
    ``null`` content, and parts without string ``text`` are skipped.
    """
    if not isinstance(messages, list):
        return ""
    texts = []
    for message in messages:
        if not isinstance(message, dict):
            continue
        content = message.get("content", "")
        if isinstance(content, str):
            texts.append(content)
        elif isinstance(content, list):
            texts.extend(
                part["text"]
                for part in content
                if isinstance(part, dict) and isinstance(part.get("text"), str)
            )
    return " ".join(texts)


def _flags_sexual_minors(text):
    """Return True if OpenAI moderation flags *text* as ``sexual/minors``.

    Raises:
        requests.RequestException: on transport errors or a non-2xx status.
        ValueError: if the response body is not valid JSON.
    """
    resp = requests.post(
        "https://api.openai.com/v1/moderations",
        headers={"Authorization": f"Bearer {MODERATION_KEY}"},
        json={"input": text},
        timeout=3,
    )
    # Without this, an auth or quota error body would be read as "not flagged".
    resp.raise_for_status()
    payload = resp.json()
    results = payload.get("results") if isinstance(payload, dict) else None
    if not isinstance(results, list) or not results or not isinstance(results[0], dict):
        return False
    categories = results[0].get("categories")
    return isinstance(categories, dict) and bool(categories.get("sexual/minors"))


@app.route('/v1/chat/completions', methods=['POST'])
def protect_chat():
    """Moderate a chat-completion request, then forward it to new-api.

    All message text is sent to the OpenAI moderation endpoint. If the
    ``sexual/minors`` category is flagged, the request is logged via
    ``log_violation`` and rejected with a 403 JSON error. Otherwise, and
    also when the moderation call itself fails (fail-open), the request is
    proxied unchanged.
    """
    # force=True: parse JSON regardless of Content-Type, so a client cannot
    # skip moderation just by omitting the header. silent=True: bad JSON -> None.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        data = {}
    content = _extract_message_text(data.get("messages", []))

    # Nothing to moderate for an empty prompt; skip the network round-trip.
    if content:
        try:
            flagged = _flags_sexual_minors(content)
        except (requests.RequestException, ValueError):
            # Fail open, as before, but leave a trace.
            # NOTE(review): confirm fail-open (vs. rejecting) is intended for a safety shield.
            app.logger.warning("Moderation check failed; forwarding unchecked", exc_info=True)
            flagged = False
        if flagged:
            # NOTE(review): this stores the raw Authorization header (the
            # caller's API key) in the database; consider storing a hash.
            log_violation(request.headers.get('Authorization', 'Anon'), content)
            return {"error": {"message": "Shield Block: Safety Violation"}}, 403

    # request.path starts with '/'; strip it so the upstream URL has no '//'.
    return forward_to_newapi(request.path.lstrip('/'))

@app.route('/', methods=['GET', 'POST', 'PUT', 'DELETE'], defaults={'path': ''})
@app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE'])
def catch_all(path):
    """Relay any request not claimed by a more specific route to new-api.

    The bare root URL reaches this view with ``path=''`` through the route
    defaults. Every other URL passes its captured path through unchanged.

    NOTE(review): only /v1/chat/completions is moderated. Any other endpoint
    that upstream serves is forwarded from here without a safety check;
    confirm that is intended.
    """
    upstream_path = path
    return forward_to_newapi(upstream_path)

# Hop-by-hop headers (RFC 9110 §7.6.1) describe a single connection and must
# not be relayed by a proxy in either direction.
_HOP_BY_HOP_HEADERS = frozenset({
    "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
    "te", "trailer", "transfer-encoding", "upgrade",
})
# `requests` transparently decompresses the body, so upstream's encoding and
# length no longer describe `resp.content`. Werkzeug recomputes Content-Length.
_EXCLUDED_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {"content-encoding", "content-length"}


def forward_to_newapi(path):
    """Proxy the current Flask request to the internal new-api service.

    Args:
        path: Upstream path, with or without a leading slash.

    Returns:
        A Flask ``Response`` mirroring upstream's status, body and end-to-end
        headers, or a 502 JSON error if upstream could not be reached.

    NOTE(review): the upstream body is fully buffered, so streamed (SSE)
    completions reach the client only once they finish.
    """
    outgoing_headers = {
        key: value
        for key, value in request.headers
        if key.lower() != 'host' and key.lower() not in _HOP_BY_HOP_HEADERS
    }
    try:
        resp = requests.request(
            method=request.method,
            url=f"{NEWAPI_INTERNAL}/{path.lstrip('/')}",
            headers=outgoing_headers,
            data=request.get_data(),
            # items(multi=True) keeps repeated query parameters (?a=1&a=2).
            params=list(request.args.items(multi=True)),
            allow_redirects=False,
            # Bounded connect time; unbounded read, since completions can be slow.
            timeout=(10, None),
        )
    except requests.RequestException:
        app.logger.exception("Upstream new-api request failed")
        return {"error": {"message": "Bad Gateway: upstream unavailable"}}, 502

    response_headers = [
        (key, value)
        for key, value in resp.headers.items()
        if key.lower() not in _EXCLUDED_RESPONSE_HEADERS
    ]
    return Response(resp.content, resp.status_code, response_headers)

if __name__ == '__main__':
    # Start Flask's built-in server, listening on all interfaces at port 7860.
    app.run(port=7860, host='0.0.0.0')