# Source revision 0cb7559 (3,062 bytes).
import io
import os
import posixpath
import zipfile
from dataclasses import dataclass
from typing import List
@dataclass
class ExtractedFile:
    """A single file accepted by the upload extractor.

    For a top-level upload ``filename`` is the upload's own name; for a ZIP
    member it is ``<archive name without extension>/<member path>``.
    """

    filename: str  # upload name, or archive-prefixed member path
    content: bytes  # the file's bytes (decompressed when taken from a ZIP)
class ZipExtractionError(Exception):
    """Raised when extraction exceeds a file-count, total-size, or nesting-depth limit."""
# Limits shared by every item in a single extract_uploaded_items() call.
MAX_EXTRACTED_FILES = 500  # accepted files, counted across all items and archives
MAX_TOTAL_UNCOMPRESSED_BYTES = 100 * 2**20  # 100 MiB summed over accepted files
MAX_SINGLE_FILE_BYTES = 25 * 2**20  # 25 MiB; larger files/members are skipped
MAX_ZIP_DEPTH = 5  # deepest archive nesting level before raising
MAX_COMPRESSION_RATIO = 200  # members above this uncompressed/compressed ratio are skipped

# Lower-case suffixes (dot included) that may be extracted. Files with no
# suffix are also accepted. ".zip" payloads are always unpacked, never stored.
SAFE_EXTENSIONS = {
    # archives
    ".zip",
    # documents and data
    ".csv", ".json", ".jsonl", ".pdf", ".docx", ".txt", ".md", ".log",
    ".xml", ".yaml", ".yml", ".sql",
    # source code and markup
    ".py", ".js", ".ts", ".tsx", ".jsx", ".java", ".go", ".rs", ".c",
    ".cpp", ".cs", ".html", ".css", ".sh",
    # images
    ".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tif", ".tiff", ".webp",
}
def is_safe_member_name(name: str) -> bool:
    """Return True if a ZIP member name is acceptable for extraction.

    The name is normalised first: backslashes become ``/``, ``.``/``..``
    segments are collapsed, and leading slashes are stripped. It is then
    rejected when it:

    * escapes the archive root through a leading ``..``;
    * is empty or refers to the archive root itself;
    * has a final component starting with ``.`` (hidden files); or
    * starts with ``__MACOSX``.

    Args:
        name: The member name as stored in the archive.

    Returns:
        True when the member may be extracted, False otherwise.
    """
    # Archives may use "\" as a separator. posixpath treats it as an ordinary
    # character, so "foo\..\..\evil" would slip past the traversal check
    # unless it is unified to "/" first.
    unified = name.replace("\\", "/")
    normalized = posixpath.normpath(unified).lstrip("/")
    # After normpath, any surviving ".." is a leading component; check the
    # bare ".." explicitly instead of relying on the hidden-file rule below.
    if normalized == ".." or normalized.startswith("../") or "/../" in normalized:
        return False
    if normalized in {".", ""}:
        return False
    basename = posixpath.basename(normalized)
    if basename.startswith(".") or normalized.startswith("__MACOSX"):
        return False
    return True
def extract_uploaded_items(items: List[tuple[str, bytes]]) -> List[ExtractedFile]:
    """Expand a batch of uploads into the files that pass all safety checks.

    Each ``(filename, content)`` pair is kept, skipped, or (for ``.zip``
    uploads) recursively unpacked. One set of counters is shared by the whole
    batch, so the file-count and total-size limits apply across all items
    rather than to each item separately.

    Args:
        items: ``(filename, raw bytes)`` pairs, processed in order.

    Returns:
        The accepted files, in the order they were encountered.

    Raises:
        ZipExtractionError: If a global limit or the nesting depth is exceeded.
    """
    results: List[ExtractedFile] = []
    counters = {"bytes": 0, "files": 0}
    for upload_name, upload_bytes in items:
        _extract_item(upload_name, upload_bytes, results, counters, depth=0)
    return results
def _extract_item(filename: str, content: bytes, extracted: List[ExtractedFile], state: dict, depth: int) -> None:
    """Accept, skip, or unpack a single named payload.

    ``.zip`` payloads (case-insensitive) go to ``_extract_zip`` one nesting
    level deeper. Any other payload is kept only if its extension is empty or
    listed in ``SAFE_EXTENSIONS`` and its size is at most
    ``MAX_SINGLE_FILE_BYTES``. Otherwise it is dropped silently.

    Args:
        filename: Name to record for the payload (may carry an archive prefix).
        content: Raw bytes of the payload.
        extracted: Output list that accepted files are appended to.
        state: Shared ``"files"``/``"bytes"`` counters for the global limits.
        depth: Nesting level of the archive this payload came from
            (0 for a top-level upload).

    Raises:
        ZipExtractionError: If accepting this payload pushes the running file
            count past ``MAX_EXTRACTED_FILES`` or the running byte total past
            ``MAX_TOTAL_UNCOMPRESSED_BYTES``, or if ``_extract_zip`` raises it.
    """
    suffix = os.path.splitext(filename)[1].lower()
    if suffix == ".zip":
        _extract_zip(filename, content, extracted, state, depth + 1)
        return

    allowed = suffix == "" or suffix in SAFE_EXTENSIONS
    if not allowed or len(content) > MAX_SINGLE_FILE_BYTES:
        return

    # Counters are bumped before the check, so the file that breaches a
    # limit is counted but never appended.
    state["files"] += 1
    state["bytes"] += len(content)
    too_many = state["files"] > MAX_EXTRACTED_FILES
    too_big = state["bytes"] > MAX_TOTAL_UNCOMPRESSED_BYTES
    if too_many or too_big:
        raise ZipExtractionError("Extraction limits exceeded")

    extracted.append(ExtractedFile(filename=filename, content=content))
def _extract_zip(filename: str, content: bytes, extracted: List[ExtractedFile], state: dict, depth: int) -> None:
    """Walk one ZIP archive and pass its acceptable members to ``_extract_item``.

    A member is skipped silently when any of these holds:

    * it is a directory;
    * its name fails ``is_safe_member_name``;
    * it is encrypted;
    * its declared size exceeds ``MAX_SINGLE_FILE_BYTES``;
    * its uncompressed/compressed ratio exceeds ``MAX_COMPRESSION_RATIO``; or
    * its extension would make ``_extract_item`` discard it anyway.

    Surviving members are named ``<filename without extension>/<member path>``.

    Args:
        filename: Name of the archive; used to prefix member names.
        content: Raw bytes of the archive.
        extracted: Output list that accepted files are appended to.
        state: Shared ``"files"``/``"bytes"`` counters for the global limits.
        depth: Nesting level of this archive (1 for a top-level ZIP upload).

    Raises:
        ZipExtractionError: If ``depth`` exceeds ``MAX_ZIP_DEPTH`` or a
            global limit is hit while processing members.
        zipfile.BadZipFile: Propagated from :mod:`zipfile` when the archive
            or a member is corrupt.
    """
    if depth > MAX_ZIP_DEPTH:
        raise ZipExtractionError("Nested ZIP depth exceeded")
    # Same prefix for every member of this archive, so compute it once.
    archive_root = os.path.splitext(filename)[0]
    with zipfile.ZipFile(io.BytesIO(content)) as zf:
        for info in zf.infolist():
            if info.is_dir():
                continue
            if not is_safe_member_name(info.filename):
                continue
            # Bit 0 of the general-purpose flags marks an encrypted member.
            # Without a password, zf.read() would raise RuntimeError and abort
            # the whole upload instead of skipping just this member.
            if info.flag_bits & 0x1:
                continue
            if info.file_size > MAX_SINGLE_FILE_BYTES:
                continue
            # max(..., 1) avoids dividing by zero for empty compressed data.
            ratio = info.file_size / max(info.compress_size, 1)
            if ratio > MAX_COMPRESSION_RATIO:
                continue
            member_name = posixpath.normpath(info.filename).lstrip("/")
            full_name = f"{archive_root}/{member_name}"
            # Mirror _extract_item's extension filter before decompressing.
            # Members it would drop are never read, so no CPU or memory is
            # spent on them. ".zip" is in SAFE_EXTENSIONS, so nested archives
            # still pass.
            ext = os.path.splitext(full_name)[1].lower()
            if ext != "" and ext not in SAFE_EXTENSIONS:
                continue
            data = zf.read(info)
            _extract_item(full_name, data, extracted, state, depth)