File size: 2,830 Bytes
906e061
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import pickle
import time

from typing import Any, List

class Client:
    """
    Wrapper class for language models that we query. It keeps a cache of prompts and
    responses so that we don't have to requery things in experiments.

    The cache is a dict keyed by ``f"{prompt}_{sample_idx}"`` (with the prompt
    stripped of surrounding whitespace) and is persisted with pickle at
    ``cache_file``. On a cache miss, ``query`` calls ``self._query(prompt, **kwargs)``;
    this class does not define ``_query``, so a subclass is expected to provide it.

    NOTE(security): the cache is read with ``pickle.load``, which can execute
    arbitrary code while unpickling. Only point ``cache_file`` at files you trust.
    """

    def __init__(self, cache_file, model: str = 'gpt-3.5-turbo'):
        """
        Args:
            cache_file: Path of the pickle file backing the cache. A path starting
                with ``s3://`` that does not exist locally is read through
                ``aws_utils.s3_open``.
            model: Stored as ``self.model``; ``query`` calls ``load_model()`` when
                this is ``None``.
        """
        self.cache_file = cache_file
        self.cache_dict = self.load_cache()
        self.model = model
        # Set by cache_outputs(); save_cache() is a no-op until this is True.
        self.modified_cache = False

    @staticmethod
    def _make_key(prompt: str, sample_idx: int) -> str:
        """Build the cache key for a prompt / sample index pair."""
        # it's important not to end with a whitespace
        return f"{prompt.strip()}_{sample_idx}"

    def load_model(self):
        """Load the model and store it as ``self.model`` (to be overridden)."""
        raise NotImplementedError()

    def query(
            self,
            prompt: str,
            sample_idx: int = 0,
            **kwargs
        ):
        """
        Return the response for ``prompt``, using the cache when possible.

        Args:
            prompt: Prompt text; surrounding whitespace is stripped before use.
            sample_idx: Distinguishes several cached samples of the same prompt.
            **kwargs: Forwarded unchanged to ``self._query`` on a cache miss.

        Returns:
            The cached output if present, otherwise the result of
            ``self._query(prompt, **kwargs)``. A fresh result is NOT stored in the
            cache here; use ``cache_outputs`` for that.
        """
        prompt = prompt.strip()  # it's important not to end with a whitespace
        cache_key = self._make_key(prompt, sample_idx)

        if cache_key in self.cache_dict:
            return self.cache_dict[cache_key]

        if self.model is None:
            self.load_model()
        return self._query(prompt, **kwargs)

    def cache_outputs(
            self,
            prompts: List[str],
            sample_indices: List[int],
            outputs: List[Any]
    ):
        """
        Store ``outputs`` in the in-memory cache and mark the cache as modified.

        The three lists are consumed in lockstep with ``zip``, so extra items in a
        longer list are ignored. Nothing is written to disk until ``save_cache``.
        """
        for prompt, sample_idx, output in zip(prompts, sample_indices, outputs):
            self.cache_dict[self._make_key(prompt, sample_idx)] = output
            self.modified_cache = True

    def save_cache(self):
        """
        Write the cache to ``cache_file`` if ``cache_outputs`` has stored anything.

        The file on disk is re-read and merged in first; for keys present in both,
        the on-disk value replaces the in-memory one.
        """
        if not self.modified_cache:
            return

        # load the latest cache first, since if there were other processes running
        # in parallel, cache might have been updated
        self.cache_dict.update(self.load_cache())

        # NOTE: this write is not atomic; a concurrent reader may see a partial
        # file, which is what the retry loop in load_cache() copes with.
        with open(self.cache_file, "wb") as f:
            pickle.dump(self.cache_dict, f)

    def load_cache(self, allow_retry=True):
        """
        Load and return the cache dict for ``cache_file``.

        Args:
            allow_retry: When True, a failed read of an existing local file is
                retried every 5 seconds until it succeeds. When False, the first
                failure raises.

        Returns:
            The unpickled cache, or an empty dict when the file does not exist
            locally and is not an ``s3://`` path.

        Raises:
            AssertionError: reading an existing local file failed and
                ``allow_retry`` is False (the original error is chained).
        """
        if os.path.exists(self.cache_file):
            while True:
                try:
                    with open(self.cache_file, "rb") as f:
                        return pickle.load(f)
                except Exception as err:  # if there are concurrent processes, things can fail
                    if not allow_retry:
                        # Explicit raise instead of `assert False`, which is
                        # stripped under `python -O` and would retry forever.
                        raise AssertionError(
                            f"could not load cache file {self.cache_file!r}"
                        ) from err
                    print("Pickle Error: Retry in 5sec...")
                    time.sleep(5)

        # Match the URL scheme rather than any "s3" substring, so a new local file
        # such as "runs3_cache.pkl" starts with an empty cache instead of hitting S3.
        cache_path = str(self.cache_file)
        if cache_path.startswith('s3://'):
            from aws_utils import s3_open
            s3_path = cache_path.removeprefix('s3://')
            bucket_name, _, path_to_file = s3_path.partition('/')
            with s3_open(bucket_name, path_to_file) as fp:
                return pickle.load(fp)

        return {}