File size: 4,754 Bytes
be7647c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | """Caching of formatted files with feature-based invalidation."""
import hashlib
import os
import pickle
import sys
import tempfile
from collections.abc import Iterable
from dataclasses import dataclass, field
from pathlib import Path
from typing import NamedTuple
from platformdirs import user_cache_dir
from _black_version import version as __version__
from black.mode import Mode
from black.output import err
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
class FileData(NamedTuple):
    """Per-file record kept in the cache.

    Instances are built by `Cache.get_file_data` from the result of
    `Path.stat()` plus a SHA-256 hex digest of the file's bytes.
    """

    # Modification time taken from `stat().st_mtime`.
    st_mtime: float
    # File size taken from `stat().st_size`.
    st_size: int
    # Hex digest of the file contents (see `Cache.hash_digest`).
    hash: str
def get_cache_dir() -> Path:
    """Return the version-specific directory in which black keeps its cache.

    The base location comes from the `BLACK_CACHE_DIR` environment variable
    when it is set; otherwise the per-user cache directory for the "black"
    application is used. The running black version is appended as the last
    path component.

    The module stores the result in the constant `black.cache.CACHE_DIR` at
    import time so that the lookup is not repeated.
    """
    # NOTE: Kept as a function mostly so tests can exercise the lookup cleanly.
    fallback = user_cache_dir("black")
    base = os.environ.get("BLACK_CACHE_DIR", fallback)
    return Path(base) / __version__
# Computed once at import time; the rest of this module reads this constant
# rather than calling `get_cache_dir()` again.
CACHE_DIR = get_cache_dir()
def get_cache_file(mode: Mode) -> Path:
    """Return the path of the pickle cache file that belongs to `mode`.

    The file name embeds `mode.get_cache_key()`, so every distinct key maps to
    its own file inside `CACHE_DIR`.
    """
    cache_key = mode.get_cache_key()
    file_name = f"cache.{cache_key}.pickle"
    return CACHE_DIR / file_name
@dataclass
class Cache:
    """In-memory view of the on-disk cache for a single `Mode`.

    `file_data` maps the string form of a resolved source path to the
    `FileData` recorded for that file the last time it was written to the
    cache.
    """

    mode: Mode
    cache_file: Path
    file_data: dict[str, FileData] = field(default_factory=dict)

    @classmethod
    def read(cls, mode: Mode) -> Self:
        """Read the cache if it exists and is well-formed.

        If the cache file is missing, unreadable, or not well-formed, an empty
        cache is returned; the call to write later should resolve the issue.
        """
        cache_file = get_cache_file(mode)
        try:
            exists = cache_file.exists()
        except OSError as e:
            # Likely file too long; see #4172 and #4174
            err(f"Unable to read cache file {cache_file} due to {e}")
            return cls(mode, cache_file)
        if not exists:
            return cls(mode, cache_file)

        try:
            with cache_file.open("rb") as fobj:
                # NOTE(security): pickle must only be fed trusted data. This
                # presumes the cache directory is writable only by the current
                # user -- confirm if the cache location can ever be shared.
                data: dict[str, tuple[float, int, str]] = pickle.load(fobj)
            file_data = {k: FileData(*v) for k, v in data.items()}
        except OSError as e:
            # The file may vanish or become unreadable between the existence
            # check and the open; treat that like the failure case above.
            err(f"Unable to read cache file {cache_file} due to {e}")
            return cls(mode, cache_file)
        except (
            pickle.UnpicklingError,
            ValueError,
            IndexError,
            # Empty or truncated cache file.
            EOFError,
            # Entry with the wrong number of fields, or a non-iterable entry.
            TypeError,
            # Payload that is not a mapping, or a pickle naming a missing attribute.
            AttributeError,
            # Pickle naming a module that cannot be imported.
            ImportError,
        ):
            return cls(mode, cache_file)

        return cls(mode, cache_file, file_data)

    @staticmethod
    def hash_digest(path: Path) -> str:
        """Return the SHA-256 hex digest of the bytes stored at `path`."""
        data = path.read_bytes()
        return hashlib.sha256(data).hexdigest()

    @staticmethod
    def get_file_data(path: Path) -> FileData:
        """Return the `FileData` (mtime, size, content hash) for `path`."""
        stat = path.stat()
        digest = Cache.hash_digest(path)
        return FileData(stat.st_mtime, stat.st_size, digest)

    def is_changed(self, source: Path) -> bool:
        """Check if source has changed compared to cached version.

        A file counts as changed when it has no cache entry, when its size
        differs from the cached size, or when its mtime differs and its
        content hash no longer matches the cached hash.
        """
        res_src = source.resolve()
        old = self.file_data.get(str(res_src))
        if old is None:
            return True

        st = res_src.stat()
        if st.st_size != old.st_size:
            return True
        if st.st_mtime != old.st_mtime:
            # Only hash the file when the cheap stat comparison is inconclusive.
            new_hash = Cache.hash_digest(res_src)
            if new_hash != old.hash:
                return True
        return False

    def filtered_cached(self, sources: Iterable[Path]) -> tuple[set[Path], set[Path]]:
        """Split an iterable of paths in `sources` into two sets.

        The first contains paths of files that were modified on disk or are
        not in the cache. The other contains paths to non-modified files.
        """
        changed: set[Path] = set()
        done: set[Path] = set()
        for src in sources:
            if self.is_changed(src):
                changed.add(src)
            else:
                done.add(src)
        return changed, done

    def write(self, sources: Iterable[Path]) -> None:
        """Update the cache file data and write a new cache file.

        The in-memory `file_data` is always updated. Writing to disk is best
        effort: an `OSError` while creating the directory, writing the
        temporary file, or moving it into place is ignored, and the temporary
        file is removed so failed attempts do not pile up in the cache
        directory.
        """
        self.file_data.update(
            {str(src.resolve()): Cache.get_file_data(src) for src in sources}
        )
        tmp_name = None
        try:
            CACHE_DIR.mkdir(parents=True, exist_ok=True)
            with tempfile.NamedTemporaryFile(
                dir=str(self.cache_file.parent), delete=False
            ) as f:
                tmp_name = f.name
                # We store raw tuples in the cache because it's faster.
                data: dict[str, tuple[float, int, str]] = {
                    k: (*v,) for k, v in self.file_data.items()
                }
                pickle.dump(data, f, protocol=4)
            # Write-then-rename so readers never see a partially written cache.
            os.replace(tmp_name, self.cache_file)
        except OSError:
            # Best effort only, but do not leave the delete=False temp file behind.
            if tmp_name is not None:
                try:
                    os.unlink(tmp_name)
                except OSError:
                    pass
|