import os
import pickle
import time
from typing import Any, List


class Client:
    """
    Wrapper class for language models that we query. It keeps a cache of
    prompts and responses so that we don't have to re-query them in
    experiments.
    """

    def __init__(self, cache_file: str, model: str = 'gpt-3.5-turbo'):
        self.cache_file = cache_file
        self.cache_dict = self.load_cache()
        self.model = model
        self.modified_cache = False
    def load_model(self):
        # Load the model and store it as self.model (implemented by subclasses).
        raise NotImplementedError()

    def _query(self, prompt: str, **kwargs) -> Any:
        # Issue a single request to the underlying model (implemented by subclasses).
        raise NotImplementedError()

    def query(
        self,
        prompt: str,
        sample_idx: int = 0,
        **kwargs,
    ) -> Any:
        prompt = prompt.strip()  # it's important not to end with a whitespace
        cache_key = f"{prompt}_{sample_idx}"
        if cache_key in self.cache_dict:
            return self.cache_dict[cache_key]
        if self.model is None:
            self.load_model()
        # Cache miss: query the model. The caller is responsible for storing
        # the output via cache_outputs().
        output = self._query(prompt, **kwargs)
        return output
    def cache_outputs(
        self,
        prompts: List[str],
        sample_indices: List[int],
        outputs: List[Any],
    ):
        for prompt, sample_idx, output in zip(prompts, sample_indices, outputs):
            prompt = prompt.strip()
            cache_key = f"{prompt}_{sample_idx}"
            self.cache_dict[cache_key] = output
            self.modified_cache = True
    def save_cache(self):
        if not self.modified_cache:
            return
        # Reload the latest cache first: if other processes were running in
        # parallel, the on-disk cache may have been updated since we loaded it.
        self.cache_dict.update(self.load_cache())
        # NOTE: saving assumes a local path; only load_cache() handles s3:// URIs.
        with open(self.cache_file, "wb") as f:
            pickle.dump(self.cache_dict, f)
        self.modified_cache = False
    def load_cache(self, allow_retry=True):
        if os.path.exists(self.cache_file):
            while True:
                try:
                    with open(self.cache_file, "rb") as f:
                        cache = pickle.load(f)
                    break
                except Exception:
                    # If concurrent processes are writing the file, reads can fail.
                    if not allow_retry:
                        raise
                    print("Pickle Error: Retry in 5sec...")
                    time.sleep(5)
        elif self.cache_file.startswith('s3://'):
            from aws_utils import s3_open
            s3_path = self.cache_file.removeprefix('s3://')
            bucket_name = s3_path.split('/')[0]
            path_to_file = '/'.join(s3_path.split('/')[1:])
            with s3_open(bucket_name, path_to_file) as fp:
                cache = pickle.load(fp)
        else:
            cache = {}
        return cache
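

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original API): Client is meant
# to be subclassed with load_model() and _query() implemented. EchoClient is
# a hypothetical stand-in that echoes the prompt instead of calling a real
# model, just to show the query -> cache_outputs -> save_cache flow.
# ---------------------------------------------------------------------------
class EchoClient(Client):
    def __init__(self, cache_file: str):
        # model=None so that query() triggers load_model() on first use.
        super().__init__(cache_file, model=None)

    def load_model(self):
        # A real subclass would load or configure its model here.
        self.model = 'echo'

    def _query(self, prompt: str, **kwargs) -> str:
        # A real subclass would call the underlying model here.
        return f"echo: {prompt}"


if __name__ == "__main__":
    client = EchoClient("demo_cache.pkl")
    prompt = "Hello, world"
    output = client.query(prompt)           # cache miss: hits _query()
    client.cache_outputs([prompt], [0], [output])
    client.save_cache()                     # persists the merged cache
    assert client.query(prompt) == output   # now served from the cache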