import io
import json
import os
import pickle
import pickletools
import struct
import subprocess
import uuid
from collections import deque

import numpy as np
import streamlit as st
from huggingface_hub import login, HfApi
from safetensors.numpy import save_file

# --- CONFIGURATION & SECRETS ---
st.set_page_config(page_title="SEC545 Lab 1", layout="wide")

HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    st.error("⚠️ HF_TOKEN not found! Please add it to your Space Secrets.")
    st.stop()
else:
    login(token=HF_TOKEN, add_to_git_credential=False)

# --- SESSION ISOLATION ---
# Each browser session gets a short random ID that is baked into the file
# names below, so students sharing one Space do not overwrite each other.
if "session_id" not in st.session_state:
    st.session_state["session_id"] = str(uuid.uuid4())[:8]
session_id = st.session_state["session_id"]

PKL_PATH = f"vulnerable_model_{session_id}.pkl"
SAFE_PATH = f"secure_model_{session_id}.safetensors"

# --- CUSTOM PICKLE SCANNER ---
# Replaces modelscan — inspects pickle opcodes without executing the file.

# Dangerous pickle opcodes that can execute arbitrary code:
DANGEROUS_OPCODES = {
    # GLOBAL, STACK_GLOBAL and INST are resolved to module.name separately (see _scan_opcodes)
    "REDUCE",     # calls a callable with args — the core RCE vector
    "BUILD",      # calls __setstate__ — can trigger code execution
    "INST",       # legacy opcode: instantiates a class by module/name string
    "OBJ",        # instantiates an object from stack
    "NEWOBJ",     # creates a new object — can invoke __new__ with arbitrary args
    "NEWOBJ_EX",  # extended version of NEWOBJ
}

# Known dangerous module/name pairs that indicate likely malicious intent
DANGEROUS_GLOBALS = [
    ("os", "system"),
    ("os", "popen"),
    ("posix", "system"),        # Linux: os.system is backed by posix.system
    ("posix", "popen"),         # Linux: os.popen is backed by posix.popen
    ("nt", "system"),           # Windows equivalent of os.system
    ("nt", "popen"),            # Windows equivalent of os.popen
    ("subprocess", "Popen"),
    ("subprocess", "call"),
    ("subprocess", "run"),
    ("builtins", "eval"),
    ("builtins", "exec"),
    ("builtins", "__import__"),
    # Pickles written with protocol < 3 (fix_imports=True) name builtins "__builtin__"
    ("__builtin__", "eval"),
    ("__builtin__", "exec"),
    ("__builtin__", "__import__"),
    ("socket", "socket"),
]

# Opcodes that push a string literal onto the pickle stack.
_STRING_OPCODES = frozenset({
    "SHORT_BINUNICODE", "BINUNICODE", "BINUNICODE8", "UNICODE",
    "STRING", "BINSTRING", "SHORT_BINSTRING",
})
# Opcodes that copy the top of the stack into an explicitly numbered memo slot.
_MEMO_PUT_OPCODES = frozenset({"PUT", "BINPUT", "LONG_BINPUT"})
# Opcodes that push a previously memoized value back onto the stack.
_MEMO_GET_OPCODES = frozenset({"GET", "BINGET", "LONG_BINGET"})
# Framing/terminator opcodes that do not push a value we need to track.
_FRAMING_OPCODES = frozenset({"PROTO", "FRAME", "STOP"})


def _is_safetensors(data: bytes) -> bool:
    """Return True if *data* has the byte layout of a safetensors file.

    A safetensors file begins with an 8-byte little-endian unsigned integer N,
    followed by an N-byte JSON object (the header), followed by raw tensor
    bytes. Checking the content rather than the file extension means a pickle
    that was merely renamed to ``.safetensors`` is not waved through as safe.
    """
    if len(data) < 8:
        return False
    (header_len,) = struct.unpack("<Q", data[:8])
    if header_len == 0 or header_len > len(data) - 8:
        return False
    try:
        header = json.loads(data[8:8 + header_len])
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return False
    return isinstance(header, dict)


def scan_pickle_file(filepath: str) -> dict:
    """
    Scans a pickle file for dangerous opcodes and globals without executing it.

    STACK_GLOBAL arguments (Python 3 default format) are resolved from the
    strings pushed just before it, including strings re-read from the pickle
    memo. Files whose content is a valid safetensors layout are reported clean
    immediately; the file extension alone is not trusted.

    Returns a dict with: safe (bool), findings (list of strings), opcode_log (str)
    """
    findings = []
    opcode_log_buffer = io.StringIO()

    try:
        with open(filepath, "rb") as f:
            data = f.read()
    except OSError as e:
        return {"safe": False, "findings": [f"❌ Scan error: {e}"], "opcode_log": ""}

    # safetensors files are not pickle — they store only raw tensor data and
    # cannot contain executable code by design. Return clean immediately.
    if _is_safetensors(data):
        return {
            "safe": True,
            "findings": [],
            "opcode_log": (
                "Not a pickle file — safetensors format detected.\n"
                "safetensors stores only raw tensor data (no Python objects, "
                "no opcodes, no callable code). It is architecturally safe."
            ),
        }

    if filepath.endswith(".safetensors"):
        findings.append(
            "⚠️ WARNING — File has a `.safetensors` extension but is not a valid "
            "safetensors file; scanning it as a pickle."
        )

    try:
        # Opcode scan first, so its findings survive even if disassembly fails.
        _scan_opcodes(data, findings)
        # Disassemble the pickle bytecode into a human-readable log. Writing to
        # `out=` keeps the capture local; swapping sys.stdout would be
        # process-wide and could mix in output from other Streamlit sessions.
        pickletools.dis(io.BytesIO(data), out=opcode_log_buffer)
    except Exception as e:  # malformed / truncated pickles raise here
        findings.append(f"❌ Scan error: {e}")

    return {
        "safe": len(findings) == 0,
        "findings": findings,
        # May be partial if disassembly stopped on an error.
        "opcode_log": opcode_log_buffer.getvalue(),
    }


def _scan_opcodes(data: bytes, findings: list) -> None:
    """Walk the pickle opcodes in *data* and append a finding for each risky one.

    For STACK_GLOBAL the module and name are the last two strings pushed before
    the opcode. The pickler memoizes every string it writes and re-emits a
    repeated one (e.g. a module name used twice) as a memo lookup, so memo
    loads of strings count as pushes too. This is a heuristic, not a full
    stack simulation.
    """
    recent_strings = deque(maxlen=2)  # last two strings pushed
    memo = {}                         # memo slot -> stored string, or None if not a string
    top_string = None                 # string on top of the stack, if the last push was one

    for opcode, arg, pos in pickletools.genops(io.BytesIO(data)):
        opname = opcode.name

        # Memo stores copy the top of the stack without changing it.
        if opname == "MEMOIZE":
            memo[len(memo)] = top_string  # MEMOIZE uses the next free slot
            continue
        if opname in _MEMO_PUT_OPCODES:
            memo[arg] = top_string
            continue
        if opname in _FRAMING_OPCODES:
            continue

        # String pushes — either a literal or a string re-read from the memo.
        if opname in _STRING_OPCODES or opname in _MEMO_GET_OPCODES:
            top_string = arg if opname in _STRING_OPCODES else memo.get(arg)
            if isinstance(top_string, str):
                recent_strings.append(top_string)
            continue

        # Any other opcode leaves something we do not track on top of the stack.
        top_string = None

        # GLOBAL (older pickle format) / INST (legacy instantiate): module and
        # name are inline in the opcode argument as "module name".
        if opname in ("GLOBAL", "INST") and arg and " " in arg:
            module, name = arg.split(" ", 1)
            _report_global(findings, pos, module, name, source=opname)

        # STACK_GLOBAL (Python 3 default): resolve from the last two strings pushed
        elif opname == "STACK_GLOBAL":
            if len(recent_strings) == 2:
                module, name = recent_strings
                _report_global(findings, pos, module, name, via_stack=True)
            else:
                findings.append(
                    f"⚠️ WARNING — STACK_GLOBAL at byte {pos}: "
                    f"could not resolve callable name (not enough string context)."
                )

        # REDUCE is the opcode that actually *invokes* the callable — the RCE trigger
        elif opname == "REDUCE":
            findings.append(
                f"🚨 CRITICAL — REDUCE opcode at byte {pos}: "
                f"a callable on the stack will be invoked when this file is loaded."
            )

        # Flag other execution-capable opcodes
        elif opname in DANGEROUS_OPCODES:
            findings.append(
                f"⚠️ WARNING — Opcode `{opname}` at byte {pos} can trigger code execution."
            )


def _report_global(findings, pos, module, name, via_stack=False, source=None):
    """Classify a global reference and append the appropriate finding.

    `source` is the opcode label shown in the message; when omitted it is
    derived from `via_stack` (STACK_GLOBAL vs. GLOBAL).
    """
    if source is None:
        source = "STACK_GLOBAL (Python 3 format)" if via_stack else "GLOBAL"
    if (module, name) in DANGEROUS_GLOBALS:
        findings.append(
            f"🚨 CRITICAL — Dangerous callable at byte {pos} via `{source}`: "
            f"`{module}.{name}` — loading this file will execute a system command."
        )
    else:
        findings.append(
            f"⚠️ WARNING — Global reference at byte {pos} via `{source}`: "
            f"`{module}.{name}` — verify this callable is expected."
        )


# --- LAB INTERFACE ---
st.title("🛡️ Lab: ML Model Serialization Vulnerabilities")
st.markdown(f"""
**Goal:** Demonstrate how malicious code can be hidden in standard ML model files (`.pkl`)
and how to fix it using `safetensors`.

> 🔑 Your session ID: `{session_id}` — your files are isolated from other students.
""")

# --- STEP 1: CREATE VULNERABLE MODEL ---
st.header("Step 1: Create a 'Vulnerable' Model")
st.markdown("""
We will create a `pickle` file that contains a hidden system command.
When the file is loaded with `pickle.load()`, the embedded code executes **automatically** —
without the loader ever intentionally calling it.
""")
st.code("""
class MaliciousPayload:
    def __reduce__(self):
        cmd = "echo 'SECURITY LAB DEMO: Payload Executed'"
        return (os.system, (cmd,))
""", language="python")


class MaliciousPayload:
    """Demo payload: pickling records a call to os.system(cmd) that unpickling performs."""

    def __reduce__(self):
        cmd = f"echo 'SECURITY LAB DEMO: Benign Payload Executed by session {session_id}'"
        return (os.system, (cmd,))


if st.button("Generate Vulnerable Model", key="gen"):
    model_data = {
        "weights": [0.1, 0.2, 0.3],
        "metadata": "Lab Demo Model",
        "payload": MaliciousPayload(),
    }
    with open(PKL_PATH, "wb") as f:
        pickle.dump(model_data, f)
    st.success(f"✅ `{PKL_PATH}` created with embedded payload!")
    st.info("ℹ️ The payload has **not executed yet** — it only fires when the file is loaded.")

# --- STEP 2: SCAN ---
st.header("Step 2: Static Analysis Scan")
st.markdown("""
Our scanner inspects the **pickle bytecode opcodes** without executing the file.
This is the same approach used by tools like ModelScan — static analysis catches the threat before it can run.
""")

if st.button("Run Pickle Scanner", key="scan"):
    if not os.path.exists(PKL_PATH):
        st.warning("⚠️ Please generate the vulnerable model first (Step 1).")
    else:
        with st.spinner("Scanning..."):
            result = scan_pickle_file(PKL_PATH)
        if result["findings"]:
            st.error(f"🚨 **{len(result['findings'])} issue(s) detected:**")
            for finding in result["findings"]:
                st.markdown(f"- {finding}")
        else:
            st.success("✅ No issues found.")
        with st.expander("🔍 View raw pickle opcode disassembly"):
            st.code(result["opcode_log"], language="text")

with st.expander("📄 Show scanner source code & how to run it on any model"):
    st.markdown("#### How this scanner works")
    st.markdown("""
The scanner uses Python's built-in `pickletools` module to **disassemble the pickle bytecode
without executing it**, then looks for opcodes that can invoke arbitrary code.
No third-party tools required — `pickletools` ships with every Python installation.
""")
    st.markdown("#### Scanner source — copy this into your own project")
    st.code('''import io
import pickletools
from collections import deque

DANGEROUS_GLOBALS = {
    ("posix", "system"), ("os", "system"), ("nt", "system"),
    ("posix", "popen"), ("os", "popen"), ("nt", "popen"),
    ("subprocess", "Popen"), ("subprocess", "call"), ("subprocess", "run"),
    ("builtins", "eval"), ("builtins", "exec"), ("builtins", "__import__"),
    ("__builtin__", "eval"), ("__builtin__", "exec"), ("__builtin__", "__import__"),
    ("socket", "socket"),
}
DANGEROUS_OPCODES = {"REDUCE", "BUILD", "INST", "OBJ", "NEWOBJ", "NEWOBJ_EX"}
STRING_OPCODES = {"SHORT_BINUNICODE", "BINUNICODE", "BINUNICODE8", "UNICODE",
                  "STRING", "BINSTRING", "SHORT_BINSTRING"}
MEMO_GET_OPCODES = {"GET", "BINGET", "LONG_BINGET"}
MEMO_PUT_OPCODES = {"PUT", "BINPUT", "LONG_BINPUT"}

def scan_pickle(filepath):
    findings = []
    recent = deque(maxlen=2)  # last two strings pushed: (module, name) for STACK_GLOBAL
    memo = {}                 # memo slot -> string stored there (None if not a string)
    top = None                # string on top of the stack, if the last push was one
    with open(filepath, "rb") as f:
        data = f.read()
    for opcode, arg, pos in pickletools.genops(io.BytesIO(data)):
        name = opcode.name
        if name == "MEMOIZE":
            memo[len(memo)] = top
            continue
        if name in MEMO_PUT_OPCODES:
            memo[arg] = top
            continue
        if name in ("PROTO", "FRAME", "STOP"):
            continue
        if name in STRING_OPCODES or name in MEMO_GET_OPCODES:
            # Repeated strings (e.g. a module name) are re-read from the memo
            top = arg if name in STRING_OPCODES else memo.get(arg)
            if isinstance(top, str):
                recent.append(top)
            continue
        top = None
        if name in ("GLOBAL", "INST") and arg and " " in arg:
            module, func = arg.split(" ", 1)
            severity = "CRITICAL" if (module, func) in DANGEROUS_GLOBALS else "WARNING"
            findings.append(f"[{severity}] byte {pos}: {name} {module}.{func}")
        elif name == "STACK_GLOBAL" and len(recent) == 2:
            module, func = recent
            severity = "CRITICAL" if (module, func) in DANGEROUS_GLOBALS else "WARNING"
            findings.append(f"[{severity}] byte {pos}: STACK_GLOBAL {module}.{func}")
        elif name == "REDUCE":
            findings.append(f"[CRITICAL] byte {pos}: REDUCE — callable will execute on load")
        elif name in DANGEROUS_OPCODES:
            findings.append(f"[WARNING] byte {pos}: {name} — can trigger code execution")
    return findings

# --- Usage ---
findings = scan_pickle("your_model.pkl")
if findings:
    print(f"UNSAFE — {len(findings)} issue(s) found:")
    for f in findings:
        print("  *", f)
else:
    print("SAFE — no dangerous opcodes detected")
''', language="python")
    st.markdown("#### Quick command-line check with `pickletools`")
    st.code("python -m pickletools your_model.pkl | grep -E 'GLOBAL|REDUCE|STACK_GLOBAL'", language="bash")
    st.markdown("""
> **Tip:** If you see `GLOBAL`, `STACK_GLOBAL`, or `REDUCE` opcodes referencing
> system modules like `os`, `subprocess`, or `builtins` — treat the file as malicious
> and do not load it.
""")

# --- STEP 3: SUPPLY CHAIN SIMULATION ---
st.header("Step 3: Supply Chain Simulation")
st.markdown("""
Upload the file to Hugging Face to simulate a **compromised model registry**.
Anyone who downloads and loads this model will unknowingly execute the payload.
""")

username = "vchirrav"
repo_id = f"{username}/security-lab-demo"


def _upload_file_to_hub(local_path: str) -> None:
    """Create `repo_id` if needed and upload *local_path* to it under the same name.

    Exceptions from huggingface_hub propagate to the caller, which reports them.
    """
    api = HfApi(token=HF_TOKEN)
    api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
    # Upload raw bytes rather than a path. The lab previously noted this avoided
    # an "LFS pointer" error for .pkl files — NOTE(review): re-verify against
    # current huggingface_hub behavior.
    with open(local_path, "rb") as f:
        file_bytes = f.read()
    api.upload_file(
        path_or_fileobj=file_bytes,
        path_in_repo=local_path,
        repo_id=repo_id,
        repo_type="model",
    )


if st.button(f"Upload to `{repo_id}`", key="upload"):
    if not os.path.exists(PKL_PATH):
        st.warning("⚠️ Please generate the vulnerable model first (Step 1).")
    else:
        st.write(f"Uploading to `{repo_id}`...")
        try:
            _upload_file_to_hub(PKL_PATH)
            st.success(f"✅ Uploaded to https://huggingface.co/{repo_id}")
            st.warning("⚠️ In a real attack, victims download and load this — silently executing the payload.")
        except Exception as e:
            st.error(f"❌ Upload failed: {e}")

# --- STEP 4: REMEDIATE ---
st.header("Step 4: Remediate with Safetensors")
st.markdown("""
Convert the model to `safetensors` format. `safetensors` stores **only raw tensor data**
in a flat binary format — it is architecturally incapable of embedding executable code.
""")

if st.button("Convert to Safetensors", key="convert"):
    # The weights are rebuilt directly instead of unpickling PKL_PATH:
    # loading the pickle is exactly what would run the payload.
    safe_model_data = {"weights": np.array([0.1, 0.2, 0.3], dtype=np.float32)}
    save_file(safe_model_data, SAFE_PATH)
    st.success(f"✅ Converted! Saved as `{SAFE_PATH}`.")
    st.info("ℹ️ Only raw tensor values were saved — no Python objects, no callable code.")

with st.expander("📄 Show real-world mitigation code — converting any model to safetensors"):
    st.markdown("#### Install the required packages")
    st.code("pip install safetensors torch", language="bash")
    st.markdown("#### Convert a PyTorch model (.pt / .pth / .pkl) to safetensors")
    st.code('''import torch
from safetensors.torch import save_file

# Load the original model (only do this with files you already trust or have scanned).
# weights_only=True restricts unpickling to tensors and plain containers
# (it is the default since PyTorch 2.6).
state_dict = torch.load("model.pt", map_location="cpu", weights_only=True)

# If the file contains a full model object rather than a plain state_dict, extract it
# (full nn.Module objects only load with weights_only=False — fully trusted files only)
if hasattr(state_dict, "state_dict"):
    state_dict = state_dict.state_dict()

# Strip out any non-tensor entries (metadata strings, config dicts, etc.)
tensor_only = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}

# Save in safetensors format — only raw tensor bytes, no executable code possible
save_file(tensor_only, "model.safetensors")
print("Conversion complete: model.safetensors")
''', language="python")
    st.markdown("#### Load the safetensors file back (safe to do with untrusted files)")
    st.code('''from safetensors.torch import load_file

state_dict = load_file("model.safetensors")

# Restore into your model architecture
model = MyModel()
model.load_state_dict(state_dict)
model.eval()
''', language="python")
    st.markdown("#### Using numpy instead of torch (no GPU/CUDA required)")
    st.code('''import numpy as np
from safetensors.numpy import save_file, load_file

# Save
arrays = {"weights": np.array([0.1, 0.2, 0.3], dtype=np.float32)}
save_file(arrays, "model.safetensors")

# Load
loaded = load_file("model.safetensors")
print(loaded["weights"])
''', language="python")
    st.markdown("""
> **Why safetensors is safe:** The format stores a JSON header describing tensor shapes
> and dtypes, followed by raw binary tensor data. There is no mechanism to store Python
> objects, callables, or executable bytecode — making it safe to load from untrusted sources.
""")

# --- STEP 5: UPLOAD SECURE MODEL ---
st.header("Step 5: Publish the Secure Model")
st.markdown("""
Upload the safe `safetensors` file to the **same repository** as the vulnerable model.
This simulates replacing a compromised model in the registry with a remediated one.
""")

if st.button(f"Upload Secure Model to `{repo_id}`", key="upload_safe"):
    if not os.path.exists(SAFE_PATH):
        st.warning("⚠️ Please convert to safetensors first (Step 4).")
    else:
        st.write(f"Uploading `{SAFE_PATH}` to `{repo_id}`...")
        try:
            _upload_file_to_hub(SAFE_PATH)
            st.success(f"✅ Secure model uploaded to https://huggingface.co/{repo_id}")
            st.info(
                f"ℹ️ Both files now exist in the same repo:\n"
                f"- `{PKL_PATH}` — the vulnerable pickle (still there as evidence)\n"
                f"- `{SAFE_PATH}` — the remediated safetensors replacement"
            )
        except Exception as e:
            st.error(f"❌ Upload failed: {e}")

# --- STEP 6: VERIFY ---
st.header("Step 6: Verify the Fix")
st.markdown("Scan the safetensors file to confirm the vulnerability is gone.")

if st.button("Scan Secure Model", key="verify"):
    if not os.path.exists(SAFE_PATH):
        st.warning("⚠️ Please convert to safetensors first (Step 4).")
    else:
        result = scan_pickle_file(SAFE_PATH)
        if result["safe"]:
            st.success("🎉 Clean scan! No dangerous opcodes found in the safetensors file.")
            st.info("ℹ️ safetensors files are not pickle-based — they cannot contain executable code.")
        else:
            st.error("Unexpected findings — review below.")
            for finding in result["findings"]:
                st.markdown(f"- {finding}")

# --- LAB SUMMARY ---
st.divider()
st.header("🧠 Key Takeaways")
st.markdown("""
| Format | Can Embed Code? | Safe to Load Untrusted Files? |
|---|---|---|
| `.pkl` (pickle) | ✅ Yes | ❌ Never |
| `.pt` / `.pth` (PyTorch) | ✅ Yes (uses pickle internally) | ❌ No |
| `.safetensors` | ❌ No | ✅ Yes |

**Best practice:** Always use `safetensors` for distributing model weights.
If you must load a pickle-based model, scan it statically first and only load files
from fully trusted, verified sources.
""")