|
|
"""Prompter.""" |
|
|
|
|
|
import asyncio |
|
|
import importlib |
|
|
import logging |
|
|
import os |
|
|
import string |
|
|
import sys |
|
|
|
|
|
import aiohttp |
|
|
import cohere |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import streamlit as st |
|
|
from datasets import load_dataset |
|
|
from datasets.features import ClassLabel |
|
|
from huggingface_hub import AsyncInferenceClient, dataset_info, model_info |
|
|
from huggingface_hub.utils import ( |
|
|
HfHubHTTPError, |
|
|
HFValidationError, |
|
|
RepositoryNotFoundError, |
|
|
) |
|
|
from imblearn.under_sampling import RandomUnderSampler |
|
|
from sklearn.metrics import ( |
|
|
ConfusionMatrixDisplay, |
|
|
accuracy_score, |
|
|
balanced_accuracy_score, |
|
|
confusion_matrix, |
|
|
matthews_corrcoef, |
|
|
) |
|
|
from sklearn.model_selection import StratifiedShuffleSplit |
|
|
from spacy.lang.en import English |
|
|
from tenacity import retry, stop_after_attempt, wait_random_exponential |
|
|
from transformers import pipeline |
|
|
|
|
|
HOW_OPENAI_INITIATED = None |
|
|
|
|
|
LOGGER = logging.getLogger(__name__) |
|
|
|
|
|
TITLE = "Prompter" |
|
|
|
|
|
|
|
|
OPENAI_API_KEY = st.secrets.get("openai_api_key", None)
|
|
TOGETHER_API_KEY = st.secrets.get("together_api_key", None) |
|
|
HF_TOKEN = st.secrets.get("hf_token", None) |
|
|
COHERE_API_KEY = st.secrets.get("cohere_api_key", None) |
|
|
AZURE_OPENAI_KEY = st.secrets.get("azure_openai_key", None) |
|
|
AZURE_OPENAI_ENDPOINT = st.secrets.get("azure_openai_endpoint", None) |
|
|
AZURE_DEPLOYMENT_NAME = st.secrets.get("azure_deployment_name", None) |
|
|
|
|
|
HF_MODEL = os.environ.get("FM_MODEL", "") |
|
|
|
|
|
HF_DATASET = os.environ.get("FM_HF_DATASET", "") |
|
|
|
|
|
DATASET_SPLIT_SEED = os.environ.get("FM_DATASET_SPLIT_SEED", "") |
|
|
TRAIN_SIZE = 15 |
|
|
TEST_SIZE = 25 |
|
|
BALANCING = True |
|
|
|
|
|
RETRY_MIN_WAIT = 1 |
|
|
RETRY_MAX_WAIT = 60 |
|
|
RETRY_MAX_ATTEMPTS = 6 |
|
|
|
|
|
PROMPT_TEXT_HEIGHT = 300 |
|
|
|
|
|
UNKNOWN_LABEL = "Unknown" |
|
|
|
|
|
SEARCH_ROW_DICT = {"First": 0, "Last": -1} |
|
|
|
|
|
|
|
|
GENERATION_CONFIG_PARAMS = { |
|
|
"temperature": { |
|
|
"NAME": "Temperature", |
|
|
"START": 0.1, |
|
|
"END": 5.0, |
|
|
"DEFAULT": 1.0, |
|
|
"STEP": 0.1, |
|
|
"SAMPLING": True, |
|
|
}, |
|
|
"top_k": { |
|
|
"NAME": "Top K", |
|
|
"START": 0, |
|
|
"END": 100, |
|
|
"DEFAULT": 0, |
|
|
"STEP": 10, |
|
|
"SAMPLING": True, |
|
|
}, |
|
|
"top_p": { |
|
|
"NAME": "Top P", |
|
|
"START": 0.1, |
|
|
"END": 1.0, |
|
|
"DEFAULT": 1.0, |
|
|
"STEP": 0.1, |
|
|
"SAMPLING": True, |
|
|
}, |
|
|
"max_new_tokens": { |
|
|
"NAME": "Max New Tokens", |
|
|
"START": 16, |
|
|
"END": 1024, |
|
|
"DEFAULT": 16, |
|
|
"STEP": 16, |
|
|
"SAMPLING": False, |
|
|
}, |
|
|
"do_sample": { |
|
|
"NAME": "Sampling", |
|
|
"DEFAULT": False, |
|
|
}, |
|
|
"stop_sequences": { |
|
|
"NAME": "Stop Sequences", |
|
|
"DEFAULT": os.environ.get("FM_STOP_SEQUENCES", "").split(), |
|
|
"SAMPLING": False, |
|
|
}, |
|
|
} |
|
|
|
|
|
GENERATION_CONFIG_DEFAULTS = { |
|
|
key: value["DEFAULT"] for key, value in GENERATION_CONFIG_PARAMS.items() |
|
|
} |
|
|
|
|
|
st.set_page_config(page_title=TITLE, initial_sidebar_state="collapsed") |
|
|
|
|
|
|
|
|
def get_processing_tokenizer(): |
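    """Return the spaCy English tokenizer used to match model outputs against labels."""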
|
|
return English().tokenizer |
|
|
|
|
|
|
|
|
PROCESSING_TOKENIZER = get_processing_tokenizer() |
|
|
|
|
|
|
|
|
class OpenAIAlreadyInitiatedError(Exception): |
|
|
"""OpenAIAlreadyInitiatedError.""" |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
def prepare_huggingface_generation_config(generation_config): |
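    """Adapt the app's generation config for Hugging Face back ends.

    Sampling parameters left at their defaults are reset to ``None`` (so the
    server uses its own defaults), chat configs are renamed to the
    chat-completion parameter names, and the ``is_chat`` flag is split out.
    Returns ``(generation_config, is_chat)``.
    """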
|
|
generation_config = generation_config.copy() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if "is_chat" not in generation_config: |
|
|
generation_config["is_chat"] = False |
|
|
|
|
|
for name, params in GENERATION_CONFIG_PARAMS.items(): |
|
|
if ( |
|
|
"START" in params |
|
|
and params["SAMPLING"] |
|
|
and name in generation_config |
|
|
and generation_config[name] is not None |
|
|
): |
|
|
if generation_config[name] == params["DEFAULT"]: |
|
|
generation_config[name] = None |
|
|
else: |
|
|
assert generation_config.get("do_sample", False) |
|
|
|
|
|
if generation_config.get("is_chat", False): |
|
|
generation_config["max_tokens"] = generation_config.pop("max_new_tokens") |
|
|
generation_config["stop"] = generation_config.pop("stop_sequences") |
|
|
del generation_config["do_sample"] |
|
|
del generation_config["top_k"] |
|
|
|
|
|
|
|
|
is_chat = generation_config.pop("is_chat", False) |
|
|
|
|
|
return generation_config, is_chat |
|
|
|
|
|
|
|
|
def escape_markdown(text): |
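    """Escape Markdown special characters so model output renders literally."""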
|
|
escape_dict = { |
|
|
"*": r"\*", |
|
|
"_": r"\_", |
|
|
"{": r"\{", |
|
|
"}": r"\}", |
|
|
"[": r"\[", |
|
|
"]": r"\]", |
|
|
"(": r"\(", |
|
|
")": r"\)", |
|
|
"+": r"\+", |
|
|
"-": r"\-", |
|
|
".": r"\.", |
|
|
"!": r"\!", |
|
|
"`": r"\`", |
|
|
">": r"\>", |
|
|
"|": r"\|", |
|
|
"#": r"\#", |
|
|
} |
|
|
return "".join([escape_dict.get(c, c) for c in text]) |
|
|
|
|
|
|
|
|
def reload_module(name): |
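    """Import module ``name``, forcing a fresh import if it is already loaded."""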
|
|
if name in sys.modules: |
|
|
del sys.modules[name] |
|
|
return importlib.import_module(name) |
|
|
|
|
|
|
|
|
def build_api_call_function(model): |
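    """Return an async ``api_call_function(prompt, generation_config)`` for ``model``.

    ``openai/``, ``azure/`` and ``together/`` prefixes are routed to the OpenAI
    SDK, ``cohere/`` to the Cohere SDK, a leading ``@`` to a local
    ``transformers`` pipeline, and anything else to the Hugging Face
    Inference API.
    """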
|
|
global HOW_OPENAI_INITIATED |
|
|
|
|
|
    if any(
        model.startswith(f"{known_provider}/")
        for known_provider in ("openai", "azure", "together")
    ):
        provider, model = model.split("/", maxsplit=1)
|
|
|
|
|
if provider == "openai": |
|
|
from openai import AsyncOpenAI |
|
|
|
|
|
aclient = AsyncOpenAI(api_key=OPENAI_API_KEY) |
|
|
|
|
|
elif provider == "azure": |
|
|
from openai import AsyncAzureOpenAI |
|
|
|
|
|
            aclient = AsyncAzureOpenAI(
                api_key=AZURE_OPENAI_KEY,
                api_version="2023-07-01-preview",
                azure_endpoint=AZURE_OPENAI_ENDPOINT,
            )
|
|
|
|
|
elif provider == "together": |
|
|
from openai import AsyncOpenAI |
|
|
|
|
|
aclient = AsyncOpenAI( |
|
|
api_key=TOGETHER_API_KEY, base_url="https://api.together.xyz/v1" |
|
|
) |
|
|
|
|
|
if provider in ("openai", "azure"): |
|
|
|
|
|
async def list_models(): |
|
|
return [model async for model in aclient.models.list()] |
|
|
|
|
|
openai_models = {model_obj.id for model_obj in asyncio.run(list_models())} |
|
|
assert model in openai_models |
|
|
|
|
|
@retry( |
|
|
wait=wait_random_exponential(min=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT), |
|
|
stop=stop_after_attempt(RETRY_MAX_ATTEMPTS), |
|
|
reraise=True, |
|
|
) |
|
|
async def api_call_function(prompt, generation_config): |
|
|
temperature = ( |
|
|
generation_config["temperature"] |
|
|
if generation_config["do_sample"] |
|
|
else 0 |
|
|
) |
|
|
top_p = generation_config["top_p"] if generation_config["do_sample"] else 1 |
|
|
max_tokens = generation_config["max_new_tokens"] |
|
|
|
|
|
if ( |
|
|
model.startswith("gpt") and "instruct" not in model |
|
|
) or provider == "together": |
|
|
response = await aclient.chat.completions.create( |
|
|
model=model, |
|
|
messages=[{"role": "user", "content": prompt}], |
|
|
temperature=temperature, |
|
|
top_p=top_p, |
|
|
max_tokens=max_tokens, |
|
|
) |
|
|
assert response.choices[0].message.role == "assistant" |
|
|
output = response.choices[0].message.content |
|
|
|
|
|
else: |
|
|
response = await aclient.completions.create( |
|
|
model=model, |
|
|
prompt=prompt, |
|
|
temperature=temperature, |
|
|
top_p=top_p, |
|
|
max_tokens=max_tokens, |
|
|
) |
|
|
output = response.choices[0].text |
|
|
|
|
|
try: |
|
|
length = response.usage.total_tokens |
|
|
except AttributeError: |
|
|
length = None |
|
|
|
|
|
return output, length |
|
|
|
|
|
elif model.startswith("cohere"): |
|
|
_, model = model.split("/") |
|
|
|
|
|
@retry( |
|
|
wait=wait_random_exponential(min=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT), |
|
|
stop=stop_after_attempt(RETRY_MAX_ATTEMPTS), |
|
|
reraise=True, |
|
|
) |
|
|
async def api_call_function(prompt, generation_config): |
|
|
async with cohere.AsyncClient(COHERE_API_KEY) as co: |
|
|
response = await co.generate( |
|
|
model=model, |
|
|
prompt=prompt, |
|
|
temperature=generation_config["temperature"] |
|
|
if generation_config["do_sample"] |
|
|
else 0, |
|
|
p=generation_config["top_p"] |
|
|
if generation_config["do_sample"] |
|
|
else 1, |
|
|
k=generation_config["top_k"] |
|
|
if generation_config["do_sample"] |
|
|
else 0, |
|
|
max_tokens=generation_config["max_new_tokens"], |
|
|
end_sequences=generation_config["stop_sequences"], |
|
|
) |
|
|
|
|
|
output = response.generations[0].text |
|
|
length = None |
|
|
|
|
|
return output, length |
|
|
|
|
|
elif model.startswith("@"): |
|
|
model = model[1:] |
|
|
pipe = pipeline( |
|
|
"text-generation", model=model, trust_remote_code=True, device_map="auto" |
|
|
) |
|
|
|
|
|
        async def api_call_function(prompt, generation_config):
            generation_config, _ = prepare_huggingface_generation_config(
                generation_config
            )

            # The local pipeline does not accept these Inference-API-only
            # parameters, so drop them before generating.
            stop_sequences = generation_config.pop("stop_sequences", None)
            generation_config.pop("seed", None)

            output = pipe(prompt, return_text=True, **generation_config)[0][
                "generated_text"
            ]
            output = output[len(prompt) :]

            # Mirror the remote back ends: truncate at the first stop sequence.
            if stop_sequences:
                for stop_sequence in stop_sequences:
                    output = output.split(stop_sequence, maxsplit=1)[0]

            length = None

            return output, length
|
|
|
|
|
else: |
|
|
|
|
|
@retry( |
|
|
wait=wait_random_exponential(min=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT), |
|
|
stop=stop_after_attempt(RETRY_MAX_ATTEMPTS), |
|
|
reraise=True, |
|
|
) |
|
|
async def api_call_function(prompt, generation_config): |
|
|
hf_client = AsyncInferenceClient(token=HF_TOKEN, model=model) |
|
|
|
|
|
generation_config, is_chat = prepare_huggingface_generation_config( |
|
|
generation_config |
|
|
) |
|
|
|
|
|
if is_chat: |
|
|
messages = [{"role": "user", "content": prompt}] |
|
|
response = await hf_client.chat_completion( |
|
|
messages, stream=False, **generation_config |
|
|
) |
|
|
output = response.choices[0].message.content |
|
|
length = None |
|
|
|
|
|
else: |
|
|
response = await hf_client.text_generation( |
|
|
prompt, stream=False, details=True, **generation_config |
|
|
) |
|
|
|
|
|
length = ( |
|
|
len(response.details.prefill) + len(response.details.tokens) |
|
|
if response.details is not None |
|
|
else None |
|
|
) |
|
|
|
|
|
output = response.generated_text |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if ( |
|
|
"stop_sequences" in generation_config |
|
|
and generation_config["stop_sequences"] is not None |
|
|
): |
|
|
for stop_sequence in generation_config["stop_sequences"]: |
|
|
output = output.rsplit(stop_sequence, maxsplit=1)[0] |
|
|
|
|
|
return output, length |
|
|
|
|
|
return api_call_function |
|
|
|
|
|
|
|
|
def strip_newline_space(text): |
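    """Strip surrounding newlines and whitespace from ``text``."""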
|
|
return text.strip("\n").strip() |
|
|
|
|
|
|
|
|
def normalize(text): |
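    """Normalize text for label comparison: strip, lowercase, then capitalize."""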
|
|
return strip_newline_space(text).lower().capitalize() |
|
|
|
|
|
|
|
|
def prepare_datasets( |
|
|
dataset_name, |
|
|
take_split="train", |
|
|
train_size=TRAIN_SIZE, |
|
|
test_size=TEST_SIZE, |
|
|
balancing=BALANCING, |
|
|
dataset_split_seed=None, |
|
|
): |
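    """Load a Hugging Face dataset and prepare it for evaluation.

    Normalizes labels, lowercases input column names, optionally undersamples
    to balance classes, and draws a stratified train/test split. Returns the
    resolved dataset name, a dict of split DataFrames, the input column names,
    the label column name, and the normalized label names.
    """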
|
|
try: |
|
|
ds = load_dataset(dataset_name, trust_remote_code=True) |
|
|
except FileNotFoundError as e: |
|
|
try: |
|
|
assert "/" in dataset_name |
|
|
dataset_name, subset_name = dataset_name.rsplit("/", 1) |
|
|
ds = load_dataset(dataset_name, subset_name, trust_remote_code=True) |
|
|
except (FileNotFoundError, AssertionError): |
|
|
st.error(f"Dataset `{dataset_name}` not found.") |
|
|
st.stop() |
|
|
|
|
|
label_columns = [ |
|
|
(name, info) |
|
|
for name, info in ds["train"].features.items() |
|
|
if isinstance(info, ClassLabel) |
|
|
] |
|
|
assert len(label_columns) == 1 |
|
|
label_column, label_column_info = label_columns[0] |
|
|
labels = [normalize(label) for label in label_column_info.names] |
|
|
label_dict = dict(enumerate(labels)) |
|
|
|
|
|
if any(len(PROCESSING_TOKENIZER(label)) > 1 for label in labels): |
|
|
        st.error(
            "Labels are not single words. "
            "Matching labels won't work as expected."
        )
|
|
|
|
|
original_input_columns = [ |
|
|
name |
|
|
for name, info in ds["train"].features.items() |
|
|
if not isinstance(info, ClassLabel) and info.dtype == "string" |
|
|
] |
|
|
|
|
|
input_columns = [] |
|
|
for input_column in original_input_columns: |
|
|
lowered_input_column = input_column.lower() |
|
|
if input_column != lowered_input_column: |
|
|
ds = ds.rename_column(input_column, lowered_input_column) |
|
|
input_columns.append(lowered_input_column) |
|
|
|
|
|
df = ds[take_split].to_pandas() |
|
|
for input_column in input_columns: |
|
|
df[input_column] = df[input_column].apply(strip_newline_space) |
|
|
df[label_column] = df[label_column].replace(label_dict) |
|
|
|
|
|
df = df[[label_column] + input_columns] |
|
|
|
|
|
    if train_size is not None and test_size is not None:
        if balancing:
            undersample = RandomUnderSampler(
                sampling_strategy="not minority", random_state=dataset_split_seed
            )
            # fit_resample returns (X, y); X already contains the label column.
            df, df[label_column] = undersample.fit_resample(df, df[label_column])
|
|
sss = StratifiedShuffleSplit( |
|
|
n_splits=1, |
|
|
train_size=train_size, |
|
|
test_size=test_size, |
|
|
random_state=dataset_split_seed, |
|
|
) |
|
|
train_index, test_index = next(iter(sss.split(df, df[label_column]))) |
|
|
|
|
|
train_df = df.iloc[train_index] |
|
|
test_df = df.iloc[test_index] |
|
|
|
|
|
dfs = {"train": train_df, "test": test_df} |
|
|
|
|
|
else: |
|
|
dfs = {take_split: df} |
|
|
|
|
|
return dataset_name, dfs, input_columns, label_column, labels |
|
|
|
|
|
|
|
|
async def complete(api_call_function, prompt, generation_config=None): |
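    """Run a single completion; return ``(output, total_tokens_or_None)``."""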
|
|
if generation_config is None: |
|
|
generation_config = {} |
|
|
|
|
|
LOGGER.info(f"API Call\n\n``{prompt}``\n\n{generation_config=}") |
|
|
|
|
|
output, length = await api_call_function(prompt, generation_config) |
|
|
|
|
|
return output, length |
|
|
|
|
|
|
|
|
async def infer(api_call_function, prompt_template, inputs, generation_config=None): |
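    """Fill ``prompt_template`` with ``inputs`` and complete the resulting prompt."""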
|
|
prompt = prompt_template.format(**inputs) |
|
|
output, length = await complete(api_call_function, prompt, generation_config) |
|
|
return output, prompt, length |
|
|
|
|
|
|
|
|
async def infer_multi( |
|
|
api_call_function, prompt_template, inputs_df, generation_config=None |
|
|
): |
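    """Run ``infer`` concurrently over the rows of ``inputs_df``.

    Returns the zipped results as three tuples: outputs, prompts, lengths.
    """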
|
|
results = await asyncio.gather( |
|
|
*( |
|
|
infer( |
|
|
api_call_function, prompt_template, inputs.to_dict(), generation_config |
|
|
) |
|
|
for _, inputs in inputs_df.iterrows() |
|
|
) |
|
|
) |
|
|
|
|
|
return zip(*results) |
|
|
|
|
|
|
|
|
def preprocess_output_line(text): |
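    """Tokenize a line of model output and normalize each token."""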
|
|
return [ |
|
|
normalize(token_str) |
|
|
for token in PROCESSING_TOKENIZER(text) |
|
|
if (token_str := str(token)) |
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def canonize_label(output, annotation_labels, search_row): |
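    """Map a raw model output to a known label.

    Searches the first or last output line (per ``search_row``) and returns
    the unique matching label, or ``UNKNOWN_LABEL`` if there is no unique match.
    """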
|
|
    assert search_row in SEARCH_ROW_DICT
|
|
|
|
|
search_row_index = SEARCH_ROW_DICT[search_row] |
|
|
|
|
|
annotation_labels_set = set(annotation_labels) |
|
|
|
|
|
output_lines = strip_newline_space(output).split("\n") |
|
|
output_search_words = preprocess_output_line(output_lines[search_row_index]) |
|
|
|
|
|
label_matches = set(output_search_words) & annotation_labels_set |
|
|
|
|
|
if len(label_matches) == 1: |
|
|
return next(iter(label_matches)) |
|
|
else: |
|
|
return UNKNOWN_LABEL |
|
|
|
|
|
|
|
|
def measure(dataset, outputs, labels, label_column, input_columns, search_row): |
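    """Compute evaluation metrics and a per-example hit/miss table."""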
|
|
inferences = [canonize_label(output, labels, search_row) for output in outputs] |
|
|
|
|
|
LOGGER.info(f"{inferences=}") |
|
|
LOGGER.info(f"{labels=}") |
|
|
inference_labels = labels + [UNKNOWN_LABEL] |
|
|
|
|
|
evaluation_df = pd.DataFrame( |
|
|
{ |
|
|
"hit/miss": np.where(dataset[label_column] == inferences, "hit", "miss"), |
|
|
"annotation": dataset[label_column], |
|
|
"inference": inferences, |
|
|
"output": outputs, |
|
|
} |
|
|
| dataset[input_columns].to_dict("list") |
|
|
) |
|
|
|
|
|
unknown_proportion = (evaluation_df["inference"] == UNKNOWN_LABEL).mean() |
|
|
|
|
|
acc = accuracy_score(evaluation_df["annotation"], evaluation_df["inference"]) |
|
|
bacc = balanced_accuracy_score( |
|
|
evaluation_df["annotation"], evaluation_df["inference"] |
|
|
) |
|
|
mcc = matthews_corrcoef(evaluation_df["annotation"], evaluation_df["inference"]) |
|
|
cm = confusion_matrix( |
|
|
evaluation_df["annotation"], evaluation_df["inference"], labels=inference_labels |
|
|
) |
|
|
|
|
|
cm_display = ConfusionMatrixDisplay(cm, display_labels=inference_labels) |
|
|
cm_display.plot() |
|
|
cm_display.ax_.set_xlabel("Inference Labels") |
|
|
cm_display.ax_.set_ylabel("Annotation Labels") |
|
|
cm_display.figure_.autofmt_xdate(rotation=45) |
|
|
|
|
|
metrics = { |
|
|
"unknown_proportion": unknown_proportion, |
|
|
"accuracy": acc, |
|
|
"balanced_accuracy": bacc, |
|
|
"mcc": mcc, |
|
|
"confusion_matrix": cm, |
|
|
"confusion_matrix_display": cm_display.figure_, |
|
|
"hit_miss": evaluation_df, |
|
|
"annotation_labels": labels, |
|
|
"inference_labels": inference_labels, |
|
|
} |
|
|
|
|
|
return metrics |
|
|
|
|
|
|
|
|
def run_evaluation( |
|
|
api_call_function, |
|
|
prompt_template, |
|
|
dataset, |
|
|
labels, |
|
|
label_column, |
|
|
input_columns, |
|
|
search_row, |
|
|
generation_config=None, |
|
|
): |
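    """Run the prompt template over ``dataset`` and compute metrics.

    The full prompts and token counts are attached to the hit/miss table.
    """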
|
|
inputs_df = dataset[input_columns] |
|
|
outputs, prompts, lengths = asyncio.run( |
|
|
infer_multi( |
|
|
api_call_function, |
|
|
prompt_template, |
|
|
inputs_df, |
|
|
generation_config, |
|
|
) |
|
|
) |
|
|
|
|
|
metrics = measure(dataset, outputs, labels, label_column, input_columns, search_row) |
|
|
|
|
|
metrics["hit_miss"]["prompt"] = prompts |
|
|
metrics["hit_miss"]["length"] = lengths |
|
|
|
|
|
return metrics |
|
|
|
|
|
|
|
|
def combine_labels(labels): |
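    """Render labels as inline code separated by ``|``."""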
|
|
return "|".join(f"``{label}``" for label in labels) |
|
|
|
|
|
|
|
|
def main(): |
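    """Render the Streamlit UI and wire it to inference and evaluation."""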
|
|
try: |
|
|
if "dataset_split_seed" not in st.session_state: |
|
|
st.session_state["dataset_split_seed"] = ( |
|
|
int(DATASET_SPLIT_SEED) if DATASET_SPLIT_SEED else None |
|
|
) |
|
|
|
|
|
if "train_size" not in st.session_state: |
|
|
st.session_state["train_size"] = TRAIN_SIZE |
|
|
|
|
|
if "test_size" not in st.session_state: |
|
|
st.session_state["test_size"] = TEST_SIZE |
|
|
|
|
|
if "api_call_function" not in st.session_state: |
|
|
st.session_state["api_call_function"] = build_api_call_function( |
|
|
model=HF_MODEL, |
|
|
) |
|
|
|
|
|
if "train_dataset" not in st.session_state: |
|
|
( |
|
|
st.session_state["dataset_name"], |
|
|
splits_df, |
|
|
st.session_state["input_columns"], |
|
|
st.session_state["label_column"], |
|
|
st.session_state["labels"], |
|
|
) = prepare_datasets( |
|
|
HF_DATASET, |
|
|
train_size=st.session_state.train_size, |
|
|
test_size=st.session_state.test_size, |
|
|
dataset_split_seed=st.session_state.dataset_split_seed, |
|
|
) |
|
|
|
|
|
for split in splits_df: |
|
|
st.session_state[f"{split}_dataset"] = splits_df[split] |
|
|
|
|
|
if "generation_config" not in st.session_state: |
|
|
st.session_state["generation_config"] = GENERATION_CONFIG_DEFAULTS.copy() |
|
|
st.session_state["generation_config"]["is_chat"] = False |
|
|
|
|
|
except Exception as e: |
|
|
st.error(e) |
|
|
|
|
|
st.title(TITLE) |
|
|
|
|
|
with st.sidebar: |
|
|
with st.form("model_form"): |
|
|
model = st.text_input("Model", HF_MODEL).strip() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generation_config_sliders = { |
|
|
name: st.slider( |
|
|
params["NAME"], |
|
|
params["START"], |
|
|
params["END"], |
|
|
params["DEFAULT"], |
|
|
params["STEP"], |
|
|
) |
|
|
for name, params in GENERATION_CONFIG_PARAMS.items() |
|
|
if "START" in params |
|
|
} |
|
|
|
|
|
do_sample = st.checkbox( |
|
|
GENERATION_CONFIG_PARAMS["do_sample"]["NAME"], |
|
|
value=GENERATION_CONFIG_PARAMS["do_sample"]["DEFAULT"], |
|
|
) |
|
|
|
|
|
stop_sequences = st.text_area( |
|
|
GENERATION_CONFIG_PARAMS["stop_sequences"]["NAME"], |
|
|
value="\n".join(GENERATION_CONFIG_PARAMS["stop_sequences"]["DEFAULT"]), |
|
|
) |
|
|
|
|
|
stop_sequences = [ |
|
|
clean_stop.encode().decode("unicode_escape") |
|
|
for stop in stop_sequences.split("\n") |
|
|
if (clean_stop := stop.strip()) |
|
|
] |
|
|
if not stop_sequences: |
|
|
stop_sequences = None |
|
|
|
|
|
decoding_seed = st.text_input("Decoding Seed").strip() |
|
|
|
|
|
st.divider() |
|
|
|
|
|
dataset = st.text_input("Dataset", HF_DATASET).strip() |
|
|
|
|
|
train_size = st.number_input("Train Size", value=TRAIN_SIZE, min_value=10) |
|
|
test_size = st.number_input("Test Size", value=TEST_SIZE, min_value=10) |
|
|
|
|
|
balancing = st.checkbox("Balancing", BALANCING) |
|
|
|
|
|
dataset_split_seed = st.text_input( |
|
|
"Dataset Split Seed", DATASET_SPLIT_SEED |
|
|
).strip() |
|
|
|
|
|
st.divider() |
|
|
|
|
|
submitted = st.form_submit_button("Set") |
|
|
|
|
|
if submitted: |
|
|
if not dataset: |
|
|
st.error("Dataset must be specified.") |
|
|
st.stop() |
|
|
|
|
|
if not model: |
|
|
st.error("Model must be specified.") |
|
|
st.stop() |
|
|
|
|
|
                if not decoding_seed:
                    decoding_seed = None
                elif decoding_seed.isnumeric():
                    decoding_seed = int(decoding_seed)
                else:
                    st.error("Decoding seed must be numeric or empty.")
                    st.stop()
|
|
|
|
|
                generation_config_slider_sampling = {
                    name: value
                    for name, value in generation_config_sliders.items()
                    if GENERATION_CONFIG_PARAMS[name]["SAMPLING"]
                }
                if (
                    any(
                        value != GENERATION_CONFIG_DEFAULTS[name]
                        for name, value in generation_config_slider_sampling.items()
                    )
                    and not do_sample
                ):
                    sampling_slider_default_values_info = " | ".join(
                        f"{name}={GENERATION_CONFIG_DEFAULTS[name]}"
                        for name in generation_config_slider_sampling
                    )
                    st.error(
                        "Sampling must be enabled to use non-default values for"
                        f" generation parameters: {sampling_slider_default_values_info}"
                    )
                    st.stop()
|
|
|
|
|
if decoding_seed is not None and not do_sample: |
|
|
st.error( |
|
|
"Sampling must be enabled to use a decoding seed. Otherwise, the seed field should be empty." |
|
|
) |
|
|
st.stop() |
|
|
|
|
|
if not dataset_split_seed: |
|
|
dataset_split_seed = None |
|
|
elif dataset_split_seed.isnumeric(): |
|
|
dataset_split_seed = int(dataset_split_seed) |
|
|
else: |
|
|
st.error("Dataset split seed must be numeric or empty.") |
|
|
st.stop() |
|
|
|
|
|
generation_config = generation_config_sliders | dict( |
|
|
do_sample=do_sample, |
|
|
stop_sequences=stop_sequences, |
|
|
seed=decoding_seed, |
|
|
) |
|
|
|
|
|
st.session_state["dataset_split_seed"] = dataset_split_seed |
|
|
st.session_state["train_size"] = train_size |
|
|
st.session_state["test_size"] = test_size |
|
|
|
|
|
try: |
|
|
st.session_state["api_call_function"] = build_api_call_function( |
|
|
model=model, |
|
|
) |
|
|
except OpenAIAlreadyInitiatedError as e: |
|
|
st.error(e) |
|
|
st.stop() |
|
|
|
|
|
st.session_state["generation_config"] = generation_config |
|
|
|
|
|
( |
|
|
st.session_state["dataset_name"], |
|
|
splits_df, |
|
|
st.session_state["input_columns"], |
|
|
st.session_state["label_column"], |
|
|
st.session_state["labels"], |
|
|
) = prepare_datasets( |
|
|
dataset, |
|
|
train_size=st.session_state.train_size, |
|
|
test_size=st.session_state.test_size, |
|
|
balancing=balancing, |
|
|
dataset_split_seed=st.session_state.dataset_split_seed, |
|
|
) |
|
|
|
|
|
for split in splits_df: |
|
|
st.session_state[f"{split}_dataset"] = splits_df[split] |
|
|
|
|
|
LOGGER.info(f"FORM {dataset=}") |
|
|
LOGGER.info(f"FORM {model=}") |
|
|
LOGGER.info(f"FORM {generation_config=}") |
|
|
|
|
|
with st.expander("Info"): |
|
|
if "dataset_name" in st.session_state: |
|
|
try: |
|
|
data_card = dataset_info(st.session_state.dataset_name).cardData |
|
|
st.caption("Dataset") |
|
|
st.write(data_card) |
|
|
except (HFValidationError, RepositoryNotFoundError): |
|
|
pass |
|
|
|
|
|
if "api_call_function" in st.session_state: |
|
|
try: |
|
|
model_info_response = model_info(model) |
|
|
model_card = model_info_response.cardData |
|
|
st.caption("Model") |
|
|
st.write(model_card) |
|
|
except (HFValidationError, RepositoryNotFoundError): |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
tab1, tab2, tab3 = st.tabs(["Evaluation", "Examples", "Playground"]) |
|
|
|
|
|
with tab1: |
|
|
with st.form("prompt_form"): |
|
|
prompt_template = st.text_area("Prompt Template", height=PROMPT_TEXT_HEIGHT) |
|
|
|
|
|
is_multi_placeholder = len(st.session_state.input_columns) > 1 |
|
|
|
|
|
            st.write(
                "To determine the inferred label of an input, the model should output one of the following words:"
                f" {combine_labels(st.session_state.labels)}"
            )
|
|
st.write( |
|
|
f"The input placeholder{'s' if is_multi_placeholder else ''} available for the prompt template {'are' if is_multi_placeholder else 'is'}:" |
|
|
f" {combine_labels(f'{{{col}}}' for col in st.session_state.input_columns)}" |
|
|
) |
|
|
|
|
|
col1, col2 = st.columns(2) |
|
|
|
|
|
with col1: |
|
|
search_row = st.selectbox( |
|
|
"Search label at which row", list(SEARCH_ROW_DICT) |
|
|
) |
|
|
|
|
|
with col2: |
|
|
submitted = st.form_submit_button("Evaluate") |
|
|
|
|
|
if submitted: |
|
|
if not prompt_template: |
|
|
st.error("Prompt template must be specified.") |
|
|
st.stop() |
|
|
|
|
|
_, formats, *_ = zip(*string.Formatter().parse(prompt_template)) |
|
|
is_valid_prompt_template = set(formats).issubset( |
|
|
{None} | set(st.session_state.input_columns) |
|
|
) |
|
|
|
|
|
if not is_valid_prompt_template: |
|
|
st.error(f"The prompt template contains unrecognized fields.") |
|
|
st.stop() |
|
|
|
|
|
with st.spinner("Executing inference..."): |
|
|
try: |
|
|
evaluation = run_evaluation( |
|
|
st.session_state.api_call_function, |
|
|
prompt_template, |
|
|
st.session_state.test_dataset, |
|
|
st.session_state.labels, |
|
|
st.session_state.label_column, |
|
|
st.session_state.input_columns, |
|
|
search_row, |
|
|
st.session_state.generation_config, |
|
|
) |
|
|
except HfHubHTTPError as e: |
|
|
st.error(e) |
|
|
st.stop() |
|
|
|
|
|
st.markdown("### Metrics") |
|
|
num_metric_cols = 2 if balancing else 4 |
|
|
cols = st.columns(num_metric_cols) |
|
|
with cols[0]: |
|
|
st.metric("Accuracy", f"{100 * evaluation['accuracy']:.0f}%") |
|
|
st.caption("The percentage of correct inferences.") |
|
|
with cols[1]: |
|
|
st.metric( |
|
|
"Unknown", |
|
|
f"{100 * evaluation['unknown_proportion']:.0f}%", |
|
|
) |
|
|
st.caption( |
|
|
"The percentage of inferences" |
|
|
" that could not be determined based on the model output." |
|
|
) |
|
|
if not balancing: |
|
|
with cols[2]: |
|
|
st.metric( |
|
|
"Balanced Accuracy", |
|
|
f"{100 * evaluation['balanced_accuracy']:.0f}%", |
|
|
) |
|
|
with cols[3]: |
|
|
st.metric("MCC", f"{evaluation['mcc']:.2f}") |
|
|
|
|
|
st.markdown("### Detailed Evaluation") |
|
|
|
|
|
            st.caption(
                "This table shows every example (input and output pair) used"
                " to evaluate the prompt template with the model (e.g., for accuracy)."
                " It comprises the input placeholder values, the raw model *output*,"
                " the deduced *inference*, and the ground-truth *annotation*."
            )
            st.caption(
                "A 'hit' is a correct inference (*inference* matches *annotation*),"
                " while a 'miss' is an incorrect one."
                " If the *inference* cannot be determined from the model output,"
                " it is labeled 'unknown'."
            )
            st.caption(
                "The *prompt* column shows the complete prompt the model was asked"
                " to complete, i.e., your prompt template filled with the input values."
            )
|
|
st.caption( |
|
|
"You are not allowed to include these examples in your prompt template." |
|
|
) |
|
|
|
|
|
st.dataframe(evaluation["hit_miss"]) |
|
|
|
|
|
with st.expander("Additional Information", expanded=False): |
|
|
st.markdown("## Confusion Matrix") |
|
|
st.pyplot(evaluation["confusion_matrix_display"]) |
|
|
|
|
|
if evaluation["accuracy"] == 1: |
|
|
st.balloons() |
|
|
|
|
|
with tab2: |
|
|
st.caption( |
|
|
"You can include the following examples in your prompt template for few-shot prompting." |
|
|
) |
|
|
st.dataframe(st.session_state.train_dataset) |
|
|
|
|
|
with tab3: |
|
|
prompt = st.text_area("Prompt", height=PROMPT_TEXT_HEIGHT) |
|
|
|
|
|
submitted = st.button("Complete") |
|
|
|
|
|
if submitted: |
|
|
if not prompt: |
|
|
st.error("Prompt must be specified.") |
|
|
st.stop() |
|
|
|
|
|
with st.spinner("Generating..."): |
|
|
try: |
|
|
                    generation_config = st.session_state.generation_config.copy()
                    generation_config["is_chat"] = False
                    output, length = asyncio.run(
                        complete(
                            st.session_state.api_call_function,
                            prompt,
                            generation_config,
                        )
                    )
|
|
except HfHubHTTPError as e: |
|
|
st.error(e) |
|
|
st.stop() |
|
|
st.markdown(escape_markdown(output)) |
|
|
if length is not None: |
|
|
with st.expander("Stats"): |
|
|
st.metric("#Tokens", length) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
logging.basicConfig(level=logging.DEBUG) |
|
|
main() |
|
|
|