| import io | |
| import os | |
| import posixpath | |
| import zipfile | |
| from dataclasses import dataclass | |
| from typing import List | |
@dataclass
class ExtractedFile:
    """A single file produced by the extraction pipeline.

    The ``@dataclass`` decorator generates the keyword-accepting ``__init__``
    that ``_extract_item`` relies on when it builds
    ``ExtractedFile(filename=..., content=...)``.
    """

    # Name of the file; for archive members this is the archive name without
    # its extension, a "/", and the normalised member path.
    filename: str
    # Raw file contents.
    content: bytes
class ZipExtractionError(Exception):
    """Raised when extraction exceeds a configured limit.

    Used for the file-count / total-size budget and for the maximum nesting
    depth of ZIP archives.
    """
# Budget for one extraction run: accepted file count and accepted bytes.
MAX_EXTRACTED_FILES = 500
MAX_TOTAL_UNCOMPRESSED_BYTES = 100 * 1024 ** 2  # 100 MiB
# Larger individual files are skipped rather than rejected.
MAX_SINGLE_FILE_BYTES = 25 * 1024 ** 2  # 25 MiB
# Deepest allowed level of ZIP-inside-ZIP nesting.
MAX_ZIP_DEPTH = 5
# Members whose declared uncompressed/compressed ratio exceeds this are skipped.
MAX_COMPRESSION_RATIO = 200

# Lower-case extensions (with leading dot) that are accepted.
SAFE_EXTENSIONS = {
    # archives
    ".zip",
    # documents and data
    ".csv", ".json", ".jsonl", ".pdf", ".docx", ".txt", ".md", ".log",
    ".xml", ".yaml", ".yml", ".sql",
    # source code and markup
    ".py", ".js", ".ts", ".tsx", ".jsx", ".java", ".go", ".rs", ".c",
    ".cpp", ".cs", ".html", ".css", ".sh",
    # images
    ".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tif", ".tiff", ".webp",
}
def is_safe_member_name(name: str) -> bool:
    """Return True when an archive member name is acceptable for extraction.

    A name is rejected when, after normalisation, it

    * is empty or refers to the current directory,
    * still climbs out of the archive root via ``..``,
    * has a final component starting with ``.`` (hidden file), or
    * starts with ``__MACOSX``.

    Backslashes are treated as path separators so that the verdict is the
    same on every host OS and Windows-style traversal such as
    ``a\\..\\..\\b.txt`` is caught as well.

    Args:
        name: Member name as stored in the archive.

    Returns:
        True if the name passes all checks, False otherwise.
    """
    # Normalise separators first, then collapse "." / ".." segments and drop
    # any leading "/" so absolute names are judged as relative ones.
    normalized = posixpath.normpath(name.replace("\\", "/")).lstrip("/")

    if normalized in {"", "."}:
        return False

    # After normpath, any remaining ".." can only sit at the front.
    if normalized == ".." or normalized.startswith("../") or "/../" in normalized:
        return False

    # posixpath (not os.path) keeps the result independent of the host OS.
    basename = posixpath.basename(normalized)
    if basename.startswith(".") or normalized.startswith("__MACOSX"):
        return False

    return True
def extract_uploaded_items(items: List[tuple[str, bytes]]) -> List[ExtractedFile]:
    """Expand uploaded ``(filename, content)`` pairs into a flat list of files.

    Each pair is handed to ``_extract_item`` at depth 0. All pairs share one
    set of running counters, so the file-count and byte budgets apply to the
    whole batch rather than to each upload separately.

    Args:
        items: Uploaded ``(filename, content)`` pairs.

    Returns:
        The files collected from every item, in processing order.
    """
    collected: List[ExtractedFile] = []
    counters = dict(files=0, bytes=0)
    for upload_name, payload in items:
        _extract_item(upload_name, payload, collected, counters, depth=0)
    return collected
def _extract_item(
    filename: str,
    content: bytes,
    extracted: List[ExtractedFile],
    state: dict,
    depth: int,
) -> None:
    """Route one named blob: recurse into ZIPs, keep allowed files, drop the rest.

    Args:
        filename: Name used to determine the extension and stored on the result.
        content: Raw bytes of the item.
        extracted: Output list; accepted files are appended to it.
        state: Running counters with keys ``"files"`` and ``"bytes"``.
        depth: Current archive nesting level; incremented when descending
            into a ZIP.

    Raises:
        ZipExtractionError: If accepting this file pushes the running file
            count or byte total past its limit.
    """
    _, suffix = os.path.splitext(filename)
    suffix = suffix.lower()

    # Archives are unpacked rather than stored.
    if suffix == ".zip":
        _extract_zip(filename, content, extracted, state, depth + 1)
        return

    # Extension-less names are allowed; unknown extensions are silently skipped.
    allowed = suffix == "" or suffix in SAFE_EXTENSIONS
    if not allowed:
        return

    size = len(content)
    if size > MAX_SINGLE_FILE_BYTES:
        return

    # Counters are updated before the limit check, so the file that crosses
    # a limit is the one that triggers the error.
    state["files"] += 1
    state["bytes"] += size
    over_count = state["files"] > MAX_EXTRACTED_FILES
    if over_count or state["bytes"] > MAX_TOTAL_UNCOMPRESSED_BYTES:
        raise ZipExtractionError("Extraction limits exceeded")

    extracted.append(ExtractedFile(filename=filename, content=content))
def _extract_zip(filename: str, content: bytes, extracted: List[ExtractedFile], state: dict, depth: int) -> None:
    """Unpack an in-memory ZIP archive and pass each acceptable member on.

    A member is skipped (not treated as an error) when it is a directory, its
    name is rejected by ``is_safe_member_name``, its declared uncompressed
    size exceeds ``MAX_SINGLE_FILE_BYTES``, its declared compression ratio
    exceeds ``MAX_COMPRESSION_RATIO``, or its extension is one that
    ``_extract_item`` would discard anyway. The last check happens before
    decompression so no work is spent on data that is thrown away.

    Args:
        filename: Name of the archive; without its extension it becomes the
            prefix of every member's name.
        content: Raw bytes of the archive.
        extracted: Output list, filled by ``_extract_item``.
        state: Running counters, passed through to ``_extract_item``.
        depth: Nesting level of this archive.

    Raises:
        ZipExtractionError: If ``depth`` exceeds ``MAX_ZIP_DEPTH``, or when
            raised by ``_extract_item`` for a member.

    Errors raised by ``zipfile`` while opening or reading the archive are not
    caught here and propagate to the caller.
    """
    if depth > MAX_ZIP_DEPTH:
        raise ZipExtractionError("Nested ZIP depth exceeded")

    # Same for every member, so computed once outside the loop.
    archive_root = os.path.splitext(filename)[0]

    with zipfile.ZipFile(io.BytesIO(content)) as zf:
        for info in zf.infolist():
            if info.is_dir() or not is_safe_member_name(info.filename):
                continue
            if info.file_size > MAX_SINGLE_FILE_BYTES:
                continue

            # Clamp the divisor: compress_size can be 0 (e.g. empty members).
            ratio = info.file_size / max(info.compress_size, 1)
            if ratio > MAX_COMPRESSION_RATIO:
                continue

            member_name = posixpath.normpath(info.filename).lstrip("/")
            full_name = f"{archive_root}/{member_name}"

            # Mirror _extract_item's extension filter before decompressing:
            # it would drop these members without recording anything.
            ext = os.path.splitext(full_name)[1].lower()
            if ext and ext not in SAFE_EXTENSIONS:
                continue

            _extract_item(full_name, zf.read(info), extracted, state, depth)