Guangsheng Bao committed on
Commit
e6450ab
·
1 Parent(s): 982d04d

demo for Glimpse

Browse files
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Base image: official Python 3.9 (Debian-based).
FROM python:3.9

# Run as a non-root user (uid 1000) — required by the Hugging Face Spaces
# Docker runtime; user-local pip installs land on PATH below.
RUN useradd -m -u 1000 gshbao
USER gshbao
ENV PATH="/home/gshbao/.local/bin:$PATH"

WORKDIR /glimpse

# Copy requirements first so the dependency layer is cached across
# code-only rebuilds.
COPY --chown=gshbao ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

# Copy the application code and launch via the restart-loop wrapper.
COPY --chown=gshbao . /glimpse
CMD ["bash", "run.sh"]
README.md CHANGED
@@ -1,12 +1,11 @@
1
  ---
2
  title: Glimpse
3
- emoji: 🌍
4
- colorFrom: yellow
5
- colorTo: pink
6
  sdk: docker
7
  pinned: false
8
- license: cc
9
- short_description: 'Glimpse: Enabling White-Box Methods to Use Proprietary Model'
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Glimpse
3
+ emoji: 🏆
4
+ colorFrom: gray
5
+ colorTo: blue
6
  sdk: docker
7
  pinned: false
8
+ license: mit
 
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
api.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import flask
from flask_cors import CORS
from concurrent.futures import ThreadPoolExecutor
import json
import datetime
import weave
from detector_base import get_detector

# Serve the pre-built Vue frontend from ./dist at the web root.
app = flask.Flask(__name__, static_folder='./dist', static_url_path='/')

CORS(app, supports_credentials=True)
# Thread pool used to run scoring calls off the Flask request thread.
executor = ThreadPoolExecutor(10)
14
+
15
def return_data(code, msg, data, cookie="", ToNone=True):
    """Build a JSON Flask response of shape {code, msg, data}.

    Args:
        code: application-level status code (0 = success).
        msg: human-readable message.
        data: payload; coerced to None when empty and ToNone is set.
        cookie: optional mapping of cookie name -> value to set (12h max-age).
        ToNone: normalize empty payloads ([], "", {}) to None.
    """
    if ToNone and len(data) <= 0:
        data = None
    payload = {
        'code': code,
        'msg': msg,
        'data': data
    }
    response = flask.make_response(flask.jsonify(payload))
    if cookie:
        for name in cookie:
            response.set_cookie(name, cookie[name], max_age=3600 * 12)
    return response
28
+
29
@weave.op()  # traced by Weave for observability of each scoring call
def process_request(text, detector):
    # Delegates to the detector; returns (prob, crit, ntoken)
    # per DetectorBase.compute_prob.
    return detector.compute_prob(text)
32
+
33
def handle_request(detector_name):
    """Handle one scoring request for the named detector.

    POST body: a JSON-encoded string (the text to score).
    Returns a JSON response {code, msg, data}; on success data carries
    'prob', 'crit' and 'ntoken'. GET acts as a liveness probe.
    """
    # request data
    if flask.request.method != 'POST':
        # GET: empty payload (return_data coerces {} to None).
        return return_data(0, '', {})
    try:
        data = flask.request.data.decode('utf-8')
        # json.loads moved inside the try: previously a malformed JSON body
        # raised an unhandled exception (HTTP 500) instead of a clean 400.
        sentence = json.loads(data)
    except Exception as ex:
        print(datetime.datetime.now().isoformat(), ex, flush=True)
        return return_data(400, 'Bad request', '')
    # handle request
    info = {}
    data = {"sentence": sentence}
    print(datetime.datetime.now().isoformat(), data, flush=True)
    try:
        text = data["sentence"]
        detector = get_detector(detector_name)
        future = executor.submit(process_request, text, detector)
        prob, crit, ntoken = future.result()
        info["crit"] = crit
        info["prob"] = prob
        info["ntoken"] = ntoken
        print(datetime.datetime.now().isoformat(), info, flush=True)
        return return_data(0, '', info)
    except Exception as ex:
        # Log and answer 400 rather than killing the process: the original
        # called os._exit(1) here, which dropped this client's connection,
        # aborted every other in-flight request, and relied on run.sh to
        # restart the server — making any single scoring error a full outage.
        # The `return` that followed os._exit was unreachable.
        print(datetime.datetime.now().isoformat(), ex, flush=True)
        return return_data(400, 'Bad request', '')
63
+
64
+
65
+ @app.route("/glimpse", methods=["GET", "POST"])
66
+ def glimpse():
67
+ return handle_request("glimpse")
68
+
69
+ @app.route("/", methods=["GET"])
70
+ def index():
71
+ return app.send_static_file('index.html')
72
+
73
+ if __name__ == '__main__':
74
+ # initialize detectors
75
+ detectors = ['glimpse']
76
+ for detector_name in detectors:
77
+ get_detector(detector_name)
78
+ # service
79
+ weave.init('Glimpse')
80
+ app.run(host='0.0.0.0', port=7860)
configs/glimpse.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "api_base": "${API_BASE}",
3
+ "api_key": "${API_KEY}",
4
+ "api_version": "2024-08-01-preview",
5
+ "scoring_model_name": "davinci-002",
6
+ "max_token_observed": 512,
7
+ "estimator": "geometric",
8
+ "prompt": "prompt3",
9
+ "rank_size": 1000,
10
+ "top_k": 5,
11
+ "linear_k": 1.34,
12
+ "linear_b": 2.41,
13
+ "cache_dir": "../cache"
14
+ }
detector_base.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Guangsheng Bao.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import os
6
+ from utils import load_json
7
+ from types import SimpleNamespace
8
+ import numpy as np
9
+
10
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)); accepts scalars or numpy arrays."""
    return 1.0 / (1.0 + np.exp(-x))
12
+
13
class DetectorBase:
    """Base class for detectors: loads a JSON config (with environment
    substitution) and maps a raw criterion score to a probability via a
    linearly calibrated sigmoid."""

    def __init__(self, config_name):
        self.config_name = config_name
        self.config = self.load_config(config_name)

    def load_config(self, config_name):
        """Load ./configs/<config_name>.json; string values of the form
        "${VAR}" are replaced by os.getenv("VAR"). Returns a SimpleNamespace."""
        config = load_json(f'./configs/{config_name}.json')
        for key in config:
            val = config[key]
            # isinstance instead of type(...) == str (idiomatic, subclass-safe)
            if isinstance(val, str) and val.startswith('${') and val.endswith('}'):
                val = os.getenv(val[2:-1])
                config[key] = val
                # Do not echo the resolved value: env-substituted entries may
                # be secrets (e.g. api_key) and the original print leaked them
                # into the service logs.
                print(f'Config entry solved: {key} -> ***')
        return SimpleNamespace(**config)

    def compute_crit(self, text):
        """Return (criterion, ntoken) for `text`; implemented by subclasses."""
        raise NotImplementedError

    def compute_prob(self, text):
        """Return (prob, crit, ntoken): calibrated probability plus the raw
        criterion and token count from compute_crit."""
        crit, ntoken = self.compute_crit(text)
        prob = sigmoid(self.config.linear_k * crit + self.config.linear_b)
        return prob, crit, ntoken

    def __str__(self):
        return self.config_name
38
+
39
+
40
# Module-level cache: detector name -> constructed detector instance.
CACHE_DETECTORS = {}

def get_detector(name):
    """Return the detector registered under `name`, constructing and caching
    it on first use. Raises NotImplementedError for unknown names."""
    from glimpse import Glimpse
    name_detectors = {
        'glimpse': ('glimpse', Glimpse),
    }
    # serve from the cache when possible
    cached = CACHE_DETECTORS.get(name)
    if cached is not None:
        return cached
    # otherwise construct, cache, and return a new instance
    if name not in name_detectors:
        raise NotImplementedError
    config_name, detector_class = name_detectors[name]
    detector = detector_class(config_name)
    CACHE_DETECTORS[name] = detector
    return detector
dist/bitbug_favicon.ico ADDED
dist/bitbug_favicon.png ADDED
dist/favicon.ico ADDED
dist/glimpse.png ADDED
dist/index.html ADDED
@@ -0,0 +1 @@
 
 
1
+ <!doctype html><html lang="en"><head><meta charset="utf-8"><meta http-equiv="X-UA-Compatible" content="IE=edge"><meta name="viewport" content="width=device-width,initial-scale=1"><link rel="icon" href="westlake.png"><title>Demo from Westlake University</title><script defer="defer" src="static/js/app.ffb0b0581194594b.js"></script><link href="static/css/app.fa664689.css" rel="stylesheet"></head><body><noscript><strong>We're sorry but default doesn't work properly without JavaScript enabled. Please enable it to continue.</strong></noscript><div id="app"></div></body></html>
dist/static/css/app.0d49958b.css ADDED
The diff for this file is too large to render. See raw diff
 
dist/static/css/app.3616b191.css ADDED
The diff for this file is too large to render. See raw diff
 
dist/static/css/app.fa664689.css ADDED
The diff for this file is too large to render. See raw diff
 
dist/static/js/app.594f31c7f0c0aa9f.js ADDED
The diff for this file is too large to render. See raw diff
 
dist/static/js/app.594f31c7f0c0aa9f.js.LICENSE.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ * Vue.js v2.7.16
3
+ * (c) 2014-2023 Evan You
4
+ * Released under the MIT License.
5
+ */
6
+
7
+ /*!
8
+ * ZRender, a high performance 2d drawing library.
9
+ *
10
+ * Copyright (c) 2013, Baidu Inc.
11
+ * All rights reserved.
12
+ *
13
+ * LICENSE
14
+ * https://github.com/ecomfe/zrender/blob/master/LICENSE.txt
15
+ */
16
+
17
+ /*! regenerator-runtime -- Copyright (c) 2014-present, Facebook, Inc. -- license (MIT): https://github.com/facebook/regenerator/blob/main/LICENSE */
dist/static/js/app.895138338b9a9bfb.js ADDED
The diff for this file is too large to render. See raw diff
 
dist/static/js/app.895138338b9a9bfb.js.LICENSE.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ * Vue.js v2.7.16
3
+ * (c) 2014-2023 Evan You
4
+ * Released under the MIT License.
5
+ */
6
+
7
+ /*!
8
+ * ZRender, a high performance 2d drawing library.
9
+ *
10
+ * Copyright (c) 2013, Baidu Inc.
11
+ * All rights reserved.
12
+ *
13
+ * LICENSE
14
+ * https://github.com/ecomfe/zrender/blob/master/LICENSE.txt
15
+ */
16
+
17
+ /*! regenerator-runtime -- Copyright (c) 2014-present, Facebook, Inc. -- license (MIT): https://github.com/facebook/regenerator/blob/main/LICENSE */
dist/static/js/app.edb5e5e9eac8f89e.js ADDED
The diff for this file is too large to render. See raw diff
 
dist/static/js/app.edb5e5e9eac8f89e.js.LICENSE.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ * Vue.js v2.7.16
3
+ * (c) 2014-2023 Evan You
4
+ * Released under the MIT License.
5
+ */
6
+
7
+ /*!
8
+ * ZRender, a high performance 2d drawing library.
9
+ *
10
+ * Copyright (c) 2013, Baidu Inc.
11
+ * All rights reserved.
12
+ *
13
+ * LICENSE
14
+ * https://github.com/ecomfe/zrender/blob/master/LICENSE.txt
15
+ */
16
+
17
+ /*! regenerator-runtime -- Copyright (c) 2014-present, Facebook, Inc. -- license (MIT): https://github.com/facebook/regenerator/blob/main/LICENSE */
dist/static/js/app.ffb0b0581194594b.js ADDED
The diff for this file is too large to render. See raw diff
 
dist/static/js/app.ffb0b0581194594b.js.LICENSE.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ * Vue.js v2.7.16
3
+ * (c) 2014-2023 Evan You
4
+ * Released under the MIT License.
5
+ */
6
+
7
+ /*!
8
+ * ZRender, a high performance 2d drawing library.
9
+ *
10
+ * Copyright (c) 2013, Baidu Inc.
11
+ * All rights reserved.
12
+ *
13
+ * LICENSE
14
+ * https://github.com/ecomfe/zrender/blob/master/LICENSE.txt
15
+ */
16
+
17
+ /*! regenerator-runtime -- Copyright (c) 2014-present, Facebook, Inc. -- license (MIT): https://github.com/facebook/regenerator/blob/main/LICENSE */
dist/westlake.png ADDED
glimpse.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Guangsheng Bao.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import numpy as np
6
+ import json
7
+ import time
8
+ from types import SimpleNamespace
9
+ from detector_base import DetectorBase
10
+
11
+
12
class OpenAIGPT:
    """Wrapper around the OpenAI completions API that scores a text by echoing
    it back with per-token logprobs (max_tokens=0, echo=True, logprobs=K)."""

    def __init__(self, config):
        self.config = config
        self.client = self.prepare_client1()
        # predefined prompts prepended before the scored text
        self.prompts = {
            "prompt0": "",
            "prompt1": f"You serve as a valuable aide, capable of generating clear and persuasive pieces of writing given a certain context. Now, assume the role of an author and strive to finalize this article.\n",
            "prompt2": f"You serve as a valuable aide, capable of generating clear and persuasive pieces of writing given a certain context. Now, assume the role of an author and strive to finalize this article.\nI operate as an entity utilizing GPT as the foundational large language model. I function in the capacity of a writer, authoring articles on a daily basis. Presented below is an example of an article I have crafted.\n",
            "prompt3": f"System:\nYou serve as a valuable aide, capable of generating clear and persuasive pieces of writing given a certain context. Now, assume the role of an author and strive to finalize this article.\nAssistant:\nI operate as an entity utilizing GPT as the foundational large language model. I function in the capacity of a writer, authoring articles on a daily basis. Presented below is an example of an article I have crafted.\n",
            "prompt4": f"Assistant:\nYou serve as a valuable aide, capable of generating clear and persuasive pieces of writing given a certain context. Now, assume the role of an author and strive to finalize this article.\nUser:\nI operate as an entity utilizing GPT as the foundational large language model. I function in the capacity of a writer, authoring articles on a daily basis. Presented below is an example of an article I have crafted.\n",
        }
        # depth of top_logprobs requested per token
        self.max_topk = 10

    def prepare_client1(self):
        """OpenAI-compatible client (used by default)."""
        api_base = self.config.api_base
        api_key = self.config.api_key
        from openai import OpenAI
        client = OpenAI(
            base_url=api_base,
            api_key=api_key)
        return client

    def prepare_client2(self):
        """Azure OpenAI client (alternative backend; not selected by default)."""
        api_base = self.config.api_base
        api_key = self.config.api_key
        api_version = self.config.api_version
        from openai import AzureOpenAI
        client = AzureOpenAI(
            azure_endpoint=api_base,
            api_key=api_key,
            api_version=api_version)
        return client

    def _response_to_text(self, response):
        """Serialize a response object's attributes to a JSON string."""
        obj = vars(response)
        text = json.dumps(obj)
        return text

    def _response_from_text(self, text):
        """Inverse of _response_to_text: JSON string -> SimpleNamespace."""
        obj = json.loads(text)
        response = SimpleNamespace(**obj)
        return response

    def evaluate(self, prompt, text):
        """Call the completions API once (with a single retry) and return the
        logprobs object of the first choice. Raises the last error on failure."""
        model_name = self.config.scoring_model_name
        kwargs = {"model": model_name,
                  "prompt": f"<|endoftext|>{prompt}{text}",
                  "max_tokens": 0, "echo": True, "logprobs": self.max_topk}
        # retry 1 time
        ntry = 2
        for idx in range(ntry):
            try:
                response = self.client.completions.create(**kwargs)
                response = response.choices[0].logprobs
                # (the original had an unreachable `break` after this return)
                return response
            except Exception as e:
                if idx < ntry - 1:
                    print(f'{model_name}, {kwargs}: {e}. Retrying ...')
                    time.sleep(5)
                    continue
                raise e

    def eval(self, text):
        """Score `text`: return (tokens, logprobs, toplogprobs) with the
        prompt prefix stripped from the echoed completion."""
        prompt = self.prompts[self.config.prompt]
        # get top tokens
        result = self.evaluate(prompt, text)
        # decide the prefix length: accumulate echoed tokens (index 0 is the
        # <|endoftext|> sentinel) until they reconstruct the prompt exactly
        prefix = ""
        nprefix = 1
        while len(prefix) < len(prompt):
            prefix += result.tokens[nprefix]
            nprefix += 1
        assert prefix == prompt, f"Mismatch: {prompt} .vs. {prefix}"
        tokens = result.tokens[nprefix:]
        logprobs = result.token_logprobs[nprefix:]
        toplogprobs = result.top_logprobs[nprefix:]
        toplogprobs = [dict(item) for item in toplogprobs]
        assert len(tokens) == len(logprobs), f"Expected {len(tokens)} logprobs, got {len(logprobs)}"
        assert len(tokens) == len(toplogprobs), f"Expected {len(tokens)} toplogprobs, got {len(toplogprobs)}"
        return tokens, logprobs, toplogprobs
94
+
95
+ # probability distribution estimation
96
def safe_log(prob):
    """Elementwise natural log with a small epsilon so log(0) never occurs."""
    return np.log(np.array(prob) + 1e-8)
98
+
99
class GeometricDistribution:
    '''
    Top-K probabilities: p_1, p_2, ..., p_K
    Estimated probabilities: Pr(X=k) = p_K * lambda ^ (k - K), for k > K.
    '''
    def __init__(self, top_k, rank_size):
        # top_k: number of observed top-token probabilities (K)
        # rank_size: assumed size of the full rank list (M)
        self.name = "GeometricDistribution"
        self.top_k = top_k
        self.rank_size = rank_size

    def estimate_distrib_token(self, toplogprobs):
        # Expand the observed top-K log-probs of one token (a mapping
        # token -> logprob) into a full distribution over M ranks, modelling
        # the unobserved tail with a geometric decay. Returns M probabilities
        # normalized to sum to 1.
        M = self.rank_size  # assuming rank list size
        K = self.top_k  # assuming top-K tokens
        assert K <= M
        toplogprobs = sorted(toplogprobs.values(), reverse=True)
        assert len(toplogprobs) >= K
        toplogprobs = toplogprobs[:K]
        probs = np.exp(toplogprobs)  # distribution over ranks
        if probs.sum() > 1.0:
            # numerical overshoot past 1.0: renormalize slightly below 1
            probs = probs / (probs.sum() + 1e-6)
        p_K = probs[-1]  # the k-th top token
        p_rest = 1 - probs.sum()  # the rest probability mass
        _lambda = p_rest / (p_K + p_rest)  # approximate the decay factor
        if _lambda ** (M - K + 1) > 1e-6:
            # Closed-form approximation invalid (tail mass beyond rank M is
            # not negligible): refine lambda by fixed-point iteration.
            _lambda_old = _lambda
            last_diff = 1.0
            while True:
                _lambda0 = _lambda
                minor = _lambda ** (M - K + 1)  # the minor part
                assert p_rest > 0, f'Error: Invalid p_rest={p_rest}'
                _lambda = 1 - (_lambda - minor) * p_K / p_rest
                # check convergence: stop on small, negative, or
                # non-decreasing updates and keep the previous iterate
                diff = abs(_lambda - _lambda0)
                if _lambda < 0 or diff < 1e-6 or diff >= last_diff:
                    _lambda = _lambda0
                    break
                last_diff = diff
        assert p_rest >= 0, f'Error: Invalid p_rest={p_rest}'
        assert 0 <= _lambda <= 1, f'Error: Invalid lambda={_lambda} calculated by p_K={p_K} and p_rest={p_rest}.'
        # estimate the probabilities of the rest tokens (ranks K+1..M)
        probs_rest = np.exp(safe_log(p_K) + np.arange(1, M - K + 1) * safe_log(_lambda))
        probs = np.concatenate([probs, probs_rest])
        # normalize so the estimated distribution sums exactly to 1
        probs = probs / probs.sum()
        return probs.tolist()
149
+
150
class PdeBase:
    """Base for criteria built on probability-distribution estimation: holds a
    distribution estimator and expands per-token top-K logprobs into full
    per-token distributions."""

    def __init__(self, distrib):
        self.distrib = distrib

    def estimate_distrib_sequence(self, item):
        """Return an (ntoken, rank_size) array of per-token distributions,
        caching the expansion on `item` under a distribution-specific key."""
        distrib = self.distrib
        key = f'{distrib.name}-top{distrib.top_k}'
        if key not in item:
            item[key] = [distrib.estimate_distrib_token(v)
                         for v in item["toplogprobs"]]
        return np.array(item[key])
163
+
164
class PdeFastDetectGPT(PdeBase):
    """Fast-DetectGPT sampling-discrepancy criterion computed over the
    estimated full-rank token distributions."""

    def __call__(self, item):
        """Return the normalized discrepancy between the observed sequence
        log-likelihood and its analytic mean under the estimated distributions."""
        distrib_probs = self.estimate_distrib_sequence(item)
        log_probs = np.nan_to_num(np.log(distrib_probs))
        # analytic per-token mean and variance of the log-likelihood
        ref_mean = (distrib_probs * log_probs).sum(axis=-1)
        ref_var = (distrib_probs * np.nan_to_num(np.square(log_probs))).sum(axis=-1) - np.square(ref_mean)
        observed = np.array(item["logprobs"]).sum(axis=-1)
        score = (observed - ref_mean.sum(axis=-1)) / np.sqrt(ref_var.sum(axis=-1))
        return score.mean().item()
176
+
177
+
178
+ # the detector
179
class Glimpse(DetectorBase):
    """Glimpse detector: scores text through the OpenAI completions API and
    applies the Fast-DetectGPT criterion over estimated distributions."""

    def __init__(self, config_name):
        super().__init__(config_name)
        self.gpt = OpenAIGPT(self.config)
        distrib = GeometricDistribution(self.config.top_k, self.config.rank_size)
        self.criterion_fn = PdeFastDetectGPT(distrib)

    def compute_crit(self, text):
        """Return (criterion, ntoken) for `text`."""
        tokens, logprobs, toplogprobs = self.gpt.eval(text)
        item = {
            'text': text,
            'tokens': tokens,
            'logprobs': logprobs,
            'toplogprobs': toplogprobs,
        }
        return self.criterion_fn(item), len(tokens)
log/log.txt ADDED
File without changes
metrics.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Guangsheng Bao.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import matplotlib.pyplot as plt
7
+ from sklearn.metrics import roc_curve, precision_recall_curve, auc
8
+
9
+ # 15 colorblind-friendly colors
10
+ COLORS = ["#0072B2", "#009E73", "#D55E00", "#CC79A7", "#F0E442",
11
+ "#56B4E9", "#E69F00", "#000000", "#0072B2", "#009E73",
12
+ "#D55E00", "#CC79A7", "#F0E442", "#56B4E9", "#E69F00"]
13
+
14
+
15
def get_roc_metrics(real_preds, sample_preds):
    """ROC curve and AUC, treating `real_preds` as the negative class (0) and
    `sample_preds` as the positive class (1)."""
    labels = [0] * len(real_preds) + [1] * len(sample_preds)
    scores = real_preds + sample_preds
    fpr, tpr, _ = roc_curve(labels, scores)
    return fpr.tolist(), tpr.tolist(), float(auc(fpr, tpr))
19
+
20
+
21
def get_precision_recall_metrics(real_preds, sample_preds):
    """Precision-recall curve and PR-AUC, with `real_preds` as negatives (0)
    and `sample_preds` as positives (1)."""
    labels = [0] * len(real_preds) + [1] * len(sample_preds)
    scores = real_preds + sample_preds
    precision, recall, _ = precision_recall_curve(labels, scores)
    return precision.tolist(), recall.tolist(), float(auc(recall, precision))
26
+
model.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Guangsheng Bao.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ import torch
8
+ import time
9
+ import os
10
+
11
def from_pretrained(cls, model_name, kwargs, cache_dir):
    """Load via `cls.from_pretrained`, preferring a local snapshot stored
    under `cache_dir` (with '/' in the model id replaced by '_')."""
    local_path = os.path.join(cache_dir, model_name.replace("/", "_"))
    if not os.path.exists(local_path):
        # no local copy: let the library resolve/download into cache_dir
        return cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir)
    return cls.from_pretrained(local_path, **kwargs)
17
+
18
# predefined models: short alias -> Hugging Face model id
model_fullnames = {'gpt2': 'gpt2',
                   'gpt2-xl': 'gpt2-xl',
                   'opt-2.7b': 'facebook/opt-2.7b',
                   'gpt-neo-2.7B': 'EleutherAI/gpt-neo-2.7B',
                   'gpt-j-6B': 'EleutherAI/gpt-j-6B',
                   'gpt-neox-20b': 'EleutherAI/gpt-neox-20b',
                   'mgpt': 'sberbank-ai/mGPT',
                   'pubmedgpt': 'stanford-crfm/pubmedgpt',
                   'mt5-xl': 'google/mt5-xl',
                   'llama-13b': 'huggyllama/llama-13b',
                   'llama2-13b': 'TheBloke/Llama-2-13B-fp16',
                   'bloom-7b1': 'bigscience/bloom-7b1',
                   'opt-13b': 'facebook/opt-13b',
                   'falcon-7b': 'falcon-7b',
                   'falcon-7b-instruct': 'falcon-7b-instruct',
                   }
# aliases loaded in float16 to halve memory (see load_model)
float16_models = ['gpt-neo-2.7B', 'gpt-j-6B', 'gpt-neox-20b', 'llama-13b', 'llama2-13b', 'bloom-7b1', 'opt-13b',
                  'falcon-7b', 'falcon-7b-instruct']
37
+
38
def get_model_fullname(model_name):
    """Resolve a model alias to its full HF id; unknown names pass through."""
    return model_fullnames.get(model_name, model_name)
40
+
41
def load_model(model_name, device, cache_dir, is_half=False):
    """Load a causal LM by alias or full name and move it to `device`.

    Models listed in float16_models are loaded in fp16; gpt-j additionally
    uses its dedicated 'float16' revision. `is_half` forces .half() after load.
    """
    model_fullname = get_model_fullname(model_name)
    print(f'Loading model {model_fullname}...')
    model_kwargs = {}
    if model_name in float16_models:
        model_kwargs.update(dict(torch_dtype=torch.float16))
    if 'gpt-j' in model_name:
        model_kwargs.update(dict(revision='float16'))
    model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir)
    print('Moving model to GPU...', end='', flush=True)
    start = time.time()
    if is_half:
        model.half()
    model.to(device)
    print(f'DONE ({time.time() - start:.2f}s)')
    return model
57
+
58
def load_tokenizer(model_name, cache_dir):
    """Load the tokenizer for a model alias/full name, with per-family fixes:
    OPT uses the slow tokenizer; a missing pad token falls back to EOS."""
    model_fullname = get_model_fullname(model_name)
    optional_tok_kwargs = {}
    if "facebook/opt-" in model_fullname:
        print("Using non-fast tokenizer for OPT")
        optional_tok_kwargs['fast'] = False
    optional_tok_kwargs['padding_side'] = 'right'
    base_tokenizer = from_pretrained(AutoTokenizer, model_fullname, optional_tok_kwargs, cache_dir=cache_dir)
    if base_tokenizer.pad_token_id is None:
        base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
    if '13b' in model_fullname:
        # NOTE(review): pad id 0 looks like the llama-style <unk> token —
        # confirm this is intended for every '13b' model matched here.
        base_tokenizer.pad_token_id = 0
    return base_tokenizer
71
+
72
+
73
if __name__ == '__main__':
    # Smoke test: load the given model's tokenizer and weights on CPU.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default="bloom-7b1")
    parser.add_argument('--cache_dir', type=str, default="../cache")
    args = parser.parse_args()

    load_tokenizer(args.model_name, args.cache_dir)
    load_model(args.model_name, 'cpu', args.cache_dir)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ Flask==3.0.2
2
+ Flask_Cors==4.0.0
3
+ numpy==1.23.5
4
+ openai==1.56.1
5
+ httpx==0.27.2
6
+ weave
run.sh ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #!/bin/sh
2
+
3
+ while true
4
+ do
5
+ echo `date`, START
6
+ python api.py
7
+ done
8
+
utils.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Guangsheng Bao.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import json
6
+ from io import open
7
+
8
def load_json(filename):
    """Read a UTF-8 JSON file and return the parsed object."""
    with open(filename, encoding='utf-8') as fin:
        return json.load(fin)
11
+
12
def save_json(filename, data):
    """Write `data` as pretty-printed UTF-8 JSON (non-ASCII kept literal)."""
    with open(filename, 'w', encoding='utf-8') as fout:
        json.dump(data, fout, indent=2, ensure_ascii=False)
15
+
16
def load_text(filename):
    """Return the full contents of a text file (platform default encoding)."""
    with open(filename) as fin:
        return fin.read()
19
+
20
def save_text(filename, text):
    """Write `text` to a file (platform default encoding), replacing any
    existing contents."""
    with open(filename, 'w') as fout:
        fout.write(text)