Viswanath Chirravuri committed on
Commit
261b663
·
0 Parent(s):

Lab1 created

Browse files
Files changed (5) hide show
  1. .gitattributes +35 -0
  2. README.md +10 -0
  3. app.py +467 -0
  4. requirements.txt +4 -0
  5. src/streamlit_app.py +40 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SEC545 Workshop Lab
3
+ emoji: 🛡️
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: streamlit
7
+ sdk_version: "1.42.0"
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
app.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import streamlit as st
import os
import pickle
import pickletools
import io
import subprocess
import uuid
import numpy as np
from huggingface_hub import login, HfApi
from safetensors.numpy import save_file

# --- CONFIGURATION & SECRETS ---
st.set_page_config(page_title="SEC545 Lab 1", layout="wide")

# The Hugging Face token must be supplied via the Space's secret store.
HF_TOKEN = os.environ.get("HF_TOKEN")

if not HF_TOKEN:
    st.error("⚠️ HF_TOKEN not found! Please add it to your Space Secrets.")
    st.stop()
# st.stop() halts script execution above, so no `else` branch is required.
login(token=HF_TOKEN, add_to_git_credential=False)

# --- SESSION ISOLATION ---
# Give every student session its own artifact file names so concurrent
# users of the shared Space never clobber each other's generated files.
if "session_id" not in st.session_state:
    st.session_state["session_id"] = str(uuid.uuid4())[:8]

session_id = st.session_state["session_id"]
PKL_PATH = f"vulnerable_model_{session_id}.pkl"
SAFE_PATH = f"secure_model_{session_id}.safetensors"
# --- CUSTOM PICKLE SCANNER ---
# Replaces modelscan — inspects pickle opcodes without executing the file.
# Dangerous pickle opcodes that can execute arbitrary code:
DANGEROUS_OPCODES = {
    # GLOBAL and STACK_GLOBAL are handled separately with stack resolution (see scan_pickle_file)
    "REDUCE",      # calls a callable with args — the core RCE vector
    "BUILD",       # calls __setstate__ — can trigger code execution
    "INST",        # legacy opcode: instantiates a class by module/name string
    "OBJ",         # instantiates an object from stack
    "NEWOBJ",      # creates a new object — can invoke __new__ with arbitrary args
    "NEWOBJ_EX",   # extended version of NEWOBJ
}

# Known dangerous module/name pairs that indicate likely malicious intent
DANGEROUS_GLOBALS = [
    ("os", "system"),
    ("os", "popen"),
    ("posix", "system"),   # Linux: os.system is backed by posix.system
    ("posix", "popen"),    # Linux: os.popen is backed by posix.popen
    ("nt", "system"),      # Windows equivalent of os.system
    ("nt", "popen"),       # Windows equivalent of os.popen
    ("subprocess", "Popen"),
    ("subprocess", "call"),
    ("subprocess", "run"),
    ("builtins", "eval"),
    ("builtins", "exec"),
    ("builtins", "__import__"),
    ("socket", "socket"),
]


def scan_pickle_file(filepath: str) -> dict:
    """
    Scan a pickle file for dangerous opcodes and globals without executing it.

    Tracks the string stack to resolve STACK_GLOBAL arguments (Python 3
    default format).

    Args:
        filepath: Path of the file to scan. Files ending in ``.safetensors``
            are reported clean immediately — they are not pickle at all.

    Returns:
        dict with keys ``safe`` (bool — True iff no findings), ``findings``
        (list of human-readable finding strings) and ``opcode_log`` (str —
        full pickletools disassembly, or "" if the scan raised).
    """
    findings: list[str] = []
    opcode_log_buffer = io.StringIO()

    # safetensors files are not pickle — they store only raw tensor data and
    # cannot contain executable code by design. Return clean immediately.
    if filepath.endswith(".safetensors"):
        return {
            "safe": True,
            "findings": [],
            "opcode_log": (
                "Not a pickle file — safetensors format detected.\n"
                "safetensors stores only raw tensor data (no Python objects, "
                "no opcodes, no callable code). It is architecturally safe."
            ),
        }

    try:
        with open(filepath, "rb") as f:
            data = f.read()

        # Disassemble the pickle bytecode into a human-readable log.
        # pickletools.dis accepts an `out=` stream, so write straight into the
        # buffer. (The previous implementation swapped sys.stdout, which is a
        # process-global mutation and is unsafe when multiple Streamlit
        # sessions scan concurrently; the 3.13-removal comment was mistaken —
        # the parameter is named `out` and is still available.)
        pickletools.dis(io.BytesIO(data), out=opcode_log_buffer)
        opcode_log = opcode_log_buffer.getvalue()

        # Walk each opcode and track all string literals seen so far.
        # For STACK_GLOBAL (Python 3 default format), the module and name are always
        # the last two string values pushed before the opcode — so we just keep an
        # ever-growing list and read [-2] and [-1] when needed. No clearing required.
        seen_strings = []

        for opcode, arg, pos in pickletools.genops(io.BytesIO(data)):
            opname = opcode.name

            # Record every string literal pushed onto the pickle stack
            if opname in ("SHORT_BINUNICODE", "BINUNICODE", "UNICODE", "STRING"):
                seen_strings.append(arg)

            # GLOBAL (older pickle format): module and name are inline in the opcode arg
            elif opname == "GLOBAL" and arg:
                parts = arg.split(" ", 1)
                if len(parts) == 2:
                    _report_global(findings, pos, parts[0], parts[1])

            # STACK_GLOBAL (Python 3 default): resolve from the last two strings seen
            elif opname == "STACK_GLOBAL":
                if len(seen_strings) >= 2:
                    module, name = seen_strings[-2], seen_strings[-1]
                    _report_global(findings, pos, module, name, via_stack=True)
                else:
                    findings.append(
                        f"⚠️ WARNING — STACK_GLOBAL at byte {pos}: "
                        f"could not resolve callable name (not enough string context)."
                    )

            # REDUCE is the opcode that actually *invokes* the callable — the RCE trigger
            elif opname == "REDUCE":
                findings.append(
                    f"🚨 CRITICAL — REDUCE opcode at byte {pos}: "
                    f"a callable on the stack will be invoked when this file is loaded."
                )

            # Flag other execution-capable opcodes
            elif opname in DANGEROUS_OPCODES:
                findings.append(
                    f"⚠️ WARNING — Opcode `{opname}` at byte {pos} can trigger code execution."
                )

    except Exception as e:
        # Unreadable/corrupt input: surface the error as a finding so the
        # file is never silently treated as safe.
        findings.append(f"❌ Scan error: {e}")
        opcode_log = ""

    return {
        "safe": len(findings) == 0,
        "findings": findings,
        "opcode_log": opcode_log,
    }


def _report_global(findings, pos, module, name, via_stack=False):
    """Classify a global reference and append the appropriate finding."""
    source = "STACK_GLOBAL (Python 3 format)" if via_stack else "GLOBAL"
    if (module, name) in DANGEROUS_GLOBALS:
        findings.append(
            f"🚨 CRITICAL — Dangerous callable at byte {pos} via `{source}`: "
            f"`{module}.{name}` — loading this file will execute a system command."
        )
    else:
        findings.append(
            f"⚠️ WARNING — Global reference at byte {pos} via `{source}`: "
            f"`{module}.{name}` — verify this callable is expected."
        )
+ # --- LAB INTERFACE ---
167
+
168
+ st.title("🛡️ Lab: ML Model Serialization Vulnerabilities")
169
+ st.markdown(f"""
170
+ **Goal:** Demonstrate how malicious code can be hidden in standard ML model files (`.pkl`)
171
+ and how to fix it using `safetensors`.
172
+
173
+ > 🔑 Your session ID: `{session_id}` — your files are isolated from other students.
174
+ """)
175
+
176
+ # --- STEP 1: CREATE VULNERABLE MODEL ---
177
+ st.header("Step 1: Create a 'Vulnerable' Model")
178
+ st.markdown("""
179
+ We will create a `pickle` file that contains a hidden system command.
180
+ When the file is loaded with `pickle.load()`, the embedded code executes **automatically** —
181
+ without the loader ever intentionally calling it.
182
+ """)
183
+
184
+ st.code("""
185
+ class MaliciousPayload:
186
+ def __reduce__(self):
187
+ cmd = "echo 'SECURITY LAB DEMO: Payload Executed'"
188
+ return (os.system, (cmd,))
189
+ """, language="python")
190
+
191
+ class MaliciousPayload:
192
+ def __reduce__(self):
193
+ cmd = f"echo 'SECURITY LAB DEMO: Benign Payload Executed by session {session_id}'"
194
+ return (os.system, (cmd,))
195
+
196
+ if st.button("Generate Vulnerable Model", key="gen"):
197
+ model_data = {
198
+ "weights": [0.1, 0.2, 0.3],
199
+ "metadata": "Lab Demo Model",
200
+ "payload": MaliciousPayload()
201
+ }
202
+ with open(PKL_PATH, "wb") as f:
203
+ pickle.dump(model_data, f)
204
+ st.success(f"✅ `{PKL_PATH}` created with embedded payload!")
205
+ st.info("ℹ️ The payload has **not executed yet** — it only fires when the file is loaded.")
206
+
# --- STEP 2: SCAN ---
st.header("Step 2: Static Analysis Scan")
st.markdown("""
Our scanner inspects the **pickle bytecode opcodes** without executing the file.
This is the same approach used by tools like ModelScan — static analysis catches the threat before it can run.
""")

if st.button("Run Pickle Scanner", key="scan"):
    if not os.path.exists(PKL_PATH):
        st.warning("⚠️ Please generate the vulnerable model first (Step 1).")
    else:
        # Static analysis only — the pickle is never deserialized here.
        with st.spinner("Scanning..."):
            result = scan_pickle_file(PKL_PATH)

        if result["findings"]:
            st.error(f"🚨 **{len(result['findings'])} issue(s) detected:**")
            for f in result["findings"]:
                st.markdown(f"- {f}")
        else:
            st.success("✅ No issues found.")

        with st.expander("🔍 View raw pickle opcode disassembly"):
            st.code(result["opcode_log"], language="text")

# NOTE(review): indentation reconstructed from a mangled paste — this reference
# expander uses no scan results, so it is assumed to be always visible (not
# nested under the scan button). Confirm against the deployed app.
with st.expander("📄 Show scanner source code & how to run it on any model"):
    st.markdown("#### How this scanner works")
    st.markdown("""
The scanner uses Python's built-in `pickletools` module to **disassemble the pickle
bytecode without executing it**, then looks for opcodes that can invoke arbitrary code.
No third-party tools required — `pickletools` ships with every Python installation.
""")
    st.markdown("#### Scanner source — copy this into your own project")
    st.code('''import pickletools
import io

DANGEROUS_GLOBALS = [
    ("posix", "system"), ("os", "system"), ("nt", "system"),
    ("posix", "popen"), ("os", "popen"), ("nt", "popen"),
    ("subprocess", "Popen"), ("subprocess", "call"), ("subprocess", "run"),
    ("builtins", "eval"), ("builtins", "exec"), ("builtins", "__import__"),
]

DANGEROUS_OPCODES = {"REDUCE", "BUILD", "INST", "OBJ", "NEWOBJ", "NEWOBJ_EX"}

def scan_pickle(filepath):
    findings = []
    seen_strings = []

    with open(filepath, "rb") as f:
        data = f.read()

    for opcode, arg, pos in pickletools.genops(io.BytesIO(data)):
        name = opcode.name

        if name in ("SHORT_BINUNICODE", "BINUNICODE", "UNICODE", "STRING"):
            seen_strings.append(arg)

        elif name == "GLOBAL" and arg:
            parts = arg.split(" ", 1)
            if len(parts) == 2:
                module, func = parts
                severity = "CRITICAL" if (module, func) in DANGEROUS_GLOBALS else "WARNING"
                findings.append(f"[{severity}] byte {pos}: GLOBAL {module}.{func}")

        elif name == "STACK_GLOBAL" and len(seen_strings) >= 2:
            module, func = seen_strings[-2], seen_strings[-1]
            severity = "CRITICAL" if (module, func) in DANGEROUS_GLOBALS else "WARNING"
            findings.append(f"[{severity}] byte {pos}: STACK_GLOBAL {module}.{func}")

        elif name == "REDUCE":
            findings.append(f"[CRITICAL] byte {pos}: REDUCE — callable will execute on load")

        elif name in DANGEROUS_OPCODES:
            findings.append(f"[WARNING] byte {pos}: {name} — can trigger code execution")

    return findings


# --- Usage ---
findings = scan_pickle("your_model.pkl")
if findings:
    print(f"UNSAFE — {len(findings)} issue(s) found:")
    for f in findings:
        print(" *", f)
else:
    print("SAFE — no dangerous opcodes detected")
''', language="python")

    st.markdown("#### Quick command-line check with `pickletools`")
    st.code("python -m pickletools your_model.pkl | grep -E 'GLOBAL|REDUCE|STACK_GLOBAL'", language="bash")
    st.markdown("""
> **Tip:** If you see `GLOBAL`, `STACK_GLOBAL`, or `REDUCE` opcodes referencing
> system modules like `os`, `subprocess`, or `builtins` — treat the file as malicious
> and do not load it.
""")
# --- STEP 3: SUPPLY CHAIN SIMULATION ---
st.header("Step 3: Supply Chain Simulation")
st.markdown("""
Upload the file to Hugging Face to simulate a **compromised model registry**.
Anyone who downloads and loads this model will unknowingly execute the payload.
""")

# Shared demo repository; per-session filenames keep student uploads distinct.
username = "vchirrav"
repo_id = f"{username}/security-lab-demo"

if st.button(f"Upload to `{repo_id}`", key="upload"):
    if not os.path.exists(PKL_PATH):
        st.warning("⚠️ Please generate the vulnerable model first (Step 1).")
    else:
        api = HfApi(token=HF_TOKEN)
        st.write(f"Uploading to `{repo_id}`...")
        try:
            api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
            # Read as bytes and pass directly — prevents HF Hub from routing
            # .pkl files through Git LFS, which causes the "LFS pointer" error.
            with open(PKL_PATH, "rb") as f:
                file_bytes = f.read()
            api.upload_file(
                path_or_fileobj=file_bytes,
                path_in_repo=PKL_PATH,
                repo_id=repo_id,
                repo_type="model",
            )
            st.success(f"✅ Uploaded to https://huggingface.co/{repo_id}")
            st.warning("⚠️ In a real attack, victims download and load this — silently executing the payload.")
        except Exception as e:
            # Network/auth failures are surfaced to the student, not raised.
            st.error(f"❌ Upload failed: {e}")
# --- STEP 4: REMEDIATE ---
st.header("Step 4: Remediate with Safetensors")
st.markdown("""
Convert the model to `safetensors` format.
`safetensors` stores **only raw tensor data** in a flat binary format —
it is architecturally incapable of embedding executable code.
""")

if st.button("Convert to Safetensors", key="convert"):
    # Only the numeric weights are re-created; the payload object has no
    # representation in safetensors and is simply dropped.
    safe_model_data = {"weights": np.array([0.1, 0.2, 0.3], dtype=np.float32)}
    save_file(safe_model_data, SAFE_PATH)
    st.success(f"✅ Converted! Saved as `{SAFE_PATH}`.")
    st.info("ℹ️ Only raw tensor values were saved — no Python objects, no callable code.")

# NOTE(review): indentation reconstructed — this reference expander is assumed
# to be always visible (it uses no button-scoped state). Confirm against the
# deployed app.
with st.expander("📄 Show real-world mitigation code — converting any model to safetensors"):
    st.markdown("#### Install the required packages")
    st.code("pip install safetensors torch", language="bash")

    st.markdown("#### Convert a PyTorch model (.pt / .pth / .pkl) to safetensors")
    st.code('''import torch
from safetensors.torch import save_file

# Load the original model (only do this with files you already trust or have scanned)
state_dict = torch.load("model.pt", map_location="cpu")

# If the file contains a full model object rather than a plain state_dict, extract it
if hasattr(state_dict, "state_dict"):
    state_dict = state_dict.state_dict()

# Strip out any non-tensor entries (metadata strings, config dicts, etc.)
tensor_only = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}

# Save in safetensors format — only raw tensor bytes, no executable code possible
save_file(tensor_only, "model.safetensors")
print("Conversion complete: model.safetensors")
''', language="python")

    st.markdown("#### Load the safetensors file back (safe to do with untrusted files)")
    st.code('''from safetensors.torch import load_file

state_dict = load_file("model.safetensors")

# Restore into your model architecture
model = MyModel()
model.load_state_dict(state_dict)
model.eval()
''', language="python")

    st.markdown("#### Using numpy instead of torch (no GPU/CUDA required)")
    st.code('''import numpy as np
from safetensors.numpy import save_file, load_file

# Save
arrays = {"weights": np.array([0.1, 0.2, 0.3], dtype=np.float32)}
save_file(arrays, "model.safetensors")

# Load
loaded = load_file("model.safetensors")
print(loaded["weights"])
''', language="python")

    st.markdown("""
> **Why safetensors is safe:** The format stores a JSON header describing tensor shapes
> and dtypes, followed by raw binary tensor data. There is no mechanism to store Python
> objects, callables, or executable bytecode — making it safe to load from untrusted sources.
""")
# --- STEP 5: UPLOAD SECURE MODEL ---
st.header("Step 5: Publish the Secure Model")
st.markdown(f"""
Upload the safe `safetensors` file to the **same repository** as the vulnerable model.
This simulates replacing a compromised model in the registry with a remediated one.
""")

if st.button(f"Upload Secure Model to `{repo_id}`", key="upload_safe"):
    if not os.path.exists(SAFE_PATH):
        st.warning("⚠️ Please convert to safetensors first (Step 4).")
    else:
        api = HfApi(token=HF_TOKEN)
        st.write(f"Uploading `{SAFE_PATH}` to `{repo_id}`...")
        try:
            api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
            # Same bytes-based upload as Step 3 to avoid Git LFS routing.
            with open(SAFE_PATH, "rb") as f:
                safe_bytes = f.read()
            api.upload_file(
                path_or_fileobj=safe_bytes,
                path_in_repo=SAFE_PATH,
                repo_id=repo_id,
                repo_type="model",
            )
            st.success(f"✅ Secure model uploaded to https://huggingface.co/{repo_id}")
            st.info(
                f"ℹ️ Both files now exist in the same repo:\n"
                f"- `{PKL_PATH}` — the vulnerable pickle (still there as evidence)\n"
                f"- `{SAFE_PATH}` — the remediated safetensors replacement"
            )
        except Exception as e:
            st.error(f"❌ Upload failed: {e}")
# --- STEP 6: VERIFY ---
st.header("Step 6: Verify the Fix")
st.markdown("Scan the safetensors file to confirm the vulnerability is gone.")

if st.button("Scan Secure Model", key="verify"):
    if not os.path.exists(SAFE_PATH):
        st.warning("⚠️ Please convert to safetensors first (Step 4).")
    else:
        # scan_pickle_file short-circuits on .safetensors and reports clean.
        result = scan_pickle_file(SAFE_PATH)
        if result["safe"]:
            st.success("🎉 Clean scan! No dangerous opcodes found in the safetensors file.")
            st.info("ℹ️ safetensors files are not pickle-based — they cannot contain executable code.")
        else:
            st.error("Unexpected findings — review below.")
            for f in result["findings"]:
                st.markdown(f"- {f}")

# --- LAB SUMMARY ---
st.divider()
st.header("🧠 Key Takeaways")
st.markdown("""
| Format | Can Embed Code? | Safe to Load Untrusted Files? |
|---|---|---|
| `.pkl` (pickle) | ✅ Yes | ❌ Never |
| `.pt` / `.pth` (PyTorch) | ✅ Yes (uses pickle internally) | ❌ No |
| `.safetensors` | ❌ No | ✅ Yes |

**Best practice:** Always use `safetensors` for distributing model weights.
If you must load a pickle-based model, scan it statically first and only load
files from fully trusted, verified sources.
""")
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ streamlit==1.42.0
2
+ huggingface_hub
3
+ safetensors
4
+ numpy
src/streamlit_app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st

"""
# Welcome to Streamlit!

Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
forums](https://discuss.streamlit.io).

In the meantime, below is an example of what you can do with just a few lines of code:
"""

# Interactive controls: resolution of the curve and how many times it winds.
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)

# Parametrize the spiral: the radius grows linearly from 0 to 1 while the
# angle sweeps num_turns full revolutions.
positions = np.linspace(0, 1, num_points)
angles = 2 * np.pi * num_turns * positions

df = pd.DataFrame({
    "x": positions * np.cos(angles),
    "y": positions * np.sin(angles),
    "idx": positions,
    "rand": np.random.randn(num_points),
})

# Color encodes progress along the spiral; point size is random jitter.
chart = (
    alt.Chart(df, height=700, width=700)
    .mark_point(filled=True)
    .encode(
        x=alt.X("x", axis=None),
        y=alt.Y("y", axis=None),
        color=alt.Color("idx", legend=None, scale=alt.Scale()),
        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
    )
)
st.altair_chart(chart)