| import os |
| import sys |
| import time |
| import hashlib |
| import numpy as np |
| import requests |
|
|
| import logging |
| import functools |
| import tiktoken |
| from tqdm import tqdm |
| from mteb import MTEB |
| |
# Configure root logging at INFO and create this script's named logger.
# NOTE(review): `logger` is not referenced in the visible code; progress
# messages go through `log()` on stderr instead.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name="main")
|
|
| all_task_list = ['Classification', 'Clustering', 'Reranking', 'Retrieval', 'STS', 'PairClassification'] |
| if len(sys.argv) > 1: |
| task_list = [t for t in sys.argv[1].split(',') if t in all_task_list] |
| else: |
| task_list = all_task_list |
|
|
def _int_env(name, default):
    """Return environment variable *name* converted with ``int``, or ``int(default)`` if unset."""
    return int(os.environ.get(name, default))


# Optional endpoint override and API key; both are empty strings when unset.
OPENAI_BASE_URL = os.getenv('OPENAI_BASE_URL', '')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', '')
# Directory for on-disk embedding cache files; created eagerly at import.
EMB_CACHE_DIR = os.getenv('EMB_CACHE_DIR', '.cache/embs')
# Request timeout (s), number of attempts, and pause between failed attempts (s).
REQ_OPENAI_TIMEOUT = _int_env('REQ_OPENAI_TIMEOUT', 120)
REQ_OPENAI_RETRY = _int_env('REQ_OPENAI_RETRY', 3)
REQ_OPENAI_INTERVAL = _int_env('REQ_OPENAI_INTERVAL', 60)
os.makedirs(EMB_CACHE_DIR, exist_ok=True)
|
|
def log(*args):
    """Write *args* to stderr, space-separated and newline-terminated, like ``print``."""
    message = " ".join(str(arg) for arg in args)
    print(message, file=sys.stderr)
|
|
def uuid_for_text(text):
    """Return the hexadecimal MD5 digest of *text* encoded as UTF-8 (used as a cache file name)."""
    hasher = hashlib.md5()
    hasher.update(text.encode('utf8'))
    return hasher.hexdigest()
|
|
def count_openai_tokens(text, model="text-embedding-3-large"):
    """Return the number of tokens *text* encodes to for OpenAI model *model*.

    The tokenizer is resolved from *model* with ``tiktoken.encoding_for_model``.
    Models tiktoken cannot map fall back to ``cl100k_base``, the encoding this
    function previously hard-coded.

    Special-token markers such as ``<|endoftext|>`` in *text* are counted as
    ordinary text. Without ``disallowed_special=()``, tiktoken raises
    ``ValueError`` on them, and one such sentence would abort the whole
    evaluation.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text, disallowed_special=()))
|
|
def request_openai_emb(texts, model="text-embedding-3-large",
                       base_url='https://api.openai.com', prefix_url='/v1/embeddings',
                       timeout=4, retry=3, interval=2, caching=True):
    """Return embeddings for *texts* from an OpenAI-compatible embeddings endpoint.

    Args:
        texts: One string or a list of strings.
        model: Model name sent in the request payload.
        base_url: Endpoint host; ignored when ``OPENAI_BASE_URL`` is non-empty.
        prefix_url: Endpoint path appended to the host.
        timeout: Per-request timeout in seconds, passed to ``requests.post``.
        retry: Maximum number of HTTP attempts.
        interval: Seconds to sleep after a failed attempt before the next one.
        caching: When true, the batch is served from ``EMB_CACHE_DIR`` if
            every text has a cache file named ``uuid_for_text(text)``.
            Freshly fetched embeddings are written back there.

    Returns:
        A list of 1-D numpy arrays in the same order as *texts*. Returns
        ``[]`` if no attempt succeeded or the number of embeddings returned
        differs from the number of texts. Empty input returns ``[]`` without
        making any request.
    """
    if isinstance(texts, str):
        texts = [texts]
    if not texts:
        # Nothing to embed: skip the network round-trip (and its retries/sleeps).
        return []

    if caching:
        # NOTE(review): the cache key depends on the text only, not on *model*,
        # so switching models reuses embeddings produced by the previous one.
        # Clear EMB_CACHE_DIR when changing models.
        cached = []
        for text in texts:
            emb_file = f"{EMB_CACHE_DIR}/{uuid_for_text(text)}"
            if not (os.path.isfile(emb_file) and os.path.getsize(emb_file) > 0):
                break  # one miss means the whole batch is requested anyway
            try:
                cached.append(np.loadtxt(emb_file))
            except (OSError, ValueError) as e:
                # Corrupt cache entry: treat it as a miss rather than crashing.
                log(f"unreadable cache file {emb_file}, ignoring it: {e}")
                break
        else:  # loop finished without break: every text came from the cache
            return cached

    url = f"{OPENAI_BASE_URL or base_url}{prefix_url}"
    headers = {
        "Authorization": f"Bearer {OPENAI_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {"input": texts, "model": model}

    data = []
    for attempt in range(retry):
        try:
            r = requests.post(url, headers=headers, json=payload,
                              timeout=timeout)
            # Report HTTP failures (e.g. rate limiting) by status code instead
            # of an uninformative KeyError on the missing "data" field.
            r.raise_for_status()
            items = r.json()["data"]
            # Each item's "index" is the position of its input; order by it so
            # embeddings line up with *texts* (sort is stable if it is absent).
            items.sort(key=lambda item: item.get("index", 0))
            data = [np.array(item["embedding"]) for item in items]
        except Exception as e:  # network, HTTP, JSON or schema errors: retry
            log(f"request openai, retry {retry - attempt}, error: {e}")
            if attempt < retry - 1:  # no point sleeping after the final attempt
                time.sleep(interval)
        if data:
            break

    if len(data) != len(texts):
        log(f"request openai, failed, texts and embs DONT match!")
        return []

    if caching:
        for text, emb in zip(texts, data):
            emb_file = f"{EMB_CACHE_DIR}/{uuid_for_text(text)}"
            # Write to a temp file, then rename it into place, so an
            # interrupted write never leaves a half-written entry for a later
            # run to load.
            tmp_file = f"{emb_file}.{os.getpid()}.tmp"
            np.savetxt(tmp_file, emb)
            os.replace(tmp_file, emb_file)

    return data
|
|
|
|
class OpenaiEmbModel:
    """Sentence encoder backed by the OpenAI embeddings API, passed to ``MTEB.run``.

    ``encode`` groups sentences into token-bounded request batches and fetches
    each batch with ``request_openai_emb``. Any batch that fails is replaced by
    zero vectors, so the result always holds one vector per input sentence.
    """

    # Most inputs sent in one embeddings request (OpenAI's per-request array limit).
    MAX_INPUTS_PER_REQUEST = 2048
    # A sentence whose token count alone exceeds the budget is cut to this many characters.
    TRUNCATE_CHARS = 2048
    # Seconds of back-off added after each failed batch (halved after each success).
    FAIL_BACKOFF = 120
    # Once the accumulated back-off exceeds this many seconds, encoding gives up.
    MAX_BACKOFF = 3600

    def __init__(self, model_name, model_dim, *args, **kwargs):
        """Store the API model name and the embedding width.

        ``model_dim`` sizes the zero vectors used in place of failed
        embeddings. Extra arguments are forwarded to ``super().__init__``.
        """
        super().__init__(*args, **kwargs)
        self.model_name = model_name
        self.model_dim = model_dim

    def _zero_emb(self):
        """Return the placeholder vector used for sentences that could not be embedded."""
        return np.zeros(self.model_dim)

    def _make_batches(self, sentences, max_tokens):
        """Split *sentences*, in order, into lists small enough for one request.

        The current batch is closed when adding the next sentence would push
        it past *max_tokens* tokens or ``MAX_INPUTS_PER_REQUEST`` inputs.
        """
        batches = []
        batch = []
        batch_tokens = 0
        for sentence in sentences:
            num_tokens = count_openai_tokens(sentence, model=self.model_name)
            if num_tokens > max_tokens:
                # NOTE(review): this is a character-based cut, so token-dense
                # text may still exceed max_tokens after truncation.
                sentence = sentence[:self.TRUNCATE_CHARS]
                num_tokens = count_openai_tokens(sentence, model=self.model_name)
            full = (batch_tokens + num_tokens > max_tokens
                    or len(batch) >= self.MAX_INPUTS_PER_REQUEST)
            if batch and full:
                batches.append(batch)
                batch = []
                batch_tokens = 0
            batch.append(sentence)
            batch_tokens += num_tokens
        if batch:
            batches.append(batch)
        return batches

    def encode(self, sentences, batch_size=32, **kwargs):
        """Embed *sentences* and return one 1-D numpy array per sentence, in input order.

        Args:
            sentences: Sequence of strings to embed.
            batch_size: Accepted for interface compatibility but unused;
                batches are formed by token budget instead.
            **kwargs: Optional ``max_tokens`` (per-request token budget,
                default 8000) plus ``caching``, ``timeout``, ``retry`` and
                ``interval``, which are forwarded to ``request_openai_emb``.

        Every sentence in a failed batch gets a zero vector. Each failure adds
        ``FAIL_BACKOFF`` seconds of sleep before the next batch. When the
        back-off exceeds ``MAX_BACKOFF``, encoding stops and every remaining
        sentence also gets a zero vector, so the output length always equals
        ``len(sentences)``.
        """
        max_tokens = kwargs.get("max_tokens", 8000)
        batch_list = self._make_batches(sentences, max_tokens)
        log(f"Total sentences={len(sentences)}, batches={len(batch_list)}")

        embs = []
        waiting = 0
        for batch_idx, batch_texts in enumerate(tqdm(batch_list)):
            batch_embs = request_openai_emb(batch_texts, model=self.model_name,
                                            caching=kwargs.get("caching", True),
                                            timeout=kwargs.get("timeout", REQ_OPENAI_TIMEOUT),
                                            retry=kwargs.get("retry", REQ_OPENAI_RETRY),
                                            interval=kwargs.get("interval", REQ_OPENAI_INTERVAL))

            if len(batch_texts) == len(batch_embs):
                embs.extend(batch_embs)
                waiting = waiting // 2  # decay the back-off after a success
                log(f"The batch-{batch_idx} encoding SUCCESS! waiting={waiting}s...")
            else:
                embs.extend(self._zero_emb() for _ in batch_texts)
                waiting += self.FAIL_BACKOFF
                log(f"The batch-{batch_idx} encoding FAILED {len(batch_texts)}:{len(batch_embs)}! waiting={waiting}s...")

            if waiting > self.MAX_BACKOFF:
                log(f"Frequently failed, should be waiting more then {self.MAX_BACKOFF}s, break down!!!")
                # Keep the output aligned with the input: zero-fill every
                # sentence that was never requested.
                embs.extend(self._zero_emb() for _ in range(len(sentences) - len(embs)))
                break
            if waiting > 0 and batch_idx < len(batch_list) - 1:
                # Back off before the next request; skip it after the last batch.
                time.sleep(waiting)

        log(f'Total encoding sentences={len(sentences)}, embeddings={len(embs)}')
        return embs
|
|
|
|
# Embedding model under evaluation and its vector width (the width sizes the
# zero-vector placeholders OpenaiEmbModel emits for failed requests).
model_name = "text-embedding-3-large"
model_dim = 3072
model = OpenaiEmbModel(model_name=model_name, model_dim=model_dim)

# Restrict the benchmark to tasks tagged with these language codes.
task_langs = ["zh", "zh-CN"]

# Results go under results/zh/<last path component of the model name>.
output_folder = "results/zh/" + model_name.rsplit("/", 1)[-1]
evaluation = MTEB(task_types=task_list, task_langs=task_langs)
evaluation.run(model, output_folder=output_folder)
|
|