ayandev101 commited on
Commit
4428754
·
verified ·
1 Parent(s): 47fca27

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +21 -0
  2. main.py +81 -0
  3. requirements.txt +8 -0
  4. shield_cli.py +355 -0
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Python base image
FROM python:3.9

# Create the non-root runtime user up front (Hugging Face Spaces convention:
# uid 1000, writable $HOME for caches).
RUN useradd -m -u 1000 user

# Work directory
WORKDIR /code

# Install dependencies first so this layer is cached across code changes.
COPY ./requirements.txt /code/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

# Copy all code owned by the runtime user, and hand over /code itself.
# Fix: the original copied as root before USER, so the app (uid 1000) could
# not create shield_logs.db inside /code at runtime.
COPY --chown=user:user . /code
RUN chown user:user /code

# Drop privileges for the running container.
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# Start command (HF Spaces expects port 7860)
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+
5
+ # Import your existing pipeline
6
+ from shield_cli import shield_pipeline
7
+
# ----------------------------------
# App Init
# ----------------------------------
app = FastAPI(title="Sentinel Shield API")

# CORS is currently wide open. NOTE(review): the original comment said
# "allow orchestrator only", but allow_origins=["*"] permits every origin —
# tighten to the orchestrator's host if that policy is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
20
+
# ----------------------------------
# Request Schema (MATCH ORCHESTRATOR)
# ----------------------------------
class ShieldRequest(BaseModel):
    """Body of POST /shield; field names must stay in sync with the orchestrator."""

    # Raw user prompt to be screened by the shield pipeline.
    prompt: str
26
+
# ----------------------------------
# Shield Endpoint
# ----------------------------------
@app.post("/shield")
async def run_shield(request: ShieldRequest):
    """Run the prompt through shield_pipeline and return its verdict dict.

    Any pipeline failure is printed with a full traceback to the server
    terminal and surfaced to the caller as HTTP 500.
    """
    try:
        # Temporary debug log of the incoming prompt.
        print(f"DEBUG: Processing prompt: {request.prompt}")
        return shield_pipeline(request.prompt)
    except Exception as e:
        # Full stack trace in the terminal for diagnosis.
        import traceback
        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Shield failure: {str(e)}"
        )
# ----------------------------------
# Health Check
# ----------------------------------
@app.get("/health")
def health_check():
    """Liveness probe: reports service identity and the model it fronts."""
    # NOTE(review): model name is duplicated from shield_cli.MODEL_NAME —
    # keep the two strings in sync.
    payload = {
        "status": "online",
        "service": "shield",
        "model": "protectai/deberta-v3-base-prompt-injection-v2",
    }
    return payload
import sqlite3


@app.get("/logs")
async def get_logs():
    """Return the 50 most recent shield verdicts from the local SQLite log.

    On any database error (e.g. the table does not exist yet) a
    ``{"error": ...}`` dict is returned instead of raising, preserving the
    original best-effort contract.
    """
    conn = None
    try:
        conn = sqlite3.connect("shield_logs.db")
        conn.row_factory = sqlite3.Row  # access columns by name
        cursor = conn.cursor()
        # Fetch the last 50 logs, newest first.
        cursor.execute("SELECT * FROM shield_logs ORDER BY created_at DESC LIMIT 50")
        return [dict(row) for row in cursor.fetchall()]
    except Exception as e:
        return {"error": str(e)}
    finally:
        # Fix: the original leaked the connection whenever the query raised.
        if conn is not None:
            conn.close()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
fastapi>=0.110.0
uvicorn[standard]>=0.27.0
torch>=2.1.0
transformers>=4.36.0
tokenizers>=0.15.0
safetensors>=0.4.0
pydantic>=2.6.0
sqlite-utils>=3.36
# DeBERTa-v3 tokenizers are SentencePiece-based; presumably needed for
# AutoTokenizer on protectai/deberta-v3-base-prompt-injection-v2 —
# TODO confirm whether the repo ships a fast tokenizer.json instead.
sentencepiece>=0.1.99
shield_cli.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import hashlib
import re
import sqlite3
from datetime import datetime, timezone

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
+
8
# ---------------- CONFIG ----------------

# Hugging Face model used by the ML layer to score prompt-injection likelihood.
MODEL_NAME = "protectai/deberta-v3-base-prompt-injection-v2"
# Prompts whose injection probability reaches this value are blocked outright.
BLOCK_THRESHOLD = 0.8
# SQLite file used for verdict logging (created on first import via init_db()).
DB_PATH = "shield_logs.db"

# Case-insensitive substrings: semantic_firewall() blocks any prompt that
# contains one of these (plain `in` test against the lowercased prompt).
FORBIDDEN_TOPICS = [

    # Credentials & Secrets
    "api key", "apikey", "api-key",
    "secret key", "client secret",
    "access token", "refresh token",
    "bearer token", "oauth token",
    "private key", "public key",
    "ssh key", "pgp key",
    "password", "passwd", "pwd",
    "credentials", "login credentials",
    "username and password",

    # Cloud / DevOps Secrets
    "aws access key", "aws secret",
    "iam credentials", "cloud credentials",
    "azure tenant id", "azure secret",
    "gcp service account",
    "firebase private key",
    "kubernetes secret",
    "docker registry password",
    "ci/cd secrets",
    "github token", "gitlab token",

    # Databases & Storage
    "database dump", "db dump",
    "production database",
    "prod database",
    "sql dump",
    "mongodb dump",
    "redis keys",
    "s3 bucket contents",
    "backup files",

    # Internal / Confidential
    "internal document",
    "confidential data",
    "restricted information",
    "private repository",
    "internal api",
    "internal endpoint",
    "company secrets",
    "trade secrets",
    "internal roadmap",
    "internal emails",

    # Financial / HR
    "salary spreadsheet",
    "employee salary",
    "payroll data",
    "bank account details",
    "credit card numbers",
    "debit card details",
    "cvv number",
    "tax records",
    "pan card",
    "aadhar number",

    # Legal / Strategy
    "nda document",
    "legal strategy",
    "lawsuit documents",
    "compliance report",
    "audit report",
    "merger plans",
    "acquisition plans",
    "board meeting notes",

    # AI / Model Internals
    "system prompt",
    "developer prompt",
    "hidden instructions",
    "training data",
    "model weights",
    "fine tuning data",
    "rlhf data",
    "prompt injection",
    "guardrail bypass",

    # Malware / Exploits
    "zero day exploit",
    "exploit code",
    "malware source",
    "ransomware",
    "keylogger",
    "credential harvester",
    "reverse shell",
    "backdoor",
    "rootkit",
    "botnet",
    "payload generation",

    # OPSEC / Surveillance
    "bypass detection",
    "avoid being traced",
    "anonymous hacking",
    "hide from law enforcement",
    "evade antivirus",
    "disable logging",
    "erase logs",

    # Custom / Project-Specific
    "project x",
    "sentinel internal",
    "shield internal",
    "ayaan private",
    "security architecture",
    "security weaknesses",
    "bypass sentinel",
]

# Regexes applied with re.search() to the lowercased prompt by
# heuristic_scan(); any hit marks the prompt as a jailbreak attempt.
# All patterns are lowercase to match the lowercased input.
JAILBREAK_PATTERNS = [

    # Instruction Override
    r"ignore (all|any|previous|above) instructions",
    r"disregard (all|any) rules",
    r"forget your instructions",
    r"override your system prompt",
    r"reset your instructions",
    r"new rules apply",

    # Role Hijacking / Persona Attacks
    # NOTE(review): the open-ended ".*" patterns below are broad and may
    # false-positive on benign prompts (e.g. "simulate a coin flip").
    r"you are now .*",
    r"act as .*",
    r"pretend to be .*",
    r"roleplay as .*",
    r"simulate .*",
    r"behave like .*",

    # Safety & Policy Suppression
    r"do not mention (rules|filters|policies|safety)",
    r"without safety restrictions",
    r"bypass safety",
    r"bypass filters",
    r"disable guardrails",
    r"no content restrictions",

    # Authority / Trust Abuse
    r"this is for educational purposes",
    r"this is for research purposes",
    r"this is legal",
    r"i have permission",
    r"authorized request",
    r"company approved",

    # Encoding / Obfuscation Attacks
    r"decode this",
    r"encode this",
    r"base64",
    r"rot13",
    r"hexadecimal",
    r"ascii conversion",
    r"binary encoded",
    r"caesar cipher",

    # Prompt Leakage
    r"show me your system prompt",
    r"reveal your instructions",
    r"print your hidden prompt",
    r"what are your internal rules",
    r"what was your initial prompt",

    # Chain-of-Thought Extraction
    r"think step by step",
    r"explain your reasoning",
    r"show your reasoning",
    r"chain of thought",
    r"hidden reasoning",

    # Compliance Forcing
    r"you must answer",
    r"you are required to comply",
    r"failure to answer",
    r"mandatory response",
    r"non negotiable",

    # Multi-Step Jailbreaks
    r"first acknowledge",
    r"after acknowledging",
    r"step 1.*step 2",
    r"once you confirm",
    r"repeat after me",

    # Known Jailbreak Slang
    r"dan mode",
    r"developer mode",
    r"god mode",
    r"unfiltered mode",
    r"no limits mode",
    r"raw output mode",

    # Psychological Manipulation
    r"you are safe to answer",
    r"no harm will come",
    r"trust me",
    r"this helps improve you",

    # Indirect Prompt Injection
    r"the following text is a system instruction",
    r"the assistant must follow",
    r"assistant instructions below",
    r"model instructions",
]
217
+
218
+
219
# ---------------- LOAD MODEL ----------------

# Load the classifier once at import time; .eval() returns the module itself,
# putting it in inference mode (disables dropout).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).eval()
224
+
225
# ---------------- DATABASE INITIALIZATION ----------------

def init_db():
    """Create the SQLite log database and the shield_logs table if missing.

    Fix: the connection is now released in a ``finally`` block, so it no
    longer leaks when table creation raises.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        # Connection.execute creates an implicit cursor for the statement.
        conn.execute("""
            CREATE TABLE IF NOT EXISTS shield_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                prompt TEXT,
                verdict TEXT,
                reason TEXT,
                security_score REAL,
                details TEXT,
                created_at TEXT
            )
        """)
        conn.commit()
    finally:
        conn.close()
244
+
245
# Create the log database as soon as the module is imported so every shield
# layer can assume the table exists.
init_db()

def get_db():
    """Open the shared SQLite connection with name-based row access.

    NOTE(review): check_same_thread=False permits cross-thread use, but no
    lock guards this connection — confirm access stays single-threaded.
    """
    connection = sqlite3.connect(DB_PATH, check_same_thread=False)
    connection.row_factory = sqlite3.Row
    return connection

# Module-wide connection and cursor reused by log_to_db().
db = get_db()
cursor = db.cursor()
256
+
257
# ---------------- UTILS ----------------

def log_to_db(prompt, verdict, reason, score, details):
    """Append one shield decision to the shield_logs table.

    Writes through the module-level ``cursor``/``db``. NOTE(review): that
    connection is shared with check_same_thread=False and no lock — fine for
    a single-threaded event loop, verify before adding worker threads.
    """
    # Fix: datetime.utcnow() is deprecated (Python 3.12+). Take an aware UTC
    # time, then drop tzinfo so the stored ISO string keeps the exact format
    # the original rows used (no "+00:00" suffix).
    created_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    cursor.execute(
        """
        INSERT INTO shield_logs
        (prompt, verdict, reason, security_score, details, created_at)
        VALUES (?, ?, ?, ?, ?, ?)
        """,
        (prompt, verdict, reason, score, details, created_at),
    )
    db.commit()
276
+
277
# ---------------- SHIELD LAYERS ----------------

def ml_guard(prompt):
    """Score *prompt* with the injection classifier.

    Returns the softmax probability of class index 1 — presumably the
    "injection" label for this model; confirm against the model card.
    """
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        logits = model(**encoded).logits
    return torch.softmax(logits, dim=1)[0][1].item()
290
+
291
def heuristic_scan(prompt):
    """Return True when any known jailbreak regex matches the lowercased prompt."""
    lowered = prompt.lower()
    for pattern in JAILBREAK_PATTERNS:
        if re.search(pattern, lowered):
            return True
    return False
294
+
295
def semantic_firewall(prompt):
    """Return True when the prompt contains any forbidden-topic substring (case-insensitive)."""
    lowered = prompt.lower()
    for topic in FORBIDDEN_TOPICS:
        if topic in lowered:
            return True
    return False
298
+
299
# ---------------- MAIN PIPELINE ----------------

def _blocked(prompt, reason, score, detail):
    """Log an UNSAFE verdict and build the matching response payload."""
    log_to_db(prompt, "UNSAFE", reason, score, detail)
    return {
        "verdict": "UNSAFE",
        "reason": reason,
        "security_score": round(score, 4),
        "forward_to_ayaan": False,
    }

def shield_pipeline(prompt):
    """Run the three shield layers over *prompt* and return a verdict dict.

    Layer order: ML injection classifier (thresholded), regex jailbreak
    scanner, forbidden-topic firewall. Every decision — safe or not — is
    logged to SQLite. The dict always carries verdict / reason /
    security_score / forward_to_ayaan.
    """
    score = ml_guard(prompt)

    if score >= BLOCK_THRESHOLD:
        return _blocked(prompt, "ML_GUARD", score, "Prompt injection detected")

    if heuristic_scan(prompt):
        # Fix: the original logged "HEURISTIC" to the DB but returned
        # "HEURISTIC_SCANNER" to callers; use one label so rows and
        # responses agree.
        return _blocked(prompt, "HEURISTIC_SCANNER", score,
                        "Jailbreak pattern detected")

    if semantic_firewall(prompt):
        return _blocked(prompt, "SEMANTIC_FIREWALL", score, "Forbidden topic")

    log_to_db(prompt, "SAFE", "CLEAN", score, "Prompt allowed")
    return {
        "verdict": "SAFE",
        "reason": "CLEAN",
        "security_score": round(score, 4),
        "forward_to_ayaan": True,
    }
338
+
339
# ---------------- CLI ENTRY ----------------

if __name__ == "__main__":
    print("\n Sentinel Shield CLI (Ctrl+C to exit)\n")
    while True:
        try:
            user_prompt = input("User Prompt ➜ ").strip()
            if not user_prompt:
                continue
            result = shield_pipeline(user_prompt)
            print("\n--- SHIELD VERDICT ---")
            for k, v in result.items():
                print(f"{k}: {v}")
            print("----------------------\n")
        except (KeyboardInterrupt, EOFError):
            # Fix: EOF on stdin (piped input or Ctrl+D) previously crashed
            # with an unhandled EOFError; treat it like Ctrl+C and exit.
            print("\n[+] Shield shutting down.")
            break