Spaces:
Sleeping
Sleeping
| import re | |
| import torch | |
| import openai | |
| from functools import partial | |
| import time | |
| import multiprocessing | |
def get_openai_embedding(text,
                         model="text-embedding-ada-002",
                         max_retry=1,
                         sleep_time=0):
    """Embed a single string with the OpenAI embeddings endpoint.

    If the API rejects the request because the text exceeds the model's
    context length, the text is truncated on space-separated words and
    retried with progressively shorter prefixes (90% down to 10% of the
    estimated maximum word count).

    Args:
        text: Non-empty string to embed.
        model: Name of the OpenAI embedding model.
        max_retry: Number of attempts at the full request before giving up.
        sleep_time: Seconds to wait after a rate-limit or timeout error
            before the next attempt.

    Returns:
        A float tensor of shape ``(1, embedding_dim)``, or ``None`` when
        every attempt failed.

    Raises:
        AssertionError: If ``text`` is not a ``str`` or is empty.
    """
    assert isinstance(text, str), f'text must be str, but got {type(text)}'
    assert len(text) > 0, 'text to be embedded should be non-empty'

    client = openai.OpenAI()
    # Loop-invariant: the word list is only needed for truncation.
    words = text.split(' ')
    ori_length = len(words)

    for _ in range(max_retry):
        try:
            emb = client.embeddings.create(input=[text], model=model)
            return torch.FloatTensor(emb.data[0].embedding).view(1, -1)
        except openai.BadRequestError as e:
            print(f'{e}')
            match = re.search(
                r'maximum context length is (\d+) tokens, however you requested (\d+) tokens',
                str(e),
            )
            if match is None:
                # Not a context-length problem; nothing to truncate.
                continue
            max_length = int(match.group(1))
            cur_length = int(match.group(2))
            # Fraction of the text that should fit, assuming tokens scale
            # linearly with the number of words.
            ratio = float(max_length) / cur_length
            for reduce_rate in range(9, 0, -1):
                length = int(ratio * ori_length * (reduce_rate * 0.1))
                if length < 1:
                    # Lengths only shrink from here, and an empty input is
                    # always rejected by the API.
                    break
                shorten_text = ' '.join(words[:length])
                try:
                    emb = client.embeddings.create(input=[shorten_text], model=model)
                    print(f'length={length} works! reduce_rate={0.1 * reduce_rate}.')
                    return torch.FloatTensor(emb.data[0].embedding).view(1, -1)
                except Exception as shorten_err:
                    # Best effort: try a shorter prefix, but do not swallow
                    # KeyboardInterrupt / SystemExit as a bare except would.
                    print(f'length={length} failed: {shorten_err}')
                    continue
        except (openai.RateLimitError, openai.APITimeoutError) as e:
            print(f'{e}, sleep for {sleep_time} seconds')
            time.sleep(sleep_time)

    # All attempts failed; callers must handle the missing embedding.
    return None
def get_openai_embeddings(texts,
                          n_max_nodes=5,
                          model="text-embedding-ada-002"
                          ):
    """Get embeddings for a list of texts.

    Each text is embedded by ``get_openai_embedding`` in a worker of a
    ``multiprocessing.Pool`` and the per-text results are concatenated
    along dimension 0.

    Args:
        texts: Non-empty list of non-empty strings to embed.
        n_max_nodes: Upper bound on the number of worker processes; the
            pool never uses more workers than there are texts.
        model: Name of the OpenAI embedding model.

    Returns:
        The per-text embeddings concatenated along dimension 0, in the
        same order as ``texts``.

    Raises:
        AssertionError: If ``texts`` is not a list or contains an empty
            string.
        ValueError: If ``texts`` is an empty list.
        RuntimeError: If the embedding of at least one text failed.
    """
    assert isinstance(texts, list), f'texts must be list, but got {type(texts)}'
    assert all(len(s) > 0 for s in texts), 'every string in the `texts` list to be embedded should be non-empty'
    if not texts:
        # Without this guard the pool is created with zero processes and
        # fails with an unrelated-looking ValueError.
        raise ValueError('`texts` must be a non-empty list')

    processes = min(len(texts), n_max_nodes)
    ada_encoder = partial(get_openai_embedding, model=model)
    with multiprocessing.Pool(processes=processes) as pool:
        results = pool.map(ada_encoder, texts)

    # get_openai_embedding returns None when all its attempts fail; report
    # that explicitly instead of letting torch.cat fail on a NoneType.
    failed = [idx for idx, emb in enumerate(results) if emb is None]
    if failed:
        raise RuntimeError(f'failed to embed texts at indices {failed}')

    results = torch.cat(results, dim=0)
    return results