shellyTa committed on
Commit 7d6e89f · 1 Parent(s): 2db80ec

setting space

Files changed (2)
  1. app.py +986 -0
  2. requirements.txt +17 -0
app.py ADDED
@@ -0,0 +1,986 @@
+ """Prompter."""
+
+ import asyncio
+ import importlib
+ import logging
+ import os
+ import string
+ import sys
+
+ import aiohttp
+ import cohere
+ import numpy as np
+ import pandas as pd
+ import streamlit as st
+ from datasets import load_dataset
+ from datasets.features import ClassLabel
+ from huggingface_hub import AsyncInferenceClient, dataset_info, model_info
+ from huggingface_hub.utils import (
+     HfHubHTTPError,
+     HFValidationError,
+     RepositoryNotFoundError,
+ )
+ from imblearn.under_sampling import RandomUnderSampler
+ from sklearn.metrics import (
+     ConfusionMatrixDisplay,
+     accuracy_score,
+     balanced_accuracy_score,
+     confusion_matrix,
+     matthews_corrcoef,
+ )
+ from sklearn.model_selection import StratifiedShuffleSplit
+ from spacy.lang.en import English
+ from tenacity import retry, stop_after_attempt, wait_random_exponential
+ from transformers import pipeline
+
+ HOW_OPENAI_INITIATED = None
+
+ LOGGER = logging.getLogger(__name__)
+
+ TITLE = "Prompter"
+
+ OPENAI_API_KEY = st.secrets.get("openai_api_key", None)
+ TOGETHER_API_KEY = st.secrets.get("together_api_key", None)
+ HF_TOKEN = st.secrets.get("hf_token", None)
+ COHERE_API_KEY = st.secrets.get("cohere_api_key", None)
+ AZURE_OPENAI_KEY = st.secrets.get("azure_openai_key", None)
+ AZURE_OPENAI_ENDPOINT = st.secrets.get("azure_openai_endpoint", None)
+ AZURE_DEPLOYMENT_NAME = st.secrets.get("azure_deployment_name", None)
+
+ HF_MODEL = os.environ.get("FM_MODEL", "")
+
+ HF_DATASET = os.environ.get("FM_HF_DATASET", "")
+
+ DATASET_SPLIT_SEED = os.environ.get("FM_DATASET_SPLIT_SEED", "")
+ TRAIN_SIZE = 15
+ TEST_SIZE = 25
+ BALANCING = True
+
+ RETRY_MIN_WAIT = 1
+ RETRY_MAX_WAIT = 60
+ RETRY_MAX_ATTEMPTS = 6
+
+ PROMPT_TEXT_HEIGHT = 300
+
+ UNKNOWN_LABEL = "Unknown"
+
+ SEARCH_ROW_DICT = {"First": 0, "Last": -1}
+
+ # TODO: Change start temperature to 0.0 when HF supports it
+ GENERATION_CONFIG_PARAMS = {
+     "temperature": {
+         "NAME": "Temperature",
+         "START": 0.1,
+         "END": 5.0,
+         "DEFAULT": 1.0,
+         "STEP": 0.1,
+         "SAMPLING": True,
+     },
+     "top_k": {
+         "NAME": "Top K",
+         "START": 0,
+         "END": 100,
+         "DEFAULT": 0,
+         "STEP": 10,
+         "SAMPLING": True,
+     },
+     "top_p": {
+         "NAME": "Top P",
+         "START": 0.1,
+         "END": 1.0,
+         "DEFAULT": 1.0,
+         "STEP": 0.1,
+         "SAMPLING": True,
+     },
+     "max_new_tokens": {
+         "NAME": "Max New Tokens",
+         "START": 16,
+         "END": 1024,
+         "DEFAULT": 16,
+         "STEP": 16,
+         "SAMPLING": False,
+     },
+     "do_sample": {
+         "NAME": "Sampling",
+         "DEFAULT": False,
+     },
+     "stop_sequences": {
+         "NAME": "Stop Sequences",
+         "DEFAULT": os.environ.get("FM_STOP_SEQUENCES", "").split(),
+         "SAMPLING": False,
+     },
+ }
+
+ GENERATION_CONFIG_DEFAULTS = {
+     key: value["DEFAULT"] for key, value in GENERATION_CONFIG_PARAMS.items()
+ }
+
+ st.set_page_config(page_title=TITLE, initial_sidebar_state="collapsed")
+
+
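+ # spaCy's English tokenizer; used below to match candidate labels word-by-word in model output.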
+ def get_processing_tokenizer():
+     return English().tokenizer
+
+
+ PROCESSING_TOKENIZER = get_processing_tokenizer()
+
+
+ class OpenAIAlreadyInitiatedError(Exception):
+     """OpenAIAlreadyInitiatedError."""
+
+     pass
+
+
+ def prepare_huggingface_generation_config(generation_config):
+     generation_config = generation_config.copy()
+
+     # Reference for decoding strategies:
+     # https://huggingface.co/docs/transformers/generation_strategies
+
+     # `text_generation_interface`
+     # Currently supports only the `greedy` and `sampling` decoding strategies.
+     # Following the implementation below, we add `do_sample` if any of the other
+     # sampling-related parameters are set:
+     # https://github.com/huggingface/text-generation-inference/blob/e943a294bca239e26828732dd6ab5b6f95dadd0a/server/text_generation_server/utils/tokens.py#L46
+
+     # `transformers`
+     # According to experimentation, it seems that `transformers` behaves similarly.
+
+     # I'm not sure what the right behavior is here, but it is better to be explicit.
+     for name, params in GENERATION_CONFIG_PARAMS.items():
+         # Checking for START to examine the slider parameters only
+         if (
+             "START" in params
+             and params["SAMPLING"]
+             and name in generation_config
+             and generation_config[name] is not None
+         ):
+             if generation_config[name] == params["DEFAULT"]:
+                 generation_config[name] = None
+             else:
+                 assert generation_config["do_sample"]
+
+     # TODO: refactor this part
+     if generation_config["is_chat"]:
+         generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
+
+         generation_config["stop"] = generation_config.pop("stop_sequences")
+         del generation_config["do_sample"]
+         del generation_config["top_k"]
+
+     is_chat = generation_config.pop("is_chat")
+
+     return generation_config, is_chat
+
+
+ def escape_markdown(text):
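+     # Backslash-escape Markdown control characters so raw model output renders literally in Streamlit.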
+     escape_dict = {
+         "*": r"\*",
+         "_": r"\_",
+         "{": r"\{",
+         "}": r"\}",
+         "[": r"\[",
+         "]": r"\]",
+         "(": r"\(",
+         ")": r"\)",
+         "+": r"\+",
+         "-": r"\-",
+         ".": r"\.",
+         "!": r"\!",
+         "`": r"\`",
+         ">": r"\>",
+         "|": r"\|",
+         "#": r"\#",
+     }
+     return "".join([escape_dict.get(c, c) for c in text])
+
+
+ def reload_module(name):
+     if name in sys.modules:
+         del sys.modules[name]
+     return importlib.import_module(name)
+
+
+ def build_api_call_function(model):
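+     """Return an async api_call_function(prompt, generation_config) -> (output, length) for the given model spec (openai/, azure/, together/, cohere/, @local, or a Hugging Face model)."""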
+     global HOW_OPENAI_INITIATED
+
+     if any(
+         model.startswith(known_providers)
+         for known_providers in ("openai", "azure", "together")
+     ):
+         provider, model = model.split("/", maxsplit=1)
+
+         if provider == "openai":
+             from openai import AsyncOpenAI
+
+             aclient = AsyncOpenAI(api_key=OPENAI_API_KEY)
+
+         elif provider == "azure":
+             from openai import AsyncAzureOpenAI
+
+             aclient = AsyncAzureOpenAI(
+                 # https://learn.microsoft.com/azure/ai-services/openai/reference#rest-api-versioning
+                 api_version="2023-07-01-preview",
+                 # https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal#create-a-resource
+                 azure_endpoint=AZURE_OPENAI_ENDPOINT,
+             )
+
+         elif provider == "together":
+             from openai import AsyncOpenAI
+
+             aclient = AsyncOpenAI(
+                 api_key=TOGETHER_API_KEY, base_url="https://api.together.xyz/v1"
+             )
+
+         if provider in ("openai", "azure"):
+
+             async def list_models():
+                 return [model async for model in aclient.models.list()]
+
+             openai_models = {model_obj.id for model_obj in asyncio.run(list_models())}
+             assert model in openai_models
+
+         @retry(
+             wait=wait_random_exponential(min=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT),
+             stop=stop_after_attempt(RETRY_MAX_ATTEMPTS),
+             reraise=True,
+         )
+         async def api_call_function(prompt, generation_config):
+             temperature = (
+                 generation_config["temperature"]
+                 if generation_config["do_sample"]
+                 else 0
+             )
+             top_p = generation_config["top_p"] if generation_config["do_sample"] else 1
+             max_tokens = generation_config["max_new_tokens"]
+
+             if (
+                 model.startswith("gpt") and "instruct" not in model
+             ) or provider == "together":
+                 response = await aclient.chat.completions.create(
+                     model=model,
+                     messages=[{"role": "user", "content": prompt}],
+                     temperature=temperature,
+                     top_p=top_p,
+                     max_tokens=max_tokens,
+                 )
+                 assert response.choices[0].message.role == "assistant"
+                 output = response.choices[0].message.content
+
+             else:
+                 response = await aclient.completions.create(
+                     model=model,
+                     prompt=prompt,
+                     temperature=temperature,
+                     top_p=top_p,
+                     max_tokens=max_tokens,
+                 )
+                 output = response.choices[0].text
+
+             try:
+                 length = response.usage.total_tokens
+             except AttributeError:
+                 length = None
+
+             return output, length
+
+     elif model.startswith("cohere"):
+         _, model = model.split("/")
+
+         @retry(
+             wait=wait_random_exponential(min=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT),
+             stop=stop_after_attempt(RETRY_MAX_ATTEMPTS),
+             reraise=True,
+         )
+         async def api_call_function(prompt, generation_config):
+             async with cohere.AsyncClient(COHERE_API_KEY) as co:
+                 response = await co.generate(
+                     model=model,
+                     prompt=prompt,
+                     temperature=generation_config["temperature"]
+                     if generation_config["do_sample"]
+                     else 0,
+                     p=generation_config["top_p"]
+                     if generation_config["do_sample"]
+                     else 1,
+                     k=generation_config["top_k"]
+                     if generation_config["do_sample"]
+                     else 0,
+                     max_tokens=generation_config["max_new_tokens"],
+                     end_sequences=generation_config["stop_sequences"],
+                 )
+
+                 output = response.generations[0].text
+                 length = None
+
+             return output, length
+
+     elif model.startswith("@"):
+         model = model[1:]
+         pipe = pipeline(
+             "text-generation", model=model, trust_remote_code=True, device_map="auto"
+         )
+
+         async def api_call_function(prompt, generation_config):
+             generation_config, _ = prepare_huggingface_generation_config(
+                 generation_config
+             )
+
+             # TODO: include chat
+             output = pipe(prompt, return_text=True, **generation_config)[0][
+                 "generated_text"
+             ]
+             output = output[len(prompt) :]
+
+             length = None
+
+             return output, length
+
+     else:
+
+         @retry(
+             wait=wait_random_exponential(min=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT),
+             stop=stop_after_attempt(RETRY_MAX_ATTEMPTS),
+             reraise=True,
+         )
+         async def api_call_function(prompt, generation_config):
+             hf_client = AsyncInferenceClient(token=HF_TOKEN, model=model)
+
+             generation_config, is_chat = prepare_huggingface_generation_config(
+                 generation_config
+             )
+
+             if is_chat:
+                 messages = [{"role": "user", "content": prompt}]
+                 response = await hf_client.chat_completion(
+                     messages, stream=False, **generation_config
+                 )
+                 output = response.choices[0].message.content
+                 length = None
+
+             else:
+                 response = await hf_client.text_generation(
+                     prompt, stream=False, details=True, **generation_config
+                 )
+
+                 length = (
+                     len(response.details.prefill) + len(response.details.tokens)
+                     if response.details is not None
+                     else None
+                 )
+
+                 output = response.generated_text
+
+                 # TODO: refactor to support stop of chats
+                 # Remove stop sequences from the output
+                 # Inspired by
+                 # https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py
+                 # https://huggingface.co/spaces/tiiuae/falcon-chat/blob/main/app.py
+                 if (
+                     "stop_sequences" in generation_config
+                     and generation_config["stop_sequences"] is not None
+                 ):
+                     for stop_sequence in generation_config["stop_sequences"]:
+                         output = output.rsplit(stop_sequence, maxsplit=1)[0]
+
+             return output, length
+
+     return api_call_function
+
+
+ def strip_newline_space(text):
+     return text.strip("\n").strip()
+
+
+ def normalize(text):
+     return strip_newline_space(text).lower().capitalize()
+
+
+ def prepare_datasets(
+     dataset_name,
+     take_split="train",
+     train_size=TRAIN_SIZE,
+     test_size=TEST_SIZE,
+     balancing=BALANCING,
+     dataset_split_seed=None,
+ ):
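+     """Load a classification dataset, normalize inputs and labels, and return (dataset_name, split DataFrames, input_columns, label_column, labels)."""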
+     try:
+         ds = load_dataset(dataset_name, trust_remote_code=True)
+     except FileNotFoundError:
+         try:
+             assert "/" in dataset_name
+             dataset_name, subset_name = dataset_name.rsplit("/", 1)
+             ds = load_dataset(dataset_name, subset_name, trust_remote_code=True)
+         except (FileNotFoundError, AssertionError):
+             st.error(f"Dataset `{dataset_name}` not found.")
+             st.stop()
+
+     label_columns = [
+         (name, info)
+         for name, info in ds["train"].features.items()
+         if isinstance(info, ClassLabel)
+     ]
+     assert len(label_columns) == 1
+     label_column, label_column_info = label_columns[0]
+     labels = [normalize(label) for label in label_column_info.names]
+     label_dict = dict(enumerate(labels))
+
+     if any(len(PROCESSING_TOKENIZER(label)) > 1 for label in labels):
+         st.error(
+             "Labels are not single words. "
+             "Matching labels won't work as expected."
+         )
+
+     original_input_columns = [
+         name
+         for name, info in ds["train"].features.items()
+         if not isinstance(info, ClassLabel) and info.dtype == "string"
+     ]
+
+     input_columns = []
+     for input_column in original_input_columns:
+         lowered_input_column = input_column.lower()
+         if input_column != lowered_input_column:
+             ds = ds.rename_column(input_column, lowered_input_column)
+         input_columns.append(lowered_input_column)
+
+     df = ds[take_split].to_pandas()
+     for input_column in input_columns:
+         df[input_column] = df[input_column].apply(strip_newline_space)
+     df[label_column] = df[label_column].replace(label_dict)
+
+     df = df[[label_column] + input_columns]
+
+     if train_size is not None and test_size is not None:
+         # Undersample the majority classes only when balancing is requested
+         if balancing:
+             undersample = RandomUnderSampler(
+                 sampling_strategy="not minority", random_state=dataset_split_seed
+             )
+             df, df[label_column] = undersample.fit_resample(df, df[label_column])
+         sss = StratifiedShuffleSplit(
+             n_splits=1,
+             train_size=train_size,
+             test_size=test_size,
+             random_state=dataset_split_seed,
+         )
+         train_index, test_index = next(iter(sss.split(df, df[label_column])))
+
+         train_df = df.iloc[train_index]
+         test_df = df.iloc[test_index]
+
+         dfs = {"train": train_df, "test": test_df}
+
+     else:
+         dfs = {take_split: df}
+
+     return dataset_name, dfs, input_columns, label_column, labels
+
+
+ async def complete(api_call_function, prompt, generation_config=None):
+     if generation_config is None:
+         generation_config = {}
+
+     LOGGER.info(f"API Call\n\n``{prompt}``\n\n{generation_config=}")
+
+     output, length = await api_call_function(prompt, generation_config)
+
+     return output, length
+
+
+ async def infer(api_call_function, prompt_template, inputs, generation_config=None):
+     prompt = prompt_template.format(**inputs)
+     output, length = await complete(api_call_function, prompt, generation_config)
+     return output, prompt, length
+
+
+ async def infer_multi(
+     api_call_function, prompt_template, inputs_df, generation_config=None
+ ):
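+     # Fan out one asynchronous inference call per input row and gather the results concurrently.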
+     results = await asyncio.gather(
+         *(
+             infer(
+                 api_call_function, prompt_template, inputs.to_dict(), generation_config
+             )
+             for _, inputs in inputs_df.iterrows()
+         )
+     )
+
+     return zip(*results)
+
+
+ def preprocess_output_line(text):
+     return [
+         normalize(token_str)
+         for token in PROCESSING_TOKENIZER(text)
+         if (token_str := str(token))
+     ]
+
+
+ # Inspired by OpenAI's deprecated classification endpoint API.
+ # They take the label from the first line of the output:
+ # https://github.com/openai/openai-cookbook/blob/main/transition_guides_for_deprecated_API_endpoints/classification_functionality_example.py
+ # https://help.openai.com/en/articles/6272941-classifications-transition-guide#h_e63b71a5c8
+ # Here we take the label from either the *first* or *last* (for CoT) line of the output.
+ # This is not very robust, but it's a start that doesn't require asking for structured output such as JSON.
+ # HELM has more robust processing options; we are not using them, but these are the references:
+ # https://github.com/stanford-crfm/helm/blob/04a75826ce75835f6d22a7d41ae1487104797964/src/helm/benchmark/metrics/classification_metrics.py
+ # https://github.com/stanford-crfm/helm/blob/04a75826ce75835f6d22a7d41ae1487104797964/src/helm/benchmark/metrics/basic_metrics.py
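+ # Illustrative example: canonize_label("I think it is positive.\nPositive", ["Positive", "Negative"], "Last") -> "Positive"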
+ def canonize_label(output, annotation_labels, search_row):
+     assert search_row in SEARCH_ROW_DICT.keys()
+
+     search_row_index = SEARCH_ROW_DICT[search_row]
+
+     annotation_labels_set = set(annotation_labels)
+
+     output_lines = strip_newline_space(output).split("\n")
+     output_search_words = preprocess_output_line(output_lines[search_row_index])
+
+     label_matches = set(output_search_words) & annotation_labels_set
+
+     if len(label_matches) == 1:
+         return next(iter(label_matches))
+     else:
+         return UNKNOWN_LABEL
+
+
+ def measure(dataset, outputs, labels, label_column, input_columns, search_row):
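+     """Canonize raw outputs into labels and compute accuracy, balanced accuracy, MCC, and a confusion matrix."""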
+     inferences = [canonize_label(output, labels, search_row) for output in outputs]
+
+     LOGGER.info(f"{inferences=}")
+     LOGGER.info(f"{labels=}")
+     inference_labels = labels + [UNKNOWN_LABEL]
+
+     evaluation_df = pd.DataFrame(
+         {
+             "hit/miss": np.where(dataset[label_column] == inferences, "hit", "miss"),
+             "annotation": dataset[label_column],
+             "inference": inferences,
+             "output": outputs,
+         }
+         | dataset[input_columns].to_dict("list")
+     )
+
+     unknown_proportion = (evaluation_df["inference"] == UNKNOWN_LABEL).mean()
+
+     acc = accuracy_score(evaluation_df["annotation"], evaluation_df["inference"])
+     bacc = balanced_accuracy_score(
+         evaluation_df["annotation"], evaluation_df["inference"]
+     )
+     mcc = matthews_corrcoef(evaluation_df["annotation"], evaluation_df["inference"])
+     cm = confusion_matrix(
+         evaluation_df["annotation"], evaluation_df["inference"], labels=inference_labels
+     )
+
+     cm_display = ConfusionMatrixDisplay(cm, display_labels=inference_labels)
+     cm_display.plot()
+     cm_display.ax_.set_xlabel("Inference Labels")
+     cm_display.ax_.set_ylabel("Annotation Labels")
+     cm_display.figure_.autofmt_xdate(rotation=45)
+
+     metrics = {
+         "unknown_proportion": unknown_proportion,
+         "accuracy": acc,
+         "balanced_accuracy": bacc,
+         "mcc": mcc,
+         "confusion_matrix": cm,
+         "confusion_matrix_display": cm_display.figure_,
+         "hit_miss": evaluation_df,
+         "annotation_labels": labels,
+         "inference_labels": inference_labels,
+     }
+
+     return metrics
+
+
+ def run_evaluation(
+     api_call_function,
+     prompt_template,
+     dataset,
+     labels,
+     label_column,
+     input_columns,
+     search_row,
+     generation_config=None,
+ ):
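+     """Prompt the model over the whole dataset and return the metrics dict from measure(), with prompts and lengths attached."""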
+     inputs_df = dataset[input_columns]
+     outputs, prompts, lengths = asyncio.run(
+         infer_multi(
+             api_call_function,
+             prompt_template,
+             inputs_df,
+             generation_config,
+         )
+     )
+
+     metrics = measure(dataset, outputs, labels, label_column, input_columns, search_row)
+
+     metrics["hit_miss"]["prompt"] = prompts
+     metrics["hit_miss"]["length"] = lengths
+
+     return metrics
+
+
+ def combine_labels(labels):
+     return "|".join(f"``{label}``" for label in labels)
+
+
+ def main():
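+     # Bootstrap default session state (seeds, sizes, API function, datasets) on first load.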
+     try:
+         if "dataset_split_seed" not in st.session_state:
+             st.session_state["dataset_split_seed"] = (
+                 int(DATASET_SPLIT_SEED) if DATASET_SPLIT_SEED else None
+             )
+
+         if "train_size" not in st.session_state:
+             st.session_state["train_size"] = TRAIN_SIZE
+
+         if "test_size" not in st.session_state:
+             st.session_state["test_size"] = TEST_SIZE
+
+         if "api_call_function" not in st.session_state:
+             st.session_state["api_call_function"] = build_api_call_function(
+                 model=HF_MODEL,
+             )
+
+         if "train_dataset" not in st.session_state:
+             (
+                 st.session_state["dataset_name"],
+                 splits_df,
+                 st.session_state["input_columns"],
+                 st.session_state["label_column"],
+                 st.session_state["labels"],
+             ) = prepare_datasets(
+                 HF_DATASET,
+                 train_size=st.session_state.train_size,
+                 test_size=st.session_state.test_size,
+                 dataset_split_seed=st.session_state.dataset_split_seed,
+             )
+
+             for split in splits_df:
+                 st.session_state[f"{split}_dataset"] = splits_df[split]
+
+         if "generation_config" not in st.session_state:
+             st.session_state["generation_config"] = GENERATION_CONFIG_DEFAULTS
+
+     except Exception as e:
+         st.error(e)
+
+     st.title(TITLE)
+
+     with st.sidebar:
+         with st.form("model_form"):
+             model = st.text_input("Model", HF_MODEL).strip()
+
+             # Default values from:
+             # https://huggingface.co/docs/transformers/v4.30.0/main_classes/text_generation
+             # Edge values from:
+             # https://docs.cohere.com/reference/generate
+             # https://platform.openai.com/playground
+
+             generation_config_sliders = {
+                 name: st.slider(
+                     params["NAME"],
+                     params["START"],
+                     params["END"],
+                     params["DEFAULT"],
+                     params["STEP"],
+                 )
+                 for name, params in GENERATION_CONFIG_PARAMS.items()
+                 if "START" in params
+             }
+
+             do_sample = st.checkbox(
+                 GENERATION_CONFIG_PARAMS["do_sample"]["NAME"],
+                 value=GENERATION_CONFIG_PARAMS["do_sample"]["DEFAULT"],
+             )
+
+             stop_sequences = st.text_area(
+                 GENERATION_CONFIG_PARAMS["stop_sequences"]["NAME"],
+                 value="\n".join(GENERATION_CONFIG_PARAMS["stop_sequences"]["DEFAULT"]),
+             )
+
+             stop_sequences = [
+                 clean_stop.encode().decode("unicode_escape")  # interpret \n as newline
+                 for stop in stop_sequences.split("\n")
+                 if (clean_stop := stop.strip())
+             ]
+             if not stop_sequences:
+                 stop_sequences = None
+
+             decoding_seed = st.text_input("Decoding Seed").strip()
+
+             st.divider()
+
+             dataset = st.text_input("Dataset", HF_DATASET).strip()
+
+             train_size = st.number_input("Train Size", value=TRAIN_SIZE, min_value=10)
+             test_size = st.number_input("Test Size", value=TEST_SIZE, min_value=10)
+
+             balancing = st.checkbox("Balancing", BALANCING)
+
+             dataset_split_seed = st.text_input(
+                 "Dataset Split Seed", DATASET_SPLIT_SEED
+             ).strip()
+
+             st.divider()
+
+             submitted = st.form_submit_button("Set")
+
+             if submitted:
+                 if not dataset:
+                     st.error("Dataset must be specified.")
+                     st.stop()
+
+                 if not model:
+                     st.error("Model must be specified.")
+                     st.stop()
+
+                 if not decoding_seed:
+                     decoding_seed = None
+                 elif decoding_seed.isnumeric():
+                     decoding_seed = int(decoding_seed)
+                 else:
+                     st.error("Seed must be numeric or empty.")
+                     st.stop()
+
+                 generation_config_slider_sampling = {
+                     name: value
+                     for name, value in generation_config_sliders.items()
+                     if GENERATION_CONFIG_PARAMS[name]["SAMPLING"]
+                 }
+                 if (
+                     any(
+                         value != GENERATION_CONFIG_DEFAULTS[name]
+                         for name, value in generation_config_slider_sampling.items()
+                     )
+                     and not do_sample
+                 ):
+                     sampling_slider_default_values_info = " | ".join(
+                         f"{name}={GENERATION_CONFIG_DEFAULTS[name]}"
+                         for name in generation_config_slider_sampling
+                     )
+                     st.error(
+                         "Sampling must be enabled to use non-default values"
+                         f" for generation parameters: {sampling_slider_default_values_info}"
+                     )
+                     st.stop()
+
+                 if decoding_seed is not None and not do_sample:
+                     st.error(
+                         "Sampling must be enabled to use a decoding seed."
+                         " Otherwise, the seed field should be empty."
+                     )
+                     st.stop()
+
+                 if not dataset_split_seed:
+                     dataset_split_seed = None
+                 elif dataset_split_seed.isnumeric():
+                     dataset_split_seed = int(dataset_split_seed)
+                 else:
+                     st.error("Dataset split seed must be numeric or empty.")
+                     st.stop()
+
+                 generation_config = generation_config_sliders | dict(
+                     do_sample=do_sample,
+                     stop_sequences=stop_sequences,
+                     seed=decoding_seed,
+                 )
+
+                 st.session_state["dataset_split_seed"] = dataset_split_seed
+                 st.session_state["train_size"] = train_size
+                 st.session_state["test_size"] = test_size
+
+                 try:
+                     st.session_state["api_call_function"] = build_api_call_function(
+                         model=model,
+                     )
+                 except OpenAIAlreadyInitiatedError as e:
+                     st.error(e)
+                     st.stop()
+
+                 st.session_state["generation_config"] = generation_config
+
+                 (
+                     st.session_state["dataset_name"],
+                     splits_df,
+                     st.session_state["input_columns"],
+                     st.session_state["label_column"],
+                     st.session_state["labels"],
+                 ) = prepare_datasets(
+                     dataset,
+                     train_size=st.session_state.train_size,
+                     test_size=st.session_state.test_size,
+                     balancing=balancing,
+                     dataset_split_seed=st.session_state.dataset_split_seed,
+                 )
+
+                 for split in splits_df:
+                     st.session_state[f"{split}_dataset"] = splits_df[split]
+
+                 LOGGER.info(f"FORM {dataset=}")
+                 LOGGER.info(f"FORM {model=}")
+                 LOGGER.info(f"FORM {generation_config=}")
+
+     with st.expander("Info"):
+         try:
+             data_card = dataset_info(st.session_state.dataset_name).cardData
+         except (HFValidationError, RepositoryNotFoundError):
+             pass
+         else:
+             st.caption("Dataset")
+             st.write(data_card)
+         try:
+             model_info_response = model_info(model)
+             model_card = model_info_response.cardData
+             st.session_state["generation_config"]["is_chat"] = (
+                 "conversational" in model_info_response.tags
+             )
+         except (HFValidationError, RepositoryNotFoundError):
+             pass
+         else:
+             st.caption("Model")
+             st.write(model_card)
+
+         # st.write(f"Model max length: {AutoTokenizer.from_pretrained(model).model_max_length}")
+
+     tab1, tab2, tab3 = st.tabs(["Evaluation", "Examples", "Playground"])
+
+     with tab1:
+         with st.form("prompt_form"):
+             prompt_template = st.text_area("Prompt Template", height=PROMPT_TEXT_HEIGHT)
+
+             is_multi_placeholder = len(st.session_state.input_columns) > 1
+
+             st.write(
+                 "To determine the inferred label of an input, the model should output one of the following words:"
+                 f" {combine_labels(st.session_state.labels)}"
+             )
+             st.write(
+                 f"The input placeholder{'s' if is_multi_placeholder else ''} available for the prompt template {'are' if is_multi_placeholder else 'is'}:"
+                 f" {combine_labels(f'{{{col}}}' for col in st.session_state.input_columns)}"
+             )
+
+             col1, col2 = st.columns(2)
+
+             with col1:
+                 search_row = st.selectbox(
+                     "Search label at which row", list(SEARCH_ROW_DICT)
+                 )
+
+             with col2:
+                 submitted = st.form_submit_button("Evaluate")
+
+         if submitted:
+             if not prompt_template:
+                 st.error("Prompt template must be specified.")
+                 st.stop()
+
+             _, formats, *_ = zip(*string.Formatter().parse(prompt_template))
+             is_valid_prompt_template = set(formats).issubset(
+                 {None} | set(st.session_state.input_columns)
+             )
+
+             if not is_valid_prompt_template:
+                 st.error("The prompt template contains unrecognized fields.")
+                 st.stop()
+
+             with st.spinner("Executing inference..."):
+                 try:
+                     evaluation = run_evaluation(
+                         st.session_state.api_call_function,
+                         prompt_template,
+                         st.session_state.test_dataset,
+                         st.session_state.labels,
+                         st.session_state.label_column,
+                         st.session_state.input_columns,
+                         search_row,
+                         st.session_state.generation_config,
+                     )
+                 except HfHubHTTPError as e:
+                     st.error(e)
+                     st.stop()
+
+             st.markdown("### Metrics")
+             num_metric_cols = 2 if balancing else 4
+             cols = st.columns(num_metric_cols)
+             with cols[0]:
+                 st.metric("Accuracy", f"{100 * evaluation['accuracy']:.0f}%")
+                 st.caption("The percentage of correct inferences.")
+             with cols[1]:
+                 st.metric(
+                     "Unknown",
+                     f"{100 * evaluation['unknown_proportion']:.0f}%",
+                 )
+                 st.caption(
+                     "The percentage of inferences"
+                     " that could not be determined based on the model output."
+                 )
+             if not balancing:
+                 with cols[2]:
+                     st.metric(
+                         "Balanced Accuracy",
+                         f"{100 * evaluation['balanced_accuracy']:.0f}%",
+                     )
+                 with cols[3]:
+                     st.metric("MCC", f"{evaluation['mcc']:.2f}")
+
+             st.markdown("### Detailed Evaluation")
+
+             st.caption(
+                 "This table shows all examples (input and output pairs) used to evaluate the prompt template with the model (for instance, for accuracy)."
+                 " It comprises the input placeholder values, the unmodified model *output*, the deduced *inference*, and the ground-truth *annotation*."
+             )
+             st.caption(
+                 "A 'hit' signifies a correct inference (the *inference* matches the *annotation*), while a 'miss' denotes an incorrect one."
+                 " If the *inference* cannot be determined from the model output, it is labeled as 'unknown'."
+             )
+             st.caption(
+                 "The *prompt* column contains the complete prompt the model was asked to complete, i.e., your prompt template filled with the input placeholder values."
+             )
+             st.caption(
+                 "You are not allowed to include these examples in your prompt template."
+             )
+
+             st.dataframe(evaluation["hit_miss"])
+
+             with st.expander("Additional Information", expanded=False):
+                 st.markdown("## Confusion Matrix")
+                 st.pyplot(evaluation["confusion_matrix_display"])
+
+             if evaluation["accuracy"] == 1:
+                 st.balloons()
+
+     with tab2:
+         st.caption(
+             "You can include the following examples in your prompt template for few-shot prompting."
+         )
+         st.dataframe(st.session_state.train_dataset)
+
+     with tab3:
+         prompt = st.text_area("Prompt", height=PROMPT_TEXT_HEIGHT)
+
+         submitted = st.button("Complete")
+
+         if submitted:
+             if not prompt:
+                 st.error("Prompt must be specified.")
+                 st.stop()
+
+             with st.spinner("Generating..."):
+                 try:
+                     output, length = asyncio.run(
+                         complete(
+                             st.session_state.api_call_function,
+                             prompt,
+                             st.session_state.generation_config,
+                         )
+                     )
+                 except HfHubHTTPError as e:
+                     st.error(e)
+                     st.stop()
+             st.markdown(escape_markdown(output))
+             if length is not None:
+                 with st.expander("Stats"):
+                     st.metric("#Tokens", length)
+
+
+ if __name__ == "__main__":
+     logging.basicConfig(level=logging.DEBUG)
+     main()
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ accelerate
+ aiohttp
+ cohere
+ datasets
+ einops
+ huggingface_hub[inference]
+ imbalanced-learn
+ numpy==1.23.5
+ pandas
+ matplotlib
+ openai
+ scikit-learn
+ spacy
+ streamlit
+ tenacity
+ torch
+ transformers