File size: 2,143 Bytes
b0c0df0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import hashlib
import os
import pickle

import dill

from lmms_eval.loggers.utils import _handle_non_serializable, is_serializable
from lmms_eval.utils import eval_logger

# Directory that holds this module; the default cache folder lives beneath it.
MODULE_DIR = os.path.split(os.path.realpath(__file__))[0]

# Optional cache location supplied through the environment.
OVERRIDE_PATH = os.environ.get("LM_HARNESS_CACHE_PATH")


# An unset or empty override falls back to "<MODULE_DIR>/.cache".
PATH = OVERRIDE_PATH or f"{MODULE_DIR}/.cache"

# Fixed seed string; its SHA-256 digest is embedded in every cache file name,
# which should be sufficient to make those names unique.
HASH_INPUT = "EleutherAI-lm-evaluation-harness"

HASH_PREFIX = hashlib.sha256(bytes(HASH_INPUT, "utf-8")).hexdigest()

FILE_SUFFIX = "." + HASH_PREFIX + ".pickle"


def load_from_cache(file_name):
    """Return the object cached under ``file_name``, or ``None`` if unavailable.

    The cache file is looked up at ``{PATH}/{file_name}{FILE_SUFFIX}`` and
    deserialized with ``dill``.

    Args:
        file_name: Cache key; combined with ``PATH`` and ``FILE_SUFFIX`` to
            form the on-disk file name.

    Returns:
        The deserialized object, or ``None`` when the file does not exist,
        cannot be read, or cannot be deserialized. Loading is best-effort and
        never raises for those cases, so callers can regenerate the data.
    """
    path = f"{PATH}/{file_name}{FILE_SUFFIX}"

    try:
        with open(path, "rb") as file:
            payload = file.read()
    except FileNotFoundError:
        # The ordinary cache miss: nothing has been saved under this key yet.
        eval_logger.debug(f"{file_name} is not cached, generating...")
        return None
    except OSError as err:
        # The file exists but is unreadable (permissions, is a directory, ...).
        eval_logger.warning(f"Could not read cache file {path} ({err!r}), generating...")
        return None

    try:
        # NOTE(security): dill.loads can execute arbitrary code embedded in the
        # payload. Only point the cache directory at locations you trust.
        return dill.loads(payload)
    except Exception as err:
        # Unpickling can fail with almost any exception type (truncated file,
        # missing class, version skew); treat all of them as a cache miss but
        # say so, instead of claiming the entry was never cached.
        eval_logger.warning(f"Could not deserialize cache file {path} ({err!r}), generating...")
        return None


def save_to_cache(file_name, obj):
    """Serialize ``obj`` with ``dill`` into the cache file for ``file_name``.

    ``obj`` is iterated as a sequence of sequences. Any inner element that has
    an ``arguments`` attribute is stored with its callable arguments replaced
    by ``None``, since callables (e.g. ``doc_to_visual``) are not serializable.
    The replacement is done on shallow copies, so the caller's objects keep
    their original ``arguments``.

    Args:
        file_name: Cache key; combined with ``PATH`` and ``FILE_SUFFIX`` to
            form the on-disk file name.
        obj: Iterable of iterables of items to cache.

    Raises:
        OSError: If the cache directory or file cannot be created or written.
    """
    # makedirs(exist_ok=True) also creates missing parents and does not fail
    # if another process creates the directory between a check and the call.
    os.makedirs(PATH, exist_ok=True)

    file_path = f"{PATH}/{file_name}{FILE_SUFFIX}"

    serializable_obj = [[_without_callable_arguments(subitem) for subitem in item] for item in obj]

    eval_logger.debug(f"Saving {file_path} to cache...")
    # Serialize before opening the file so that a failed dump never leaves a
    # truncated cache file behind.
    try:
        payload = dill.dumps(serializable_obj)
    except (pickle.PickleError, dill.PicklingError, TypeError, AttributeError):
        # Fall back to converting whatever still cannot be serialized.
        payload = dill.dumps([[subitem if is_serializable(subitem) else _handle_non_serializable(subitem) for subitem in item] for item in serializable_obj])

    with open(file_path, "wb") as file:
        file.write(payload)


def _without_callable_arguments(subitem):
    """Return ``subitem``, or a shallow copy whose callable ``arguments`` are ``None``."""
    import copy

    if not hasattr(subitem, "arguments"):
        return subitem

    stripped = copy.copy(subitem)
    stripped.arguments = tuple(None if callable(arg) else arg for arg in subitem.arguments)
    return stripped


# NOTE the "key" param is to allow for flexibility
def delete_cache(key: str = ""):
    """Delete cache files in ``PATH`` whose names start with ``key``.

    Only files ending in ``FILE_SUFFIX`` are removed, so unrelated files in
    the cache directory are left alone. If the cache directory does not exist
    there is nothing to delete and the call is a no-op.

    Args:
        key: File-name prefix to match. The default empty string matches every
            cache file.
    """
    try:
        files = os.listdir(PATH)
    except FileNotFoundError:
        # Nothing has been cached yet, so the directory was never created.
        return

    for file in files:
        if file.startswith(key) and file.endswith(FILE_SUFFIX):
            os.unlink(f"{PATH}/{file}")