Philips656 commited on
Commit
2f1c692
·
verified ·
1 Parent(s): 00a84e0

Update shield.py

Browse files
Files changed (1) hide show
  1. shield.py +39 -21
shield.py CHANGED
@@ -6,6 +6,7 @@ app = Flask(__name__)
6
  # --- CONFIGURATION ---
7
  OPENAI_KEY = "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"
8
 
 
9
  TIDB_CONFIG = {
10
  "host": "gateway01.eu-central-1.prod.aws.tidbcloud.com",
11
  "port": 4000,
@@ -13,18 +14,19 @@ TIDB_CONFIG = {
13
  "password": "Bxg8rpU27gyH60E0",
14
  "database": "test",
15
  "autocommit": True,
16
- "ssl_verify_cert": False, # Bypasses the certificate path issue in HF
17
- "use_pure": True # More stable in Docker environments
 
18
  }
19
 
20
  def log_to_tidb(user_id, prompt):
21
- """Force logs the violation into TiDB."""
 
22
  try:
23
  print(f"DEBUG: Attempting to log violation for user {user_id}...")
24
  conn = mysql.connector.connect(**TIDB_CONFIG)
25
  cursor = conn.cursor()
26
 
27
- # SQL matches the columns you created in your DESCRIBE screenshot
28
  query = "INSERT INTO safety_violations (user_id, prompt_content) VALUES (%s, %s)"
29
  values = (str(user_id), str(prompt)[:1000])
30
 
@@ -33,10 +35,11 @@ def log_to_tidb(user_id, prompt):
33
 
34
  print("✅ SUCCESS: Violation logged to TiDB.")
35
  cursor.close()
36
- conn.close()
37
  except Exception as e:
38
- # If it still fails, this will show up in your 'Logs' tab on Hugging Face
39
  print(f"❌ DATABASE ERROR: {str(e)}")
 
 
 
40
 
41
  @app.route('/v1/chat/completions', methods=['POST'])
42
  def protect_and_proxy():
@@ -44,33 +47,38 @@ def protect_and_proxy():
44
  messages = data.get('messages', [])
45
  full_text = " ".join([m.get('content', '') for m in messages])
46
 
47
- # 1. Moderation Check (FREE)
48
- res = requests.post(
49
- "https://api.openai.com/v1/moderations",
50
- headers={"Authorization": f"Bearer {OPENAI_KEY}"},
51
- json={"input": full_text}
52
- ).json()
53
-
54
- # 2. Block only for 'sexual/minors'
55
- results = res.json().get('results', [{}])[0]
 
 
 
 
 
 
56
  if results.get('categories', {}).get('sexual/minors'):
57
  user_auth = request.headers.get('Authorization', 'Anonymous')
58
  print(f"!!! CSAM DETECTED: {user_auth} !!!")
59
 
60
- # Trigger the log function
61
  log_to_tidb(user_auth, full_text)
62
 
63
  return {"error": {"message": "Policy Violation: CSAM is strictly prohibited.", "type": "safety_error"}}, 403
64
 
65
- # 3. If safe, pass to the real NewAPI on internal port 3000
66
  try:
67
  resp = requests.post("http://127.0.0.1:3000/v1/chat/completions",
68
- json=data, headers=dict(request.headers), timeout=60)
69
  return Response(resp.content, resp.status_code, resp.headers.items())
70
  except Exception as e:
71
- return {"error": {"message": f"Proxy error: {str(e)}", "type": "internal_error"}}, 500
 
72
 
73
- # Proxy all other dashboard/admin traffic
74
  @app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE'])
75
  def catch_all(path):
76
  resp = requests.request(method=request.method, url=f"http://127.0.0.1:3000/{path}",
@@ -79,4 +87,14 @@ def catch_all(path):
79
  return Response(resp.content, resp.status_code, resp.headers.items())
80
 
81
  if __name__ == '__main__':
82
- app.run(host='0.0.0.0', port=7860)
 
 
 
 
 
 
 
 
 
 
 
6
  # --- CONFIGURATION ---
7
  OPENAI_KEY = "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"
8
 
9
+ # The data you provided, formatted for a secure Python connection
10
  TIDB_CONFIG = {
11
  "host": "gateway01.eu-central-1.prod.aws.tidbcloud.com",
12
  "port": 4000,
 
14
  "password": "Bxg8rpU27gyH60E0",
15
  "database": "test",
16
  "autocommit": True,
17
+ "use_pure": True,
18
+ "ssl_ca": "/etc/ssl/certs/ca-certificates.crt", # Standard path for Debian/HF
19
+ "ssl_verify_cert": True
20
  }
21
 
22
  def log_to_tidb(user_id, prompt):
23
+ """Attempts to log the violation with detailed error reporting."""
24
+ conn = None
25
  try:
26
  print(f"DEBUG: Attempting to log violation for user {user_id}...")
27
  conn = mysql.connector.connect(**TIDB_CONFIG)
28
  cursor = conn.cursor()
29
 
 
30
  query = "INSERT INTO safety_violations (user_id, prompt_content) VALUES (%s, %s)"
31
  values = (str(user_id), str(prompt)[:1000])
32
 
 
35
 
36
  print("✅ SUCCESS: Violation logged to TiDB.")
37
  cursor.close()
 
38
  except Exception as e:
 
39
  print(f"❌ DATABASE ERROR: {str(e)}")
40
+ finally:
41
+ if conn and conn.is_connected():
42
+ conn.close()
43
 
44
  @app.route('/v1/chat/completions', methods=['POST'])
45
  def protect_and_proxy():
 
47
  messages = data.get('messages', [])
48
  full_text = " ".join([m.get('content', '') for m in messages])
49
 
50
+ # 1. Moderation Check via OpenAI
51
+ try:
52
+ res = requests.post(
53
+ "https://api.openai.com/v1/moderations",
54
+ headers={"Authorization": f"Bearer {OPENAI_KEY}"},
55
+ json={"input": full_text},
56
+ timeout=10
57
+ )
58
+ mod_data = res.json()
59
+ results = mod_data.get('results', [{}])[0]
60
+ except Exception as e:
61
+ print(f"Moderation API Error: {e}")
62
+ results = {}
63
+
64
+ # 2. Block for CSAM (sexual/minors)
65
  if results.get('categories', {}).get('sexual/minors'):
66
  user_auth = request.headers.get('Authorization', 'Anonymous')
67
  print(f"!!! CSAM DETECTED: {user_auth} !!!")
68
 
 
69
  log_to_tidb(user_auth, full_text)
70
 
71
  return {"error": {"message": "Policy Violation: CSAM is strictly prohibited.", "type": "safety_error"}}, 403
72
 
73
+ # 3. Forward safe traffic to NewAPI on port 3000
74
  try:
75
  resp = requests.post("http://127.0.0.1:3000/v1/chat/completions",
76
+ json=data, headers=dict(request.headers), timeout=120)
77
  return Response(resp.content, resp.status_code, resp.headers.items())
78
  except Exception as e:
79
+ print(f"Proxy Error: {e}")
80
+ return {"error": {"message": "Internal Proxy Error", "type": "server_error"}}, 500
81
 
 
82
  @app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE'])
83
  def catch_all(path):
84
  resp = requests.request(method=request.method, url=f"http://127.0.0.1:3000/{path}",
 
87
  return Response(resp.content, resp.status_code, resp.headers.items())
88
 
89
  if __name__ == '__main__':
90
+ # Startup Test: Try to connect to TiDB once to verify settings
91
+ print("Starting Shield... Testing TiDB connection...")
92
+ try:
93
+ test_conn = mysql.connector.connect(**TIDB_CONFIG)
94
+ print("✅ TiDB Connection Test: SUCCESSFUL")
95
+ test_conn.close()
96
+ except Exception as e:
97
+ print(f"⚠️ TiDB Connection Test: FAILED - {e}")
98
+ print("Tip: Check if your TiDB IP Access List allows all connections (0.0.0.0/0)")
99
+
100
+ app.run(host='0.0.0.0', port=7860)