|
|
import os |
|
|
import pickle |
|
|
import time |
|
|
|
|
|
from typing import Any, List |
|
|
|
|
|
class Client:
    """
    Wrapper class for language models that we query. It keeps a cache of prompts and
    responses so that we don't have to requery things in experiments.

    Cache entries are keyed by ``f"{prompt.strip()}_{sample_idx}"`` and persisted with
    ``pickle`` in ``cache_file``. That is either a local path or an ``s3://bucket/key``
    URI; S3 is only supported when loading.

    ``query`` calls ``self._query``, which is not defined here. Subclasses are expected
    to provide it, and to override ``load_model`` if ``model`` may be ``None``.
    """

    def __init__(self, cache_file, model: str = 'gpt-3.5-turbo'):
        """
        Args:
            cache_file: Local path or ``s3://`` URI of the pickled cache dictionary.
            model: Model identifier. When ``None``, ``load_model`` is called before
                the first uncached query.
        """
        self.cache_file = cache_file
        self.cache_dict = self.load_cache()
        self.model = model
        # True while cache_dict holds entries that save_cache has not yet written.
        self.modified_cache = False

    def load_model(self):
        """Load the underlying model into ``self.model``; subclasses must override."""
        raise NotImplementedError()

    @staticmethod
    def _cache_key(prompt: str, sample_idx: int) -> str:
        """Build the cache key shared by ``query`` and ``cache_outputs``."""
        return f"{prompt.strip()}_{sample_idx}"

    def query(
        self,
        prompt: str,
        sample_idx: int = 0,
        **kwargs
    ):
        """
        Return the model output for ``prompt``, serving it from the cache when possible.

        Args:
            prompt: Prompt text. Leading/trailing whitespace is stripped before use.
            sample_idx: Distinguishes several samples of the same prompt in the cache.
            **kwargs: Forwarded unchanged to ``self._query`` on a cache miss.

        Returns:
            The cached output if present, otherwise the result of ``self._query``.

        Raises:
            NotImplementedError: on a cache miss when ``self.model`` is ``None`` and
                ``load_model`` has not been overridden.

        NOTE(review): a freshly computed result is NOT added to the cache here.
        Callers appear to be expected to persist it via ``cache_outputs``; confirm
        this is intended.
        """
        prompt = prompt.strip()
        cache_key = self._cache_key(prompt, sample_idx)
        if cache_key in self.cache_dict:
            return self.cache_dict[cache_key]

        if self.model is None:
            self.load_model()

        return self._query(prompt, **kwargs)

    def cache_outputs(
        self,
        prompts: List[str],
        sample_indices: List[int],
        outputs: List[Any]
    ):
        """
        Store ``outputs`` in the in-memory cache under their prompt/sample keys.

        The three lists are consumed in lockstep with ``zip``, so surplus elements in
        a longer list are ignored. Call ``save_cache`` to persist the additions.
        """
        for prompt, sample_idx, output in zip(prompts, sample_indices, outputs):
            self.cache_dict[self._cache_key(prompt, sample_idx)] = output
            self.modified_cache = True

    def save_cache(self):
        """
        Write the cache to ``cache_file`` if it changed since the last save.

        The current on-disk cache is re-read and merged in first, and its entries win
        for duplicate keys. Presumably this is so entries written by other processes
        sharing the file are not lost.

        NOTE(review): only local paths are written. ``s3://`` cache files are
        supported by ``load_cache`` but not here.
        """
        if not self.modified_cache:
            return

        self.cache_dict.update(self.load_cache())

        with open(self.cache_file, "wb") as f:
            pickle.dump(self.cache_dict, f)

        # Everything in memory is now on disk; skip redundant rewrites.
        self.modified_cache = False

    def load_cache(self, allow_retry=True):
        """
        Load and return the pickled cache dictionary.

        - ``cache_file`` exists locally: unpickle it. On any error, if ``allow_retry``
          is true, print a message and retry every 5 seconds with no upper bound
          (presumably to wait out a concurrent writer). Otherwise raise
          ``AssertionError``.
        - ``cache_file`` is an ``s3://bucket/key`` URI: fetch it via
          ``aws_utils.s3_open``.
        - Otherwise: return an empty dict.

        Security: ``pickle.load`` can execute arbitrary code. Only point
        ``cache_file`` at trusted data.
        """
        if os.path.exists(self.cache_file):
            while True:
                try:
                    with open(self.cache_file, "rb") as f:
                        return pickle.load(f)
                except Exception as err:
                    if not allow_retry:
                        # Raise explicitly. A bare ``assert False`` is stripped under
                        # ``python -O``, turning this into an endless retry loop.
                        raise AssertionError(
                            f"Failed to load cache file {self.cache_file!r}"
                        ) from err
                    print("Pickle Error: Retry in 5sec...")
                    time.sleep(5)
        elif self.cache_file.startswith('s3://'):
            # Prefix check rather than substring: local paths such as
            # "results3.pkl" must not be routed to S3.
            from aws_utils import s3_open

            s3_path = self.cache_file.removeprefix('s3://')
            bucket_name, _, path_to_file = s3_path.partition('/')
            with s3_open(bucket_name, path_to_file) as fp:
                return pickle.load(fp)

        return {}
|
|
|
|
|
|
|
|
|
|
|
|