xiaowenbin committed on
Commit
50a041e
·
verified ·
1 Parent(s): 58a9abd

Delete mteb_eval_openai.py

Browse files
Files changed (1) hide show
  1. mteb_eval_openai.py +0 -100
mteb_eval_openai.py DELETED
@@ -1,100 +0,0 @@
1
- import os
2
- import sys
3
- import time
4
- import hashlib
5
- import numpy as np
6
- import requests
7
-
8
- import logging
9
- import functools
10
- from mteb import MTEB
11
- from sentence_transformers import SentenceTransformer
12
# Script-wide logging: INFO level, named "main" to match the original script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("main")

# API endpoint and credentials come from the environment; empty string means
# "fall back to the caller-supplied base_url" for the endpoint, and an
# unauthenticated request for the key.
OPENAI_BASE_URL = os.getenv('OPENAI_BASE_URL', '')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', '')

# Per-text embedding cache directory; created eagerly so callers can write
# into it without checking.
EMB_CACHE_DIR = os.getenv('EMB_CACHE_DIR', '.cache/embs')
os.makedirs(EMB_CACHE_DIR, exist_ok=True)
def uuid_for_text(text):
    """Return a stable cache key for *text*: the hex MD5 of its UTF-8 bytes."""
    # MD5 is used only as a cache-file name, not for security.
    digest = hashlib.md5(text.encode('utf8'))
    return digest.hexdigest()
24
def request_openai_emb(texts, model="text-embedding-3-large",
                       base_url='https://api.openai.com', prefix_url='/v1/embeddings',
                       timeout=4, retry=3, interval=2, caching=True):
    """Embed *texts* via the OpenAI embeddings API with an on-disk cache.

    Args:
        texts: a single string or a list of at most 256 strings.
        model: embedding model id sent in the request payload.
        base_url: API host used when the OPENAI_BASE_URL env var is empty.
        prefix_url: endpoint path appended to the host.
        timeout: per-request timeout in seconds.
        retry: number of HTTP attempts for the uncached texts.
        interval: seconds slept after each failed attempt.
        caching: read/write one embedding file per text under EMB_CACHE_DIR.

    Returns:
        A list of np.ndarray embeddings aligned with *texts*, or [] when the
        request ultimately fails (the original failure contract).
    """
    if isinstance(texts, str):
        texts = [texts]
    assert len(texts) <= 256

    # Resolve cache hits first, remembering which positions still need the
    # API. BUG FIX: the original appended cache hits to one flat list, so a
    # *partial* cache hit left `data` non-empty, the retry loop was skipped
    # (its condition required len(data) == 0), and the function returned []
    # even though the API was never called. We now request only the missing
    # texts and merge results back by index.
    embs = [None] * len(texts)
    missing = list(range(len(texts)))
    if caching:
        missing = []
        for i, text in enumerate(texts):
            emb_file = f"{EMB_CACHE_DIR}/{uuid_for_text(text)}"
            if os.path.isfile(emb_file) and os.path.getsize(emb_file) > 0:
                embs[i] = np.loadtxt(emb_file)
            else:
                missing.append(i)
    if not missing:
        return embs

    url = f"{OPENAI_BASE_URL}{prefix_url}" if OPENAI_BASE_URL else f"{base_url}{prefix_url}"
    headers = {
        "Authorization": f"Bearer {OPENAI_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {"input": [texts[i] for i in missing], "model": model}

    fetched = []
    while retry > 0 and not fetched:
        try:
            r = requests.post(url, headers=headers, json=payload,
                              timeout=timeout)
            res = r.json()
            fetched = [np.array(x["embedding"]) for x in res["data"]]
        except Exception as e:
            print(f"request openai, retry {retry}, error: {e}", file=sys.stderr)
            time.sleep(interval)
        # BUG FIX: decrement on every attempt. The original decremented only
        # inside `except`, so a successful HTTP response whose "data" field
        # was empty looped forever.
        retry -= 1

    if len(fetched) != len(missing):
        # Response incomplete or all attempts failed: keep the original
        # contract of returning an empty list.
        return []

    for i, emb in zip(missing, fetched):
        embs[i] = emb
        if caching:
            emb_file = f"{EMB_CACHE_DIR}/{uuid_for_text(texts[i])}"
            np.savetxt(emb_file, emb)

    return embs
69
-
70
class OpenaiEmbModel:
    """Minimal MTEB-compatible model wrapper around request_openai_emb."""

    def encode(self, sentences, batch_size=32, **kwargs):
        """Embed *sentences* batch-by-batch; returns one embedding per input."""
        # Cap the request size well under the API's 256-text assert.
        batch_size = min(64, batch_size)

        results = []
        for start in range(0, len(sentences), batch_size):
            chunk = sentences[start:start + batch_size]
            chunk_embs = request_openai_emb(chunk,
                                            caching=True, retry=3, interval=2)
            assert len(chunk) == len(chunk_embs), "The batch of texts and embs DONT match!"
            results.extend(chunk_embs)

        return results
86
model = OpenaiEmbModel()

######
# test
#####
#embs = model.encode(['全国', '北京'])
#print(embs)

# Task types and languages to evaluate.
task_list = ['Classification', 'Clustering', 'Reranking', 'Retrieval', 'STS', 'PairClassification']
task_langs = ["zh", "zh-CN"]

# BUG FIX: `model_name` was never defined, so the final line raised a
# NameError before any evaluation ran. Use the embedding model id that
# request_openai_emb defaults to, keeping the original output path scheme.
model_name = "text-embedding-3-large"

evaluation = MTEB(task_types=task_list, task_langs=task_langs)
evaluation.run(model, output_folder=f"results/zh/{model_name.split('/')[-1]}")