Philips656 commited on
Commit
5adcdb9
·
verified ·
1 Parent(s): 6cfec2a

Update shield.py

Browse files
Files changed (1) hide show
  1. shield.py +35 -59
shield.py CHANGED
@@ -6,7 +6,6 @@ app = Flask(__name__)
6
  # --- CONFIGURATION ---
7
  OPENAI_KEY = "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"
8
 
9
- # The data you provided, formatted for a secure Python connection
10
  TIDB_CONFIG = {
11
  "host": "gateway01.eu-central-1.prod.aws.tidbcloud.com",
12
  "port": 4000,
@@ -15,86 +14,63 @@ TIDB_CONFIG = {
15
  "database": "test",
16
  "autocommit": True,
17
  "use_pure": True,
18
- "ssl_ca": "/etc/ssl/certs/ca-certificates.crt", # Standard path for Debian/HF
19
- "ssl_verify_cert": True
20
  }
21
 
22
  def log_to_tidb(user_id, prompt):
23
- """Attempts to log the violation with detailed error reporting."""
24
- conn = None
25
  try:
26
- print(f"DEBUG: Attempting to log violation for user {user_id}...")
27
  conn = mysql.connector.connect(**TIDB_CONFIG)
28
  cursor = conn.cursor()
29
-
30
  query = "INSERT INTO safety_violations (user_id, prompt_content) VALUES (%s, %s)"
31
- values = (str(user_id), str(prompt)[:1000])
32
-
33
- cursor.execute(query, values)
34
  conn.commit()
35
-
36
- print("✅ SUCCESS: Violation logged to TiDB.")
37
  cursor.close()
 
38
  except Exception as e:
39
- print(f"❌ DATABASE ERROR: {str(e)}")
40
- finally:
41
- if conn and conn.is_connected():
42
- conn.close()
43
 
 
 
 
 
 
 
 
 
44
  @app.route('/v1/chat/completions', methods=['POST'])
45
- def protect_and_proxy():
46
  data = request.json
47
  messages = data.get('messages', [])
48
- full_text = " ".join([m.get('content', '') for m in messages])
 
49
 
50
- # 1. Moderation Check via OpenAI
51
  try:
52
- res = requests.post(
53
  "https://api.openai.com/v1/moderations",
54
  headers={"Authorization": f"Bearer {OPENAI_KEY}"},
55
- json={"input": full_text},
56
- timeout=10
57
- )
58
- mod_data = res.json()
59
- results = mod_data.get('results', [{}])[0]
60
- except Exception as e:
61
- print(f"Moderation API Error: {e}")
62
- results = {}
63
 
64
- # 2. Block for CSAM (sexual/minors)
65
- if results.get('categories', {}).get('sexual/minors'):
66
- user_auth = request.headers.get('Authorization', 'Anonymous')
67
- print(f"!!! CSAM DETECTED: {user_auth} !!!")
68
-
69
- log_to_tidb(user_auth, full_text)
70
-
71
- return {"error": {"message": "Policy Violation: CSAM is strictly prohibited.", "type": "safety_error"}}, 403
72
 
73
- # 3. Forward safe traffic to NewAPI on port 3000
 
74
  try:
75
- resp = requests.post("http://127.0.0.1:3000/v1/chat/completions",
76
- json=data, headers=dict(request.headers), timeout=120)
 
 
 
77
  return Response(resp.content, resp.status_code, resp.headers.items())
78
  except Exception as e:
79
- print(f"Proxy Error: {e}")
80
- return {"error": {"message": "Internal Proxy Error", "type": "server_error"}}, 500
81
-
82
- @app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE'])
83
- def catch_all(path):
84
- resp = requests.request(method=request.method, url=f"http://127.0.0.1:3000/{path}",
85
- headers={k: v for k, v in request.headers if k.lower() != 'host'},
86
- data=request.get_data(), allow_redirects=False)
87
- return Response(resp.content, resp.status_code, resp.headers.items())
88
 
89
  if __name__ == '__main__':
90
- # Startup Test: Try to connect to TiDB once to verify settings
91
- print("Starting Shield... Testing TiDB connection...")
92
- try:
93
- test_conn = mysql.connector.connect(**TIDB_CONFIG)
94
- print("✅ TiDB Connection Test: SUCCESSFUL")
95
- test_conn.close()
96
- except Exception as e:
97
- print(f"⚠️ TiDB Connection Test: FAILED - {e}")
98
- print("Tip: Check if your TiDB IP Access List allows all connections (0.0.0.0/0)")
99
-
100
- app.run(host='0.0.0.0', port=7860)
 
6
  # --- CONFIGURATION ---
7
  OPENAI_KEY = "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"
8
 
 
9
  TIDB_CONFIG = {
10
  "host": "gateway01.eu-central-1.prod.aws.tidbcloud.com",
11
  "port": 4000,
 
14
  "database": "test",
15
  "autocommit": True,
16
  "use_pure": True,
17
+ "ssl_verify_cert": False
 
18
  }
19
 
20
  def log_to_tidb(user_id, prompt):
 
 
21
  try:
 
22
  conn = mysql.connector.connect(**TIDB_CONFIG)
23
  cursor = conn.cursor()
 
24
  query = "INSERT INTO safety_violations (user_id, prompt_content) VALUES (%s, %s)"
25
+ cursor.execute(query, (str(user_id), str(prompt)[:1000]))
 
 
26
  conn.commit()
27
+ print(f"✅ LOGGED TO TiDB: {user_id}")
 
28
  cursor.close()
29
+ conn.close()
30
  except Exception as e:
31
+ print(f"❌ TiDB ERROR: {e}")
 
 
 
32
 
33
+ # --- SERVE THE UI ---
34
+ @app.route('/')
35
+ def home():
36
+ # Reads the index.html file we created
37
+ with open('index.html', 'r') as f:
38
+ return f.read()
39
+
40
+ # --- API ENDPOINT ---
41
  @app.route('/v1/chat/completions', methods=['POST'])
42
+ def handle_request():
43
  data = request.json
44
  messages = data.get('messages', [])
45
+ user_input = " ".join([m.get('content', '') for m in messages])
46
+ auth_header = request.headers.get('Authorization', 'Anonymous')
47
 
48
+ # 1. SAFETY CHECK
49
  try:
50
+ mod_res = requests.post(
51
  "https://api.openai.com/v1/moderations",
52
  headers={"Authorization": f"Bearer {OPENAI_KEY}"},
53
+ json={"input": user_input}
54
+ ).json()
 
 
 
 
 
 
55
 
56
+ if mod_res.get('results', [{}])[0].get('categories', {}).get('sexual/minors'):
57
+ print(f"!!! BLOCKING CSAM REQUEST FROM {auth_header} !!!")
58
+ log_to_tidb(auth_header, user_input)
59
+ return {"error": {"message": "Policy Violation: Content blocked by Shield.", "type": "safety_error"}}, 403
60
+ except Exception as e:
61
+ print(f"Moderation Error: {e}")
 
 
62
 
63
+ # 2. FORWARD TO OPENAI (Using your key)
64
+ # This powers the chat response in the UI
65
  try:
66
+ resp = requests.post(
67
+ "https://api.openai.com/v1/chat/completions",
68
+ headers={"Authorization": f"Bearer {OPENAI_KEY}"},
69
+ json=data
70
+ )
71
  return Response(resp.content, resp.status_code, resp.headers.items())
72
  except Exception as e:
73
+ return {"error": {"message": str(e)}}, 500
 
 
 
 
 
 
 
 
74
 
75
  if __name__ == '__main__':
76
+ app.run(host='0.0.0.0', port=7860)