| import bz2 |
| import gzip |
| import lzma |
| import os |
| import shutil |
| import struct |
| import tarfile |
| import warnings |
| import zipfile |
| from abc import ABC, abstractmethod |
| from pathlib import Path |
| from typing import Optional, Union |
|
|
| from .. import config |
| from ._filelock import FileLock |
| from .logging import get_logger |
|
|
|
|
# Module-level logger, created by the package's logging helper under this module's name.
logger = get_logger(__name__)
|
|
|
|
class ExtractManager:
    """Extract archives into a cache directory, naming each output after a hash of its input path.

    Instance attributes (set in ``__init__``):
        extract_dir: directory under which extracted outputs are placed.
        extractor: the ``Extractor`` class used to detect the format and to extract.
    """

    def __init__(self, cache_dir: Optional[str] = None):
        """Select the extraction directory.

        Args:
            cache_dir: When truthy, outputs go to ``config.EXTRACTED_DATASETS_DIR`` joined
                onto it; otherwise ``config.EXTRACTED_DATASETS_PATH`` is used.
        """
        if cache_dir:
            self.extract_dir = os.path.join(cache_dir, config.EXTRACTED_DATASETS_DIR)
        else:
            self.extract_dir = config.EXTRACTED_DATASETS_PATH
        self.extractor = Extractor

    def _get_output_path(self, path: str) -> str:
        """Return the location inside ``extract_dir`` where ``path`` gets extracted.

        The file name is produced by hashing the absolute form of ``path``.
        """
        from .file_utils import hash_url_to_filename

        hashed_name = hash_url_to_filename(os.path.abspath(path))
        return os.path.join(self.extract_dir, hashed_name)

    def _do_extract(self, output_path: str, force_extract: bool) -> bool:
        """Tell whether an extraction has to be performed.

        Extraction is needed when it is forced, or when ``output_path`` is neither an
        existing file nor a non-empty directory.
        """
        if force_extract:
            return force_extract
        if os.path.isfile(output_path):
            return False
        if os.path.isdir(output_path) and os.listdir(output_path):
            return False
        return True

    def extract(self, input_path: str, force_extract: bool = False) -> str:
        """Extract ``input_path`` if it is a recognised archive and return the resulting path.

        Args:
            input_path: file that may be an archive or a compressed file.
            force_extract: extract again even if a previous output is present.

        Returns:
            The extraction output path, or ``input_path`` unchanged when no extractor
            format is inferred for it.
        """
        detected_format = self.extractor.infer_extractor_format(input_path)
        if detected_format:
            destination = self._get_output_path(input_path)
            if self._do_extract(destination, force_extract):
                self.extractor.extract(input_path, destination, detected_format)
            return destination
        return input_path
|
|
|
|
class BaseExtractor(ABC):
    """Abstract interface shared by all extractors: a format check and an extraction routine."""

    @classmethod
    @abstractmethod
    def is_extractable(cls, path: Union[Path, str], **kwargs) -> bool:
        """Return whether the file at ``path`` can be handled by this extractor."""

    @staticmethod
    @abstractmethod
    def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
        """Extract the content of ``input_path`` to ``output_path``."""
|
|
|
|
class MagicNumberBaseExtractor(BaseExtractor, ABC):
    """Base class for extractors that recognise their format from the first bytes of a file.

    Subclasses list the accepted byte prefixes in ``magic_numbers``.
    """

    # Byte prefixes identifying the format; empty on this base class.
    magic_numbers: list[bytes] = []

    @staticmethod
    def read_magic_number(path: Union[Path, str], magic_number_length: int):
        """Return at most ``magic_number_length`` bytes read from the start of ``path``.

        Raises:
            OSError: if the file cannot be opened or read.
        """
        with open(path, "rb") as f:
            return f.read(magic_number_length)

    @classmethod
    def is_extractable(cls, path: Union[Path, str], magic_number: bytes = b"") -> bool:
        """Return whether the file starts with one of this class's magic numbers.

        Args:
            path: file to inspect; only read when ``magic_number`` is empty.
            magic_number: leading bytes of the file, if the caller already read them.

        Returns:
            ``True`` if the leading bytes start with any entry of ``magic_numbers``;
            ``False`` otherwise, including when the file cannot be read or when the
            class declares no magic numbers.
        """
        if not cls.magic_numbers:
            # Nothing to match against; also avoids max() raising on an empty sequence.
            return False
        if not magic_number:
            magic_number_length = max(len(cls_magic_number) for cls_magic_number in cls.magic_numbers)
            try:
                magic_number = cls.read_magic_number(path, magic_number_length)
            except OSError:
                return False
        return any(magic_number.startswith(cls_magic_number) for cls_magic_number in cls.magic_numbers)
|
|
|
|
class TarExtractor(BaseExtractor):
    """Extractor for tar archives, as recognised by ``tarfile.is_tarfile``."""

    @classmethod
    def is_extractable(cls, path: Union[Path, str], **kwargs) -> bool:
        """Return whether ``tarfile`` recognises ``path`` as a tar archive (extra kwargs are ignored)."""
        return tarfile.is_tarfile(path)

    @staticmethod
    def safemembers(members, output_path):
        """
        Yield only the archive members that are safe to extract under ``output_path``.

        Fix for CVE-2007-4559
        Desc:
            Directory traversal vulnerability in the (1) extract and (2) extractall functions in the tarfile
            module in Python allows user-assisted remote attackers to overwrite arbitrary files via a .. (dot dot)
            sequence in filenames in a TAR archive, a related issue to CVE-2001-1267.
        See: https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2007-4559
        From: https://stackoverflow.com/a/10077309

        Args:
            members: iterable of ``tarfile.TarInfo`` objects (e.g. an open ``TarFile``).
            output_path: directory the archive is being extracted into.

        Yields:
            Members whose path, and link target for symlinks/hard links, stay inside the
            allowed directory. Rejected members are reported with ``logger.error`` and skipped.
        """

        def resolved(path: str) -> str:
            return os.path.realpath(os.path.abspath(path))

        def badpath(path: str, base: str) -> bool:
            # os.path.join ignores `base` when `path` is absolute, so absolute members are caught too.
            target = resolved(os.path.join(base, path))
            try:
                # Compare whole path components: a plain string-prefix test would accept a
                # sibling such as "<base>_evil/file" because it shares the textual prefix.
                return os.path.commonpath([base, target]) != base
            except ValueError:
                # No common path at all (e.g. different drives on Windows): outside base.
                return True

        def badlink(info, base: str) -> bool:
            # The link target is checked against the directory that contains the link itself.
            tip = resolved(os.path.join(base, os.path.dirname(info.name)))
            return badpath(info.linkname, base=tip)

        base = resolved(output_path)

        for finfo in members:
            if badpath(finfo.name, base):
                logger.error(f"Extraction of {finfo.name} is blocked (illegal path)")
            elif finfo.issym() and badlink(finfo, base):
                logger.error(f"Extraction of {finfo.name} is blocked: Symlink to {finfo.linkname}")
            elif finfo.islnk() and badlink(finfo, base):
                logger.error(f"Extraction of {finfo.name} is blocked: Hard link to {finfo.linkname}")
            else:
                yield finfo

    @staticmethod
    def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
        """Extract the safe members of the tar archive ``input_path`` into the directory ``output_path``.

        The directory is created if needed. The archive is closed even if extraction fails.
        """
        os.makedirs(output_path, exist_ok=True)
        with tarfile.open(input_path) as tar_file:
            tar_file.extractall(output_path, members=TarExtractor.safemembers(tar_file, output_path))
|
|
|
|
class GzipExtractor(MagicNumberBaseExtractor):
    """Extractor for gzip-compressed files, detected by their two-byte magic number."""

    magic_numbers = [b"\x1f\x8b"]

    @staticmethod
    def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
        """Decompress ``input_path`` and write the resulting bytes to the file ``output_path``."""
        with gzip.open(input_path, "rb") as source, open(output_path, "wb") as target:
            shutil.copyfileobj(source, target)
|
|
|
|
class ZipExtractor(MagicNumberBaseExtractor):
    """Extractor for zip archives.

    Detection first tries the magic numbers, then falls back to locating the
    end-of-central-directory record and validating the central directory signature.
    """

    magic_numbers = [
        b"PK\x03\x04",
        b"PK\x05\x06",  # empty archive
        b"PK\x07\x08",  # spanned archive
    ]

    @classmethod
    def is_extractable(cls, path: Union[Path, str], magic_number: bytes = b"") -> bool:
        """Return whether ``path`` looks like a zip archive.

        Args:
            path: file to inspect.
            magic_number: leading bytes of the file, if the caller already read them.

        Returns:
            ``True`` when a magic number matches or the central-directory check succeeds;
            ``False`` otherwise, including when any error occurs during the check.
        """
        if super().is_extractable(path, magic_number=magic_number):
            return True
        try:
            # Fallback relying on private helpers of the stdlib ``zipfile`` module; any
            # failure (including a missing private name) is treated as "not a zip".
            from zipfile import (
                _CD_SIGNATURE,
                _ECD_DISK_NUMBER,
                _ECD_DISK_START,
                _ECD_ENTRIES_TOTAL,
                _ECD_OFFSET,
                _ECD_SIZE,
                _EndRecData,
                sizeCentralDir,
                stringCentralDir,
                structCentralDir,
            )

            with open(path, "rb") as fp:
                endrec = _EndRecData(fp)
                if not endrec:
                    return False
                if endrec[_ECD_ENTRIES_TOTAL] == 0 and endrec[_ECD_SIZE] == 0 and endrec[_ECD_OFFSET] == 0:
                    # Record describes an archive with no entries.
                    return True
                if endrec[_ECD_DISK_NUMBER] != endrec[_ECD_DISK_START]:
                    return False
                # Jump to where the record says the central directory starts.
                fp.seek(endrec[_ECD_OFFSET])
                if fp.tell() != endrec[_ECD_OFFSET] or endrec[_ECD_SIZE] < sizeCentralDir:
                    return False
                data = fp.read(sizeCentralDir)
                if len(data) != sizeCentralDir:
                    return False
                centdir = struct.unpack(structCentralDir, data)
                return centdir[_CD_SIGNATURE] == stringCentralDir
        except Exception:
            return False

    @staticmethod
    def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
        """Extract every member of the zip archive ``input_path`` into the directory ``output_path``.

        The directory is created if needed; the archive is closed by the ``with`` block.
        """
        os.makedirs(output_path, exist_ok=True)
        with zipfile.ZipFile(input_path, "r") as zip_file:
            zip_file.extractall(output_path)
|
|
|
|
class XzExtractor(MagicNumberBaseExtractor):
    """Extractor for xz-compressed files, detected by their six-byte magic number."""

    magic_numbers = [b"\xfd\x37\x7a\x58\x5a\x00"]

    @staticmethod
    def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
        """Decompress ``input_path`` with ``lzma`` and write the bytes to the file ``output_path``."""
        with lzma.open(input_path) as source, open(output_path, "wb") as target:
            shutil.copyfileobj(source, target)
|
|
|
|
class RarExtractor(MagicNumberBaseExtractor):
    """Extractor for RAR archives; requires the optional ``rarfile`` package."""

    # Two accepted signatures (7-byte and 8-byte variants).
    magic_numbers = [b"Rar!\x1a\x07\x00", b"Rar!\x1a\x07\x01\x00"]

    @staticmethod
    def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
        """Extract the RAR archive ``input_path`` into the directory ``output_path``.

        The directory is created if needed. The archive is closed even if extraction fails.

        Raises:
            ImportError: if ``config.RARFILE_AVAILABLE`` is false.
        """
        if not config.RARFILE_AVAILABLE:
            raise ImportError("Please pip install rarfile")
        import rarfile

        os.makedirs(output_path, exist_ok=True)
        with rarfile.RarFile(input_path) as rf:
            rf.extractall(output_path)
|
|
|
|
class ZstdExtractor(MagicNumberBaseExtractor):
    """Extractor for Zstandard-compressed files; requires the optional ``zstandard`` package."""

    magic_numbers = [b"\x28\xb5\x2f\xfd"]

    @staticmethod
    def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
        """Stream-decompress ``input_path`` into the file ``output_path``.

        Raises:
            ImportError: if ``config.ZSTANDARD_AVAILABLE`` is false.
        """
        if not config.ZSTANDARD_AVAILABLE:
            raise ImportError("Please pip install zstandard")
        import zstandard as zstd

        decompressor = zstd.ZstdDecompressor()
        with open(input_path, "rb") as source:
            with open(output_path, "wb") as target:
                decompressor.copy_stream(source, target)
|
|
|
|
class Bzip2Extractor(MagicNumberBaseExtractor):
    """Extractor for bzip2-compressed files, detected by their three-byte magic number."""

    magic_numbers = [b"\x42\x5a\x68"]

    @staticmethod
    def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
        """Decompress ``input_path`` with ``bz2`` and write the bytes to the file ``output_path``."""
        with bz2.open(input_path, "rb") as source, open(output_path, "wb") as target:
            shutil.copyfileobj(source, target)
|
|
|
|
class SevenZipExtractor(MagicNumberBaseExtractor):
    """Extractor for 7z archives; requires the optional ``py7zr`` package."""

    magic_numbers = [b"\x37\x7a\xbc\xaf\x27\x1c"]

    @staticmethod
    def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
        """Extract the 7z archive ``input_path`` into the directory ``output_path``.

        The directory is created if needed.

        Raises:
            ImportError: if ``config.PY7ZR_AVAILABLE`` is false.
        """
        if not config.PY7ZR_AVAILABLE:
            raise ImportError("Please pip install py7zr")
        import py7zr

        os.makedirs(output_path, exist_ok=True)
        with py7zr.SevenZipFile(input_path, "r") as seven_zip:
            seven_zip.extractall(output_path)
|
|
|
|
class Lz4Extractor(MagicNumberBaseExtractor):
    """Extractor for LZ4 frame files; requires the optional ``lz4`` package."""

    magic_numbers = [b"\x04\x22\x4d\x18"]

    @staticmethod
    def extract(input_path: Union[Path, str], output_path: Union[Path, str]) -> None:
        """Decompress ``input_path`` with ``lz4.frame`` and write the bytes to the file ``output_path``.

        Raises:
            ImportError: if ``config.LZ4_AVAILABLE`` is false.
        """
        if not config.LZ4_AVAILABLE:
            raise ImportError("Please pip install lz4")
        import lz4.frame

        with lz4.frame.open(input_path, "rb") as source, open(output_path, "wb") as target:
            shutil.copyfileobj(source, target)
|
|
|
|
class Extractor:
    """Detect the format of a file and dispatch extraction to the matching extractor class."""

    # Registry of format name -> extractor class. `infer_extractor_format` probes the
    # entries in this order and returns the first match.
    extractors: dict[str, type[BaseExtractor]] = {
        "tar": TarExtractor,
        "gzip": GzipExtractor,
        "zip": ZipExtractor,
        "xz": XzExtractor,
        "rar": RarExtractor,
        "zstd": ZstdExtractor,
        "bz2": Bzip2Extractor,
        "7z": SevenZipExtractor,
        "lz4": Lz4Extractor,
    }

    @classmethod
    def _get_magic_number_max_length(cls):
        """Return the length of the longest magic number among the magic-number-based extractors."""
        return max(
            len(extractor_magic_number)
            for extractor in cls.extractors.values()
            if issubclass(extractor, MagicNumberBaseExtractor)
            for extractor_magic_number in extractor.magic_numbers
        )

    @staticmethod
    def _read_magic_number(path: Union[Path, str], magic_number_length: int):
        """Return the first ``magic_number_length`` bytes of ``path``, or ``b""`` if it cannot be read."""
        try:
            return MagicNumberBaseExtractor.read_magic_number(path, magic_number_length=magic_number_length)
        except OSError:
            return b""

    @classmethod
    def is_extractable(
        cls, path: Union[Path, str], return_extractor: bool = False
    ) -> Union[bool, tuple[bool, Optional[type[BaseExtractor]]]]:
        """Deprecated: tell whether ``path`` can be extracted.

        Args:
            path: file to inspect.
            return_extractor: when true, return a ``(found, extractor_class_or_None)`` tuple
                instead of a plain bool.

        Returns:
            ``bool``, or a ``(bool, extractor class or None)`` tuple if ``return_extractor`` is true.
        """
        warnings.warn(
            "Method 'is_extractable' was deprecated in version 2.4.0 and will be removed in 3.0.0. "
            "Use 'infer_extractor_format' instead.",
            category=FutureWarning,
            # Attribute the warning to the caller of this deprecated method.
            stacklevel=2,
        )
        extractor_format = cls.infer_extractor_format(path)
        found = bool(extractor_format)
        if not return_extractor:
            return found
        if found:
            return True, cls.extractors[extractor_format]
        return False, None

    @classmethod
    def infer_extractor_format(cls, path: Union[Path, str]) -> Optional[str]:
        """Return the name of the first registered format that recognises ``path``, or ``None``.

        The leading bytes are read once and handed to every extractor's ``is_extractable``.
        """
        magic_number_max_length = cls._get_magic_number_max_length()
        magic_number = cls._read_magic_number(path, magic_number_max_length)
        for extractor_format, extractor in cls.extractors.items():
            if extractor.is_extractable(path, magic_number=magic_number):
                return extractor_format
        return None

    @classmethod
    def extract(
        cls,
        input_path: Union[Path, str],
        output_path: Union[Path, str],
        extractor_format: str,
    ) -> None:
        """Extract ``input_path`` to ``output_path`` with the extractor registered for ``extractor_format``.

        The parent directory of ``output_path`` is created if needed, and any existing
        directory tree at ``output_path`` is removed first (removal errors are ignored).

        Raises:
            KeyError: if ``extractor_format`` is not a key of ``extractors``.
        """
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # Hold a lock file next to the output while removing and re-creating it
        # (presumably to prevent parallel extractions of the same output).
        lock_path = str(Path(output_path).with_suffix(".lock"))
        with FileLock(lock_path):
            shutil.rmtree(output_path, ignore_errors=True)
            extractor = cls.extractors[extractor_format]
            return extractor.extract(input_path, output_path)
|
|