| import os.path |
| import shutil |
| from typing import List, Optional, Union |
|
|
| from inference.core.env import MODEL_CACHE_DIR |
| from inference.core.utils.file_system import ( |
| dump_bytes, |
| dump_json, |
| dump_text_lines, |
| read_json, |
| read_text_file, |
| ) |
|
|
|
|
def initialise_cache(model_id: Optional[str] = None) -> None:
    """Ensure the cache directory exists on disk.

    Args:
        model_id: When given, the directory returned by
            ``get_cache_dir(model_id=model_id)`` (a per-model subdirectory) is
            created; otherwise the root cache directory is created.

    Missing intermediate directories are created as well, and an already
    existing directory is not an error (``exist_ok=True``).
    """
    os.makedirs(get_cache_dir(model_id=model_id), exist_ok=True)
|
|
|
|
def are_all_files_cached(files: List[str], model_id: Optional[str] = None) -> bool:
    """Report whether every file in ``files`` is present in the cache.

    Args:
        files: File names, resolved relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache.

    Returns:
        True when each entry is cached (vacuously True for an empty list);
        False as soon as the first missing file is found.
    """
    for candidate in files:
        if not is_file_cached(file=candidate, model_id=model_id):
            # Stop at the first miss, same short-circuit as all().
            return False
    return True
|
|
|
|
def is_file_cached(file: str, model_id: Optional[str] = None) -> bool:
    """Return True if ``file`` exists as a regular file in the cache.

    Args:
        file: File name, resolved relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache.

    ``os.path.isfile`` follows symlinks and yields False both for missing
    paths and for paths that are directories.
    """
    return os.path.isfile(get_cache_file_path(file=file, model_id=model_id))
|
|
|
|
def load_text_file_from_cache(
    file: str,
    model_id: Optional[str] = None,
    split_lines: bool = False,
    strip_white_chars: bool = False,
) -> Union[str, List[str]]:
    """Read a text file from the cache.

    Args:
        file: File name, resolved relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache.
        split_lines: Forwarded to ``read_text_file``; presumably makes it
            return a list of lines instead of one string (confirm there).
        strip_white_chars: Forwarded to ``read_text_file``; presumably strips
            surrounding whitespace (confirm there).

    Returns:
        Whatever ``read_text_file`` returns for the resolved path.
    """
    return read_text_file(
        path=get_cache_file_path(file=file, model_id=model_id),
        split_lines=split_lines,
        strip_white_chars=strip_white_chars,
    )
|
|
|
|
def load_json_from_cache(
    file: str, model_id: Optional[str] = None, **kwargs
) -> Optional[Union[dict, list]]:
    """Deserialise a JSON file stored in the cache.

    Args:
        file: File name, resolved relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache.
        **kwargs: Passed through unchanged to ``read_json``.

    Returns:
        The value produced by ``read_json`` for the resolved path.
    """
    json_path = get_cache_file_path(file=file, model_id=model_id)
    return read_json(path=json_path, **kwargs)
|
|
|
|
def save_bytes_in_cache(
    content: bytes,
    file: str,
    model_id: Optional[str] = None,
    allow_override: bool = True,
) -> None:
    """Write raw bytes to ``file`` inside the cache.

    Args:
        content: Payload to write.
        file: Target file name, resolved relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache.
        allow_override: Forwarded to ``dump_bytes``; presumably decides whether
            an existing file may be overwritten (confirm in ``dump_bytes``).
    """
    dump_bytes(
        path=get_cache_file_path(file=file, model_id=model_id),
        content=content,
        allow_override=allow_override,
    )
|
|
|
|
def save_json_in_cache(
    content: Union[dict, list],
    file: str,
    model_id: Optional[str] = None,
    allow_override: bool = True,
    **kwargs,
) -> None:
    """Serialise ``content`` as JSON into ``file`` inside the cache.

    Args:
        content: JSON-serialisable dict or list.
        file: Target file name, resolved relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache.
        allow_override: Forwarded to ``dump_json``; presumably decides whether
            an existing file may be overwritten (confirm in ``dump_json``).
        **kwargs: Passed through unchanged to ``dump_json``.
    """
    destination = get_cache_file_path(file=file, model_id=model_id)
    dump_json(
        path=destination,
        content=content,
        allow_override=allow_override,
        **kwargs,
    )
|
|
|
|
def save_text_lines_in_cache(
    content: List[str],
    file: str,
    model_id: Optional[str] = None,
    allow_override: bool = True,
) -> None:
    """Write a sequence of text lines to ``file`` inside the cache.

    Args:
        content: Lines to write; the joining/line-ending policy is whatever
            ``dump_text_lines`` implements.
        file: Target file name, resolved relative to the cache directory.
        model_id: Optional model identifier selecting the per-model cache.
        allow_override: Forwarded to ``dump_text_lines``; presumably decides
            whether an existing file may be overwritten (confirm there).
    """
    dump_text_lines(
        path=get_cache_file_path(file=file, model_id=model_id),
        content=content,
        allow_override=allow_override,
    )
|
|
|
|
def get_cache_file_path(file: str, model_id: Optional[str] = None) -> str:
    """Build the on-disk path of ``file`` within the cache directory.

    Args:
        file: File name (may contain subdirectories) relative to the cache.
        model_id: Optional model identifier selecting the per-model cache.

    Returns:
        ``os.path.join(<cache dir>, file)``.

    NOTE(review): ``os.path.join`` discards the cache directory when ``file``
    is absolute, and ``..`` components are kept as-is, so the result stays
    inside the cache only if ``file`` is a trusted relative name.
    """
    return os.path.join(get_cache_dir(model_id=model_id), file)
|
|
|
|
def clear_cache(model_id: Optional[str] = None) -> None:
    """Delete the cache directory and everything beneath it.

    Args:
        model_id: When given, only that model's cache subdirectory is removed;
            when None, the whole root cache directory is removed.

    A cache directory that does not exist is not an error. Other failures
    from ``shutil.rmtree`` (e.g. permission errors, or the path being a
    regular file) propagate to the caller.
    """
    cache_dir = get_cache_dir(model_id=model_id)
    try:
        shutil.rmtree(cache_dir)
    except FileNotFoundError:
        # EAFP instead of exists()+rmtree(): avoids the race where another
        # process removes the directory between the check and the delete.
        # Only swallow the error if the directory is really gone; otherwise
        # something inside vanished mid-walk and the tree is still partially
        # present, so surface it just as the original code would.
        if os.path.exists(cache_dir):
            raise
|
|
|
|
def get_cache_dir(model_id: Optional[str] = None) -> str:
    """Return the cache directory, optionally scoped to one model.

    Args:
        model_id: When None, the root ``MODEL_CACHE_DIR`` is returned;
            otherwise ``model_id`` is joined onto it as a subdirectory.

    Returns:
        The directory path as a string (not created or checked here).
    """
    if model_id is None:
        return MODEL_CACHE_DIR
    return os.path.join(MODEL_CACHE_DIR, model_id)
|
|