# SEC545 Lab 1 — Streamlit app (Hugging Face Space)
| import streamlit as st | |
| import os | |
| import pickle | |
| import pickletools | |
| import io | |
| import subprocess | |
| import uuid | |
| import numpy as np | |
| from huggingface_hub import login, HfApi | |
| from safetensors.numpy import save_file | |
# --- CONFIGURATION & SECRETS ---
st.set_page_config(page_title="SEC545 Lab 1", layout="wide")

# The Hugging Face token must come from the Space's secret store; without it
# the upload steps cannot work, so refuse to start at all.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN, add_to_git_credential=False)
else:
    st.error("⚠️ HF_TOKEN not found! Please add it to your Space Secrets.")
    st.stop()

# --- SESSION ISOLATION ---
# Each browser session gets a short random ID so that generated artifacts
# never collide between concurrent students using the same Space.
if "session_id" not in st.session_state:
    st.session_state["session_id"] = str(uuid.uuid4())[:8]
session_id = st.session_state["session_id"]

PKL_PATH = f"vulnerable_model_{session_id}.pkl"
SAFE_PATH = f"secure_model_{session_id}.safetensors"
| # --- CUSTOM PICKLE SCANNER --- | |
| # Replaces modelscan β inspects pickle opcodes without executing the file. | |
| # Dangerous pickle opcodes that can execute arbitrary code: | |
| DANGEROUS_OPCODES = { | |
| # GLOBAL and STACK_GLOBAL are handled separately with stack resolution (see scan_pickle_file) | |
| "REDUCE", # calls a callable with args β the core RCE vector | |
| "BUILD", # calls __setstate__ β can trigger code execution | |
| "INST", # legacy opcode: instantiates a class by module/name string | |
| "OBJ", # instantiates an object from stack | |
| "NEWOBJ", # creates a new object β can invoke __new__ with arbitrary args | |
| "NEWOBJ_EX", # extended version of NEWOBJ | |
| } | |
| # Known dangerous module/name pairs that indicate likely malicious intent | |
| DANGEROUS_GLOBALS = [ | |
| ("os", "system"), | |
| ("os", "popen"), | |
| ("posix", "system"), # Linux: os.system is backed by posix.system | |
| ("posix", "popen"), # Linux: os.popen is backed by posix.popen | |
| ("nt", "system"), # Windows equivalent of os.system | |
| ("nt", "popen"), # Windows equivalent of os.popen | |
| ("subprocess", "Popen"), | |
| ("subprocess", "call"), | |
| ("subprocess", "run"), | |
| ("builtins", "eval"), | |
| ("builtins", "exec"), | |
| ("builtins", "__import__"), | |
| ("socket", "socket"), | |
| ] | |
def scan_pickle_file(filepath: str) -> dict:
    """
    Scan a pickle file for dangerous opcodes and globals WITHOUT executing it.

    The pickle bytecode is disassembled with `pickletools`, and a running list
    of string literals is tracked so STACK_GLOBAL arguments (the Python 3
    default format) can be resolved to (module, name) pairs.

    Args:
        filepath: Path of the file to scan. Files ending in ``.safetensors``
            are reported clean immediately — that format has no opcode stream.

    Returns:
        dict with keys:
            ``safe`` (bool): True when no findings were produced.
            ``findings`` (list[str]): human-readable issue descriptions.
            ``opcode_log`` (str): full disassembly (empty string on scan error).
    """
    findings = []
    # safetensors files are not pickle — they store only raw tensor data and
    # cannot contain executable code by design. Return clean immediately.
    if filepath.endswith(".safetensors"):
        return {
            "safe": True,
            "findings": [],
            "opcode_log": (
                "Not a pickle file — safetensors format detected.\n"
                "safetensors stores only raw tensor data (no Python objects, "
                "no opcodes, no callable code). It is architecturally safe."
            ),
        }
    try:
        with open(filepath, "rb") as f:
            data = f.read()
        # Disassemble the pickle bytecode into a human-readable log.
        # FIX: pickletools.dis() accepts an `out=` file object (the parameter
        # is named `out`, and it was never removed), so we write directly into
        # a StringIO instead of monkey-patching sys.stdout. The old
        # stdout-redirect approach was racy: sys.stdout is process-global and
        # Streamlit serves concurrent sessions on multiple threads.
        opcode_log_buffer = io.StringIO()
        pickletools.dis(io.BytesIO(data), out=opcode_log_buffer)
        opcode_log = opcode_log_buffer.getvalue()

        # Walk each opcode and track all string literals seen so far.
        # For STACK_GLOBAL, the module and name are always the last two string
        # values pushed before the opcode — so an ever-growing list read at
        # [-2] and [-1] is sufficient; no clearing required.
        seen_strings = []
        for opcode, arg, pos in pickletools.genops(io.BytesIO(data)):
            opname = opcode.name
            if opname in ("SHORT_BINUNICODE", "BINUNICODE", "UNICODE", "STRING"):
                # Record every string literal pushed onto the pickle stack.
                seen_strings.append(arg)
            elif opname == "GLOBAL" and arg:
                # Older pickle format: "module name" is inline in the opcode arg.
                parts = arg.split(" ", 1)
                if len(parts) == 2:
                    _report_global(findings, pos, parts[0], parts[1])
            elif opname == "STACK_GLOBAL":
                # Python 3 default: resolve from the last two strings seen.
                if len(seen_strings) >= 2:
                    module, name = seen_strings[-2], seen_strings[-1]
                    _report_global(findings, pos, module, name, via_stack=True)
                else:
                    findings.append(
                        f"⚠️ WARNING — STACK_GLOBAL at byte {pos}: "
                        f"could not resolve callable name (not enough string context)."
                    )
            elif opname == "REDUCE":
                # REDUCE is the opcode that actually *invokes* the callable —
                # the RCE trigger. Checked before the generic set so it gets
                # the CRITICAL message.
                findings.append(
                    f"🚨 CRITICAL — REDUCE opcode at byte {pos}: "
                    f"a callable on the stack will be invoked when this file is loaded."
                )
            elif opname in DANGEROUS_OPCODES:
                # Other execution-capable opcodes (BUILD, INST, OBJ, ...).
                findings.append(
                    f"⚠️ WARNING — Opcode `{opname}` at byte {pos} can trigger code execution."
                )
    except Exception as e:
        # Unreadable / truncated / non-pickle input: surface the error as a
        # finding so the file is reported unsafe rather than silently clean.
        findings.append(f"❌ Scan error: {e}")
        opcode_log = ""
    return {
        "safe": len(findings) == 0,
        "findings": findings,
        "opcode_log": opcode_log,
    }
def _report_global(findings, pos, module, name, via_stack=False):
    """Append a CRITICAL or WARNING finding for a resolved global reference.

    A (module, name) pair found in DANGEROUS_GLOBALS is reported as critical;
    any other global reference is flagged as a warning for manual review.
    """
    source = "STACK_GLOBAL (Python 3 format)" if via_stack else "GLOBAL"
    is_known_dangerous = (module, name) in DANGEROUS_GLOBALS
    if is_known_dangerous:
        message = (
            f"🚨 CRITICAL — Dangerous callable at byte {pos} via `{source}`: "
            f"`{module}.{name}` — loading this file will execute a system command."
        )
    else:
        message = (
            f"⚠️ WARNING — Global reference at byte {pos} via `{source}`: "
            f"`{module}.{name}` — verify this callable is expected."
        )
    findings.append(message)
# --- LAB INTERFACE ---
st.title("🛡️ Lab: ML Model Serialization Vulnerabilities")
st.markdown(f"""
**Goal:** Demonstrate how malicious code can be hidden in standard ML model files (`.pkl`)
and how to fix it using `safetensors`.

> 🔑 Your session ID: `{session_id}` — your files are isolated from other students.
""")

# --- STEP 1: CREATE VULNERABLE MODEL ---
st.header("Step 1: Create a 'Vulnerable' Model")
st.markdown("""
We will create a `pickle` file that contains a hidden system command.
When the file is loaded with `pickle.load()`, the embedded code executes **automatically** —
without the loader ever intentionally calling it.
""")

# Show students the attack mechanism before they run it.
_PAYLOAD_SNIPPET = """
class MaliciousPayload:
    def __reduce__(self):
        cmd = "echo 'SECURITY LAB DEMO: Payload Executed'"
        return (os.system, (cmd,))
"""
st.code(_PAYLOAD_SNIPPET, language="python")
class MaliciousPayload:
    """Demo object whose pickling embeds a (benign) shell command.

    pickle honors ``__reduce__``: on ``pickle.load`` the returned callable is
    invoked with the returned arguments — here ``os.system`` with an echo —
    which is exactly the mechanism real attacks abuse.
    """

    def __reduce__(self):
        command = f"echo 'SECURITY LAB DEMO: Benign Payload Executed by session {session_id}'"
        return (os.system, (command,))
if st.button("Generate Vulnerable Model", key="gen"):
    # Bundle fake weights with the payload object; pickling the dict embeds
    # the payload's __reduce__ instructions inside the file on disk.
    model_data = {
        "weights": [0.1, 0.2, 0.3],
        "metadata": "Lab Demo Model",
        "payload": MaliciousPayload(),
    }
    with open(PKL_PATH, "wb") as handle:
        pickle.dump(model_data, handle)
    st.success(f"✅ `{PKL_PATH}` created with embedded payload!")
    st.info("ℹ️ The payload has **not executed yet** — it only fires when the file is loaded.")
# --- STEP 2: SCAN ---
st.header("Step 2: Static Analysis Scan")
st.markdown("""
Our scanner inspects the **pickle bytecode opcodes** without executing the file.
This is the same approach used by tools like ModelScan — static analysis catches the threat before it can run.
""")

if st.button("Run Pickle Scanner", key="scan"):
    if os.path.exists(PKL_PATH):
        with st.spinner("Scanning..."):
            result = scan_pickle_file(PKL_PATH)
        issues = result["findings"]
        if issues:
            st.error(f"🚨 **{len(issues)} issue(s) detected:**")
            for finding in issues:
                st.markdown(f"- {finding}")
        else:
            st.success("✅ No issues found.")
        with st.expander("🔍 View raw pickle opcode disassembly"):
            st.code(result["opcode_log"], language="text")
    else:
        st.warning("⚠️ Please generate the vulnerable model first (Step 1).")
# Reference material for students: the scanner's internals and a standalone,
# copy-pastable version of it. Display-only — nothing here executes a scan.
with st.expander("🔍 Show scanner source code & how to run it on any model"):
    st.markdown("#### How this scanner works")
    st.markdown("""
The scanner uses Python's built-in `pickletools` module to **disassemble the pickle
bytecode without executing it**, then looks for opcodes that can invoke arbitrary code.
No third-party tools required — `pickletools` ships with every Python installation.
""")
    st.markdown("#### Scanner source — copy this into your own project")
    # NOTE: the string below is displayed verbatim; it is a simplified version
    # of scan_pickle_file / _report_global above, kept in sync by hand.
    st.code('''import pickletools
import io

DANGEROUS_GLOBALS = [
    ("posix", "system"), ("os", "system"), ("nt", "system"),
    ("posix", "popen"), ("os", "popen"), ("nt", "popen"),
    ("subprocess", "Popen"), ("subprocess", "call"), ("subprocess", "run"),
    ("builtins", "eval"), ("builtins", "exec"), ("builtins", "__import__"),
]
DANGEROUS_OPCODES = {"REDUCE", "BUILD", "INST", "OBJ", "NEWOBJ", "NEWOBJ_EX"}

def scan_pickle(filepath):
    findings = []
    seen_strings = []
    with open(filepath, "rb") as f:
        data = f.read()
    for opcode, arg, pos in pickletools.genops(io.BytesIO(data)):
        name = opcode.name
        if name in ("SHORT_BINUNICODE", "BINUNICODE", "UNICODE", "STRING"):
            seen_strings.append(arg)
        elif name == "GLOBAL" and arg:
            parts = arg.split(" ", 1)
            if len(parts) == 2:
                module, func = parts
                severity = "CRITICAL" if (module, func) in DANGEROUS_GLOBALS else "WARNING"
                findings.append(f"[{severity}] byte {pos}: GLOBAL {module}.{func}")
        elif name == "STACK_GLOBAL" and len(seen_strings) >= 2:
            module, func = seen_strings[-2], seen_strings[-1]
            severity = "CRITICAL" if (module, func) in DANGEROUS_GLOBALS else "WARNING"
            findings.append(f"[{severity}] byte {pos}: STACK_GLOBAL {module}.{func}")
        elif name == "REDUCE":
            findings.append(f"[CRITICAL] byte {pos}: REDUCE — callable will execute on load")
        elif name in DANGEROUS_OPCODES:
            findings.append(f"[WARNING] byte {pos}: {name} — can trigger code execution")
    return findings

# --- Usage ---
findings = scan_pickle("your_model.pkl")
if findings:
    print(f"UNSAFE — {len(findings)} issue(s) found:")
    for f in findings:
        print(" *", f)
else:
    print("SAFE — no dangerous opcodes detected")
''', language="python")
    st.markdown("#### Quick command-line check with `pickletools`")
    st.code("python -m pickletools your_model.pkl | grep -E 'GLOBAL|REDUCE|STACK_GLOBAL'", language="bash")
    st.markdown("""
> **Tip:** If you see `GLOBAL`, `STACK_GLOBAL`, or `REDUCE` opcodes referencing
> system modules like `os`, `subprocess`, or `builtins` — treat the file as malicious
> and do not load it.
""")
# --- STEP 3: SUPPLY CHAIN SIMULATION ---
st.header("Step 3: Supply Chain Simulation")
st.markdown("""
Upload the file to Hugging Face to simulate a **compromised model registry**.
Anyone who downloads and loads this model will unknowingly execute the payload.
""")

# All sessions publish to one shared demo repository.
username = "vchirrav"
repo_id = f"{username}/security-lab-demo"

if st.button(f"Upload to `{repo_id}`", key="upload"):
    if not os.path.exists(PKL_PATH):
        st.warning("⚠️ Please generate the vulnerable model first (Step 1).")
    else:
        api = HfApi(token=HF_TOKEN)
        st.write(f"Uploading to `{repo_id}`...")
        try:
            api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
            # Read as bytes and pass directly — prevents HF Hub from routing
            # .pkl files through Git LFS, which causes the "LFS pointer" error.
            with open(PKL_PATH, "rb") as fh:
                pickled_bytes = fh.read()
            api.upload_file(
                path_or_fileobj=pickled_bytes,
                path_in_repo=PKL_PATH,
                repo_id=repo_id,
                repo_type="model",
            )
            st.success(f"✅ Uploaded to https://huggingface.co/{repo_id}")
            st.warning("⚠️ In a real attack, victims download and load this — silently executing the payload.")
        except Exception as e:
            st.error(f"❌ Upload failed: {e}")
# --- STEP 4: REMEDIATE ---
st.header("Step 4: Remediate with Safetensors")
st.markdown("""
Convert the model to `safetensors` format.
`safetensors` stores **only raw tensor data** in a flat binary format —
it is architecturally incapable of embedding executable code.
""")

if st.button("Convert to Safetensors", key="convert"):
    # Only the numeric weights survive conversion — the payload object simply
    # cannot be represented in the safetensors format.
    tensors = {"weights": np.array([0.1, 0.2, 0.3], dtype=np.float32)}
    save_file(tensors, SAFE_PATH)
    st.success(f"✅ Converted! Saved as `{SAFE_PATH}`.")
    st.info("ℹ️ Only raw tensor values were saved — no Python objects, no callable code.")
# Reference material for students: how to apply this remediation to real
# models outside the lab. Display-only — nothing here runs a conversion.
with st.expander("🔐 Show real-world mitigation code — converting any model to safetensors"):
    st.markdown("#### Install the required packages")
    st.code("pip install safetensors torch", language="bash")
    st.markdown("#### Convert a PyTorch model (.pt / .pth / .pkl) to safetensors")
    st.code('''import torch
from safetensors.torch import save_file

# Load the original model (only do this with files you already trust or have scanned)
state_dict = torch.load("model.pt", map_location="cpu")

# If the file contains a full model object rather than a plain state_dict, extract it
if hasattr(state_dict, "state_dict"):
    state_dict = state_dict.state_dict()

# Strip out any non-tensor entries (metadata strings, config dicts, etc.)
tensor_only = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}

# Save in safetensors format — only raw tensor bytes, no executable code possible
save_file(tensor_only, "model.safetensors")
print("Conversion complete: model.safetensors")
''', language="python")
    st.markdown("#### Load the safetensors file back (safe to do with untrusted files)")
    st.code('''from safetensors.torch import load_file

state_dict = load_file("model.safetensors")

# Restore into your model architecture
model = MyModel()
model.load_state_dict(state_dict)
model.eval()
''', language="python")
    st.markdown("#### Using numpy instead of torch (no GPU/CUDA required)")
    st.code('''import numpy as np
from safetensors.numpy import save_file, load_file

# Save
arrays = {"weights": np.array([0.1, 0.2, 0.3], dtype=np.float32)}
save_file(arrays, "model.safetensors")

# Load
loaded = load_file("model.safetensors")
print(loaded["weights"])
''', language="python")
    st.markdown("""
> **Why safetensors is safe:** The format stores a JSON header describing tensor shapes
> and dtypes, followed by raw binary tensor data. There is no mechanism to store Python
> objects, callables, or executable bytecode — making it safe to load from untrusted sources.
""")
# --- STEP 5: UPLOAD SECURE MODEL ---
st.header("Step 5: Publish the Secure Model")
st.markdown(f"""
Upload the safe `safetensors` file to the **same repository** as the vulnerable model.
This simulates replacing a compromised model in the registry with a remediated one.
""")

if st.button(f"Upload Secure Model to `{repo_id}`", key="upload_safe"):
    if os.path.exists(SAFE_PATH):
        api = HfApi(token=HF_TOKEN)
        st.write(f"Uploading `{SAFE_PATH}` to `{repo_id}`...")
        try:
            api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
            # Upload raw bytes (same LFS workaround as the Step 3 upload).
            with open(SAFE_PATH, "rb") as fh:
                safe_bytes = fh.read()
            api.upload_file(
                path_or_fileobj=safe_bytes,
                path_in_repo=SAFE_PATH,
                repo_id=repo_id,
                repo_type="model",
            )
            st.success(f"✅ Secure model uploaded to https://huggingface.co/{repo_id}")
            st.info(
                f"ℹ️ Both files now exist in the same repo:\n"
                f"- `{PKL_PATH}` — the vulnerable pickle (still there as evidence)\n"
                f"- `{SAFE_PATH}` — the remediated safetensors replacement"
            )
        except Exception as e:
            st.error(f"❌ Upload failed: {e}")
    else:
        st.warning("⚠️ Please convert to safetensors first (Step 4).")
# --- STEP 6: VERIFY ---
st.header("Step 6: Verify the Fix")
st.markdown("Scan the safetensors file to confirm the vulnerability is gone.")

if st.button("Scan Secure Model", key="verify"):
    if os.path.exists(SAFE_PATH):
        verdict = scan_pickle_file(SAFE_PATH)
        if verdict["safe"]:
            st.success("🎉 Clean scan! No dangerous opcodes found in the safetensors file.")
            st.info("ℹ️ safetensors files are not pickle-based — they cannot contain executable code.")
        else:
            # Should not happen for a freshly converted file; surface anyway.
            st.error("Unexpected findings — review below.")
            for finding in verdict["findings"]:
                st.markdown(f"- {finding}")
    else:
        st.warning("⚠️ Please convert to safetensors first (Step 4).")
# --- LAB SUMMARY ---
# Closing recap shown unconditionally at the bottom of the page.
st.divider()
st.header("🧠 Key Takeaways")
st.markdown("""
| Format | Can Embed Code? | Safe to Load Untrusted Files? |
|---|---|---|
| `.pkl` (pickle) | ✅ Yes | ❌ Never |
| `.pt` / `.pth` (PyTorch) | ✅ Yes (uses pickle internally) | ❌ No |
| `.safetensors` | ❌ No | ✅ Yes |

**Best practice:** Always use `safetensors` for distributing model weights.
If you must load a pickle-based model, scan it statically first and only load
files from fully trusted, verified sources.
""")